changed: polymorphic structs can have specializations
authorBrendan Hansen <brendan.f.hansen@gmail.com>
Wed, 5 Apr 2023 02:58:23 +0000 (21:58 -0500)
committerBrendan Hansen <brendan.f.hansen@gmail.com>
Wed, 5 Apr 2023 02:58:23 +0000 (21:58 -0500)
compiler/include/astnodes.h
compiler/src/astnodes.c
compiler/src/checker.c
compiler/src/lex.c
compiler/src/parser.c
compiler/src/polymorph.h
compiler/src/symres.c
compiler/src/types.c
compiler/src/utils.c
compiler/src/wasm_emit.c
core/encoding/csv.onyx

index c0ec84f0a168463f8e87cce084afad3d087abcbf..44b7c1cb6bd39f7a1dd1ea04ec4973249fe1975b 100644 (file)
@@ -999,6 +999,7 @@ struct AstPolyCallType {
     AstType_base;
 
     AstType* callee;
+    Type *resolved_type;
 
     // NOTE: These nodes can be either AstTypes, or AstTyped expressions.
     bh_arr(AstNode *) params;
index f7e81e86525c6ea5612203bbec124366a65936cd..04bf9f646f93297401339083b7192255b0b63119 100644 (file)
@@ -1704,7 +1704,7 @@ AstPolyCallType* convert_call_to_polycall(AstCall* call) {
     pct->token = call->token;
     pct->__unused = call->next;
     pct->callee = (AstType *) call->callee;
-    pct->params = (AstNode **) call->args.values;
+    pct->params = (AstNode **) bh_arr_copy(global_heap_allocator, call->args.values);
     bh_arr_each(AstNode *, pp, pct->params) {
         if ((*pp)->kind == Ast_Kind_Argument) {
             *pp = (AstNode *) (*(AstArgument **) pp)->value;
index 086e2dab86e0b02f4b0954adbce99bf36a2c4dab..d5dd147c352ed9444e077e0679373e65c312fc80 100644 (file)
@@ -118,6 +118,7 @@ b32 inside_for_iterator            = 0;
 bh_arr(AstFor *) for_node_stack    = NULL;
 static bh_imap __binop_impossible_cache[Binary_Op_Count];
 static AstCall __op_maybe_overloaded;
+static Entity *current_entity = NULL;
 
 
 #define STATEMENT_LEVEL 1
@@ -1960,6 +1961,7 @@ CheckStatus check_field_access(AstFieldAccess** pfield) {
     n = try_symbol_raw_resolve_from_type(field->expr->type, field->field);
 
     type_node = field->expr->type->ast_type;
+    if (!n) n = try_symbol_raw_resolve_from_node((AstNode *) field->expr, field->field);
     if (!n) n = try_symbol_raw_resolve_from_node((AstNode *) type_node, field->field);
 
     if (n) {
@@ -1980,15 +1982,21 @@ CheckStatus check_field_access(AstFieldAccess** pfield) {
         return Check_Yield_Macro;
     }
 
+    char* type_name = (char *) node_get_type_name(field->expr);
+    if (field->expr->type == &basic_types[Basic_Kind_Type_Index]) {
+        Type *actual_type = type_build_from_ast(context.ast_alloc, (AstType *) field->expr);
+        type_name = (char *) type_get_name(actual_type);
+    }
+
     if (!type_node) goto closest_not_found;
 
     char* closest = find_closest_symbol_in_node((AstNode *) type_node, field->field);
     if (closest) {
-        ERROR_(field->token->pos, "Field '%s' does not exists on '%s'. Did you mean '%s'?", field->field, node_get_type_name(field->expr), closest);
+        ERROR_(field->token->pos, "Field '%s' does not exists on '%s'. Did you mean '%s'?", field->field, type_name, closest);
     }
 
   closest_not_found:
-    ERROR_(field->token->pos, "Field '%s' does not exists on '%s'.", field->field, node_get_type_name(field->expr));
+    ERROR_(field->token->pos, "Field '%s' does not exists on '%s'.", field->field, type_name);
 }
 
 CheckStatus check_method_call(AstBinaryOp** pmcall) {
@@ -2568,7 +2576,7 @@ CheckStatus check_overloaded_function(AstOverloadedFunction* ofunc) {
     build_all_overload_options(ofunc->overloads, &all_overloads);
 
     bh_arr_each(bh__imap_entry, entry, all_overloads.entries) {
-        AstTyped* node = (AstTyped *) entry->key;
+        AstTyped* node = (AstTyped *) strip_aliases((AstNode *) entry->key);
         if (node->kind == Ast_Kind_Overloaded_Function) continue;
 
         if (   node->kind != Ast_Kind_Function
@@ -3189,6 +3197,33 @@ CheckStatus check_process_directive(AstNode* directive) {
         return Check_Success;
     }
 
+    if (directive->kind == Ast_Kind_Injection) {
+        AstInjection *inject = (AstInjection *) directive;
+        if (!node_is_type((AstNode *) inject->dest)) {
+            CHECK(expression, &inject->dest);
+        }
+
+        Scope *scope = get_scope_from_node_or_create((AstNode *) inject->dest);
+        if (scope == NULL) {
+            YIELD_ERROR(inject->token->pos, "Cannot #inject here.");
+        }
+
+        AstBinding *binding = onyx_ast_node_new(context.ast_alloc, sizeof(AstBinding), Ast_Kind_Binding);
+        binding->token = inject->symbol;
+        binding->node = (AstNode *) inject->to_inject;
+        binding->documentation = inject->documentation;
+
+        Package *pac = NULL;
+        if (inject->dest->kind == Ast_Kind_Package) {
+            pac = ((AstPackage *) inject->dest)->package;
+        } else {
+            pac = current_entity->package;
+        }
+
+        add_entities_for_node(NULL, (AstNode *) binding, scope, pac);
+        return Check_Complete;
+    }
+
     return Check_Success;
 }
 
@@ -3508,6 +3543,7 @@ CheckStatus check_arbitrary_job(EntityJobData *job) {
 
 void check_entity(Entity* ent) {
     CheckStatus cs = Check_Success;
+    current_entity = ent;
 
     switch (ent->type) {
         case Entity_Type_Foreign_Function_Header:
index 7bf9aa1fec322f129057816ee83e5745f2b54a8b..45918bc9efa20b3276295d02ccd2a855ab3490f1 100644 (file)
@@ -131,6 +131,7 @@ const char* token_name(TokenType tkn_type) {
 void token_toggle_end(OnyxToken* tkn) {
     static char backup = 0;
     char tmp = tkn->text[tkn->length];
+    assert(backup == '\0' || tmp == '\0'); // Sanity check
     tkn->text[tkn->length] = backup;
     backup = tmp;
 }
index 400b4b8a4504deb87e1ba50481e28fd6e4917b81..11c36a2eea3cbe337a25b36f6a468ec60b2a9ef5 100644 (file)
@@ -848,10 +848,7 @@ static AstTyped* parse_factor(OnyxParser* parser) {
                     onyx_report_error((parser->curr - 2)->pos, Error_Critical, "#Self is only allowed in an #inject block.");
                 }
 
-                AstAlias* alias = make_node(AstAlias, Ast_Kind_Alias);
-                alias->token = parser->injection_point->token;
-                alias->alias = parser->injection_point;
-                retval = (AstTyped *) alias;
+                retval = (AstTyped *) parser->injection_point;
                 break;
             }
 
@@ -2069,7 +2066,6 @@ static AstTypeOf* parse_typeof(OnyxParser* parser) {
 static void struct_type_create_scope(OnyxParser *parser, AstStructType *s_node) {
     if (!s_node->scope) {
         s_node->scope = scope_create(context.ast_alloc, parser->current_scope, s_node->token->pos);
-        parser->current_scope = s_node->scope;
 
         if (bh_arr_length(parser->current_symbol_stack) == 0) {
             s_node->scope->name = "<anonymous>";
@@ -2092,6 +2088,10 @@ static AstStructType* parse_struct(OnyxParser* parser) {
 
     flush_stored_tags(parser, &s_node->meta_tags);
 
+    struct_type_create_scope(parser, s_node);
+    Scope *scope_to_restore_parser_to = parser->current_scope;
+    Scope *scope_symbols_in_structures_should_be_bound_to = s_node->scope;
+
     // Parse polymorphic parameters
     if (consume_token_if_next(parser, '(')) {
         bh_arr(AstPolyStructParam) poly_params = NULL;
@@ -2119,6 +2119,8 @@ static AstStructType* parse_struct(OnyxParser* parser) {
         poly_struct->token = s_token;
         poly_struct->poly_params = poly_params;
         poly_struct->base_struct = s_node;
+        poly_struct->scope = s_node->scope;
+        s_node->scope = NULL;
     }
 
     // Parse constraints clause
@@ -2153,8 +2155,8 @@ static AstStructType* parse_struct(OnyxParser* parser) {
     }
 
     expect_token(parser, '{');
-
-    struct_type_create_scope(parser, s_node);
+    
+    parser->current_scope = scope_symbols_in_structures_should_be_bound_to;
 
     b32 member_is_used = 0;
     bh_arr(OnyxToken *) member_list_temp = NULL;
@@ -2258,7 +2260,7 @@ static AstStructType* parse_struct(OnyxParser* parser) {
         }
     }
 
-    if (s_node->scope) parser->current_scope = parser->current_scope->parent;
+    parser->current_scope = scope_to_restore_parser_to;
 
     bh_arr_free(member_list_temp);
 
@@ -3411,10 +3413,9 @@ static void parse_top_level_statement(OnyxParser* parser) {
                 return;
             }
             else if (parse_possible_directive(parser, "inject")) {
-                AstTyped *injection_point;
-                parser->parse_calls = 0;
-                injection_point = parse_expression(parser, 0);
-                parser->parse_calls = 1;
+                AstAlias *injection_point = make_node(AstAlias, Ast_Kind_Alias);
+                injection_point->alias = parse_expression(parser, 0);
+                injection_point->token = injection_point->alias->token;
 
                 if (peek_token(0)->type == '{') {
                     if (parser->injection_point) {
@@ -3422,7 +3423,7 @@ static void parse_top_level_statement(OnyxParser* parser) {
                         return;
                     }
 
-                    parser->injection_point = injection_point;
+                    parser->injection_point = (AstTyped *) injection_point;
 
                     expect_token(parser, '{');
                     parse_top_level_statements_until(parser, '}');
@@ -3439,7 +3440,7 @@ static void parse_top_level_statement(OnyxParser* parser) {
 
                 AstInjection *inject = make_node(AstInjection, Ast_Kind_Injection);
                 inject->token = dir_token;
-                inject->full_loc = injection_point;
+                inject->full_loc = (AstTyped *) injection_point;
                 inject->to_inject = parse_top_level_expression(parser);
                 if (parser->last_documentation_token) {
                     inject->documentation = parser->last_documentation_token;
index 4b2b7630d7abd3a7bb43068b0cd0679b7e8a18bb..a79c9373199eafd6429bad091c23e17386b60a8f 100644 (file)
@@ -963,12 +963,13 @@ b32 potentially_convert_function_to_polyproc(AstFunction *func) {
         b32 done = 0;
         while (!done && param_type) {
             switch (param_type->kind) {
-                case Ast_Kind_Pointer_Type: to_replace = &((AstPointerType *) *to_replace)->elem;  param_type = ((AstPointerType *) param_type)->elem;  break;
-                case Ast_Kind_Array_Type:   to_replace = &((AstArrayType *)   *to_replace)->elem;  param_type = ((AstArrayType *)   param_type)->elem;  break;
-                case Ast_Kind_Slice_Type:   to_replace = &((AstSliceType *)   *to_replace)->elem;  param_type = ((AstSliceType *)   param_type)->elem;  break;
-                case Ast_Kind_DynArr_Type:  to_replace = &((AstDynArrType *)  *to_replace)->elem;  param_type = ((AstDynArrType *)  param_type)->elem;  break;
-                case Ast_Kind_Alias:                                                               param_type = (AstType *) ((AstAlias *) param_type)->alias; break;
-                case Ast_Kind_Type_Alias:                                                          param_type = ((AstTypeAlias *)   param_type)->to;    break;
+                case Ast_Kind_Pointer_Type:       to_replace = &((AstPointerType *)      *to_replace)->elem;  param_type = ((AstPointerType *) param_type)->elem;  break;
+                case Ast_Kind_Multi_Pointer_Type: to_replace = &((AstMultiPointerType *) *to_replace)->elem;  param_type = ((AstMultiPointerType *) param_type)->elem;  break;
+                case Ast_Kind_Array_Type:         to_replace = &((AstArrayType *)        *to_replace)->elem;  param_type = ((AstArrayType *)   param_type)->elem;  break;
+                case Ast_Kind_Slice_Type:         to_replace = &((AstSliceType *)        *to_replace)->elem;  param_type = ((AstSliceType *)   param_type)->elem;  break;
+                case Ast_Kind_DynArr_Type:        to_replace = &((AstDynArrType *)       *to_replace)->elem;  param_type = ((AstDynArrType *)  param_type)->elem;  break;
+                case Ast_Kind_Alias:                                                                          param_type = (AstType *) ((AstAlias *) param_type)->alias; break;
+                case Ast_Kind_Type_Alias:                                                                     param_type = ((AstTypeAlias *)   param_type)->to;    break;
                 case Ast_Kind_Poly_Struct_Type: {
                     AutoPolymorphVariable apv;
                     apv.idx = param_idx;
@@ -1003,7 +1004,9 @@ b32 potentially_convert_function_to_polyproc(AstFunction *func) {
         pcall->flags |= Ast_Flag_Poly_Call_From_Auto;
         bh_arr_new(global_heap_allocator, pcall->params, apv->variable_count);
 
-        if (apv->base_type->kind == Ast_Kind_Poly_Struct_Type) {
+        AstType *dealiased_base_type = (AstType *) strip_aliases((AstNode *) apv->base_type);
+
+        if (dealiased_base_type->kind == Ast_Kind_Poly_Struct_Type) {
             pp.type_expr = (AstType *) pcall;
         } else {
             pp.type_expr = apv->base_type;
@@ -1106,6 +1109,8 @@ Type* polymorphic_struct_lookup(AstPolyStructType* ps_type, bh_arr(AstPolySoluti
         return NULL;
     }
 
+    assert(!ps_type->base_struct->scope);
+
     if (ps_type->concrete_structs == NULL) {
         sh_new_arena(ps_type->concrete_structs);
     }
@@ -1153,6 +1158,7 @@ Type* polymorphic_struct_lookup(AstPolyStructType* ps_type, bh_arr(AstPolySoluti
     insert_poly_slns_into_scope(sln_scope, slns);
 
     AstStructType* concrete_struct = (AstStructType *) ast_clone(context.ast_alloc, ps_type->base_struct);
+    concrete_struct->scope = scope_create(context.ast_alloc, sln_scope, ps_type->token->pos);
     concrete_struct->polymorphic_error_loc = pos;
     BH_MASK_SET(concrete_struct->flags, !error_if_failed, Ast_Flag_Header_Check_No_Error);
 
index a5e8662187872f6ddb2f89b63c0e711ea79ec15c..73dae651c11e27f86eb9d2db4b4ab0c67ea7782c 100644 (file)
@@ -115,13 +115,8 @@ static SymresStatus symres_struct_type(AstStructType* s_node) {
     s_node->flags |= Ast_Flag_Type_Is_Resolved;
     s_node->flags |= Ast_Flag_Comptime;
 
-    if (s_node->scope) {
-        assert(s_node->entity);
-        assert(s_node->entity->scope);
-        s_node->scope->parent = s_node->entity->scope;
-
-        scope_enter(s_node->scope);
-    }
+    assert(s_node->scope);
+    scope_enter(s_node->scope);
     
     if (s_node->min_size_)      SYMRES(expression, &s_node->min_size_);
     if (s_node->min_alignment_) SYMRES(expression, &s_node->min_alignment_);
@@ -210,10 +205,7 @@ static SymresStatus symres_type(AstType** type) {
 
         case Ast_Kind_Poly_Struct_Type: {
             AstPolyStructType* pst_node = (AstPolyStructType *) *type;
-
-            if (pst_node->scope == NULL) {
-                pst_node->scope = scope_create(context.ast_alloc, pst_node->entity->scope, pst_node->token->pos);
-            }
+            assert(pst_node->scope);
             break;
         }
 
@@ -1430,43 +1422,21 @@ static SymresStatus symres_process_directive(AstNode* directive) {
             if (inject->dest == NULL) {
                 if (inject->full_loc == NULL) return Symres_Error;
 
-                if (inject->full_loc->kind != Ast_Kind_Field_Access) {
+                AstTyped *full_loc = (AstTyped *) strip_aliases((AstNode *) inject->full_loc);
+
+                if (full_loc->kind != Ast_Kind_Field_Access) {
                     onyx_report_error(inject->token->pos, Error_Critical, "#inject expects a dot (a.b) expression for the injection point.");
                     return Symres_Error;
                 }
 
-                AstFieldAccess *acc = (AstFieldAccess *) inject->full_loc;
+                AstFieldAccess *acc = (AstFieldAccess *) full_loc;
                 inject->dest = acc->expr;
                 inject->symbol = acc->token;
             }
 
             SYMRES(expression, &inject->dest);
             SYMRES(expression, &inject->to_inject);
-
-            Scope *scope = get_scope_from_node_or_create((AstNode *) inject->dest);
-            if (scope == NULL) {
-                if (context.cycle_almost_detected >= 1) {
-                    onyx_report_error(inject->token->pos, Error_Critical, "Cannot #inject here.");
-                    return Symres_Error;
-                }
-
-                return Symres_Yield_Macro;
-            }
-
-            AstBinding *binding = onyx_ast_node_new(context.ast_alloc, sizeof(AstBinding), Ast_Kind_Binding);
-            binding->token = inject->symbol;
-            binding->node = (AstNode *) inject->to_inject;
-            binding->documentation = inject->documentation;
-
-            Package *pac = NULL;
-            if (inject->dest->kind == Ast_Kind_Package) {
-                pac = ((AstPackage *) inject->dest)->package;
-            } else {
-                pac = current_entity->package;
-            }
-
-            add_entities_for_node(NULL, (AstNode *) binding, scope, pac);
-            return Symres_Complete;
+            break;
         }
 
         case Ast_Kind_Directive_This_Package: {
@@ -1612,7 +1582,7 @@ static SymresStatus symres_foreign_block(AstForeignBlock *fb) {
                 return Symres_Error;
             }
 
-            ent->function->foreign.import_name = make_string_literal(context.ast_alloc, ent->function->intrinsic_name);
+            ent->function->foreign.import_name = (AstTyped *) make_string_literal(context.ast_alloc, ent->function->intrinsic_name);
             ent->function->foreign.module_name = fb->module_name;
             ent->function->is_foreign = 1;
             ent->function->is_foreign_dyncall = fb->uses_dyncall;
index 4611cbedd1e7778c8165ed375059271036416618..f03c91a2c4788eab6f278d05db84fb61a5c01479 100644 (file)
@@ -449,6 +449,7 @@ static Type* type_build_from_ast_inner(bh_allocator alloc, AstType* type_node, b
                 token_toggle_end((*member)->token);
                 if (shgeti(s_type->Struct.members, (*member)->token->text) != -1) {
                     onyx_report_error((*member)->token->pos, Error_Critical, "Duplicate struct member, '%s'.", (*member)->token->text);
+                    token_toggle_end((*member)->token);
                     return NULL;
                 }
 
@@ -602,6 +603,7 @@ static Type* type_build_from_ast_inner(bh_allocator alloc, AstType* type_node, b
             if (!concrete) return NULL;
             if (concrete == (Type *) &node_that_signals_failure) return concrete;
             concrete->Struct.constructed_from = (AstType *) ps_type;
+            pc_type->resolved_type = concrete;
             return concrete;
         }
 
index e474f008e5311bae604cc5acc9e92440885cd2b2..2cbb09f36d0fb45df39adead912c86859e800708 100644 (file)
@@ -279,13 +279,25 @@ all_types_peeled_off:
             // Temporarily disable the parent scope so that you can't access things
             // "above" the structures scope. This leads to unintended behavior, as when
             // you are accessing a static element on a structure, you don't expect to
-            // bleed to the top level scope.
+            // bleed to the top level scope. This code is currently very GROSS, and
+            // should be refactored soon.
             AstNode *result = NULL;
             if (stype->scope) {
-                Scope *tmp_parent = stype->scope->parent;
-                stype->scope->parent = NULL;
+                Scope **tmp_parent;
+                Scope *tmp_parent_backup;
+                if (stype->stcache && stype->stcache->Struct.constructed_from) {
+                    // Structs scope -> Poly Solution Scope -> Poly Struct Scope -> Enclosing Scope
+                    tmp_parent = &stype->scope->parent->parent->parent;
+                } else {
+                    tmp_parent = &stype->scope->parent;
+                }
+
+                tmp_parent_backup = *tmp_parent;
+                *tmp_parent = NULL;
+
                 result = symbol_raw_resolve(stype->scope, symbol);
-                stype->scope->parent = tmp_parent;
+
+                *tmp_parent = tmp_parent_backup;
             }
 
             if (result == NULL && stype->stcache != NULL) {
@@ -307,13 +319,16 @@ all_types_peeled_off:
         }
 
         case Ast_Kind_Poly_Struct_Type: {
-            AstStructType* stype = ((AstPolyStructType *) node)->base_struct;
+            AstPolyStructType* stype = ((AstPolyStructType *) node);
             return symbol_raw_resolve(stype->scope, symbol);
         }
 
         case Ast_Kind_Poly_Call_Type: {
-            AstNode* callee = (AstNode *) ((AstPolyCallType *) node)->callee;
-            return try_symbol_raw_resolve_from_node(callee, symbol);
+            AstPolyCallType* pctype = (AstPolyCallType *) node;
+            if (pctype->resolved_type) {
+                return try_symbol_raw_resolve_from_node((AstNode*) pctype->resolved_type->ast_type, symbol);
+            }
+            return NULL;
         }
 
         case Ast_Kind_Distinct_Type: {
@@ -1259,8 +1274,16 @@ all_types_peeled_off:
 
         case Ast_Kind_Poly_Struct_Type: {
             AstPolyStructType* pstype = (AstPolyStructType *) node;
-            AstStructType* stype = pstype->base_struct;
-            return &stype->scope;
+            return &pstype->scope;
+        }
+
+        case Ast_Kind_Poly_Call_Type: {
+            AstPolyCallType* pctype = (AstPolyCallType *) node;
+            Type *t = type_build_from_ast(context.ast_alloc, (AstType *) pctype);
+            if (t) {
+                return &((AstStructType *) t->ast_type)->scope;
+            }
+            return NULL;
         }
 
         case Ast_Kind_Distinct_Type: {
index 7595168e7540d40180504bb4f5a9883f2617b846..f94c0f923f919991329e4dd4a125349f06d524e0 100644 (file)
@@ -3198,7 +3198,12 @@ EMIT_FUNC(location_return_offset, AstTyped* expr, u64* offset_return) {
         }
 
         default: {
-            onyx_report_error(expr->token->pos, Error_Critical, "Unable to generate location for '%s'.", onyx_ast_node_kind_string(expr->kind));
+            if (expr->token) {
+                onyx_report_error(expr->token->pos, Error_Critical, "Unable to generate location for '%s'.", onyx_ast_node_kind_string(expr->kind));
+            } else {
+                OnyxFilePos pos = {0};
+                onyx_report_error(pos, Error_Critical, "Unable to generate location for '%s'.", onyx_ast_node_kind_string(expr->kind));
+            }
             break;
         }
     }
index c3a636a6d18dafd75aa617a076eb53fca1e1cdc0..25b4b7247fd55f46d1ac5e60e6c4fe2bd355daac 100644 (file)
@@ -112,7 +112,7 @@ CSV_Column :: struct {
         for line: reader->lines(allocator = context.temp_allocator) {
             out: csv.Output_Type;
 
-            for entry: string.split_iter(line, #char ",")
+            for entry: string.split_iter(string.strip_trailing_whitespace(line), #char ",")
                     |> iter.enumerate()
             {
                 header := &any_headers[entry.index];
@@ -157,7 +157,7 @@ CSV_Column :: struct {
             for &member: output_type_info.members {
                 if !#first do io.write(writer, ",");
 
-                io.write_format_va(writer, "{}", .[ .{cast(&u8) it + member.offset, member.type} ]);
+                io.write_format_va(writer, "{}", .[ .{cast([&] u8) it + member.offset, member.type} ]);
             }
 
             io.write(writer, "\n");