added '#inject'
authorBrendan Hansen <brendan.f.hansen@gmail.com>
Sun, 28 Aug 2022 19:26:34 +0000 (14:26 -0500)
committerBrendan Hansen <brendan.f.hansen@gmail.com>
Sun, 28 Aug 2022 19:26:34 +0000 (14:26 -0500)
include/astnodes.h
include/utils.h
src/astnodes.c
src/entities.c
src/parser.c
src/symres.c
src/utils.c

index 3f2069571c62a060756d9571a06912ccb9f0c2d0..0052d233c4b424f8e53befaec5126b7ff75f2327 100644 (file)
@@ -77,6 +77,7 @@
                                \
     NODE(Binding)              \
     NODE(Alias)                \
+    NODE(Injection)            \
     NODE(MemRes)               \
     NODE(Include)              \
     NODE(UsePackage)           \
@@ -132,6 +133,7 @@ typedef enum AstKind {
 
     Ast_Kind_Binding,
     Ast_Kind_Alias,
+    Ast_Kind_Injection,
     Ast_Kind_Function,
     Ast_Kind_Overloaded_Function,
     Ast_Kind_Polymorphic_Proc,
@@ -1011,6 +1013,13 @@ struct AstDistinctType {
 struct AstBinding       { AstTyped_base; AstNode* node; };
 struct AstAlias         { AstTyped_base; AstTyped* alias; };
 struct AstInclude       { AstNode_base;  AstTyped* name_node; char* name; };
+struct AstInjection     {
+    AstTyped_base;
+    AstTyped* full_loc;
+    AstTyped* to_inject;
+    AstTyped* dest;
+    OnyxToken *symbol;
+};
 struct AstMemRes        {
     AstTyped_base;
     AstTyped *initial_value;
index 46ee92b68606bb18779b172cb10ed48d04b0fe4d..2b0083c2c7afccb3c013c253818d9f11968fd6d1 100644 (file)
@@ -26,6 +26,8 @@ AstNode* symbol_resolve(Scope* start_scope, OnyxToken* tkn);
 AstNode* try_symbol_raw_resolve_from_node(AstNode* node, char* symbol);
 AstNode* try_symbol_resolve_from_node(AstNode* node, OnyxToken* token);
 AstNode* try_symbol_raw_resolve_from_type(Type *type, char* symbol);
+Scope *get_scope_from_node(AstNode *node);
+Scope *get_scope_from_node_or_create(AstNode *node);
 
 void build_all_overload_options(bh_arr(OverloadOption) overloads, bh_imap* all_overloads);
 
index 71d36652906ff6d02d784a0cedab5a2c589b8c68..f2c6faf2bec8cb48bb894fef819ba65dea42eb70 100644 (file)
@@ -13,6 +13,7 @@ static const char* ast_node_names[] = {
 
     "BINDING",
     "ALIAS",
+    "INJECTION",
     "FUNCTION",
     "OVERLOADED_FUNCTION",
     "POLYMORPHIC PROC",
index e5961e5111ed366e27eb222c170dc32a8b03a47b..0ceb2f9286b5eb8fdd34bdc11f1379493aa40264 100644 (file)
@@ -354,7 +354,8 @@ void add_entities_for_node(bh_arr(Entity *) *target_arr, AstNode* node, Scope* s
         case Ast_Kind_Directive_Add_Overload:
         case Ast_Kind_Directive_Operator:
         case Ast_Kind_Directive_Init:
-        case Ast_Kind_Directive_Library: {
+        case Ast_Kind_Directive_Library:
+        case Ast_Kind_Injection: {
             ent.type = Entity_Type_Process_Directive;
             ent.expr = (AstTyped *) node;
             ENTITY_INSERT(ent);
index 82ca91b90cb5b48daba112f83b7c33ddc841c988..cd289d4821255e65c18b724653864b6b63cb394a 100644 (file)
@@ -2435,6 +2435,7 @@ static AstFunction* parse_function_definition(OnyxParser* parser, OnyxToken* tok
 
         } else {
             AstTyped* returned_value = parse_compound_expression(parser, 0);
+            if (returned_value == NULL) goto function_defined;
 
             AstReturn* return_node = make_node(AstReturn, Ast_Kind_Return);
             return_node->token = returned_value->token;
@@ -3249,6 +3250,24 @@ static void parse_top_level_statement(OnyxParser* parser) {
                 ENTITY_SUBMIT(add_overload);
                 return;
             }
+            else if (parse_possible_directive(parser, "inject")) {
+                AstInjection *inject = make_node(AstInjection, Ast_Kind_Injection);
+                inject->token = dir_token;
+
+                parser->parse_calls = 0;
+                inject->full_loc = parse_expression(parser, 0);
+                parser->parse_calls = 1;
+
+                // See comment above
+                if (next_tokens_are(parser, 2, ':', ':')) {
+                    consume_tokens(parser, 2);
+                }
+
+                inject->to_inject = parse_expression(parser, 0);
+                
+                ENTITY_SUBMIT(inject);
+                return;
+            }
             else if (parse_possible_directive(parser, "export")) {
                 AstDirectiveExport *export = make_node(AstDirectiveExport, Ast_Kind_Directive_Export);
                 export->token = dir_token;
index 50b87b61281bfdb85bfac9099fd4975e164ca2ce..16206daded9ef79e20ecb03faaa26234aebe3b7d 100644 (file)
@@ -1455,6 +1455,44 @@ static SymresStatus symres_process_directive(AstNode* directive) {
             SYMRES(expression, &library->library_symbol);
             break;
         }
+
+        case Ast_Kind_Injection: {
+            AstInjection *inject = (AstInjection *) directive;
+
+            if (inject->dest == NULL) {
+                if (inject->full_loc == NULL) return Symres_Error;
+
+                if (inject->full_loc->kind != Ast_Kind_Field_Access) {
+                    onyx_report_error(inject->token->pos, Error_Critical, "#inject expects a dot (a.b) expression for the injection point.");
+                    return Symres_Error;
+                }
+
+                AstFieldAccess *acc = (AstFieldAccess *) inject->full_loc;
+                inject->dest = acc->expr;
+                inject->symbol = acc->token;
+            }
+
+            SYMRES(expression, &inject->dest);
+            SYMRES(expression, &inject->to_inject);
+
+            Scope *scope = get_scope_from_node_or_create((AstNode *) inject->dest);
+            if (scope == NULL) {
+                onyx_report_error(inject->token->pos, Error_Critical, "Cannot #inject here.");
+                return Symres_Error;
+            }
+
+            AstBinding *binding = onyx_ast_node_new(context.ast_alloc, sizeof(AstBinding), Ast_Kind_Binding);
+            binding->token = inject->symbol;
+            binding->node = (AstNode *) inject->to_inject;
+
+            Package *pac = NULL;
+            if (inject->dest->kind == Ast_Kind_Package) {
+                pac = ((AstPackage *) inject->dest)->package;
+            }
+
+            add_entities_for_node(NULL, (AstNode *) binding, scope, pac);
+            return Symres_Complete;
+        }
     }
 
     return Symres_Success;
index 8c6c259badf20d5658b79068022aff7b9f60d1fa..7f4ea49b187e55dcc6823a9803d6ca5d81d450d6 100644 (file)
@@ -74,6 +74,7 @@ void package_track_use_package(Package* package, Entity* entity) {
 }
 
 void package_reinsert_use_packages(Package* package) {
+    if (!package) return;
     if (!package->use_package_entities) return;
 
     bh_arr_each(Entity *, use_package, package->use_package_entities) {
@@ -1046,6 +1047,85 @@ i32 string_process_escape_seqs(char* dest, char* src, i32 len) {
     return total_len;
 }
 
+static Scope **get_scope_from_node_helper(AstNode *node) {
+    b32 used_pointer = 0;
+
+    while (1) {
+        if (!node) return NULL;
+
+        switch (node->kind) {
+            case Ast_Kind_Type_Raw_Alias: node = (AstNode *) ((AstTypeRawAlias *) node)->to->ast_type; break;
+            case Ast_Kind_Type_Alias:     node = (AstNode *) ((AstTypeAlias *) node)->to; break;
+            case Ast_Kind_Alias:          node = (AstNode *) ((AstAlias *) node)->alias; break;
+            case Ast_Kind_Pointer_Type: {
+                if (used_pointer) goto all_types_peeled_off;
+                used_pointer = 1;
+
+                node = (AstNode *) ((AstPointerType *) node)->elem;
+                break;
+            }
+
+            default: goto all_types_peeled_off;
+        }
+    }
+
+all_types_peeled_off:
+    if (!node) return NULL;
+
+    switch (node->kind) {
+        case Ast_Kind_Package: {
+            AstPackage* package = (AstPackage *) node;
+            if (package->package == NULL) return NULL;
+
+            return &package->package->scope;
+        } 
+
+        case Ast_Kind_Enum_Type: {
+            AstEnumType* etype = (AstEnumType *) node;
+            return &etype->scope;
+        }
+
+        case Ast_Kind_Struct_Type: {
+            AstStructType* stype = (AstStructType *) node;
+            return &stype->scope;
+        }
+
+        case Ast_Kind_Poly_Struct_Type: {
+            AstPolyStructType* pstype = (AstPolyStructType *) node;
+            AstStructType* stype = pstype->base_struct;
+            return &stype->scope;
+        }
+    }
+
+    return NULL;
+}
+
+Scope *get_scope_from_node(AstNode *node) {
+    if (!node) return NULL;
+
+    Scope **pscope = get_scope_from_node_helper(node);
+    if (!pscope) return NULL;
+    return *pscope;
+}
+
+Scope *get_scope_from_node_or_create(AstNode *node) {
+    if (!node) return NULL;
+
+    Scope **pscope = get_scope_from_node_helper(node);
+    if (!pscope) return NULL;
+
+    // Create the scope if it does not exist.
+    // This uses a NULL parent, which I think is what 
+    // is used in other parts of the compiler for struct/enum
+    // scopes?
+    if (!*pscope) {
+        assert(node->token);
+        *pscope = scope_create(context.ast_alloc, NULL, node->token->pos);
+    }
+
+    return *pscope;
+}
+
 u32 levenshtein_distance(const char *str1, const char *str2) {
     i32 m = strlen(str1) + 1;
     i32 n = strlen(str2) + 1;
@@ -1110,90 +1190,17 @@ char *find_closest_symbol_in_scope_and_parents(Scope *scope, char *sym) {
 }
         
 char *find_closest_symbol_in_node(AstNode* node, char *sym) {
-    b32 used_pointer = 0;
-
-    while (1) {
-        if (!node) return NULL;
-
-        switch (node->kind) {
-            case Ast_Kind_Type_Raw_Alias: node = (AstNode *) ((AstTypeRawAlias *) node)->to->ast_type; break;
-            case Ast_Kind_Type_Alias:     node = (AstNode *) ((AstTypeAlias *) node)->to; break;
-            case Ast_Kind_Alias:          node = (AstNode *) ((AstAlias *) node)->alias; break;
-            case Ast_Kind_Pointer_Type: {
-                if (used_pointer) goto all_types_peeled_off;
-                used_pointer = 1;
-
-                node = (AstNode *) ((AstPointerType *) node)->elem;
-                break;
-            }
-
-            default: goto all_types_peeled_off;
-        }
-    }
-
-all_types_peeled_off:
-    if (!node) return NULL;
-
-    switch (node->kind) {
-        case Ast_Kind_Package: {
-            AstPackage* package = (AstPackage *) node;
-            if (package->package == NULL) return NULL;
-
-            u32 dist;
-            return find_closest_symbol_in_scope(package->package->scope, sym, &dist);
-        } 
-
-        case Ast_Kind_Enum_Type: {
-            AstEnumType* etype = (AstEnumType *) node;
-            u32 dist;
-            return find_closest_symbol_in_scope(etype->scope, sym, &dist);
-        }
-
-        case Ast_Kind_Struct_Type: {
-            AstStructType* stype = (AstStructType *) node;
-
-            u32 dist;
-            char *closest = find_closest_symbol_in_scope(stype->scope, sym, &dist);
-
-            Type *type = type_build_from_ast(context.ast_alloc, (AstType *) stype);
-            assert(type);
-            bh_arr_each(StructMember *, mem, type->Struct.memarr) {
-                u32 d = levenshtein_distance((*mem)->name, sym);
-                if (d < dist) {
-                    dist = d;
-                    closest = (*mem)->name;
-                }
-            }
-
-            return closest;
-        }
-
-        case Ast_Kind_Poly_Struct_Type: {
-            AstPolyStructType* pstype = (AstPolyStructType *) node;
-            AstStructType* stype = pstype->base_struct;
-            u32 dist;
-            char *closest =  find_closest_symbol_in_scope(stype->scope, sym, &dist);
-
-            bh_arr_each(AstPolyStructParam, param, pstype->poly_params) {
-                token_toggle_end(param->token);
-                u32 d = levenshtein_distance(param->token->text, sym);
-
-                if (d < dist) {
-                    dist = d;
-                    closest = bh_strdup(context.ast_alloc, param->token->text);
-                }
-                token_toggle_end(param->token);
-            }
-
-            return closest;
-        }
-
-        case Ast_Kind_Poly_Call_Type: {
+    Scope *scope = get_scope_from_node(node);
+    if (!scope) {
+        if (node->kind == Ast_Kind_Poly_Call_Type) {
             AstPolyCallType* pcall = (AstPolyCallType *) node;
             return find_closest_symbol_in_node((AstNode *) pcall->callee, sym);
         }
+
+        return NULL;
     }
 
-    return NULL;
+    u32 dist;
+    return find_closest_symbol_in_scope(scope, sym, &dist);
 }