changed: refactored closures to be less stupid
authorBrendan Hansen <brendan.f.hansen@gmail.com>
Fri, 21 Apr 2023 19:36:23 +0000 (14:36 -0500)
committerBrendan Hansen <brendan.f.hansen@gmail.com>
Fri, 21 Apr 2023 19:36:23 +0000 (14:36 -0500)
compiler/include/astnodes.h
compiler/include/utils.h
compiler/src/astnodes.c
compiler/src/checker.c
compiler/src/clone.c
compiler/src/onyx.c
compiler/src/parser.c
compiler/src/symres.c
compiler/src/utils.c
compiler/src/wasm_emit.c
tests/aoc-2021/day12.onyx

index b1296d09784247e6acc77b1def62063f6bb6c441..ef896aa624b4e8a57571a44a9a48343d95191c76 100644 (file)
                                \
     NODE(CaptureBlock)         \
     NODE(CaptureLocal)         \
-    NODE(CaptureBuilder)       \
                                \
     NODE(ForeignBlock)         \
                                \
@@ -237,7 +236,6 @@ typedef enum AstKind {
 
     Ast_Kind_Capture_Block,
     Ast_Kind_Capture_Local,
-    Ast_Kind_Capture_Builder,
 
     Ast_Kind_Foreign_Block,
 
@@ -301,7 +299,8 @@ typedef enum AstFlags {
 
     Ast_Flag_Binding_Isnt_Captured = BH_BIT(25),
 
-    Ast_Flag_Function_Is_Lambda    = BH_BIT(26)
+    Ast_Flag_Function_Is_Lambda    = BH_BIT(26),
+    Ast_Flag_Function_Is_Lambda_Inside_PolyProc = BH_BIT(27),
 } AstFlags;
 
 typedef enum UnaryOp {
@@ -1341,6 +1340,7 @@ struct AstFunction {
     AstBinding *original_binding_to_node;
 
     AstCaptureBlock *captures;
+    Scope *scope_to_lookup_captured_values;
 
     b32 is_exported        : 1;
     b32 is_foreign         : 1;
@@ -1359,16 +1359,9 @@ struct AstCaptureBlock {
 struct AstCaptureLocal {
     AstTyped_base;
 
-    u32 offset;
-};
-
-struct AstCaptureBuilder {
-    AstTyped_base;
+    AstTyped *captured_value;
 
-    AstTyped *func;
-    AstCaptureBlock *captures;
-
-    bh_arr(AstTyped *) capture_values;
+    u32 offset;
 };
 
 struct AstPolyQuery {
index 1b1fc61e43c4b06d55f75f145ba09b21bc067198..36885992182cc05e396c03610bdf57776b867c45 100644 (file)
@@ -41,7 +41,5 @@ u32 levenshtein_distance(const char *str1, const char *str2);
 char *find_closest_symbol_in_scope_and_parents(Scope *scope, char *sym);
 char *find_closest_symbol_in_node(AstNode *node, char *sym);
 
-b32 maybe_create_capture_builder_for_function_expression(AstTyped **pexpr);
-
 extern AstTyped node_that_signals_a_yield;
 extern AstTyped node_that_signals_failure;
index f604561219e0ec1fbfc9f6f2ea0aee8f9ed6cb4e..d9b40122104c184529e98b2d501dc1186e50d8d1 100644 (file)
@@ -110,7 +110,6 @@ static const char* ast_node_names[] = {
 
     "CAPTURE BLOCK",
     "CAPTURE LOCAL",
-    "CAPTURE BUILDER",
 
     "FOREIGN BLOCK",
     "ZERO VALUE",
@@ -696,13 +695,6 @@ TypeMatch unify_node_and_type_(AstTyped** pnode, Type* type, b32 permanent) {
         node = *pnode;
     }
 
-    if (node->kind == Ast_Kind_Function && permanent) {
-        if (maybe_create_capture_builder_for_function_expression(pnode)) {
-            return TYPE_MATCH_SPECIAL;
-        }
-    }
-
-
     // HACK: NullProcHack
     // The null_proc matches any procedure, and because of that, will cause a runtime error if you
     // try to call it.
index 6f6c277fee4cd9342f26bfbd9d56fbd44d019c51..b1ccbed394ce7d400503d9c960705abbda705cdc 100644 (file)
@@ -2157,9 +2157,6 @@ CheckStatus check_expression(AstTyped** pexpr) {
                 YIELD(expr->token->pos, "Waiting for function type to be resolved.");
 
             expr->flags |= Ast_Flag_Function_Used;
-            if (maybe_create_capture_builder_for_function_expression(pexpr)) {
-                retval = Check_Return_To_Symres;
-            }
             break;
 
         case Ast_Kind_Directive_Solidify:
@@ -2228,22 +2225,9 @@ CheckStatus check_expression(AstTyped** pexpr) {
             YIELD(expr->token->pos, "Waiting to resolve #this_package.");
             break;
 
-        case Ast_Kind_Capture_Builder: {
-            AstCaptureBuilder *builder = (void *) expr;
-            builder->type = get_expression_type(builder->func);
-
-            fori (i, 0, bh_arr_length(builder->capture_values)) {
-                if (!builder->captures->captures[i]->type) {
-                    YIELD(expr->token->pos, "Waiting to know capture value types.");
-                }
-
-                TYPE_CHECK(&builder->capture_values[i], builder->captures->captures[i]->type) {
-                    ERROR_(builder->captures->captures[i]->token->pos, "Type mismatch for this captured value. Expected '%s', got '%s'.",
-                            type_get_name(builder->captures->captures[i]->type), type_get_name(builder->capture_values[i]->type));
-                }
-            }
+        case Ast_Kind_Capture_Local:
+            expr->type = ((AstCaptureLocal *) expr)->captured_value->type;
             break;
-        }
 
         case Ast_Kind_File_Contents: break;
         case Ast_Kind_Overloaded_Function: break;
@@ -2255,7 +2239,6 @@ CheckStatus check_expression(AstTyped** pexpr) {
         case Ast_Kind_Switch_Case: break;
         case Ast_Kind_Foreign_Block: break;
         case Ast_Kind_Zero_Value: break;
-        case Ast_Kind_Capture_Local: break;
 
         default:
             retval = Check_Error;
index a689d74cfbdabf742d4bb6fa5b66f8a19d931ed8..d9b209ed6a82ee7419092bff7df8bc1ad7c2ce32 100644 (file)
@@ -448,21 +448,26 @@ AstNode* ast_clone(bh_allocator a, void* n) {
 
         case Ast_Kind_Function:
         case Ast_Kind_Polymorphic_Proc: {
-            if (clone_depth > 1) {
-                clone_depth--;
-                return node;
-            }
-
             AstFunction* df = (AstFunction *) nn;
             AstFunction* sf = (AstFunction *) node;
 
-            convert_polyproc_to_function(df);
+            if (clone_depth > 1) {
+                if ((node->flags & Ast_Flag_Function_Is_Lambda) == 0 || !sf->captures) {
+                    clone_depth--;
+                    return node;
+                }
+            }
+            else {
+                convert_polyproc_to_function(df);
+            }
 
             if (sf->is_foreign) return node;
             assert(df->scope == NULL);
 
             df->nodes_that_need_entities_after_clone = NULL;
             bh_arr_new(global_heap_allocator, df->nodes_that_need_entities_after_clone, 1);
+
+            bh_arr(AstNode *) old_captured_entities = captured_entities;
             captured_entities = df->nodes_that_need_entities_after_clone;
 
             df->return_type = (AstType *) ast_clone(a, sf->return_type);
@@ -470,7 +475,7 @@ AstNode* ast_clone(bh_allocator a, void* n) {
             df->captures = (AstCaptureBlock *) ast_clone(a, sf->captures);
 
             df->nodes_that_need_entities_after_clone = captured_entities;
-            captured_entities = NULL;
+            captured_entities = old_captured_entities;
 
             df->params = NULL;
             bh_arr_new(context.ast_alloc, df->params, bh_arr_length(sf->params));
@@ -506,6 +511,12 @@ AstNode* ast_clone(bh_allocator a, void* n) {
                 }    
             }
 
+            if (clone_depth > 1) {
+                sf->flags |= Ast_Flag_Function_Is_Lambda_Inside_PolyProc;
+                df->flags &= ~Ast_Flag_Function_Is_Lambda_Inside_PolyProc;
+                E(df);
+            }
+
             break;
         }
 
index b7126d572a47cc82a462a0de33029c26b0b60d88..cf572acb02650ea881c9bac8605b95e1bfbda33a 100644 (file)
@@ -542,11 +542,12 @@ static b32 process_entity(Entity* ent) {
     if (context.options->verbose_output == 3) {
         if (ent->expr && ent->expr->token)
             snprintf(verbose_output_buffer, 511,
-                    "%20s | %24s (%d, %d) | %s:%i:%i \n",
+                    "%20s | %24s (%d, %d) | %5d | %s:%i:%i \n",
                    entity_state_strings[ent->state],
                    entity_type_strings[ent->type],
                    (u32) ent->macro_attempts,
                    (u32) ent->micro_attempts,
+                   ent->id,
                    ent->expr->token->pos.filename,
                    ent->expr->token->pos.line,
                    ent->expr->token->pos.column);
index 1774480507acd1b27bd3a51feba1bf866daf9391..f74622be4a10f700dba82611a9d5831eb74284f7 100644 (file)
@@ -2402,9 +2402,6 @@ static AstCaptureBlock *parse_capture_list(OnyxParser* parser, TokenType end_tok
         AstCaptureLocal *capture = make_node(AstCaptureLocal, Ast_Kind_Capture_Local);
         capture->token = expect_token(parser, Token_Type_Symbol);
 
-        expect_token(parser, ':');
-        capture->type_node = parse_type(parser);
-
         bh_arr_push(captures->captures, capture);
 
         if (peek_token(0)->type != end_token)
@@ -2427,8 +2424,8 @@ static void parse_function_params(OnyxParser* parser, AstFunction* func) {
 
     OnyxToken* symbol;
     while (!consume_token_if_next(parser, ')')) {
-        if (consume_token_if_next(parser, '|') && !func->captures) {
-            func->captures = parse_capture_list(parser, '|');
+        if (consume_token_if_next(parser, '[') && !func->captures) {
+            func->captures = parse_capture_list(parser, ']');
             consume_token_if_next(parser, ',');
             continue;
         }
@@ -2730,25 +2727,13 @@ static b32 parse_possible_function_definition_no_consume(OnyxParser* parser) {
         b32 is_params = (parser->curr + 1) == matching_paren;
         OnyxToken* tmp_token = parser->curr;
         while (!is_params && tmp_token < matching_paren) {
-            if (tmp_token->type == '|') {
-                tmp_token++;
-                while (tmp_token->type != '|' && tmp_token < matching_paren) {
-                    tmp_token++;
-                }
-                tmp_token++;
-            }
-
             if (tmp_token->type == ':') is_params = 1;
 
             tmp_token++;
         }
 
-        if (peek_token(1)->type == '|' && (matching_paren - 1)->type == '|') {
-            OnyxToken* tmp_token = parser->curr + 1;
-            while (!is_params && tmp_token < matching_paren - 1) {
-                if (tmp_token->type == ':') is_params = 1;
-                tmp_token++;
-            }
+        if (peek_token(1)->type == '[' && (matching_paren - 1)->type == ']') {
+            is_params = 1;
         }
 
         return is_params;
@@ -2814,8 +2799,8 @@ static b32 parse_possible_quick_function_definition(OnyxParser* parser, AstTyped
         while (!consume_token_if_next(parser, ')')) {
             if (parser->hit_unexpected_token) return 0;
 
-            if (consume_token_if_next(parser, '|') && !captures) {
-                captures = parse_capture_list(parser, '|');
+            if (consume_token_if_next(parser, '[') && !captures) {
+                captures = parse_capture_list(parser, ']');
 
             } else {
                 QuickParam param = { 0 };
index 20fe2b3636e25902af9c62a8e90e9c60e5d64fb8..483cda9af188d71e0f274701a0de2c7d717c12d6 100644 (file)
@@ -12,9 +12,6 @@ static Scope*       current_scope    = NULL;
 static b32 report_unresolved_symbols = 1;
 static b32 resolved_a_symbol         = 0;
 
-// Everything related to waiting on is imcomplete at the moment.
-static Entity* waiting_on         = NULL;
-
 static Entity* current_entity = NULL;
 
 #define SYMRES(kind, ...) do { \
@@ -73,7 +70,6 @@ static SymresStatus symres_static_if(AstIf* static_if);
 static SymresStatus symres_macro(AstMacro* macro);
 static SymresStatus symres_constraint(AstConstraint* constraint);
 static SymresStatus symres_polyquery(AstPolyQuery *query);
-static SymresStatus symres_capture_builder(AstCaptureBuilder *builder);
 
 static void scope_enter(Scope* new_scope) {
     current_scope = new_scope;
@@ -576,7 +572,20 @@ static SymresStatus symres_expression(AstTyped** expr) {
             (*expr)->type_node = builtin_range_type;
             break;
 
+        case Ast_Kind_Polymorphic_Proc:
+            if (((AstFunction *) *expr)->captures) {
+                ((AstFunction *) *expr)->scope_to_lookup_captured_values = current_scope;
+            }
+            break;
+
         case Ast_Kind_Function:
+            if (((AstFunction *) *expr)->captures) {
+                ((AstFunction *) *expr)->scope_to_lookup_captured_values = current_scope;
+            }
+
+            SYMRES(type, &(*expr)->type_node);
+            break;
+
         case Ast_Kind_NumLit:
             SYMRES(type, &(*expr)->type_node);
             break;
@@ -660,8 +669,6 @@ static SymresStatus symres_expression(AstTyped** expr) {
             break;
         }
 
-        case Ast_Kind_Capture_Builder: SYMRES(capture_builder, (AstCaptureBuilder *) *expr); break;
-
         default: break;
     }
 
@@ -840,22 +847,11 @@ static SymresStatus symres_directive_insert(AstDirectiveInsert* insert) {
     return Symres_Success;
 }
 
-static SymresStatus symres_capture_block(AstCaptureBlock *block) {
+static SymresStatus symres_capture_block(AstCaptureBlock *block, Scope *captured_scope) {
     bh_arr_each(AstCaptureLocal *, capture, block->captures) {
-        SYMRES(type, &(*capture)->type_node);
-    }
-
-    bh_arr_each(AstCaptureLocal *, capture, block->captures) {
-        symbol_introduce(current_scope, (*capture)->token, (AstNode *) *capture);
-    }
-
-    return Symres_Success;
-}
+        OnyxToken *token = (*capture)->token;
+        AstTyped *resolved = (AstTyped *) symbol_resolve(captured_scope, token);
 
-static SymresStatus symres_capture_builder(AstCaptureBuilder *builder) {
-    fori (i, bh_arr_length(builder->capture_values), bh_arr_length(builder->captures->captures)) {
-        OnyxToken *token = builder->captures->captures[i]->token;
-        AstTyped *resolved = (AstTyped *) symbol_resolve(current_scope, token);
         if (!resolved) {
             // Should this do a yield? In there any case that that would make sense?
             onyx_report_error(token->pos, Error_Critical, "'%b' is not found in the enclosing scope.",
@@ -863,7 +859,11 @@ static SymresStatus symres_capture_builder(AstCaptureBuilder *builder) {
             return Symres_Error;
         }
 
-        bh_arr_push(builder->capture_values, resolved);
+        (*capture)->captured_value = resolved;
+    }
+
+    bh_arr_each(AstCaptureLocal *, capture, block->captures) {
+        symbol_introduce(current_scope, (*capture)->token, (AstNode *) *capture);
     }
 
     return Symres_Success;
@@ -973,6 +973,17 @@ static SymresStatus symres_block(AstBlock* block) {
 SymresStatus symres_function_header(AstFunction* func) {
     func->flags |= Ast_Flag_Comptime;
 
+    if (func->captures && !func->scope_to_lookup_captured_values) {
+        if (!(func->flags & Ast_Flag_Function_Is_Lambda)) {
+            onyx_report_error(func->captures->token->pos, Error_Critical, "This procedure cannot capture values as it is not defined in an expression.");
+            return Symres_Error;
+        }
+
+        if (func->flags & Ast_Flag_Function_Is_Lambda_Inside_PolyProc) return Symres_Complete;
+
+        return Symres_Yield_Macro;
+    }
+
     if (func->scope == NULL)
         func->scope = scope_create(context.ast_alloc, current_scope, func->token->pos);
 
@@ -1014,7 +1025,8 @@ SymresStatus symres_function_header(AstFunction* func) {
             // This makes a lot of assumptions about how these nodes are being processed,
             // and I don't want to start using this with other nodes without considering
             // what the ramifications of that is.
-            assert((*node)->kind == Ast_Kind_Static_If || (*node)->kind == Ast_Kind_File_Contents);
+            assert((*node)->kind == Ast_Kind_Static_If || (*node)->kind == Ast_Kind_File_Contents
+                    || (*node)->kind == Ast_Kind_Function || (*node)->kind == Ast_Kind_Polymorphic_Proc);
 
             // Need to current_scope->parent because current_scope is the function body scope.
             Scope *scope = current_scope->parent;
@@ -1037,12 +1049,7 @@ SymresStatus symres_function_header(AstFunction* func) {
     }
 
     if (func->captures) {
-        if (!(func->flags & Ast_Flag_Function_Is_Lambda)) {
-            onyx_report_error(func->captures->token->pos, Error_Critical, "This procedure cannot capture values as it is not defined in an expression.");
-            return Symres_Error;
-        }
-        
-        SYMRES(capture_block, func->captures);
+        SYMRES(capture_block, func->captures, func->scope_to_lookup_captured_values);
     }
 
     SYMRES(type, &func->return_type);
@@ -1064,6 +1071,7 @@ SymresStatus symres_function_header(AstFunction* func) {
 SymresStatus symres_function(AstFunction* func) {
     if (func->entity_header && func->entity_header->state < Entity_State_Check_Types) return Symres_Yield_Macro;
     if (func->kind == Ast_Kind_Polymorphic_Proc) return Symres_Complete;
+    if (func->flags & Ast_Flag_Function_Is_Lambda_Inside_PolyProc) return Symres_Complete;
     assert(func->scope);
 
     scope_enter(func->scope);
index 6c6cb59ba1129e6a05c718ce7a5364848c843ccf..cf6d35fe7923655ae325a5a41867256bee1e7b9d 100644 (file)
@@ -1482,23 +1482,3 @@ void track_resolution_for_symbol_info(AstNode *original, AstNode *resolved) {
 }
 
 
-
-b32 maybe_create_capture_builder_for_function_expression(AstTyped **pexpr) {
-    AstFunction *func = (void *) *pexpr;
-
-    if (!(func->flags & Ast_Flag_Function_Is_Lambda)) return 0;
-    if (!func->captures) return 0;
-
-    AstCaptureBuilder *builder = onyx_ast_node_new(context.ast_alloc, sizeof(AstCaptureBuilder), Ast_Kind_Capture_Builder);
-    builder->token = func->captures->token - 1;
-
-    builder->func = (void *) func;
-    // builder->type = builder->func->type;
-    builder->captures = func->captures;
-
-    bh_arr_new(context.ast_alloc, builder->capture_values, bh_arr_length(builder->captures->captures));
-
-    *((void **) pexpr) = builder;
-    return 1;
-}
-
index bd168ea05fab66409a6f8c5be708c8ef2b78d45c..d5cc13982308a040290efe8cc91e6b119db95ef9 100644 (file)
@@ -3391,22 +3391,17 @@ EMIT_FUNC(expression, AstTyped* expr) {
         }
 
         case Ast_Kind_Function: {
-            i32 elemidx = get_element_idx(mod, (AstFunction *) expr);
+            AstFunction *func = (AstFunction *) expr;
+            i32 elemidx = get_element_idx(mod, func);
 
             WID(NULL, WI_I32_CONST, elemidx);
-            WIL(NULL, WI_I32_CONST, 0);
-            break;
-        }
-
-        case Ast_Kind_Capture_Builder: {
-            AstCaptureBuilder *builder = (AstCaptureBuilder *) expr;
-            
-            assert(builder->func->kind == Ast_Kind_Function);
-            i32 elemidx = get_element_idx(mod, (AstFunction *) builder->func);
-            WID(NULL, WI_I32_CONST, elemidx);
+            if (!func->captures) {
+                WIL(NULL, WI_I32_CONST, 0);
+                break;
+            }
 
             // Allocate the block
-            WIL(NULL, WI_I32_CONST, builder->captures->total_size_in_bytes);
+            WIL(NULL, WI_I32_CONST, func->captures->total_size_in_bytes);
             i32 func_idx = (i32) bh_imap_get(&mod->index_map, (u64) builtin_closure_block_allocate);
             WIL(NULL, WI_CALL, func_idx);
 
@@ -3414,10 +3409,10 @@ EMIT_FUNC(expression, AstTyped* expr) {
             WIL(NULL, WI_LOCAL_TEE, capture_block_ptr);
             
             // Populate the block
-            fori (i, 0, bh_arr_length(builder->capture_values)) {
+            bh_arr_each(AstCaptureLocal *, capture, func->captures->captures) {
                 WIL(NULL, WI_LOCAL_GET, capture_block_ptr);
-                emit_expression(mod, &code, builder->capture_values[i]);
-                emit_store_instruction(mod, &code, builder->capture_values[i]->type, builder->captures->captures[i]->offset);
+                emit_expression(mod, &code, (*capture)->captured_value);
+                emit_store_instruction(mod, &code, (*capture)->captured_value->type, (*capture)->offset);
             }
             
             local_raw_free(mod->local_alloc, WASM_TYPE_PTR);
index c079d9bb3da740ae0d4514ce17207eeed9334393..c4723a9f3b8337e4f217791f681cf4eccba8673a 100644 (file)
@@ -43,11 +43,11 @@ main :: (args) => {
         children_of :: (edges: &$T, name: str) -> Iterator(str) {
             return iter.concat(
                 iter.as_iter(edges)
-                ->filter((x, |name: str|) => x.a == name)
+                ->filter((x, [name]) => x.a == name)
                 ->map(x => x.b),
 
                 iter.as_iter(edges)
-                ->filter((x, |name: str|) => x.b == name)
+                ->filter((x, [name]) => x.b == name)
                 ->map(x => x.a)
             );
         }