From: Brendan Hansen Date: Fri, 21 Apr 2023 19:36:23 +0000 (-0500) Subject: changed: refactored closures to be less stupid X-Git-Url: https://git.brendanfh.com/?a=commitdiff_plain;h=a458c032ae5768f50cd0cba523c65fefb6022b10;p=onyx.git changed: refactored closures to be less stupid --- diff --git a/compiler/include/astnodes.h b/compiler/include/astnodes.h index b1296d09..ef896aa6 100644 --- a/compiler/include/astnodes.h +++ b/compiler/include/astnodes.h @@ -105,7 +105,6 @@ \ NODE(CaptureBlock) \ NODE(CaptureLocal) \ - NODE(CaptureBuilder) \ \ NODE(ForeignBlock) \ \ @@ -237,7 +236,6 @@ typedef enum AstKind { Ast_Kind_Capture_Block, Ast_Kind_Capture_Local, - Ast_Kind_Capture_Builder, Ast_Kind_Foreign_Block, @@ -301,7 +299,8 @@ typedef enum AstFlags { Ast_Flag_Binding_Isnt_Captured = BH_BIT(25), - Ast_Flag_Function_Is_Lambda = BH_BIT(26) + Ast_Flag_Function_Is_Lambda = BH_BIT(26), + Ast_Flag_Function_Is_Lambda_Inside_PolyProc = BH_BIT(27), } AstFlags; typedef enum UnaryOp { @@ -1341,6 +1340,7 @@ struct AstFunction { AstBinding *original_binding_to_node; AstCaptureBlock *captures; + Scope *scope_to_lookup_captured_values; b32 is_exported : 1; b32 is_foreign : 1; @@ -1359,16 +1359,9 @@ struct AstCaptureBlock { struct AstCaptureLocal { AstTyped_base; - u32 offset; -}; - -struct AstCaptureBuilder { - AstTyped_base; + AstTyped *captured_value; - AstTyped *func; - AstCaptureBlock *captures; - - bh_arr(AstTyped *) capture_values; + u32 offset; }; struct AstPolyQuery { diff --git a/compiler/include/utils.h b/compiler/include/utils.h index 1b1fc61e..36885992 100644 --- a/compiler/include/utils.h +++ b/compiler/include/utils.h @@ -41,7 +41,5 @@ u32 levenshtein_distance(const char *str1, const char *str2); char *find_closest_symbol_in_scope_and_parents(Scope *scope, char *sym); char *find_closest_symbol_in_node(AstNode *node, char *sym); -b32 maybe_create_capture_builder_for_function_expression(AstTyped **pexpr); - extern AstTyped node_that_signals_a_yield; extern AstTyped node_that_signals_failure; diff --git a/compiler/src/astnodes.c b/compiler/src/astnodes.c index f6045612..d9b40122 100644 --- a/compiler/src/astnodes.c +++ b/compiler/src/astnodes.c @@ -110,7 +110,6 @@ static const char* ast_node_names[] = { "CAPTURE BLOCK", "CAPTURE LOCAL", - "CAPTURE BUILDER", "FOREIGN BLOCK", "ZERO VALUE", @@ -696,13 +695,6 @@ TypeMatch unify_node_and_type_(AstTyped** pnode, Type* type, b32 permanent) { node = *pnode; } - if (node->kind == Ast_Kind_Function && permanent) { - if (maybe_create_capture_builder_for_function_expression(pnode)) { - return TYPE_MATCH_SPECIAL; - } - } - - // HACK: NullProcHack // The null_proc matches any procedure, and because of that, will cause a runtime error if you // try to call it. diff --git a/compiler/src/checker.c b/compiler/src/checker.c index 6f6c277f..b1ccbed3 100644 --- a/compiler/src/checker.c +++ b/compiler/src/checker.c @@ -2157,9 +2157,6 @@ CheckStatus check_expression(AstTyped** pexpr) { YIELD(expr->token->pos, "Waiting for function type to be resolved."); expr->flags |= Ast_Flag_Function_Used; - if (maybe_create_capture_builder_for_function_expression(pexpr)) { - retval = Check_Return_To_Symres; - } break; case Ast_Kind_Directive_Solidify: @@ -2228,22 +2225,9 @@ CheckStatus check_expression(AstTyped** pexpr) { YIELD(expr->token->pos, "Waiting to resolve #this_package."); break; - case Ast_Kind_Capture_Builder: { - AstCaptureBuilder *builder = (void *) expr; - builder->type = get_expression_type(builder->func); - - fori (i, 0, bh_arr_length(builder->capture_values)) { - if (!builder->captures->captures[i]->type) { - YIELD(expr->token->pos, "Waiting to know capture value types."); - } - - TYPE_CHECK(&builder->capture_values[i], builder->captures->captures[i]->type) { - ERROR_(builder->captures->captures[i]->token->pos, "Type mismatch for this captured value. Expected '%s', got '%s'.", - type_get_name(builder->captures->captures[i]->type), type_get_name(builder->capture_values[i]->type)); - } - } + case Ast_Kind_Capture_Local: + expr->type = ((AstCaptureLocal *) expr)->captured_value->type; break; - } case Ast_Kind_File_Contents: break; case Ast_Kind_Overloaded_Function: break; @@ -2255,7 +2239,6 @@ CheckStatus check_expression(AstTyped** pexpr) { case Ast_Kind_Switch_Case: break; case Ast_Kind_Foreign_Block: break; case Ast_Kind_Zero_Value: break; - case Ast_Kind_Capture_Local: break; default: retval = Check_Error; diff --git a/compiler/src/clone.c b/compiler/src/clone.c index a689d74c..d9b209ed 100644 --- a/compiler/src/clone.c +++ b/compiler/src/clone.c @@ -448,21 +448,26 @@ AstNode* ast_clone(bh_allocator a, void* n) { case Ast_Kind_Function: case Ast_Kind_Polymorphic_Proc: { - if (clone_depth > 1) { - clone_depth--; - return node; - } - AstFunction* df = (AstFunction *) nn; AstFunction* sf = (AstFunction *) node; - convert_polyproc_to_function(df); + if (clone_depth > 1) { + if ((node->flags & Ast_Flag_Function_Is_Lambda) == 0 || !sf->captures) { + clone_depth--; + return node; + } + } + else { + convert_polyproc_to_function(df); + } if (sf->is_foreign) return node; assert(df->scope == NULL); df->nodes_that_need_entities_after_clone = NULL; bh_arr_new(global_heap_allocator, df->nodes_that_need_entities_after_clone, 1); + + bh_arr(AstNode *) old_captured_entities = captured_entities; captured_entities = df->nodes_that_need_entities_after_clone; df->return_type = (AstType *) ast_clone(a, sf->return_type); @@ -470,7 +475,7 @@ AstNode* ast_clone(bh_allocator a, void* n) { df->captures = (AstCaptureBlock *) ast_clone(a, sf->captures); df->nodes_that_need_entities_after_clone = captured_entities; - captured_entities = NULL; + captured_entities = old_captured_entities; df->params = NULL; bh_arr_new(context.ast_alloc, df->params, bh_arr_length(sf->params)); @@ -506,6 +511,12 @@ AstNode* ast_clone(bh_allocator a, void* n) { } } + if (clone_depth > 1) { + sf->flags |= Ast_Flag_Function_Is_Lambda_Inside_PolyProc; + df->flags &= ~Ast_Flag_Function_Is_Lambda_Inside_PolyProc; + E(df); + } + break; } diff --git a/compiler/src/onyx.c b/compiler/src/onyx.c index b7126d57..cf572acb 100644 --- a/compiler/src/onyx.c +++ b/compiler/src/onyx.c @@ -542,11 +542,12 @@ static b32 process_entity(Entity* ent) { if (context.options->verbose_output == 3) { if (ent->expr && ent->expr->token) snprintf(verbose_output_buffer, 511, - "%20s | %24s (%d, %d) | %s:%i:%i \n", + "%20s | %24s (%d, %d) | %5d | %s:%i:%i \n", entity_state_strings[ent->state], entity_type_strings[ent->type], (u32) ent->macro_attempts, (u32) ent->micro_attempts, + ent->id, ent->expr->token->pos.filename, ent->expr->token->pos.line, ent->expr->token->pos.column); diff --git a/compiler/src/parser.c b/compiler/src/parser.c index 17744805..f74622be 100644 --- a/compiler/src/parser.c +++ b/compiler/src/parser.c @@ -2402,9 +2402,6 @@ static AstCaptureBlock *parse_capture_list(OnyxParser* parser, TokenType end_tok AstCaptureLocal *capture = make_node(AstCaptureLocal, Ast_Kind_Capture_Local); capture->token = expect_token(parser, Token_Type_Symbol); - expect_token(parser, ':'); - capture->type_node = parse_type(parser); - bh_arr_push(captures->captures, capture); if (peek_token(0)->type != end_token) @@ -2427,8 +2424,8 @@ static void parse_function_params(OnyxParser* parser, AstFunction* func) { OnyxToken* symbol; while (!consume_token_if_next(parser, ')')) { - if (consume_token_if_next(parser, '|') && !func->captures) { - func->captures = parse_capture_list(parser, '|'); + if (consume_token_if_next(parser, '[') && !func->captures) { + func->captures = parse_capture_list(parser, ']'); consume_token_if_next(parser, ','); continue; } @@ -2730,25 +2727,13 @@ static b32 parse_possible_function_definition_no_consume(OnyxParser* parser) { b32 is_params = (parser->curr + 1) == matching_paren; OnyxToken* tmp_token = parser->curr; while (!is_params && tmp_token < matching_paren) { - if (tmp_token->type == '|') { - tmp_token++; - while (tmp_token->type != '|' && tmp_token < matching_paren) { - tmp_token++; - } - tmp_token++; - } - if (tmp_token->type == ':') is_params = 1; tmp_token++; } - if (peek_token(1)->type == '|' && (matching_paren - 1)->type == '|') { - OnyxToken* tmp_token = parser->curr + 1; - while (!is_params && tmp_token < matching_paren - 1) { - if (tmp_token->type == ':') is_params = 1; - tmp_token++; - } + if (peek_token(1)->type == '[' && (matching_paren - 1)->type == ']') { + is_params = 1; } return is_params; @@ -2814,8 +2799,8 @@ static b32 parse_possible_quick_function_definition(OnyxParser* parser, AstTyped while (!consume_token_if_next(parser, ')')) { if (parser->hit_unexpected_token) return 0; - if (consume_token_if_next(parser, '|') && !captures) { - captures = parse_capture_list(parser, '|'); + if (consume_token_if_next(parser, '[') && !captures) { + captures = parse_capture_list(parser, ']'); } else { QuickParam param = { 0 }; diff --git a/compiler/src/symres.c b/compiler/src/symres.c index 20fe2b36..483cda9a 100644 --- a/compiler/src/symres.c +++ b/compiler/src/symres.c @@ -12,9 +12,6 @@ static Scope* current_scope = NULL; static b32 report_unresolved_symbols = 1; static b32 resolved_a_symbol = 0; -// Everything related to waiting on is imcomplete at the moment. -static Entity* waiting_on = NULL; - static Entity* current_entity = NULL; #define SYMRES(kind, ...) do { \ @@ -73,7 +70,6 @@ static SymresStatus symres_static_if(AstIf* static_if); static SymresStatus symres_macro(AstMacro* macro); static SymresStatus symres_constraint(AstConstraint* constraint); static SymresStatus symres_polyquery(AstPolyQuery *query); -static SymresStatus symres_capture_builder(AstCaptureBuilder *builder); static void scope_enter(Scope* new_scope) { current_scope = new_scope; @@ -576,7 +572,20 @@ static SymresStatus symres_expression(AstTyped** expr) { (*expr)->type_node = builtin_range_type; break; + case Ast_Kind_Polymorphic_Proc: + if (((AstFunction *) *expr)->captures) { + ((AstFunction *) *expr)->scope_to_lookup_captured_values = current_scope; + } + break; + case Ast_Kind_Function: + if (((AstFunction *) *expr)->captures) { + ((AstFunction *) *expr)->scope_to_lookup_captured_values = current_scope; + } + + SYMRES(type, &(*expr)->type_node); + break; + case Ast_Kind_NumLit: SYMRES(type, &(*expr)->type_node); break; @@ -660,8 +669,6 @@ static SymresStatus symres_expression(AstTyped** expr) { break; } - case Ast_Kind_Capture_Builder: SYMRES(capture_builder, (AstCaptureBuilder *) *expr); break; - default: break; } @@ -840,22 +847,11 @@ static SymresStatus symres_directive_insert(AstDirectiveInsert* insert) { return Symres_Success; } -static SymresStatus symres_capture_block(AstCaptureBlock *block) { +static SymresStatus symres_capture_block(AstCaptureBlock *block, Scope *captured_scope) { bh_arr_each(AstCaptureLocal *, capture, block->captures) { - SYMRES(type, &(*capture)->type_node); - } - - bh_arr_each(AstCaptureLocal *, capture, block->captures) { - symbol_introduce(current_scope, (*capture)->token, (AstNode *) *capture); - } - - return Symres_Success; -} + OnyxToken *token = (*capture)->token; + AstTyped *resolved = (AstTyped *) symbol_resolve(captured_scope, token); -static SymresStatus symres_capture_builder(AstCaptureBuilder *builder) { - fori (i, bh_arr_length(builder->capture_values), bh_arr_length(builder->captures->captures)) { - OnyxToken *token = builder->captures->captures[i]->token; - AstTyped *resolved = (AstTyped *) symbol_resolve(current_scope, token); if (!resolved) { // Should this do a yield? In there any case that that would make sense? onyx_report_error(token->pos, Error_Critical, "'%b' is not found in the enclosing scope.", @@ -863,7 +859,11 @@ static SymresStatus symres_capture_builder(AstCaptureBuilder *builder) { return Symres_Error; } - bh_arr_push(builder->capture_values, resolved); + (*capture)->captured_value = resolved; + } + + bh_arr_each(AstCaptureLocal *, capture, block->captures) { + symbol_introduce(current_scope, (*capture)->token, (AstNode *) *capture); } return Symres_Success; @@ -973,6 +973,17 @@ static SymresStatus symres_block(AstBlock* block) { SymresStatus symres_function_header(AstFunction* func) { func->flags |= Ast_Flag_Comptime; + if (func->captures && !func->scope_to_lookup_captured_values) { + if (!(func->flags & Ast_Flag_Function_Is_Lambda)) { + onyx_report_error(func->captures->token->pos, Error_Critical, "This procedure cannot capture values as it is not defined in an expression."); + return Symres_Error; + } + + if (func->flags & Ast_Flag_Function_Is_Lambda_Inside_PolyProc) return Symres_Complete; + + return Symres_Yield_Macro; + } + if (func->scope == NULL) func->scope = scope_create(context.ast_alloc, current_scope, func->token->pos); @@ -1014,7 +1025,8 @@ SymresStatus symres_function_header(AstFunction* func) { // This makes a lot of assumptions about how these nodes are being processed, // and I don't want to start using this with other nodes without considering // what the ramifications of that is. - assert((*node)->kind == Ast_Kind_Static_If || (*node)->kind == Ast_Kind_File_Contents); + assert((*node)->kind == Ast_Kind_Static_If || (*node)->kind == Ast_Kind_File_Contents + || (*node)->kind == Ast_Kind_Function || (*node)->kind == Ast_Kind_Polymorphic_Proc); // Need to current_scope->parent because current_scope is the function body scope. Scope *scope = current_scope->parent; @@ -1037,12 +1049,7 @@ SymresStatus symres_function_header(AstFunction* func) { } if (func->captures) { - if (!(func->flags & Ast_Flag_Function_Is_Lambda)) { - onyx_report_error(func->captures->token->pos, Error_Critical, "This procedure cannot capture values as it is not defined in an expression."); - return Symres_Error; - } - - SYMRES(capture_block, func->captures); + SYMRES(capture_block, func->captures, func->scope_to_lookup_captured_values); } SYMRES(type, &func->return_type); @@ -1064,6 +1071,7 @@ SymresStatus symres_function_header(AstFunction* func) { SymresStatus symres_function(AstFunction* func) { if (func->entity_header && func->entity_header->state < Entity_State_Check_Types) return Symres_Yield_Macro; if (func->kind == Ast_Kind_Polymorphic_Proc) return Symres_Complete; + if (func->flags & Ast_Flag_Function_Is_Lambda_Inside_PolyProc) return Symres_Complete; assert(func->scope); scope_enter(func->scope); diff --git a/compiler/src/utils.c b/compiler/src/utils.c index 6c6cb59b..cf6d35fe 100644 --- a/compiler/src/utils.c +++ b/compiler/src/utils.c @@ -1482,23 +1482,3 @@ void track_resolution_for_symbol_info(AstNode *original, AstNode *resolved) { } - -b32 maybe_create_capture_builder_for_function_expression(AstTyped **pexpr) { - AstFunction *func = (void *) *pexpr; - - if (!(func->flags & Ast_Flag_Function_Is_Lambda)) return 0; - if (!func->captures) return 0; - - AstCaptureBuilder *builder = onyx_ast_node_new(context.ast_alloc, sizeof(AstCaptureBuilder), Ast_Kind_Capture_Builder); - builder->token = func->captures->token - 1; - - builder->func = (void *) func; - // builder->type = builder->func->type; - builder->captures = func->captures; - - bh_arr_new(context.ast_alloc, builder->capture_values, bh_arr_length(builder->captures->captures)); - - *((void **) pexpr) = builder; - return 1; -} - diff --git a/compiler/src/wasm_emit.c b/compiler/src/wasm_emit.c index bd168ea0..d5cc1398 100644 --- a/compiler/src/wasm_emit.c +++ b/compiler/src/wasm_emit.c @@ -3391,22 +3391,17 @@ EMIT_FUNC(expression, AstTyped* expr) { } case Ast_Kind_Function: { - i32 elemidx = get_element_idx(mod, (AstFunction *) expr); + AstFunction *func = (AstFunction *) expr; + i32 elemidx = get_element_idx(mod, func); WID(NULL, WI_I32_CONST, elemidx); - WIL(NULL, WI_I32_CONST, 0); - break; - } - - case Ast_Kind_Capture_Builder: { - AstCaptureBuilder *builder = (AstCaptureBuilder *) expr; - - assert(builder->func->kind == Ast_Kind_Function); - i32 elemidx = get_element_idx(mod, (AstFunction *) builder->func); - WID(NULL, WI_I32_CONST, elemidx); + if (!func->captures) { + WIL(NULL, WI_I32_CONST, 0); + break; + } // Allocate the block - WIL(NULL, WI_I32_CONST, builder->captures->total_size_in_bytes); + WIL(NULL, WI_I32_CONST, func->captures->total_size_in_bytes); i32 func_idx = (i32) bh_imap_get(&mod->index_map, (u64) builtin_closure_block_allocate); WIL(NULL, WI_CALL, func_idx); @@ -3414,10 +3409,10 @@ EMIT_FUNC(expression, AstTyped* expr) { WIL(NULL, WI_LOCAL_TEE, capture_block_ptr); // Populate the block - fori (i, 0, bh_arr_length(builder->capture_values)) { + bh_arr_each(AstCaptureLocal *, capture, func->captures->captures) { WIL(NULL, WI_LOCAL_GET, capture_block_ptr); - emit_expression(mod, &code, builder->capture_values[i]); - emit_store_instruction(mod, &code, builder->capture_values[i]->type, builder->captures->captures[i]->offset); + emit_expression(mod, &code, (*capture)->captured_value); + emit_store_instruction(mod, &code, (*capture)->captured_value->type, (*capture)->offset); } local_raw_free(mod->local_alloc, WASM_TYPE_PTR); diff --git a/tests/aoc-2021/day12.onyx b/tests/aoc-2021/day12.onyx index c079d9bb..c4723a9f 100644 --- a/tests/aoc-2021/day12.onyx +++ b/tests/aoc-2021/day12.onyx @@ -43,11 +43,11 @@ main :: (args) => { children_of :: (edges: &$T, name: str) -> Iterator(str) { return iter.concat( iter.as_iter(edges) - ->filter((x, |name: str|) => x.a == name) + ->filter((x, [name]) => x.a == name) ->map(x => x.b), iter.as_iter(edges) - ->filter((x, |name: str|) => x.b == name) + ->filter((x, [name]) => x.b == name) ->map(x => x.a) ); }