resolving polymorphic params in function parameters; code cleanup
authorBrendan Hansen <brendan.f.hansen@gmail.com>
Tue, 1 Sep 2020 19:01:24 +0000 (14:01 -0500)
committerBrendan Hansen <brendan.f.hansen@gmail.com>
Tue, 1 Sep 2020 19:01:24 +0000 (14:01 -0500)
include/onyxparser.h
onyx
progs/poly_test.onyx
src/onyxparser.c
src/onyxtypes.c
src/onyxutils.c

index 0da8dc2a8916a10ba27f602386e2da7037312b9d..2d5766a0a220b4d1b6ca6adc99ac1d22f6163907 100644 (file)
@@ -23,6 +23,11 @@ typedef struct ParseResults {
     bh_arr(NodeToProcess) nodes_to_process;
 } ParseResults;
 
+typedef struct PolymorphicContext {
+    AstType* root_node;
+    bh_arr(AstPolyParam)* poly_params;
+} PolymorphicContext;
+
 typedef struct OnyxParser {
     bh_allocator allocator;
 
@@ -36,6 +41,8 @@ typedef struct OnyxParser {
 
     ParseResults results;
 
+    PolymorphicContext polymorph_context;
+
     b32 hit_unexpected_token : 1;
 } OnyxParser;
 
diff --git a/onyx b/onyx
index 708fe4561c3244026968b17a827fb697c16b5720..03ec709f4d09b4175d948a951e143a5b86d003e4 100755 (executable)
Binary files a/onyx and b/onyx differ
index 4d0dbafe9d802f29e9c32fc6375d6dfb40922e57..eca29822170c324581f1482eca6ec043ef0805ce 100644 (file)
@@ -37,15 +37,14 @@ get_count :: proc (x: $T) -> u32 do return x.count;
 // array/slice in most places.
 Dummy :: struct {
     count : u32 = 5;
-    data  : ^u32;
+    data  : [5] u32;
 }
 
 
-/* TODO: Make this work at some point
-compose :: proc (a: $A, f: proc (A) -> $B, g: proc (B) -> $C) -> C {
+/* This demos some of the power you have with the polymorphic types */
+compose :: proc (a: A, f: proc ($A) -> $B, g: proc (B) -> $C) -> C {
     return a |> f() |> g();
 }
-*/
 
 
 SOA :: struct {
@@ -64,6 +63,10 @@ soa_deinit :: proc (s: ^SOA) {
 }
 
 main :: proc (args: [] cstring) {
+    res := compose(5, proc (x: i32) -> i32 do return x * 3;,
+                      proc (x: i32) -> i32 do return x + 5;);
+    print(res);
+
     s : SOA;
     soa_init(^s);
     defer soa_deinit(^s);
@@ -109,7 +112,7 @@ main :: proc (args: [] cstring) {
 
     print("Deleteing ^a[20]\n");
     ptrmap_delete(^map, ^s.a[20]);
-    
+
     print("Has ^a[20]? ");
     print(ptrmap_has(^map, ^s.a[20]));
     print("\n");
@@ -191,7 +194,9 @@ main2 :: proc (args: [] cstring) {
     print_array(varr, "\n");
 
 
-    dummy := Dummy.{ data = cast(^u32) calloc(sizeof [5] i32) };
+    dummy := cast(^Dummy) calloc(sizeof Dummy);
+    defer cfree(dummy);
+    dummy.count = 5;
     for i: 0, dummy.count do dummy.data[i] = i * 5;
 
     print_array(dummy);
index 3d7262f43e26229310d4cb81f00b985e32cbc72c..1f12e8516dd8bb5e6e6d78d714079fc8e0ea1bea 100644 (file)
@@ -35,9 +35,9 @@ static b32            parse_possible_symbol_declaration(OnyxParser* parser, AstN
 static AstReturn*     parse_return_statement(OnyxParser* parser);
 static AstBlock*      parse_block(OnyxParser* parser);
 static AstNode*       parse_statement(OnyxParser* parser);
-static AstType*       parse_type(OnyxParser* parser, bh_arr(AstPolyParam)* polymorphic_vars);
+static AstType*       parse_type(OnyxParser* parser);
 static AstStructType* parse_struct(OnyxParser* parser);
-static void           parse_function_params(OnyxParser* parser, AstFunction* func, bh_arr(AstPolyParam)* poly_vars);
+static void           parse_function_params(OnyxParser* parser, AstFunction* func);
 static b32            parse_possible_directive(OnyxParser* parser, const char* dir);
 static AstFunction*   parse_function_definition(OnyxParser* parser);
 static AstTyped*      parse_global_declaration(OnyxParser* parser);
@@ -328,7 +328,7 @@ static AstTyped* parse_factor(OnyxParser* parser) {
             AstUnaryOp* cast_node = make_node(AstUnaryOp, Ast_Kind_Unary_Op);
             cast_node->token = expect_token(parser, Token_Type_Keyword_Cast);
             expect_token(parser, '(');
-            cast_node->type_node = parse_type(parser, NULL);
+            cast_node->type_node = parse_type(parser);
             expect_token(parser, ')');
             cast_node->operation = Unary_Op_Cast;
             cast_node->expr = parse_factor(parser);
@@ -340,7 +340,7 @@ static AstTyped* parse_factor(OnyxParser* parser) {
         case Token_Type_Keyword_Sizeof: {
             AstSizeOf* so_node = make_node(AstSizeOf, Ast_Kind_Size_Of);
             so_node->token = expect_token(parser, Token_Type_Keyword_Sizeof);
-            so_node->so_type = (AstType *) parse_type(parser, NULL);
+            so_node->so_type = (AstType *) parse_type(parser);
             so_node->type_node = (AstType *) &basic_type_i32;
 
             retval = (AstTyped *) so_node;
@@ -350,7 +350,7 @@ static AstTyped* parse_factor(OnyxParser* parser) {
         case Token_Type_Keyword_Alignof: {
             AstAlignOf* ao_node = make_node(AstAlignOf, Ast_Kind_Align_Of);
             ao_node->token = expect_token(parser, Token_Type_Keyword_Alignof);
-            ao_node->ao_type = (AstType *) parse_type(parser, NULL);
+            ao_node->ao_type = (AstType *) parse_type(parser);
             ao_node->type_node = (AstType *) &basic_type_i32;
 
             retval = (AstTyped *) ao_node;
@@ -961,7 +961,7 @@ static b32 parse_possible_symbol_declaration(OnyxParser* parser, AstNode** ret)
     // NOTE: var: type
     if (parser->curr->type != ':'
             && parser->curr->type != '=') {
-        type_node = parse_type(parser, NULL);
+        type_node = parse_type(parser);
     }
 
     AstLocal* local = make_node(AstLocal, Ast_Kind_Local);
@@ -1187,7 +1187,7 @@ static AstBlock* parse_block(OnyxParser* parser) {
 
 // <symbol>
 // '^' <type>
-static AstType* parse_type(OnyxParser* parser, bh_arr(AstPolyParam)* poly_vars) {
+static AstType* parse_type(OnyxParser* parser) {
     AstType* root = NULL;
     AstType** next_insertion = &root;
 
@@ -1238,7 +1238,7 @@ static AstType* parse_type(OnyxParser* parser, bh_arr(AstPolyParam)* poly_vars)
             while (parser->curr->type != ')') {
                 if (parser->hit_unexpected_token) return root;
 
-                AstType* param_type = parse_type(parser, poly_vars);
+                AstType* param_type = parse_type(parser);
                 bh_arr_push(params, param_type);
 
                 if (parser->curr->type != ')')
@@ -1249,7 +1249,7 @@ static AstType* parse_type(OnyxParser* parser, bh_arr(AstPolyParam)* poly_vars)
             AstType* return_type = (AstType *) &basic_type_void;
             if (parser->curr->type == Token_Type_Right_Arrow) {
                 consume_token(parser);
-                return_type = parse_type(parser, poly_vars);
+                return_type = parse_type(parser);
             }
 
             u64 param_count = bh_arr_length(params);
@@ -1270,10 +1270,10 @@ static AstType* parse_type(OnyxParser* parser, bh_arr(AstPolyParam)* poly_vars)
         else if (parser->curr->type == '$') {
             bh_arr(AstPolyParam) pv = NULL;
 
-            if (poly_vars == NULL)
+            if (parser->polymorph_context.poly_params == NULL)
                 onyx_report_error(parser->curr->pos, "polymorphic variable not valid here.");
             else
-                pv = *poly_vars;
+                pv = *parser->polymorph_context.poly_params;
 
             consume_token(parser);
 
@@ -1286,11 +1286,13 @@ static AstType* parse_type(OnyxParser* parser, bh_arr(AstPolyParam)* poly_vars)
             if (pv != NULL) {
                 bh_arr_push(pv, ((AstPolyParam) {
                     .poly_sym = symbol_node,
-                    .type_expr = root,
+
+                    // These will be filled out by function_params()
+                    .type_expr = NULL,
                     .idx = -1,
                 }));
 
-                *poly_vars = pv;
+                *parser->polymorph_context.poly_params = pv;
             }
         }
 
@@ -1373,7 +1375,7 @@ static AstStructType* parse_struct(OnyxParser* parser) {
 
         mem->token = expect_token(parser, Token_Type_Symbol);
         expect_token(parser, ':');
-        mem->type_node = parse_type(parser, NULL);
+        mem->type_node = parse_type(parser);
 
         if (parser->curr->type == '=') {
             consume_token(parser);
@@ -1392,7 +1394,7 @@ static AstStructType* parse_struct(OnyxParser* parser) {
 
 // e
 // '(' (<symbol>: <type>,?)* ')'
-static void parse_function_params(OnyxParser* parser, AstFunction* func, bh_arr(AstPolyParam)* poly_vars) {
+static void parse_function_params(OnyxParser* parser, AstFunction* func) {
     if (parser->curr->type != '(')
         return;
 
@@ -1406,7 +1408,7 @@ static void parse_function_params(OnyxParser* parser, AstFunction* func, bh_arr(
     AstParam curr_param = { 0 };
 
     u32 param_idx = 0;
-    bh_arr(AstPolyParam) pv = *poly_vars;
+    assert(parser->polymorph_context.poly_params != NULL);
 
     b32 param_use = 0;
     OnyxToken* symbol;
@@ -1431,11 +1433,15 @@ static void parse_function_params(OnyxParser* parser, AstFunction* func, bh_arr(
         }
 
         if (parser->curr->type != '=') {
-            u32 old_len = bh_arr_length(pv);
-            curr_param.local->type_node = parse_type(parser, &pv);
+            i32 old_len = bh_arr_length(*parser->polymorph_context.poly_params);
+            curr_param.local->type_node = parse_type(parser);
+
+            i32 new_len = bh_arr_length(*parser->polymorph_context.poly_params);
+            i32 new_poly_params = new_len - old_len;
 
-            if (old_len != bh_arr_length(pv)) {
-                bh_arr_last(pv).idx = param_idx;
+            fori (i, 0, new_poly_params) {
+                (*parser->polymorph_context.poly_params)[old_len + i].type_expr = curr_param.local->type_node;
+                (*parser->polymorph_context.poly_params)[old_len + i].idx = param_idx;
             }
         }
 
@@ -1455,8 +1461,6 @@ static void parse_function_params(OnyxParser* parser, AstFunction* func, bh_arr(
         param_idx++;
     }
 
-    *poly_vars = pv;
-
     consume_token(parser); // Skip the )
     return;
 }
@@ -1488,13 +1492,15 @@ static AstFunction* parse_function_definition(OnyxParser* parser) {
     bh_arr(AstPolyParam) polymorphic_vars = NULL;
     bh_arr_new(global_heap_allocator, polymorphic_vars, 4);
 
-    parse_function_params(parser, func_def, &polymorphic_vars);
+    parser->polymorph_context.poly_params = &polymorphic_vars;
+    parse_function_params(parser, func_def);
+    parser->polymorph_context.poly_params = NULL;
 
     AstType* return_type = (AstType *) &basic_type_void;
     if (parser->curr->type == Token_Type_Right_Arrow) {
         expect_token(parser, Token_Type_Right_Arrow);
 
-        return_type = parse_type(parser, NULL);
+        return_type = parse_type(parser);
     }
     func_def->return_type = return_type;
 
@@ -1619,7 +1625,7 @@ static AstTyped* parse_global_declaration(OnyxParser* parser) {
         }
     }
 
-    global_node->type_node = parse_type(parser, NULL);
+    global_node->type_node = parse_type(parser);
 
     add_node_to_process(parser, (AstNode *) global_node);
 
@@ -1700,7 +1706,7 @@ static AstTyped* parse_top_level_expression(OnyxParser* parser) {
     }
     else if (parse_possible_directive(parser, "type")) {
         AstTypeAlias* alias = make_node(AstTypeAlias, Ast_Kind_Type_Alias);
-        alias->to = parse_type(parser, NULL);
+        alias->to = parse_type(parser);
         return (AstTyped *) alias;
     }
     else if (parser->curr->type == Token_Type_Keyword_Enum) {
@@ -1828,7 +1834,7 @@ static AstNode* parse_top_level_statement(OnyxParser* parser) {
                     memres->initial_value = parse_expression(parser);
 
                 } else {
-                    memres->type_node = parse_type(parser, NULL);
+                    memres->type_node = parse_type(parser);
 
                     if (parser->curr->type == '=') {
                         consume_token(parser);
@@ -1913,6 +1919,11 @@ OnyxParser onyx_parser_create(bh_allocator alloc, OnyxTokenizer *tokenizer, Prog
 
     };
 
+    parser.polymorph_context = (PolymorphicContext) {
+        .root_node = NULL,
+        .poly_params = NULL,
+    };
+
     bh_arr_new(parser.results.allocator, parser.results.includes, 4);
     bh_arr_new(parser.results.allocator, parser.results.nodes_to_process, 4);
 
index 040d26bd537f34d893d9c5367d959c4f031dbc55..5fcbe809e26eab98e80fef77b1c81a556e29ac42 100644 (file)
@@ -464,10 +464,26 @@ const char* type_get_name(Type* type) {
             else
                 return "<anonymous enum>";
 
-        case Type_Kind_Function: return bh_aprintf(global_scratch_allocator, "proc (...) -> %s", type_get_name(type->Function.return_type));
         case Type_Kind_Slice: return bh_aprintf(global_scratch_allocator, "[] %s", type_get_name(type->Slice.ptr_to_data->Pointer.elem));
         case Type_Kind_DynArray: return bh_aprintf(global_scratch_allocator, "[..] %s", type_get_name(type->DynArray.ptr_to_data->Pointer.elem));
 
+        case Type_Kind_Function: {
+            char buf[512];
+            fori (i, 0, 512) buf[i] = 0;
+
+            strncat(buf, "proc (", 511);
+            fori (i, 0, type->Function.param_count) {
+                strncat(buf, type_get_name(type->Function.params[i]), 511);
+                if (i != type->Function.param_count - 1)
+                    strncat(buf, ", ", 511);
+            }
+
+            strncat(buf, ") -> ", 511);
+            strncat(buf, type_get_name(type->Function.return_type), 511);
+
+            return bh_aprintf(global_scratch_allocator, "%s", buf);
+        }
+
         default: return "unknown";
     }
 }
index 67d91f72086e4e5a2647d22f806960e5c3730e8a..a5eec07cf0dc27c0c6d13ab9ad6832cf0eb9f954 100644 (file)
@@ -342,41 +342,90 @@ void promote_numlit_to_larger(AstNumLit* num) {
     }
 }
 
+typedef struct PolySolveElem {
+    AstType* type_expr;
+    Type*    actual;
+} PolySolveElem;
+
 static Type* solve_poly_type(AstNode* target, AstType* type_expr, Type* actual) {
-    while (1) {
-        if (type_expr == (AstType *) target) return actual;
+    bh_arr(PolySolveElem) elem_queue = NULL;
+    bh_arr_new(global_heap_allocator, elem_queue, 4);
+
+    Type* result = NULL;
+
+    bh_arr_push(elem_queue, ((PolySolveElem) {
+        .type_expr = type_expr,
+        .actual    = actual
+    }));
+
+    while (!bh_arr_is_empty(elem_queue)) {
+        PolySolveElem elem = elem_queue[0];
+        bh_arr_deleten(elem_queue, 0, 1);
 
-        switch (type_expr->kind) {
+        if (elem.type_expr == (AstType *) target) {
+            result = elem.actual;
+            break;
+        }
+
+        switch (elem.type_expr->kind) {
             case Ast_Kind_Pointer_Type: {
-                if (actual->kind != Type_Kind_Pointer) return NULL;
+                if (elem.actual->kind != Type_Kind_Pointer) break;
 
-                type_expr = ((AstPointerType *) type_expr)->elem;
-                actual = actual->Pointer.elem;
+                bh_arr_push(elem_queue, ((PolySolveElem) {
+                    .type_expr = ((AstPointerType *) elem.type_expr)->elem,
+                    .actual = elem.actual->Pointer.elem,
+                }));
                 break;
             }
 
             case Ast_Kind_Slice_Type: {
-                if (actual->kind != Type_Kind_Slice) return NULL;
+                if (elem.actual->kind != Type_Kind_Slice) break;
 
-                type_expr = ((AstSliceType *) type_expr)->elem;
-                actual = actual->Slice.ptr_to_data->Pointer.elem;
+                bh_arr_push(elem_queue, ((PolySolveElem) {
+                    .type_expr = ((AstSliceType *) elem.type_expr)->elem,
+                    .actual = elem.actual->Slice.ptr_to_data->Pointer.elem,
+                }));
                 break;
             }
 
             case Ast_Kind_DynArr_Type: {
-                if (actual->kind != Type_Kind_DynArray) return NULL;
+                if (elem.actual->kind != Type_Kind_DynArray) break;
 
-                type_expr = ((AstDynArrType *) type_expr)->elem;
-                actual = actual->DynArray.ptr_to_data->Pointer.elem;
+                bh_arr_push(elem_queue, ((PolySolveElem) {
+                    .type_expr = ((AstDynArrType *) elem.type_expr)->elem,
+                    .actual = elem.actual->DynArray.ptr_to_data->Pointer.elem,
+                }));
                 break;
             }
 
-            default:
-                return NULL;
+            case Ast_Kind_Function_Type: {
+                if (elem.actual->kind != Type_Kind_Function) break;
+
+                AstFunctionType* ft = (AstFunctionType *) elem.type_expr;
+
+                fori (i, 0, ft->param_count) {
+                    bh_arr_push(elem_queue, ((PolySolveElem) {
+                        .type_expr = ft->params[i],
+                        .actual = elem.actual->Function.params[i],
+                    }));
+                }
+
+                bh_arr_push(elem_queue, ((PolySolveElem) {
+                    .type_expr = ft->return_type,
+                    .actual = elem.actual->Function.return_type,
+                }));
+
+                break;
+            }
+
+            default: break;
         }
     }
 
-    return NULL;
+solving_done:
+    bh_arr_free(elem_queue);
+
+    return result;
 }
 
 AstFunction* polymorphic_proc_lookup(AstPolyProc* pp, PolyProcLookupMethod pp_lookup, ptr actual, OnyxFilePos pos) {
@@ -417,12 +466,7 @@ AstFunction* polymorphic_proc_lookup(AstPolyProc* pp, PolyProcLookupMethod pp_lo
         Type* resolved_type = solve_poly_type(param->poly_sym, param->type_expr, actual_type);
 
         if (resolved_type == NULL) {
-            if (pp_lookup == PPLM_By_Call) {
-                onyx_report_error(pos, "Unable to match polymorphic procedure type.");
-            }
-            else if (pp_lookup == PPLM_By_Function_Type) {
-                onyx_report_error(pos, "Unable to match polymorphic procedure type.");
-            }
+            onyx_report_error(pos, "Unable to match polymorphic procedure type with actual type, '%s'.", type_get_name(actual_type));
             return NULL;
         }