changed how #operator and #add_overload work
authorBrendan Hansen <brendan.f.hansen@gmail.com>
Thu, 15 Apr 2021 18:34:07 +0000 (13:34 -0500)
committerBrendan Hansen <brendan.f.hansen@gmail.com>
Thu, 15 Apr 2021 18:34:07 +0000 (13:34 -0500)
13 files changed:
bin/onyx
core/string.onyx
include/onyxastnodes.h
src/onyx.c
src/onyxastnodes.c
src/onyxchecker.c
src/onyxentities.c
src/onyxparser.c
src/onyxsymres.c
tests/aoc-2020/day17.onyx
tests/aoc-2020/day24.onyx
tests/array_struct_robustness.onyx
tests/operator_overload.onyx

index dc651f1db62f18f611f470d0f1b9a85cac5d9277..a5b6e88e669ab7d3e62c67f9b0d0ffa3db41d620 100755 (executable)
Binary files a/bin/onyx and b/bin/onyx differ
index 76d8696307d40a5636351cdf8168a7b75cd5da73..e9ad0fc1024514c6756b23e735badf1a090a4581 100644 (file)
@@ -117,7 +117,8 @@ compare :: (str1: str, str2: str) -> i32 {
     return ~~(str1[i] - str2[i]);
 }
 
-equal :: (str1: str, str2: str) -> bool #operator== {
+#operator == equal
+equal :: (str1: str, str2: str) -> bool {
     if str1.count != str2.count do return false;
     while i := 0; i < str1.count {
         if str1[i] != str2[i] do return false;
index 075cd8daf7e2144d7094ab1cdb360b78205f1481..c44353f6832009c9d835dea9e0b8e2968419a8b7 100644 (file)
@@ -4,79 +4,81 @@
 #include "onyxlex.h"
 #include "onyxtypes.h"
 
-#define AST_NODES            \
-    NODE(Node)               \
-    NODE(Typed)              \
-                             \
-    NODE(NamedValue)         \
-    NODE(BinaryOp)           \
-    NODE(UnaryOp)            \
-    NODE(NumLit)             \
-    NODE(StrLit)             \
-    NODE(Local)              \
-    NODE(Call)               \
-    NODE(Argument)           \
-    NODE(AddressOf)          \
-    NODE(Dereference)        \
-    NODE(ArrayAccess)        \
-    NODE(FieldAccess)        \
-    NODE(SizeOf)             \
-    NODE(AlignOf)            \
-    NODE(FileContents)       \
-    NODE(StructLiteral)      \
-    NODE(ArrayLiteral)       \
-    NODE(RangeLiteral)       \
-    NODE(Compound)           \
-                             \
-    NODE(DirectiveSolidify)  \
-    NODE(StaticIf)           \
-    NODE(DirectiveError)     \
-                             \
-    NODE(Return)             \
-    NODE(Jump)               \
-    NODE(Use)                \
-                             \
-    NODE(Block)              \
-    NODE(IfWhile)            \
-    NODE(For)                \
-    NODE(Defer)              \
-    NODE(SwitchCase)         \
-    NODE(Switch)             \
-                             \
-    NODE(Type)               \
-    NODE(BasicType)          \
-    NODE(PointerType)        \
-    NODE(FunctionType)       \
-    NODE(ArrayType)          \
-    NODE(SliceType)          \
-    NODE(DynArrType)         \
-    NODE(VarArgType)         \
-    NODE(StructType)         \
-    NODE(StructMember)       \
-    NODE(PolyStructType)     \
-    NODE(PolyStructParam)    \
-    NODE(PolyCallType)       \
-    NODE(EnumType)           \
-    NODE(EnumValue)          \
-    NODE(TypeAlias)          \
-    NODE(TypeRawAlias)       \
-    NODE(CompoundType)       \
-                             \
-    NODE(Binding)            \
-    NODE(MemRes)             \
-    NODE(Include)            \
-    NODE(UsePackage)         \
-    NODE(Alias)              \
-    NODE(Global)             \
-    NODE(Param)              \
-    NODE(Function)           \
-    NODE(OverloadedFunction) \
-                             \
-    NODE(PolyParam)          \
-    NODE(PolySolution)       \
-    NODE(SolidifiedFunction) \
-    NODE(PolyProc)           \
-                             \
+#define AST_NODES              \
+    NODE(Node)                 \
+    NODE(Typed)                \
+                               \
+    NODE(NamedValue)           \
+    NODE(BinaryOp)             \
+    NODE(UnaryOp)              \
+    NODE(NumLit)               \
+    NODE(StrLit)               \
+    NODE(Local)                \
+    NODE(Call)                 \
+    NODE(Argument)             \
+    NODE(AddressOf)            \
+    NODE(Dereference)          \
+    NODE(ArrayAccess)          \
+    NODE(FieldAccess)          \
+    NODE(SizeOf)               \
+    NODE(AlignOf)              \
+    NODE(FileContents)         \
+    NODE(StructLiteral)        \
+    NODE(ArrayLiteral)         \
+    NODE(RangeLiteral)         \
+    NODE(Compound)             \
+                               \
+    NODE(DirectiveSolidify)    \
+    NODE(StaticIf)             \
+    NODE(DirectiveError)       \
+    NODE(DirectiveAddOverload) \
+    NODE(DirectiveOperator)    \
+                               \
+    NODE(Return)               \
+    NODE(Jump)                 \
+    NODE(Use)                  \
+                               \
+    NODE(Block)                \
+    NODE(IfWhile)              \
+    NODE(For)                  \
+    NODE(Defer)                \
+    NODE(SwitchCase)           \
+    NODE(Switch)               \
+                               \
+    NODE(Type)                 \
+    NODE(BasicType)            \
+    NODE(PointerType)          \
+    NODE(FunctionType)         \
+    NODE(ArrayType)            \
+    NODE(SliceType)            \
+    NODE(DynArrType)           \
+    NODE(VarArgType)           \
+    NODE(StructType)           \
+    NODE(StructMember)         \
+    NODE(PolyStructType)       \
+    NODE(PolyStructParam)      \
+    NODE(PolyCallType)         \
+    NODE(EnumType)             \
+    NODE(EnumValue)            \
+    NODE(TypeAlias)            \
+    NODE(TypeRawAlias)         \
+    NODE(CompoundType)         \
+                               \
+    NODE(Binding)              \
+    NODE(MemRes)               \
+    NODE(Include)              \
+    NODE(UsePackage)           \
+    NODE(Alias)                \
+    NODE(Global)               \
+    NODE(Param)                \
+    NODE(Function)             \
+    NODE(OverloadedFunction)   \
+                               \
+    NODE(PolyParam)            \
+    NODE(PolySolution)         \
+    NODE(SolidifiedFunction)   \
+    NODE(PolyProc)             \
+                               \
     NODE(Package)          
 
 #define NODE(name) typedef struct Ast ## name Ast ## name;
@@ -171,6 +173,8 @@ typedef enum AstKind {
     Ast_Kind_Directive_Solidify,
     Ast_Kind_Static_If,
     Ast_Kind_Directive_Error,
+    Ast_Kind_Directive_Add_Overload,
+    Ast_Kind_Directive_Operator,
 
     Ast_Kind_Count
 } AstKind;
@@ -753,14 +757,6 @@ struct AstFunction {
     AstBlock *body;
     bh_arr(AstTyped *) allocate_exprs;
 
-    // NOTE: used by the #add_overload directive. Initially set to a symbol,
-    // then resolved to an overloaded function.
-    AstNode *overloaded_function;
-
-    // NOTE: set to -1 if the function is not an operator overload;
-    // set to a BinaryOp value if it is.
-    BinaryOp operator_overload;
-
     OnyxToken* name;
 
 
@@ -888,6 +884,23 @@ struct AstDirectiveError {
     OnyxToken* error_msg;
 };
 
+struct AstDirectiveAddOverload {
+    AstNode_base;
+
+    // NOTE: used by the #add_overload directive. Initially set to a symbol,
+    // then resolved to an overloaded function.
+    AstNode *overloaded_function;
+
+    AstTyped *overload;
+};
+
+struct AstDirectiveOperator {
+    AstNode_base;
+
+    BinaryOp operator;
+    AstTyped *overload;
+};
+
 
 extern AstNode empty_node;
 
@@ -931,6 +944,7 @@ typedef enum EntityType {
     Entity_Type_Foreign_Global_Header,
     Entity_Type_Function_Header,
     Entity_Type_Global_Header,
+    Entity_Type_Process_Directive,
     Entity_Type_Struct_Member_Default,
     Entity_Type_Memory_Reservation,
     Entity_Type_Expression,
@@ -1233,4 +1247,10 @@ static inline ParamPassType type_get_param_pass(Type* type) {
     return Param_Pass_By_Value;
 }
 
+static inline AstFunction* get_function_from_node(AstNode* node) {
+    if (node->kind == Ast_Kind_Function) return (AstFunction *) node;
+    if (node->kind == Ast_Kind_Polymorphic_Proc) return ((AstPolyProc *) node)->base_func;
+    return NULL;
+}
+
 #endif // #ifndef ONYXASTNODES_H
index 68d89d318cf05984db3209db68bc6e06da712f1d..beb0a4982b671177201db43ed6cb36133c8095f2 100644 (file)
@@ -394,7 +394,7 @@ static void output_dummy_progress_bar() {
         for (i32 j = 0; j < Entity_State_Count; j++) {
             if (eh->all_count[j][i] == 0) continue;
 
-            printf(state_colors[j]);
+            printf("%s", state_colors[j]);
 
             i32 count = (eh->all_count[j][i] >> 5) + 1;
             for (i32 c = 0; c < count * 2; c++) printf("\xe2\x96\x88");
index c7a0354ad9fa68d5f366ee3112dd197200095e26..0f4c83e50dc65ef67c3e18c1d106b6db4e411614 100644 (file)
@@ -82,6 +82,8 @@ static const char* ast_node_names[] = {
     "SOLIDIFY",
     "STATIC IF",
     "STATIC ERROR",
+    "ADD OVERLOAD",
+    "OPERATOR OVERLOAD",
 
     "AST_NODE_KIND_COUNT",
 };
@@ -136,6 +138,7 @@ const char* entity_type_strings[Entity_Type_Count] = {
     "Foreign_Global Header",
     "Function Header",
     "Global Header",
+    "Process Directive",
     "Struct Member Default",
     "Memory Reservation",
     "Expression",
index 10e9607c8d3f3009501abc782811b3d5b4ef8575..32e91e7daf86a528351300f723c586fe85d9dc2b 100644 (file)
@@ -939,7 +939,7 @@ CheckStatus check_struct_literal(AstStructLiteral* sl) {
 
     if (!type_is_structlike_strict(sl->type)) {
         onyx_report_error(sl->token->pos,
-                "'%s' is not constructable using a struct literal.",
+                "'%s' is not constructable using a struct literal.",
                 type_get_name(sl->type));
         return Check_Error;
     }
index 955e039b6fb6ba69b841d69aaa3a956179983595..5a11f4d708e638e32e3bac6e1e1eb8d575a86857 100644 (file)
@@ -302,6 +302,14 @@ void add_entities_for_node(bh_arr(Entity *) *target_arr, AstNode* node, Scope* s
             ENTITY_INSERT(ent);
             break;   
         }
+
+        case Ast_Kind_Directive_Add_Overload:
+        case Ast_Kind_Directive_Operator: {
+            ent.type = Entity_Type_Process_Directive;
+            ent.expr = (AstTyped *) node;
+            ENTITY_INSERT(ent);
+            break;
+        }
         
         default: {
             ent.type = Entity_Type_Expression;
index 34e6d7820ea945f945798a8d3a5c9f71d13519d6..e4bd7450913d11a3381984a24ebc29e68e689897 100644 (file)
@@ -1869,7 +1869,6 @@ static AstFunction* parse_function_definition(OnyxParser* parser, OnyxToken* tok
 
     AstFunction* func_def = make_node(AstFunction, Ast_Kind_Function);
     func_def->token = token;
-    func_def->operator_overload = -1;
 
     bh_arr_new(global_heap_allocator, func_def->allocate_exprs, 4);
     bh_arr_new(global_heap_allocator, func_def->params, 4);
@@ -1886,28 +1885,7 @@ static AstFunction* parse_function_definition(OnyxParser* parser, OnyxToken* tok
         func_def->return_type = parse_type(parser);
 
     while (parser->curr->type == '#') {
-        if (parse_possible_directive(parser, "add_overload")) {
-            if (func_def->overloaded_function != NULL) {
-                onyx_report_error(parser->curr->pos, "cannot have multiple #add_overload directives on a single procedure.");
-                expect_token(parser, Token_Type_Symbol);
-
-            } else {
-                func_def->overloaded_function = (AstNode *) parse_expression(parser, 0);
-            }
-        }
-
-        else if (parse_possible_directive(parser, "operator")) {
-            BinaryOp op = binary_op_from_token_type(parser->curr->type);
-            consume_token(parser);
-            
-            if (op == Binary_Op_Count) {
-                onyx_report_error(parser->curr->pos, "Invalid binary operator.");
-            } else {
-                func_def->operator_overload = op;
-            }
-        }
-
-        else if (parse_possible_directive(parser, "intrinsic")) {
+        if (parse_possible_directive(parser, "intrinsic")) {
             func_def->flags |= Ast_Flag_Intrinsic;
 
             if (parser->curr->type == Token_Type_Literal_String) {
@@ -2293,6 +2271,35 @@ static void parse_top_level_statement(OnyxParser* parser) {
                 ENTITY_SUBMIT(error);
                 return;
             }
+            else if (parse_possible_directive(parser, "operator")) {
+                AstDirectiveOperator *operator = make_node(AstDirectiveOperator, Ast_Kind_Directive_Operator);
+                operator->token = dir_token;
+
+                BinaryOp op = binary_op_from_token_type(parser->curr->type);
+                consume_token(parser);
+                
+                if (op == Binary_Op_Count) {
+                    onyx_report_error(parser->curr->pos, "Invalid binary operator.");
+                } else {
+                    operator->operator = op;
+                }
+
+                operator->overload = parse_expression(parser, 0);
+
+                ENTITY_SUBMIT(operator);
+                return;
+            }
+            else if (parse_possible_directive(parser, "add_overload")) {
+                AstDirectiveAddOverload *add_overload = make_node(AstDirectiveAddOverload, Ast_Kind_Directive_Add_Overload);
+                add_overload->token = dir_token;
+                add_overload->overloaded_function = (AstNode *) parse_expression(parser, 0);
+
+                expect_token(parser, ',');
+                add_overload->overload = parse_expression(parser, 0);
+
+                ENTITY_SUBMIT(add_overload);
+                return;
+            }
             else {
                 OnyxToken* directive_token = expect_token(parser, '#');
                 OnyxToken* symbol_token = expect_token(parser, Token_Type_Symbol);
index 4a0282a8e01cc595de9d56613f8a270f3f4f1abb..99c356a551acca13938f202746dce39cb7467b6a 100644 (file)
@@ -810,33 +810,6 @@ SymresStatus symres_function_header(AstFunction* func) {
             if (onyx_has_errors()) return Symres_Error;
         }
     }
-    
-    if ((func->flags & Ast_Flag_From_Polymorphism) == 0) {
-        if (func->overloaded_function != NULL) {
-            SYMRES(expression, (AstTyped **) &func->overloaded_function);
-            if (func->overloaded_function == NULL) return Symres_Error; // NOTE: Error message will already be generated
-
-            if (func->overloaded_function->kind != Ast_Kind_Overloaded_Function) {
-                onyx_report_error(func->token->pos, "#add_overload directive did not resolve to an overloaded function.");
-
-            } else {
-                AstOverloadedFunction* ofunc = (AstOverloadedFunction *) func->overloaded_function;
-                bh_arr_push(ofunc->overloads, (AstTyped *) func);
-            }
-        }
-
-        if (func->operator_overload != (BinaryOp) -1) {
-            if (bh_arr_length(func->params) != 2) {
-                onyx_report_error(func->token->pos, "Expected 2 exactly arguments for binary operator overload.");
-            }
-
-            if (binop_is_assignment(func->operator_overload)) {
-                onyx_report_error(func->token->pos, "'%s' is not currently overloadable.", binaryop_string[func->operator_overload]);
-            }
-
-            bh_arr_push(operator_overloads[func->operator_overload], (AstTyped *) func);
-        }
-    }
 
     SYMRES(type, &func->return_type);
     if (!node_is_type((AstNode *) func->return_type)) {
@@ -1035,18 +1008,6 @@ static SymresStatus symres_polyproc(AstPolyProc* pp) {
         SYMRES(type, &param->type_expr);
     }
 
-    // CLEANUP: This was copied from symres_function_header.
-    if (pp->base_func->operator_overload != (BinaryOp) -1) {
-        if (bh_arr_length(pp->base_func->params) != 2) {
-            onyx_report_error(pp->base_func->token->pos, "Expected 2 exactly arguments for binary operator overload.");
-        }
-
-        if (binop_is_assignment(pp->base_func->operator_overload)) {
-            onyx_report_error(pp->base_func->token->pos, "'%s' is not currently overloadable.", binaryop_string[pp->base_func->operator_overload]);
-        }
-
-        bh_arr_push(operator_overloads[pp->base_func->operator_overload], (AstTyped *) pp);
-    }
     return Symres_Success;
 }
 
@@ -1055,6 +1016,54 @@ static SymresStatus symres_static_if(AstStaticIf* static_if) {
     return Symres_Success;
 }
 
+static SymresStatus symres_process_directive(AstNode* directive) {
+    switch (directive->kind) {
+        case Ast_Kind_Directive_Add_Overload: {
+            AstDirectiveAddOverload *add_overload = (AstDirectiveAddOverload *) directive;
+
+            SYMRES(expression, (AstTyped **) &add_overload->overloaded_function);
+            if (add_overload->overloaded_function == NULL) return Symres_Error; // NOTE: Error message will already be generated
+
+            if (add_overload->overloaded_function->kind != Ast_Kind_Overloaded_Function) {
+                onyx_report_error(add_overload->token->pos, "#add_overload directive did not resolve to an overloaded function.");
+
+            } else {
+                AstOverloadedFunction* ofunc = (AstOverloadedFunction *) add_overload->overloaded_function;
+                bh_arr_push(ofunc->overloads, (AstTyped *) add_overload->overload);
+            }
+
+            break;
+        }
+
+        case Ast_Kind_Directive_Operator: {
+            AstDirectiveOperator *operator = (AstDirectiveOperator *) directive;
+            SYMRES(expression, &operator->overload);
+            if (!operator->overload) return Symres_Error;
+
+            AstFunction* overload = get_function_from_node((AstNode *) operator->overload);
+            if (overload == NULL) {
+                onyx_report_error(operator->token->pos, "This cannot be used as an operator overload.");
+                return Symres_Error;
+            }
+
+            if (bh_arr_length(overload->params) != 2) {
+                onyx_report_error(operator->token->pos, "Expected 2 exactly arguments for binary operator overload.");
+                return Symres_Error;
+            }
+
+            if (binop_is_assignment(operator->operator)) {
+                onyx_report_error(overload->token->pos, "'%s' is not currently overloadable.", binaryop_string[operator->operator]);
+                return Symres_Error;
+            }
+
+            bh_arr_push(operator_overloads[operator->operator], operator->overload);
+            break;
+        }
+    }
+
+    return Symres_Success;
+}
+
 void symres_entity(Entity* ent) {
     if (block_stack == NULL) bh_arr_new(global_heap_allocator, block_stack, 16);
 
@@ -1113,6 +1122,9 @@ void symres_entity(Entity* ent) {
         case Entity_Type_Polymorphic_Proc:        ss = symres_polyproc(ent->poly_proc); break;
         case Entity_Type_String_Literal:          ss = symres_expression(&ent->expr); break;
         case Entity_Type_Struct_Member_Default:   ss = symres_struct_defaults((AstType *) ent->type_alias); break;
+        case Entity_Type_Process_Directive:       ss = symres_process_directive((AstNode *) ent->expr);
+                                                  next_state = Entity_State_Finalized;
+                                                  break;
 
         default: break;
     }
index 8aff1ac742e4792d29fbd1cbad8b0e7917a86d39..e4f51bf026a5d99382143a436ed460a89108ca92 100644 (file)
@@ -12,11 +12,11 @@ CubeState :: struct {
     next  := false;
 }
 
-proc (c: CubePos) -> u32 #add_overload map.hash_function {
+#add_overload map.hash_function, (c: CubePos) -> u32 {
     return 17 * c.x + 13 * c.y + 11 * c.z + 19 * c.w;
 }
 
-proc (a: CubePos, b: CubePos) -> bool #add_overload map.cmp_function {
+#add_overload map.cmp_function, (a: CubePos, b: CubePos) -> bool {
     return (a.x == b.x)
         && (a.y == b.y)
         && (a.z == b.z)
index 099d34141149da0f69f52cbd1f4f76963e9bd0d9..3180fbf9bc50e3d405a119f518bb83acb64b0a94 100644 (file)
@@ -7,13 +7,12 @@ Vec2 :: struct {
        y: i32 = 0;
 }
 
-proc (v: Vec2) -> u32 #add_overload map.hash_function {
+#add_overload map.hash_function, (v: Vec2) -> u32 {
        return v.x * 11 + v.y * 17;
 }
 
-proc (v1: Vec2, v2: Vec2) -> bool #add_overload map.cmp_function {
-       return v1.x == v2.x
-               && v1.y == v2.y;
+#add_overload map.cmp_function, (v1: Vec2, v2: Vec2) -> bool {
+       return v1.x == v2.x && v1.y == v2.y;
 }
 
 Hex_Directions := Vec2.[
index dc169bc8397171167d690b581f4be5a193e09c20..e201ae363904af04852fd4bf51e13618ada5a7e7 100644 (file)
@@ -5,7 +5,7 @@ use package core
 Vec2 :: struct { x: i32; y: i32; }
 
 // Overload print() to print Vec2's.
-proc (use writer: ^io.Writer, use v: Vec2) #add_overload io.write {
+#add_overload io.write, (use writer: ^io.Writer, use v: Vec2) {
     io.write_format(writer, "Vec2(%i, %i)", x, y);
 }
 
index e5b58f55fd059c0f3c4b01f86f02ed32f65c7140..4041ffa8ef60a19c9049a2785f211955f76c59fb 100644 (file)
@@ -7,15 +7,15 @@ Complex :: struct {
     im : f32 = 0;
 }
 
-proc (a: Complex, b: Complex) -> Complex #operator+ {
+#operator+ (a: Complex, b: Complex) -> Complex {
     return Complex.{ a.re + b.re, a.im + b.im };
 }
 
-proc (a: Complex, b: Complex) -> Complex #operator- {
+#operator- (a: Complex, b: Complex) -> Complex {
     return Complex.{ a.re - b.re, a.im - b.im };
 }
 
-proc (a: Complex, b: Complex) -> Complex #operator* {
+#operator* (a: Complex, b: Complex) -> Complex {
     return Complex.{ a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re };
 }
 
@@ -23,30 +23,29 @@ C :: (re: f32, im: f32) -> Complex do return Complex.{ re, im };
 
 
 
-
 Vec :: struct (T: type_expr, N: i32) {
     data: [N] T;
 }
 
-proc (a: Vec($T, $N), b: Vec(T, N)) -> Vec(T, N) #operator+ {
+#operator+ (a: Vec($T, $N), b: Vec(T, N)) -> Vec(T, N) {
     out : Vec(T, N);
     for i: 0 .. N do out.data[i] = a.data[i] + b.data[i];
     return out;
 }
 
-proc (a: Vec($T, $N), b: Vec(T, N)) -> Vec(T, N) #operator- {
+#operator- (a: Vec($T, $N), b: Vec(T, N)) -> Vec(T, N) {
     out : Vec(T, N);
     for i: 0 .. N do out.data[i] = a.data[i] - b.data[i];
     return out;
 }
 
-proc (a: Vec($T, $N), s: T) -> Vec(T, N) #operator* {
+#operator* (a: Vec($T, $N), s: T) -> Vec(T, N) {
     out : Vec(T, N);
     for i: 0 .. N do out.data[i] = a.data[i] * s;
     return out;
 }
 
-proc (a: Vec($T, $N), b: Vec(T, N)) -> T #operator* {
+#operator* (a: Vec($T, $N), b: Vec(T, N)) -> T {
     res := T.{};
     for i: 0 .. N do res += a.data[i] * b.data[i];
     return res;