From: Brendan Hansen Date: Thu, 15 Apr 2021 18:34:07 +0000 (-0500) Subject: changed how #operator and #add_overload work X-Git-Url: https://git.brendanfh.com/?a=commitdiff_plain;h=d370d96bb04dec7ba693685cc21783c570027458;p=onyx.git changed how #operator and #add_overload work --- diff --git a/bin/onyx b/bin/onyx index dc651f1d..a5b6e88e 100755 Binary files a/bin/onyx and b/bin/onyx differ diff --git a/core/string.onyx b/core/string.onyx index 76d86963..e9ad0fc1 100644 --- a/core/string.onyx +++ b/core/string.onyx @@ -117,7 +117,8 @@ compare :: (str1: str, str2: str) -> i32 { return ~~(str1[i] - str2[i]); } -equal :: (str1: str, str2: str) -> bool #operator== { +#operator == equal +equal :: (str1: str, str2: str) -> bool { if str1.count != str2.count do return false; while i := 0; i < str1.count { if str1[i] != str2[i] do return false; diff --git a/include/onyxastnodes.h b/include/onyxastnodes.h index 075cd8da..c44353f6 100644 --- a/include/onyxastnodes.h +++ b/include/onyxastnodes.h @@ -4,79 +4,81 @@ #include "onyxlex.h" #include "onyxtypes.h" -#define AST_NODES \ - NODE(Node) \ - NODE(Typed) \ - \ - NODE(NamedValue) \ - NODE(BinaryOp) \ - NODE(UnaryOp) \ - NODE(NumLit) \ - NODE(StrLit) \ - NODE(Local) \ - NODE(Call) \ - NODE(Argument) \ - NODE(AddressOf) \ - NODE(Dereference) \ - NODE(ArrayAccess) \ - NODE(FieldAccess) \ - NODE(SizeOf) \ - NODE(AlignOf) \ - NODE(FileContents) \ - NODE(StructLiteral) \ - NODE(ArrayLiteral) \ - NODE(RangeLiteral) \ - NODE(Compound) \ - \ - NODE(DirectiveSolidify) \ - NODE(StaticIf) \ - NODE(DirectiveError) \ - \ - NODE(Return) \ - NODE(Jump) \ - NODE(Use) \ - \ - NODE(Block) \ - NODE(IfWhile) \ - NODE(For) \ - NODE(Defer) \ - NODE(SwitchCase) \ - NODE(Switch) \ - \ - NODE(Type) \ - NODE(BasicType) \ - NODE(PointerType) \ - NODE(FunctionType) \ - NODE(ArrayType) \ - NODE(SliceType) \ - NODE(DynArrType) \ - NODE(VarArgType) \ - NODE(StructType) \ - NODE(StructMember) \ - NODE(PolyStructType) \ - NODE(PolyStructParam) \ - NODE(PolyCallType) \ - NODE(EnumType) \ - NODE(EnumValue) \ - NODE(TypeAlias) \ - NODE(TypeRawAlias) \ - NODE(CompoundType) \ - \ - NODE(Binding) \ - NODE(MemRes) \ - NODE(Include) \ - NODE(UsePackage) \ - NODE(Alias) \ - NODE(Global) \ - NODE(Param) \ - NODE(Function) \ - NODE(OverloadedFunction) \ - \ - NODE(PolyParam) \ - NODE(PolySolution) \ - NODE(SolidifiedFunction) \ - NODE(PolyProc) \ - \ +#define AST_NODES \ + NODE(Node) \ + NODE(Typed) \ + \ + NODE(NamedValue) \ + NODE(BinaryOp) \ + NODE(UnaryOp) \ + NODE(NumLit) \ + NODE(StrLit) \ + NODE(Local) \ + NODE(Call) \ + NODE(Argument) \ + NODE(AddressOf) \ + NODE(Dereference) \ + NODE(ArrayAccess) \ + NODE(FieldAccess) \ + NODE(SizeOf) \ + NODE(AlignOf) \ + NODE(FileContents) \ + NODE(StructLiteral) \ + NODE(ArrayLiteral) \ + NODE(RangeLiteral) \ + NODE(Compound) \ + \ + NODE(DirectiveSolidify) \ + NODE(StaticIf) \ + NODE(DirectiveError) \ + NODE(DirectiveAddOverload) \ + NODE(DirectiveOperator) \ + \ + NODE(Return) \ + NODE(Jump) \ + NODE(Use) \ + \ + NODE(Block) \ + NODE(IfWhile) \ + NODE(For) \ + NODE(Defer) \ + NODE(SwitchCase) \ + NODE(Switch) \ + \ + NODE(Type) \ + NODE(BasicType) \ + NODE(PointerType) \ + NODE(FunctionType) \ + NODE(ArrayType) \ + NODE(SliceType) \ + NODE(DynArrType) \ + NODE(VarArgType) \ + NODE(StructType) \ + NODE(StructMember) \ + NODE(PolyStructType) \ + NODE(PolyStructParam) \ + NODE(PolyCallType) \ + NODE(EnumType) \ + NODE(EnumValue) \ + NODE(TypeAlias) \ + NODE(TypeRawAlias) \ + NODE(CompoundType) \ + \ + NODE(Binding) \ + NODE(MemRes) \ + NODE(Include) \ + NODE(UsePackage) \ + NODE(Alias) \ + NODE(Global) \ + NODE(Param) \ + NODE(Function) \ + NODE(OverloadedFunction) \ + \ + NODE(PolyParam) \ + NODE(PolySolution) \ + NODE(SolidifiedFunction) \ + NODE(PolyProc) \ + \ NODE(Package) #define NODE(name) typedef struct Ast ## name Ast ## name; @@ -171,6 +173,8 @@ typedef enum AstKind { Ast_Kind_Directive_Solidify, Ast_Kind_Static_If, Ast_Kind_Directive_Error, + Ast_Kind_Directive_Add_Overload, + Ast_Kind_Directive_Operator, Ast_Kind_Count } AstKind; @@ -753,14 +757,6 @@ struct AstFunction { AstBlock *body; bh_arr(AstTyped *) allocate_exprs; - // NOTE: used by the #add_overload directive. Initially set to a symbol, - // then resolved to an overloaded function. - AstNode *overloaded_function; - - // NOTE: set to -1 if the function is not an operator overload; - // set to a BinaryOp value if it is. - BinaryOp operator_overload; - OnyxToken* name; @@ -888,6 +884,23 @@ struct AstDirectiveError { OnyxToken* error_msg; }; +struct AstDirectiveAddOverload { + AstNode_base; + + // NOTE: used by the #add_overload directive. Initially set to a symbol, + // then resolved to an overloaded function. + AstNode *overloaded_function; + + AstTyped *overload; +}; + +struct AstDirectiveOperator { + AstNode_base; + + BinaryOp operator; + AstTyped *overload; +}; + extern AstNode empty_node; @@ -931,6 +944,7 @@ typedef enum EntityType { Entity_Type_Foreign_Global_Header, Entity_Type_Function_Header, Entity_Type_Global_Header, + Entity_Type_Process_Directive, Entity_Type_Struct_Member_Default, Entity_Type_Memory_Reservation, Entity_Type_Expression, @@ -1233,4 +1247,10 @@ static inline ParamPassType type_get_param_pass(Type* type) { return Param_Pass_By_Value; } +static inline AstFunction* get_function_from_node(AstNode* node) { + if (node->kind == Ast_Kind_Function) return (AstFunction *) node; + if (node->kind == Ast_Kind_Polymorphic_Proc) return ((AstPolyProc *) node)->base_func; + return NULL; +} + #endif // #ifndef ONYXASTNODES_H diff --git a/src/onyx.c b/src/onyx.c index 68d89d31..beb0a498 100644 --- a/src/onyx.c +++ b/src/onyx.c @@ -394,7 +394,7 @@ static void output_dummy_progress_bar() { for (i32 j = 0; j < Entity_State_Count; j++) { if (eh->all_count[j][i] == 0) continue; - printf(state_colors[j]); + printf("%s", state_colors[j]); i32 count = (eh->all_count[j][i] >> 5) + 1; for (i32 c = 0; c < count * 2; c++) printf("\xe2\x96\x88"); diff --git a/src/onyxastnodes.c b/src/onyxastnodes.c index c7a0354a..0f4c83e5 100644 --- a/src/onyxastnodes.c +++ b/src/onyxastnodes.c @@ -82,6 +82,8 @@ static const char* ast_node_names[] = { "SOLIDIFY", "STATIC IF", "STATIC ERROR", + "ADD OVERLOAD", + "OPERATOR OVERLOAD", "AST_NODE_KIND_COUNT", }; @@ -136,6 +138,7 @@ const char* entity_type_strings[Entity_Type_Count] = { "Foreign_Global Header", "Function Header", "Global Header", + "Process Directive", "Struct Member Default", "Memory Reservation", "Expression", diff --git a/src/onyxchecker.c b/src/onyxchecker.c index 10e9607c..32e91e7d 100644 --- a/src/onyxchecker.c +++ b/src/onyxchecker.c @@ -939,7 +939,7 @@ CheckStatus check_struct_literal(AstStructLiteral* sl) { if (!type_is_structlike_strict(sl->type)) { onyx_report_error(sl->token->pos, - "'%s' is not a constructable using a struct literal.", + "'%s' is not constructable using a struct literal.", type_get_name(sl->type)); return Check_Error; } diff --git a/src/onyxentities.c b/src/onyxentities.c index 955e039b..5a11f4d7 100644 --- a/src/onyxentities.c +++ b/src/onyxentities.c @@ -302,6 +302,14 @@ void add_entities_for_node(bh_arr(Entity *) *target_arr, AstNode* node, Scope* s ENTITY_INSERT(ent); break; } + + case Ast_Kind_Directive_Add_Overload: + case Ast_Kind_Directive_Operator: { + ent.type = Entity_Type_Process_Directive; + ent.expr = (AstTyped *) node; + ENTITY_INSERT(ent); + break; + } default: { ent.type = Entity_Type_Expression; diff --git a/src/onyxparser.c b/src/onyxparser.c index 34e6d782..e4bd7450 100644 --- a/src/onyxparser.c +++ b/src/onyxparser.c @@ -1869,7 +1869,6 @@ static AstFunction* parse_function_definition(OnyxParser* parser, OnyxToken* tok AstFunction* func_def = make_node(AstFunction, Ast_Kind_Function); func_def->token = token; - func_def->operator_overload = -1; bh_arr_new(global_heap_allocator, func_def->allocate_exprs, 4); bh_arr_new(global_heap_allocator, func_def->params, 4); @@ -1886,28 +1885,7 @@ static AstFunction* parse_function_definition(OnyxParser* parser, OnyxToken* tok func_def->return_type = parse_type(parser); while (parser->curr->type == '#') { - if (parse_possible_directive(parser, "add_overload")) { - if (func_def->overloaded_function != NULL) { - onyx_report_error(parser->curr->pos, "cannot have multiple #add_overload directives on a single procedure."); - expect_token(parser, Token_Type_Symbol); - - } else { - func_def->overloaded_function = (AstNode *) parse_expression(parser, 0); - } - } - - else if (parse_possible_directive(parser, "operator")) { - BinaryOp op = binary_op_from_token_type(parser->curr->type); - consume_token(parser); - - if (op == Binary_Op_Count) { - onyx_report_error(parser->curr->pos, "Invalid binary operator."); - } else { - func_def->operator_overload = op; - } - } - - else if (parse_possible_directive(parser, "intrinsic")) { + if (parse_possible_directive(parser, "intrinsic")) { func_def->flags |= Ast_Flag_Intrinsic; if (parser->curr->type == Token_Type_Literal_String) { @@ -2293,6 +2271,35 @@ static void parse_top_level_statement(OnyxParser* parser) { ENTITY_SUBMIT(error); return; } + else if (parse_possible_directive(parser, "operator")) { + AstDirectiveOperator *operator = make_node(AstDirectiveOperator, Ast_Kind_Directive_Operator); + operator->token = dir_token; + + BinaryOp op = binary_op_from_token_type(parser->curr->type); + consume_token(parser); + + if (op == Binary_Op_Count) { + onyx_report_error(parser->curr->pos, "Invalid binary operator."); + } else { + operator->operator = op; + } + + operator->overload = parse_expression(parser, 0); + + ENTITY_SUBMIT(operator); + return; + } + else if (parse_possible_directive(parser, "add_overload")) { + AstDirectiveAddOverload *add_overload = make_node(AstDirectiveAddOverload, Ast_Kind_Directive_Add_Overload); + add_overload->token = dir_token; + add_overload->overloaded_function = (AstNode *) parse_expression(parser, 0); + + expect_token(parser, ','); + add_overload->overload = parse_expression(parser, 0); + + ENTITY_SUBMIT(add_overload); + return; + } else { OnyxToken* directive_token = expect_token(parser, '#'); OnyxToken* symbol_token = expect_token(parser, Token_Type_Symbol); diff --git a/src/onyxsymres.c b/src/onyxsymres.c index 4a0282a8..99c356a5 100644 --- a/src/onyxsymres.c +++ b/src/onyxsymres.c @@ -810,33 +810,6 @@ SymresStatus symres_function_header(AstFunction* func) { if (onyx_has_errors()) return Symres_Error; } } - - if ((func->flags & Ast_Flag_From_Polymorphism) == 0) { - if (func->overloaded_function != NULL) { - SYMRES(expression, (AstTyped **) &func->overloaded_function); - if (func->overloaded_function == NULL) return Symres_Error; // NOTE: Error message will already be generated - - if (func->overloaded_function->kind != Ast_Kind_Overloaded_Function) { - onyx_report_error(func->token->pos, "#add_overload directive did not resolve to an overloaded function."); - - } else { - AstOverloadedFunction* ofunc = (AstOverloadedFunction *) func->overloaded_function; - bh_arr_push(ofunc->overloads, (AstTyped *) func); - } - } - - if (func->operator_overload != (BinaryOp) -1) { - if (bh_arr_length(func->params) != 2) { - onyx_report_error(func->token->pos, "Expected 2 exactly arguments for binary operator overload."); - } - - if (binop_is_assignment(func->operator_overload)) { - onyx_report_error(func->token->pos, "'%s' is not currently overloadable.", binaryop_string[func->operator_overload]); - } - - bh_arr_push(operator_overloads[func->operator_overload], (AstTyped *) func); - } - } SYMRES(type, &func->return_type); if (!node_is_type((AstNode *) func->return_type)) { @@ -1035,18 +1008,6 @@ static SymresStatus symres_polyproc(AstPolyProc* pp) { SYMRES(type, ¶m->type_expr); } - // CLEANUP: This was copied from symres_function_header. - if (pp->base_func->operator_overload != (BinaryOp) -1) { - if (bh_arr_length(pp->base_func->params) != 2) { - onyx_report_error(pp->base_func->token->pos, "Expected 2 exactly arguments for binary operator overload."); - } - - if (binop_is_assignment(pp->base_func->operator_overload)) { - onyx_report_error(pp->base_func->token->pos, "'%s' is not currently overloadable.", binaryop_string[pp->base_func->operator_overload]); - } - - bh_arr_push(operator_overloads[pp->base_func->operator_overload], (AstTyped *) pp); - } return Symres_Success; } @@ -1055,6 +1016,54 @@ static SymresStatus symres_static_if(AstStaticIf* static_if) { return Symres_Success; } +static SymresStatus symres_process_directive(AstNode* directive) { + switch (directive->kind) { + case Ast_Kind_Directive_Add_Overload: { + AstDirectiveAddOverload *add_overload = (AstDirectiveAddOverload *) directive; + + SYMRES(expression, (AstTyped **) &add_overload->overloaded_function); + if (add_overload->overloaded_function == NULL) return Symres_Error; // NOTE: Error message will already be generated + + if (add_overload->overloaded_function->kind != Ast_Kind_Overloaded_Function) { + onyx_report_error(add_overload->token->pos, "#add_overload directive did not resolve to an overloaded function."); + + } else { + AstOverloadedFunction* ofunc = (AstOverloadedFunction *) add_overload->overloaded_function; + bh_arr_push(ofunc->overloads, (AstTyped *) add_overload->overload); + } + + break; + } + + case Ast_Kind_Directive_Operator: { + AstDirectiveOperator *operator = (AstDirectiveOperator *) directive; + SYMRES(expression, &operator->overload); + if (!operator->overload) return Symres_Error; + + AstFunction* overload = get_function_from_node((AstNode *) operator->overload); + if (overload == NULL) { + onyx_report_error(operator->token->pos, "This cannot be used as an operator overload."); + return Symres_Error; + } + + if (bh_arr_length(overload->params) != 2) { + onyx_report_error(operator->token->pos, "Expected 2 exactly arguments for binary operator overload."); + return Symres_Error; + } + + if (binop_is_assignment(operator->operator)) { + onyx_report_error(overload->token->pos, "'%s' is not currently overloadable.", binaryop_string[operator->operator]); + return Symres_Error; + } + + bh_arr_push(operator_overloads[operator->operator], operator->overload); + break; + } + } + + return Symres_Success; +} + void symres_entity(Entity* ent) { if (block_stack == NULL) bh_arr_new(global_heap_allocator, block_stack, 16); @@ -1113,6 +1122,9 @@ void symres_entity(Entity* ent) { case Entity_Type_Polymorphic_Proc: ss = symres_polyproc(ent->poly_proc); break; case Entity_Type_String_Literal: ss = symres_expression(&ent->expr); break; case Entity_Type_Struct_Member_Default: ss = symres_struct_defaults((AstType *) ent->type_alias); break; + case Entity_Type_Process_Directive: ss = symres_process_directive((AstNode *) ent->expr); + next_state = Entity_State_Finalized; + break; default: break; } diff --git a/tests/aoc-2020/day17.onyx b/tests/aoc-2020/day17.onyx index 8aff1ac7..e4f51bf0 100644 --- a/tests/aoc-2020/day17.onyx +++ b/tests/aoc-2020/day17.onyx @@ -12,11 +12,11 @@ CubeState :: struct { next := false; } -proc (c: CubePos) -> u32 #add_overload map.hash_function { +#add_overload map.hash_function, (c: CubePos) -> u32 { return 17 * c.x + 13 * c.y + 11 * c.z + 19 * c.w; } -proc (a: CubePos, b: CubePos) -> bool #add_overload map.cmp_function { +#add_overload map.cmp_function, (a: CubePos, b: CubePos) -> bool { return (a.x == b.x) && (a.y == b.y) && (a.z == b.z) diff --git a/tests/aoc-2020/day24.onyx b/tests/aoc-2020/day24.onyx index 099d3414..3180fbf9 100644 --- a/tests/aoc-2020/day24.onyx +++ b/tests/aoc-2020/day24.onyx @@ -7,13 +7,12 @@ Vec2 :: struct { y: i32 = 0; } -proc (v: Vec2) -> u32 #add_overload map.hash_function { +#add_overload map.hash_function, (v: Vec2) -> u32 { return v.x * 11 + v.y * 17; } -proc (v1: Vec2, v2: Vec2) -> bool #add_overload map.cmp_function { - return v1.x == v2.x - && v1.y == v2.y; +#add_overload map.cmp_function, (v1: Vec2, v2: Vec2) -> bool { + return v1.x == v2.x && v1.y == v2.y; } Hex_Directions := Vec2.[ diff --git a/tests/array_struct_robustness.onyx b/tests/array_struct_robustness.onyx index dc169bc8..e201ae36 100644 --- a/tests/array_struct_robustness.onyx +++ b/tests/array_struct_robustness.onyx @@ -5,7 +5,7 @@ use package core Vec2 :: struct { x: i32; y: i32; } // Overload print() to print Vec2's. -proc (use writer: ^io.Writer, use v: Vec2) #add_overload io.write { +#add_overload io.write, (use writer: ^io.Writer, use v: Vec2) { io.write_format(writer, "Vec2(%i, %i)", x, y); } diff --git a/tests/operator_overload.onyx b/tests/operator_overload.onyx index e5b58f55..4041ffa8 100644 --- a/tests/operator_overload.onyx +++ b/tests/operator_overload.onyx @@ -7,15 +7,15 @@ Complex :: struct { im : f32 = 0; } -proc (a: Complex, b: Complex) -> Complex #operator+ { +#operator+ (a: Complex, b: Complex) -> Complex { return Complex.{ a.re + b.re, a.im + b.im }; } -proc (a: Complex, b: Complex) -> Complex #operator- { +#operator- (a: Complex, b: Complex) -> Complex { return Complex.{ a.re - b.re, a.im - b.im }; } -proc (a: Complex, b: Complex) -> Complex #operator* { +#operator* (a: Complex, b: Complex) -> Complex { return Complex.{ a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re }; } @@ -23,30 +23,29 @@ C :: (re: f32, im: f32) -> Complex do return Complex.{ re, im }; - Vec :: struct (T: type_expr, N: i32) { data: [N] T; } -proc (a: Vec($T, $N), b: Vec(T, N)) -> Vec(T, N) #operator+ { +#operator+ (a: Vec($T, $N), b: Vec(T, N)) -> Vec(T, N) { out : Vec(T, N); for i: 0 .. N do out.data[i] = a.data[i] + b.data[i]; return out; } -proc (a: Vec($T, $N), b: Vec(T, N)) -> Vec(T, N) #operator- { +#operator- (a: Vec($T, $N), b: Vec(T, N)) -> Vec(T, N) { out : Vec(T, N); for i: 0 .. N do out.data[i] = a.data[i] - b.data[i]; return out; } -proc (a: Vec($T, $N), s: T) -> Vec(T, N) #operator* { +#operator* (a: Vec($T, $N), s: T) -> Vec(T, N) { out : Vec(T, N); for i: 0 .. N do out.data[i] = a.data[i] * s; return out; } -proc (a: Vec($T, $N), b: Vec(T, N)) -> T #operator* { +#operator* (a: Vec($T, $N), b: Vec(T, N)) -> T { res := T.{}; for i: 0 .. N do res += a.data[i] * b.data[i]; return res;