added: working closures for non-polymorphic functions! feature/closures
authorBrendan Hansen <brendan.f.hansen@gmail.com>
Wed, 19 Apr 2023 16:03:30 +0000 (11:03 -0500)
committerBrendan Hansen <brendan.f.hansen@gmail.com>
Wed, 19 Apr 2023 16:03:30 +0000 (11:03 -0500)
compiler/include/astnodes.h
compiler/include/utils.h
compiler/include/wasm_emit.h
compiler/src/astnodes.c
compiler/src/builtins.c
compiler/src/checker.c
compiler/src/onyx.c
compiler/src/symres.c
compiler/src/utils.c
compiler/src/wasm_emit.c
core/builtin.onyx

index c3ffdc3bf9f649a7e0cbcc94cd131d45bd4ff654..654ef4d6407808dcd0867c37e710c885a7506919 100644 (file)
                                \
     NODE(CaptureBlock)         \
     NODE(CaptureLocal)         \
+    NODE(CaptureBuilder)       \
                                \
     NODE(ForeignBlock)         \
                                \
@@ -236,6 +237,7 @@ typedef enum AstKind {
 
     Ast_Kind_Capture_Block,
     Ast_Kind_Capture_Local,
+    Ast_Kind_Capture_Builder,
 
     Ast_Kind_Foreign_Block,
 
@@ -1362,6 +1364,15 @@ struct AstCaptureLocal {
     u32 offset;
 };
 
+struct AstCaptureBuilder {
+    AstTyped_base;
+
+    AstTyped *func;
+    AstCaptureBlock *captures;
+
+    bh_arr(AstTyped *) capture_values;
+};
+
 struct AstPolyQuery {
     AstNode_base;
 
@@ -1797,6 +1808,7 @@ extern AstGlobal builtin_heap_start;
 extern AstGlobal builtin_stack_top;
 extern AstGlobal builtin_tls_base;
 extern AstGlobal builtin_tls_size;
+extern AstGlobal builtin_closure_base;
 extern AstType  *builtin_string_type;
 extern AstType  *builtin_cstring_type;
 extern AstType  *builtin_range_type;
@@ -1818,6 +1830,7 @@ extern AstType  *foreign_block_type;
 extern AstTyped *tagged_procedures_node;
 extern AstFunction *builtin_initialize_data_segments;
 extern AstFunction *builtin_run_init_procedures;
+extern AstFunction *builtin_closure_block_allocate;
 extern bh_arr(AstFunction *) init_procedures;
 extern AstOverloadedFunction *builtin_implicit_bool_cast;
 
@@ -1964,6 +1977,7 @@ static inline b32 is_lval(AstNode* node) {
         || (node->kind == Ast_Kind_Subscript)
         || (node->kind == Ast_Kind_Field_Access)
         || (node->kind == Ast_Kind_Memres)
+        || (node->kind == Ast_Kind_Capture_Local)
         || (node->kind == Ast_Kind_Constraint_Sentinel)) // Bit of a hack, but this makes constraints like 'T->foo()' work.
         return 1;
 
index 36885992182cc05e396c03610bdf57776b867c45..1b1fc61e43c4b06d55f75f145ba09b21bc067198 100644 (file)
@@ -41,5 +41,7 @@ u32 levenshtein_distance(const char *str1, const char *str2);
 char *find_closest_symbol_in_scope_and_parents(Scope *scope, char *sym);
 char *find_closest_symbol_in_node(AstNode *node, char *sym);
 
+b32 maybe_create_capture_builder_for_function_expression(AstTyped **pexpr);
+
 extern AstTyped node_that_signals_a_yield;
 extern AstTyped node_that_signals_failure;
index f54daf910940a85ad1dfbddb2780f7c4f56afb96..7cd54e0da98c5fd9cf979c2d1f86b29a9837f043 100644 (file)
@@ -748,6 +748,7 @@ typedef struct OnyxWasmModule {
     i32 *tls_size_ptr;
     i32 *heap_start_ptr;
     u64 stack_base_idx;
+    u64 closure_base_idx;
     CallingConvention curr_cc;
     i32 null_proc_func_idx;
 
index d97b793ca694e4919d023e2bc4dc4acaf38c7a39..a5324eff6db195a7fcb6b7867ce5e90abe4482fd 100644 (file)
@@ -110,6 +110,7 @@ static const char* ast_node_names[] = {
 
     "CAPTURE BLOCK",
     "CAPTURE LOCAL",
+    "CAPTURE BUILDER",
 
     "FOREIGN BLOCK",
     "ZERO VALUE",
index 6491f3850e4a3e14c6d38d7af1d6240baef90b1b..b077d667466fc9f6d338e7909be33132025662fe 100644 (file)
@@ -38,14 +38,16 @@ AstBasicType basic_type_auto_return = { Ast_Kind_Basic_Type, 0, &simd_token, NUL
 
 OnyxToken builtin_package_token = { Token_Type_Symbol, 7, "builtin ", { 0 } };
 
-static OnyxToken builtin_heap_start_token = { Token_Type_Symbol, 12, "__heap_start ", { 0 } };
-static OnyxToken builtin_stack_top_token  = { Token_Type_Symbol, 11, "__stack_top ",  { 0 } };
-static OnyxToken builtin_tls_base_token   = { Token_Type_Symbol, 10, "__tls_base ",  { 0 } };
-static OnyxToken builtin_tls_size_token   = { Token_Type_Symbol, 10, "__tls_size ",  { 0 } };
+static OnyxToken builtin_heap_start_token  = { Token_Type_Symbol, 12, "__heap_start ", { 0 } };
+static OnyxToken builtin_stack_top_token   = { Token_Type_Symbol, 11, "__stack_top ",  { 0 } };
+static OnyxToken builtin_tls_base_token    = { Token_Type_Symbol, 10, "__tls_base ",  { 0 } };
+static OnyxToken builtin_tls_size_token    = { Token_Type_Symbol, 10, "__tls_size ",  { 0 } };
+static OnyxToken builtin_closure_base_token = { Token_Type_Symbol, 14, "__closure_base ",  { 0 } };
 AstGlobal builtin_heap_start  = { Ast_Kind_Global, Ast_Flag_Const, &builtin_heap_start_token, NULL, NULL, (AstType *) &basic_type_rawptr, NULL };
 AstGlobal builtin_stack_top   = { Ast_Kind_Global, 0, &builtin_stack_top_token, NULL, NULL, (AstType *) &basic_type_rawptr, NULL };
 AstGlobal builtin_tls_base    = { Ast_Kind_Global, 0, &builtin_tls_base_token, NULL, NULL, (AstType *) &basic_type_rawptr, NULL };
 AstGlobal builtin_tls_size    = { Ast_Kind_Global, 0, &builtin_tls_size_token, NULL, NULL, (AstType *) &basic_type_u32, NULL };
+AstGlobal builtin_closure_base = { Ast_Kind_Global, 0, &builtin_closure_base_token, NULL, NULL, (AstType *) &basic_type_rawptr, NULL };
 
 AstType  *builtin_string_type;
 AstType  *builtin_cstring_type;
@@ -69,6 +71,7 @@ AstType     *foreign_block_type = NULL;
 AstTyped    *tagged_procedures_node = NULL;
 AstFunction *builtin_initialize_data_segments = NULL;
 AstFunction *builtin_run_init_procedures = NULL;
+AstFunction *builtin_closure_block_allocate = NULL;
 bh_arr(AstFunction *) init_procedures = NULL;
 AstOverloadedFunction *builtin_implicit_bool_cast;
 
@@ -100,6 +103,7 @@ const BuiltinSymbol builtin_symbols[] = {
     { "builtin", "__stack_top",  (AstNode *) &builtin_stack_top },
     { "builtin", "__tls_base",   (AstNode *) &builtin_tls_base },
     { "builtin", "__tls_size",   (AstNode *) &builtin_tls_size },
+    { "builtin", "__closure_base",   (AstNode *) &builtin_closure_base },
 
     { NULL, NULL, NULL },
 };
@@ -479,6 +483,15 @@ void initialize_builtins(bh_allocator a) {
         return;
     }
 
+    builtin_closure_block_allocate = (AstFunction *) symbol_raw_resolve(p->scope, "__closure_block_allocate");
+    if (builtin_closure_block_allocate == NULL || builtin_closure_block_allocate->kind != Ast_Kind_Function) {
+        onyx_report_error((OnyxFilePos) { 0 }, Error_Critical, "'__closure_block_allocate' procedure not found.");
+        return;
+    }
+    // HACK
+    builtin_closure_block_allocate->flags |= Ast_Flag_Function_Used;
+
+
     builtin_link_options_type = (AstType *) symbol_raw_resolve(p->scope, "Link_Options");
     if (builtin_link_options_type == NULL) {
         onyx_report_error((OnyxFilePos) { 0 }, Error_Critical, "'Link_Options' type not found.");
index 7112f11b57428d3a4daae3a69b877a8e2ff9c98b..2ea8471a901c4dad3c8a8bc9fe6264c5fc84ae51 100644 (file)
@@ -1756,6 +1756,7 @@ CheckStatus check_address_of(AstAddressOf** paof) {
             && expr->kind != Ast_Kind_Field_Access
             && expr->kind != Ast_Kind_Memres
             && expr->kind != Ast_Kind_Local
+            && expr->kind != Ast_Kind_Capture_Local
             && expr->kind != Ast_Kind_Constraint_Sentinel
             && !node_is_addressable_literal((AstNode *) expr))
             || (expr->flags & Ast_Flag_Cannot_Take_Addr) != 0) {
@@ -2162,6 +2163,9 @@ CheckStatus check_expression(AstTyped** pexpr) {
                 YIELD(expr->token->pos, "Waiting for function type to be resolved.");
 
             expr->flags |= Ast_Flag_Function_Used;
+            if (maybe_create_capture_builder_for_function_expression(pexpr)) {
+                retval = Check_Return_To_Symres;
+            }
             break;
 
         case Ast_Kind_Directive_Solidify:
@@ -2230,6 +2234,22 @@ CheckStatus check_expression(AstTyped** pexpr) {
             YIELD(expr->token->pos, "Waiting to resolve #this_package.");
             break;
 
+        case Ast_Kind_Capture_Builder: {
+            AstCaptureBuilder *builder = (void *) expr;
+
+            fori (i, 0, bh_arr_length(builder->capture_values)) {
+                if (!builder->captures->captures[i]->type) {
+                    YIELD(expr->token->pos, "Waiting to know capture value types.");
+                }
+
+                TYPE_CHECK(&builder->capture_values[i], builder->captures->captures[i]->type) {
+                    ERROR_(builder->captures->captures[i]->token->pos, "Type mismatch for this captured value. Expected '%s', got '%s'.",
+                            type_get_name(builder->captures->captures[i]->type), type_get_name(builder->capture_values[i]->type));
+                }
+            }
+            break;
+        }
+
         case Ast_Kind_File_Contents: break;
         case Ast_Kind_Overloaded_Function: break;
         case Ast_Kind_Enum_Value: break;
@@ -2414,8 +2434,7 @@ CheckStatus check_capture_block(AstCaptureBlock *block) {
 
     bh_arr_each(AstCaptureLocal *, capture, block->captures) {
         CHECK(expression, (AstTyped **) capture);
-
-        assert((*capture)->type);
+        if (!(*capture)->type) YIELD((*capture)->token->pos, "Waiting to resolve captures type.");
 
         (*capture)->offset = block->total_size_in_bytes;
         block->total_size_in_bytes += type_size_of((*capture)->type);
index 28c79a5d8baa15b2535cc82c4d5e07f52beb0faf..255848925fc0f49705c6f4f7900d8e7260f09240 100644 (file)
@@ -384,6 +384,7 @@ static void context_init(CompileOptions* opts) {
     add_entities_for_node(NULL, (AstNode *) &builtin_heap_start, context.global_scope, NULL);
     add_entities_for_node(NULL, (AstNode *) &builtin_tls_base, context.global_scope, NULL);
     add_entities_for_node(NULL, (AstNode *) &builtin_tls_size, context.global_scope, NULL);
+    add_entities_for_node(NULL, (AstNode *) &builtin_closure_base, context.global_scope, NULL);
 
     // NOTE: Add all files passed by command line to the queue
     bh_arr_each(const char *, filename, opts->files) {
index 5cc9cab2923eee5ebffa14c4df47ec903333465b..0d9b6dbd3c450026092c7c4e312fd0fa3d80dc2f 100644 (file)
@@ -73,6 +73,7 @@ static SymresStatus symres_static_if(AstIf* static_if);
 static SymresStatus symres_macro(AstMacro* macro);
 static SymresStatus symres_constraint(AstConstraint* constraint);
 static SymresStatus symres_polyquery(AstPolyQuery *query);
+static SymresStatus symres_capture_builder(AstCaptureBuilder *builder);
 
 static void scope_enter(Scope* new_scope) {
     current_scope = new_scope;
@@ -659,6 +660,8 @@ static SymresStatus symres_expression(AstTyped** expr) {
             break;
         }
 
+        case Ast_Kind_Capture_Builder: SYMRES(capture_builder, (AstCaptureBuilder *) *expr); break;
+
         default: break;
     }
 
@@ -854,6 +857,23 @@ static SymresStatus symres_capture_block(AstCaptureBlock *block) {
     return Symres_Success;
 }
 
+static SymresStatus symres_capture_builder(AstCaptureBuilder *builder) {
+    fori (i, bh_arr_length(builder->capture_values), bh_arr_length(builder->captures->captures)) {
+        OnyxToken *token = builder->captures->captures[i]->token;
+        AstTyped *resolved = (AstTyped *) symbol_resolve(current_scope, token);
+        if (!resolved) {
+            // Should this do a yield? In there any case that that would make sense?
+            onyx_report_error(token->pos, Error_Critical, "'%b' is not found in the enclosing scope.",
+                    token->text, token->length);
+            return Symres_Error;
+        }
+
+        bh_arr_push(builder->capture_values, resolved);
+    }
+
+    return Symres_Success;
+}
+
 static SymresStatus symres_statement(AstNode** stmt, b32 *remove) {
     if (remove) *remove = 0;
 
index cbc8fff4e42bbaa548f1ad0d2930b908d91bdfa3..298f6ed1eb7eabe74346855c543846945bc5c4b3 100644 (file)
@@ -1480,3 +1480,25 @@ void track_resolution_for_symbol_info(AstNode *original, AstNode *resolved) {
 
     bh_arr_push(syminfo->symbols_resolutions, res);
 }
+
+
+
+b32 maybe_create_capture_builder_for_function_expression(AstTyped **pexpr) {
+    AstFunction *func = (void *) *pexpr;
+
+    if (!(func->flags & Ast_Flag_Function_Is_Lambda)) return 0;
+    if (!func->captures) return 0;
+
+    AstCaptureBuilder *builder = onyx_ast_node_new(context.ast_alloc, sizeof(AstCaptureBuilder), Ast_Kind_Capture_Builder);
+    builder->token = func->captures->token - 1;
+
+    builder->func = (void *) func;
+    builder->type = builder->func->type;
+    builder->captures = func->captures;
+
+    bh_arr_new(context.ast_alloc, builder->capture_values, bh_arr_length(builder->captures->captures));
+
+    *((void **) pexpr) = builder;
+    return 1;
+}
+
index 5c27686868b5e6a8694adfa9cba1a92a3e866549..5a5bc47d09989e0db95d38aba2cb516c877c9dee 100644 (file)
@@ -516,6 +516,7 @@ EMIT_FUNC(intrinsic_call,                  AstCall* call);
 EMIT_FUNC(subscript_location,              AstSubscript* sub, u64* offset_return);
 EMIT_FUNC(field_access_location,           AstFieldAccess* field, u64* offset_return);
 EMIT_FUNC(local_location,                  AstLocal* local, u64* offset_return);
+EMIT_FUNC(capture_local_location,          AstCaptureLocal *capture, u64 *offset_return);
 EMIT_FUNC(memory_reservation_location,     AstMemRes* memres);
 EMIT_FUNC(location_return_offset,          AstTyped* expr, u64* offset_return);
 EMIT_FUNC(location,                        AstTyped* expr);
@@ -2123,7 +2124,9 @@ EMIT_FUNC(call, AstCall* call) {
 
     } else {
         emit_expression(mod, &code, call->callee);
-        WI(NULL, WI_DROP);
+        
+        u64 global_closure_base_idx = bh_imap_get(&mod->index_map, (u64) &builtin_closure_base);
+        WIL(NULL, WI_GLOBAL_SET, global_closure_base_idx);
 
         i32 type_idx = generate_type_idx(mod, call->callee->type);
         WID(NULL, WI_CALL_INDIRECT, ((WasmInstructionData) { type_idx, 0x00 }));
@@ -2784,6 +2787,16 @@ EMIT_FUNC(local_location, AstLocal* local, u64* offset_return) {
     *pcode = code;
 }
 
+EMIT_FUNC(capture_local_location, AstCaptureLocal *capture, u64 *offset_return) {
+    bh_arr(WasmInstruction) code = *pcode;
+
+    WIL(NULL, WI_LOCAL_GET, mod->closure_base_idx);
+
+    *offset_return = capture->offset;
+
+    *pcode = code;
+}
+
 EMIT_FUNC(compound_load, Type* type, u64 offset, i32 ignored_value_count) {
     bh_arr(WasmInstruction) code = *pcode;
     i32 mem_count = type_linear_member_count(type);
@@ -3213,6 +3226,12 @@ EMIT_FUNC(location_return_offset, AstTyped* expr, u64* offset_return) {
             break;
         }
 
+        case Ast_Kind_Capture_Local: {
+            AstCaptureLocal *capture = (AstCaptureLocal *) expr;
+            emit_capture_local_location(mod, &code, capture, offset_return);
+            break;
+        }
+
         default: {
             if (expr->token) {
                 onyx_report_error(expr->token->pos, Error_Critical, "Unable to generate location for '%s'.", onyx_ast_node_kind_string(expr->kind));
@@ -3379,6 +3398,32 @@ EMIT_FUNC(expression, AstTyped* expr) {
             break;
         }
 
+        case Ast_Kind_Capture_Builder: {
+            AstCaptureBuilder *builder = (AstCaptureBuilder *) expr;
+            
+            assert(builder->func->kind == Ast_Kind_Function);
+            i32 elemidx = get_element_idx(mod, (AstFunction *) builder->func);
+            WID(NULL, WI_I32_CONST, elemidx);
+
+            // Allocate the block
+            WIL(NULL, WI_I32_CONST, builder->captures->total_size_in_bytes);
+            i32 func_idx = (i32) bh_imap_get(&mod->index_map, (u64) builtin_closure_block_allocate);
+            WIL(NULL, WI_CALL, func_idx);
+
+            u64 capture_block_ptr = local_raw_allocate(mod->local_alloc, WASM_TYPE_PTR);
+            WIL(NULL, WI_LOCAL_TEE, capture_block_ptr);
+            
+            // Populate the block
+            fori (i, 0, bh_arr_length(builder->capture_values)) {
+                WIL(NULL, WI_LOCAL_GET, capture_block_ptr);
+                emit_expression(mod, &code, builder->capture_values[i]);
+                emit_store_instruction(mod, &code, builder->capture_values[i]->type, builder->captures->captures[i]->offset);
+            }
+            
+            local_raw_free(mod->local_alloc, WASM_TYPE_PTR);
+            break;
+        }
+
         case Ast_Kind_Block:          emit_block(mod, &code, (AstBlock *) expr, 1); break;
         case Ast_Kind_Do_Block:       emit_do_block(mod, &code, (AstDoBlock *) expr); break;
         case Ast_Kind_Call:           emit_call(mod, &code, (AstCall *) expr); break;
@@ -3644,8 +3689,10 @@ EMIT_FUNC(expression, AstTyped* expr) {
         }
 
         case Ast_Kind_Capture_Local: {
-            printf("HANDLE CAPTURE LOCAL!!!\n");
-            assert(0);
+            AstCaptureLocal* capture = (AstCaptureLocal *) expr;
+            u64 offset = 0;
+            emit_capture_local_location(mod, &code, capture, &offset);
+            emit_load_instruction(mod, &code, capture->type, offset);
             break;
         }
 
@@ -4123,6 +4170,17 @@ static void emit_function(OnyxWasmModule* mod, AstFunction* fd) {
         mod->stack_base_idx = local_raw_allocate(mod->local_alloc, WASM_TYPE_PTR);
         debug_function_set_ptr_idx(mod, func_idx, mod->stack_base_idx);
 
+        if (fd->captures) {
+            mod->closure_base_idx = local_raw_allocate(mod->local_alloc, WASM_TYPE_PTR);
+
+            debug_emit_instruction(mod, NULL);
+            debug_emit_instruction(mod, NULL);
+
+            u64 global_closure_base_idx = bh_imap_get(&mod->index_map, (u64) &builtin_closure_base);
+            bh_arr_push(wasm_func.code, ((WasmInstruction) { WI_GLOBAL_GET, { .l = global_closure_base_idx } }));
+            bh_arr_push(wasm_func.code, ((WasmInstruction) { WI_LOCAL_SET,  { .l = mod->closure_base_idx } }));
+        }
+
         // Generate code
         emit_function_body(mod, &wasm_func.code, fd);
 
@@ -4708,6 +4766,8 @@ OnyxWasmModule onyx_wasm_module_create(bh_allocator alloc) {
         .stack_top_ptr = NULL,
         .stack_base_idx = 0,
 
+        .closure_base_idx = 0,
+
         .foreign_function_count = 0,
 
         .null_proc_func_idx = -1,
index abe73849a07e254001ae16f5d867fee7a8bf6f7a..1cea27ae23da133bd3bfcb9e3f2033957475f7e3 100644 (file)
@@ -469,6 +469,15 @@ __run_init_procedures :: () -> void ---
 """
 __implicit_bool_cast :: #match -> bool {}
 
+#doc """
+    Internal procedure to allocate space for the captures in a closure. This will be soon
+    changed to a configurable way, but for now it simply allocates out of the heap allocator.
+"""
+__closure_block_allocate :: (size: i32) -> rawptr {
+    return raw_alloc(context.allocator, size);
+}
+
+
 #doc """
     Defines all options for changing the memory layout, imports and exports,
     and more of an Onyx binary.