added basics of operator overloading
authorBrendan Hansen <brendan.f.hansen@gmail.com>
Sun, 10 Jan 2021 01:06:22 +0000 (19:06 -0600)
committerBrendan Hansen <brendan.f.hansen@gmail.com>
Sun, 10 Jan 2021 01:06:22 +0000 (19:06 -0600)
bin/onyx
include/onyxastnodes.h
onyx.exe
src/onyxbuiltins.c
src/onyxchecker.c
src/onyxparser.c
src/onyxsymres.c

index c3d9cd57d081302a2c6a7df8cfbb3a2718231e75..ea4c0c52959aed9255588b349bc2f3cfae051834 100755 (executable)
Binary files a/bin/onyx and b/bin/onyx differ
index aadbb47a1cf5f0c9489197635b5fb6f37fd9ef91..f87ccdf7560c98c36671cc1e9f5b7b8d37984975 100644 (file)
@@ -722,6 +722,10 @@ struct AstFunction {
     // then resolved to an overloaded function.
     AstNode *overloaded_function;
 
+    // NOTE: set to -1 if the function is not an operator overload;
+    // set to a BinaryOp value if it is.
+    BinaryOp operator_overload;
+
     OnyxToken* name;
 
     union {
@@ -940,6 +944,8 @@ typedef struct IntrinsicMap {
 
 extern bh_table(OnyxIntrinsic) intrinsic_table;
 
+extern bh_arr(AstTyped *) operator_overloads[Binary_Op_Count];
+
 void initialize_builtins(bh_allocator a, ProgramInfo* prog);
 
 
index 249cd882ce72a784bd93ef7f5b08af307496b2c3..515eb88918fbc3c139c4adc6bb9945faad9c717d 100644 (file)
Binary files a/onyx.exe and b/onyx.exe differ
index e0960975f44ed788032763b5deac4a228901cf13..60bb9726610061f038e5d935ac9fc52a7c6a50d8 100644 (file)
@@ -322,6 +322,8 @@ static IntrinsicMap builtin_intrinsics[] = {
     { NULL, ONYX_INTRINSIC_UNDEFINED },
 };
 
+bh_arr(AstTyped *) operator_overloads[Binary_Op_Count] = { 0 };
+
 void initialize_builtins(bh_allocator a, ProgramInfo* prog) {
     // HACK
     builtin_package_token.text = bh_strdup(global_heap_allocator, builtin_package_token.text);
@@ -369,6 +371,9 @@ void initialize_builtins(bh_allocator a, ProgramInfo* prog) {
         return;
     }
 
+    fori (i, 0, Binary_Op_Count) {
+        bh_arr_new(global_heap_allocator, operator_overloads[i], 4); 
+    }
 
     bh_table_init(global_heap_allocator, intrinsic_table, 128);
     IntrinsicMap* intrinsic = &builtin_intrinsics[0];
index 77c8ae8f663e84ea0f002471bed6d251db5c4380..47ea5c7f017bbcdf456a6742b5087e555303eaaa 100644 (file)
@@ -293,25 +293,29 @@ CheckStatus check_switch(AstSwitch* switchnode) {
 
     return 0;
 }
-
-static AstTyped* match_overloaded_function(AstCall* call, AstOverloadedFunction* ofunc) {
-    bh_arr_each(AstTyped *, node, ofunc->overloads) {
+static AstTyped* match_overloaded_function(bh_arr(AstTyped *) arg_arr, bh_arr(AstTyped *) overloads) {
+    bh_arr_each(AstTyped *, node, overloads) {
         AstFunction* overload = (AstFunction *) *node;
 
         fill_in_type((AstTyped *) overload);
 
         TypeFunction* ol_type = &overload->type->Function;
 
-        if (call->arg_count < ol_type->needed_param_count) continue;
+        if (bh_arr_length(arg_arr) < (i32) ol_type->needed_param_count) continue;
 
         Type** param_type = ol_type->params;
-        bh_arr_each(AstArgument*, arg, call->arg_arr) {
-            fill_in_type((AstTyped *) *arg);
-            if ((*param_type)->kind == Type_Kind_VarArgs) {
-                if (!type_check_or_auto_cast(&(*arg)->value, (*param_type)->VarArgs.ptr_to_data->Pointer.elem))
-                    goto no_match;
-            }
-            else if (!type_check_or_auto_cast(&(*arg)->value, *param_type)) goto no_match;
+        bh_arr_each(AstTyped*, arg, arg_arr) {
+            fill_in_type(*arg);
+
+            Type* type_to_match = *param_type;
+            if ((*param_type)->kind == Type_Kind_VarArgs)
+                type_to_match = (*param_type)->VarArgs.ptr_to_data->Pointer.elem;
+
+            AstTyped** value = arg;
+            if ((*arg)->kind == Ast_Kind_Argument)
+                value = &((AstArgument *) *arg)->value;
+
+            if (!type_check_or_auto_cast(value, type_to_match)) goto no_match;
 
             param_type++;
         }
@@ -321,20 +325,6 @@ static AstTyped* match_overloaded_function(AstCall* call, AstOverloadedFunction*
 no_match:
         continue;
     }
-
-    char* arg_str = bh_alloc(global_scratch_allocator, 1024);
-    arg_str[0] = '\0';
-
-    bh_arr_each(AstArgument *, arg, call->arg_arr) {
-        strncat(arg_str, type_get_name((*arg)->value->type), 1023);
-
-        if (arg != &bh_arr_last(call->arg_arr))
-            strncat(arg_str, ", ", 1023);
-    }
-
-    onyx_report_error(call->token->pos, "unable to match overloaded function with provided argument types: (%s)", arg_str);
-
-    bh_free(global_scratch_allocator, arg_str);
     return NULL;
 }
 
@@ -376,10 +366,28 @@ CheckStatus check_call(AstCall* call) {
     }
 
     if (callee->kind == Ast_Kind_Overloaded_Function) {
-        call->callee = match_overloaded_function(call, (AstOverloadedFunction *) callee);
-        callee = (AstFunction *) call->callee;
+        call->callee = match_overloaded_function(
+            (bh_arr(AstTyped *)) call->arg_arr,
+            ((AstOverloadedFunction *) callee)->overloads);
+
+        if (call->callee == NULL) {
+            char* arg_str = bh_alloc(global_scratch_allocator, 1024);
+            arg_str[0] = '\0';
+
+            bh_arr_each(AstArgument *, arg, call->arg_arr) {
+                strncat(arg_str, type_get_name((*arg)->value->type), 1023);
+
+                if (arg != &bh_arr_last(call->arg_arr))
+                    strncat(arg_str, ", ", 1023);
+            }
+
+            onyx_report_error(call->token->pos, "unable to match overloaded function with provided argument types: (%s)", arg_str);
+
+            bh_free(global_scratch_allocator, arg_str);
+            return Check_Error;
+        }
 
-        if (callee == NULL) return Check_Error;
+        callee = (AstFunction *) call->callee;
     }
 
     if (callee->kind == Ast_Kind_Polymorphic_Proc) {
@@ -744,6 +752,45 @@ CheckStatus check_binaryop(AstBinaryOp** pbinop, b32 assignment_is_ok) {
         return Check_Error;
     }
 
+    if (binop->left->type->kind != Type_Kind_Basic || binop->right->type->kind != Type_Kind_Basic) {
+        bh_arr(AstTyped *) args = NULL;
+        bh_arr_new(global_heap_allocator, args, 2);
+        bh_arr_push(args, binop->left);
+        bh_arr_push(args, binop->right);
+
+        AstTyped* overload = match_overloaded_function(args, operator_overloads[binop->operation]);
+        if (overload == NULL) {
+            bh_arr_free(args);
+            goto not_overloaded;
+        }
+
+        AstCall* implicit_call = onyx_ast_node_new(semstate.node_allocator, sizeof(AstCall), Ast_Kind_Call);
+        implicit_call->token = binop->token;
+        implicit_call->arg_count = 2;
+        implicit_call->callee = overload;
+        implicit_call->va_kind = VA_Kind_Not_VA;
+
+        bh_arr_each(AstTyped *, arg, args) {
+            AstArgument* new_arg = onyx_ast_node_new(semstate.node_allocator, sizeof(AstArgument), Ast_Kind_Argument);
+            new_arg->token = (*arg)->token;
+            new_arg->type  = (*arg)->type;
+            new_arg->value = *arg;
+            new_arg->va_kind = VA_Kind_Not_VA;
+
+            *arg = (AstTyped *) new_arg;
+        }
+
+        implicit_call->arg_arr = (AstArgument **) args;
+
+        CHECK(call, implicit_call);
+
+        // NOTE: Not a binary op
+        *pbinop = (AstBinaryOp *) implicit_call;
+        return Check_Success;
+    }
+
+not_overloaded:
+
     if (!type_is_numeric(binop->left->type) && !type_is_pointer(binop->left->type)) {
         onyx_report_error(binop->token->pos,
                 "Expected numeric or pointer type for left side of binary operator, got '%s'.",
@@ -883,8 +930,9 @@ CheckStatus check_binaryop(AstBinaryOp** pbinop, b32 assignment_is_ok) {
     }
 
     if ((binop_allowed[binop->operation] & effective_flags) == 0) {
-        onyx_report_error(binop->token->pos, "Binary operator not allowed for arguments of type '%s'.",
-                type_get_name(binop->type));
+        onyx_report_error(binop->token->pos, "Binary operator not allowed for arguments of type '%s' and '%s'.",
+                type_get_name(binop->left->type),
+                type_get_name(binop->right->type));
         return Check_Error;
     }
 
index 9ac5151f6c83156aed9cab7006086a5cdfd991c4..1bfe4f01e1cc829adff360b7e21f89e487a59aea 100644 (file)
@@ -695,6 +695,50 @@ static inline i32 get_precedence(BinaryOp kind) {
     }
 }
 
+static BinaryOp binary_op_from_token_type(TokenType t) {
+    switch (t) {
+        case Token_Type_Equal_Equal:       return Binary_Op_Equal;
+        case Token_Type_Not_Equal:         return Binary_Op_Not_Equal;
+        case Token_Type_Less_Equal:        return Binary_Op_Less_Equal;
+        case Token_Type_Greater_Equal:     return Binary_Op_Greater_Equal;
+        case '<':                          return Binary_Op_Less;
+        case '>':                          return Binary_Op_Greater;
+
+        case '+':                          return Binary_Op_Add;
+        case '-':                          return Binary_Op_Minus;
+        case '*':                          return Binary_Op_Multiply;
+        case '/':                          return Binary_Op_Divide;
+        case '%':                          return Binary_Op_Modulus;
+
+        case '&':                          return Binary_Op_And;
+        case '|':                          return Binary_Op_Or;
+        case '^':                          return Binary_Op_Xor;
+        case Token_Type_Shift_Left:        return Binary_Op_Shl;
+        case Token_Type_Shift_Right:       return Binary_Op_Shr;
+        case Token_Type_Shift_Arith_Right: return Binary_Op_Sar;
+
+        case Token_Type_And_And:           return Binary_Op_Bool_And;
+        case Token_Type_Or_Or:             return Binary_Op_Bool_Or;
+
+        case '=':                          return Binary_Op_Assign;
+        case Token_Type_Plus_Equal:        return Binary_Op_Assign_Add;
+        case Token_Type_Minus_Equal:       return Binary_Op_Assign_Minus;
+        case Token_Type_Star_Equal:        return Binary_Op_Assign_Multiply;
+        case Token_Type_Fslash_Equal:      return Binary_Op_Assign_Divide;
+        case Token_Type_Percent_Equal:     return Binary_Op_Assign_Modulus;
+        case Token_Type_And_Equal:         return Binary_Op_Assign_And;
+        case Token_Type_Or_Equal:          return Binary_Op_Assign_Or;
+        case Token_Type_Xor_Equal:         return Binary_Op_Assign_Xor;
+        case Token_Type_Shl_Equal:         return Binary_Op_Assign_Shl;
+        case Token_Type_Shr_Equal:         return Binary_Op_Assign_Shr;
+        case Token_Type_Sar_Equal:         return Binary_Op_Assign_Sar;
+
+        case Token_Type_Pipe:              return Binary_Op_Pipe;
+        case Token_Type_Dot_Dot:           return Binary_Op_Range;
+        default: return Binary_Op_Count;
+    }
+}
+
 // <expr> +  <expr>
 // <expr> -  <expr>
 // <expr> *  <expr>
@@ -729,85 +773,43 @@ static AstTyped* parse_expression(OnyxParser* parser) {
     while (1) {
         if (parser->hit_unexpected_token) return root;
 
-        bin_op_kind = Binary_Op_Count;
-        switch ((u16) parser->curr->type) {
-            case Token_Type_Equal_Equal:       bin_op_kind = Binary_Op_Equal; break;
-            case Token_Type_Not_Equal:         bin_op_kind = Binary_Op_Not_Equal; break;
-            case Token_Type_Less_Equal:        bin_op_kind = Binary_Op_Less_Equal; break;
-            case Token_Type_Greater_Equal:     bin_op_kind = Binary_Op_Greater_Equal; break;
-            case '<':                          bin_op_kind = Binary_Op_Less; break;
-            case '>':                          bin_op_kind = Binary_Op_Greater; break;
-
-            case '+':                          bin_op_kind = Binary_Op_Add; break;
-            case '-':                          bin_op_kind = Binary_Op_Minus; break;
-            case '*':                          bin_op_kind = Binary_Op_Multiply; break;
-            case '/':                          bin_op_kind = Binary_Op_Divide; break;
-            case '%':                          bin_op_kind = Binary_Op_Modulus; break;
-
-            case '&':                          bin_op_kind = Binary_Op_And; break;
-            case '|':                          bin_op_kind = Binary_Op_Or; break;
-            case '^':                          bin_op_kind = Binary_Op_Xor; break;
-            case Token_Type_Shift_Left:        bin_op_kind = Binary_Op_Shl; break;
-            case Token_Type_Shift_Right:       bin_op_kind = Binary_Op_Shr; break;
-            case Token_Type_Shift_Arith_Right: bin_op_kind = Binary_Op_Sar; break;
-
-            case Token_Type_And_And:           bin_op_kind = Binary_Op_Bool_And; break;
-            case Token_Type_Or_Or:             bin_op_kind = Binary_Op_Bool_Or; break;
-
-            case '=':                          bin_op_kind = Binary_Op_Assign; break;
-            case Token_Type_Plus_Equal:        bin_op_kind = Binary_Op_Assign_Add; break;
-            case Token_Type_Minus_Equal:       bin_op_kind = Binary_Op_Assign_Minus; break;
-            case Token_Type_Star_Equal:        bin_op_kind = Binary_Op_Assign_Multiply; break;
-            case Token_Type_Fslash_Equal:      bin_op_kind = Binary_Op_Assign_Divide; break;
-            case Token_Type_Percent_Equal:     bin_op_kind = Binary_Op_Assign_Modulus; break;
-            case Token_Type_And_Equal:         bin_op_kind = Binary_Op_Assign_And; break;
-            case Token_Type_Or_Equal:          bin_op_kind = Binary_Op_Assign_Or; break;
-            case Token_Type_Xor_Equal:         bin_op_kind = Binary_Op_Assign_Xor; break;
-            case Token_Type_Shl_Equal:         bin_op_kind = Binary_Op_Assign_Shl; break;
-            case Token_Type_Shr_Equal:         bin_op_kind = Binary_Op_Assign_Shr; break;
-            case Token_Type_Sar_Equal:         bin_op_kind = Binary_Op_Assign_Sar; break;
-
-            case Token_Type_Pipe:              bin_op_kind = Binary_Op_Pipe; break;
-            case Token_Type_Dot_Dot:           bin_op_kind = Binary_Op_Range; break;
-            default: goto expression_done;
-        }
-
-        if (bin_op_kind != Binary_Op_Count) {
-            bin_op_tok = parser->curr;
-            consume_token(parser);
-
-            AstBinaryOp* bin_op;
-            if (bin_op_kind == Binary_Op_Pipe) {
-                bin_op = make_node(AstBinaryOp, Ast_Kind_Pipe);
+        bin_op_kind = binary_op_from_token_type(parser->curr->type);
+        if (bin_op_kind == Binary_Op_Count) goto expression_done;
 
-            } else if (bin_op_kind == Binary_Op_Range) {
-                bin_op = (AstBinaryOp *) make_node(AstRangeLiteral, Ast_Kind_Range_Literal);
+        bin_op_tok = parser->curr;
+        consume_token(parser);
 
-            } else {
-                bin_op = make_node(AstBinaryOp, Ast_Kind_Binary_Op);
-            }
+        AstBinaryOp* bin_op;
+        if (bin_op_kind == Binary_Op_Pipe) {
+            bin_op = make_node(AstBinaryOp, Ast_Kind_Pipe);
 
-            bin_op->token = bin_op_tok;
-            bin_op->operation = bin_op_kind;
+        } else if (bin_op_kind == Binary_Op_Range) {
+            bin_op = (AstBinaryOp *) make_node(AstRangeLiteral, Ast_Kind_Range_Literal);
 
-            while ( !bh_arr_is_empty(tree_stack) &&
-                    get_precedence(bh_arr_last(tree_stack)->operation) >= get_precedence(bin_op_kind))
-                bh_arr_pop(tree_stack);
+        } else {
+            bin_op = make_node(AstBinaryOp, Ast_Kind_Binary_Op);
+        }
 
-            if (bh_arr_is_empty(tree_stack)) {
-                // NOTE: new is now the root node
-                bin_op->left = root;
-                root = (AstTyped *) bin_op;
-            } else {
-                bin_op->left = bh_arr_last(tree_stack)->right;
-                bh_arr_last(tree_stack)->right = (AstTyped *) bin_op;
-            }
+        bin_op->token = bin_op_tok;
+        bin_op->operation = bin_op_kind;
 
-            bh_arr_push(tree_stack, bin_op);
+        while ( !bh_arr_is_empty(tree_stack) &&
+                get_precedence(bh_arr_last(tree_stack)->operation) >= get_precedence(bin_op_kind))
+            bh_arr_pop(tree_stack);
 
-            right = parse_factor(parser);
-            bin_op->right = right;
+        if (bh_arr_is_empty(tree_stack)) {
+            // NOTE: new is now the root node
+            bin_op->left = root;
+            root = (AstTyped *) bin_op;
+        } else {
+            bin_op->left = bh_arr_last(tree_stack)->right;
+            bh_arr_last(tree_stack)->right = (AstTyped *) bin_op;
         }
+
+        bh_arr_push(tree_stack, bin_op);
+
+        right = parse_factor(parser);
+        bin_op->right = right;
     }
 
     bh_arr_free(tree_stack);
@@ -1838,6 +1840,7 @@ static AstFunction* parse_function_definition(OnyxParser* parser) {
 
     AstFunction* func_def = make_node(AstFunction, Ast_Kind_Function);
     func_def->token = proc_token;
+    func_def->operator_overload = -1;
 
     bh_arr_new(global_heap_allocator, func_def->allocate_exprs, 4);
     bh_arr_new(global_heap_allocator, func_def->params, 4);
@@ -1864,10 +1867,25 @@ static AstFunction* parse_function_definition(OnyxParser* parser) {
                 expect_token(parser, Token_Type_Symbol);
 
             } else {
+                if (bh_arr_length(parser->block_stack) != 0) {
+                    onyx_report_error(parser->curr->pos, "#add_overload cannot be placed on procedures inside of other scopes");
+                }
+
                 func_def->overloaded_function = (AstNode *) parse_expression(parser);
             }
         }
 
+        else if (parse_possible_directive(parser, "operator")) {
+            BinaryOp op = binary_op_from_token_type(parser->curr->type);
+            consume_token(parser);
+            
+            if (op == Binary_Op_Count) {
+                onyx_report_error(parser->curr->pos, "Invalid binary operator.");
+            } else {
+                func_def->operator_overload = op;
+            }
+        }
+
         else if (parse_possible_directive(parser, "intrinsic")) {
             func_def->flags |= Ast_Flag_Intrinsic;
 
index 3f9feb93240e2397d8ef4190f9ca6f9acaaf7e65..10ab50b2bc38b0328a98615efa5ce951b7cdd987 100644 (file)
@@ -657,6 +657,14 @@ void symres_function_header(AstFunction* func) {
         }
     }
 
+    if (func->operator_overload != (BinaryOp) -1) {
+        if (bh_arr_length(func->params) != 2) {
+            onyx_report_error(func->token->pos, "Expected 2 exactly arguments for binary operator overload.");
+        }
+
+        bh_arr_push(operator_overloads[func->operator_overload], (AstTyped *) func);
+    }
+
     func->return_type = symres_type(func->return_type);
     if (!node_is_type((AstNode *) func->return_type)) {
         onyx_report_error(func->token->pos, "Return type is not a type.");