added: `switch` over tagged unions
author: Brendan Hansen <brendan.f.hansen@gmail.com>
Mon, 22 May 2023 16:42:25 +0000 (11:42 -0500)
committer: Brendan Hansen <brendan.f.hansen@gmail.com>
Mon, 22 May 2023 16:42:25 +0000 (11:42 -0500)
compiler/include/astnodes.h
compiler/include/lex.h
compiler/include/types.h
compiler/src/checker.c
compiler/src/lex.c
compiler/src/parser.c
compiler/src/symres.c
compiler/src/types.c
compiler/src/wasm_emit.c
tests/tagged_unions.onyx

index fc9ecacf31b81ab9a920164c4efe725efc2da44b..f5aab0afc4ecb71473b185834618c4c45f5dc478 100644 (file)
@@ -872,6 +872,7 @@ typedef struct AstIfWhile AstWhile;
 typedef enum SwitchKind {
     Switch_Kind_Integer,
     Switch_Kind_Use_Equals,
+    Switch_Kind_Union,
 } SwitchKind;
 
 typedef struct CaseToBlock {
@@ -888,7 +889,10 @@ struct AstSwitchCase {
 
     AstBlock *block;
 
+    AstLocal *capture;
+
     b32 is_default: 1; // Could this be inferred by the values array being null?
+    b32 capture_is_by_pointer: 1;
 };
 
 struct AstSwitch {
index 37c4e084c55598ea00bb943949ede1c6d34f8ccd..bc9b0cf3882102ca04625cfef1d45a12e2195d7b 100644 (file)
@@ -44,6 +44,7 @@ typedef enum TokenType {
     Token_Type_Keyword_End,
 
     Token_Type_Right_Arrow,
+    Token_Type_Fat_Right_Arrow,
     Token_Type_Left_Arrow,
     Token_Type_Empty_Block,
     Token_Type_Pipe,
index 430e4f8fdf30d9ae9c0274c24915477d0be772c7..29ce2659dc05779495159f0efea9e781bccab04b 100644 (file)
@@ -168,6 +168,7 @@ typedef struct UnionVariant {
         char* name;                                               \
         Type* tag_type;                                           \
         Table(UnionVariant *) variants;                           \
+        bh_arr(UnionVariant *) variants_ordered;                  \
         bh_arr(struct AstPolySolution) poly_sln;                  \
         struct AstType *constructed_from;                         \
         bh_arr(struct AstTyped *) meta_tags;                      \
index fe492897b9956e9b431fcaa78f20d132cd08f02f..e9140f79b558065ac9e5c9aa04278983135ba4d7 100644 (file)
@@ -363,7 +363,7 @@ fornode_expr_checked:
 }
 
 static b32 add_case_to_switch_statement(AstSwitch* switchnode, u64 case_value, AstBlock* block, OnyxFilePos pos) {
-    assert(switchnode->switch_kind == Switch_Kind_Integer);
+    assert(switchnode->switch_kind == Switch_Kind_Integer || switchnode->switch_kind == Switch_Kind_Union);
 
     switchnode->min_case = bh_min(switchnode->min_case, case_value);
     switchnode->max_case = bh_max(switchnode->max_case, case_value);
@@ -440,6 +440,9 @@ CheckStatus check_switch(AstSwitch* switchnode) {
         if (!type_is_integer(switchnode->expr->type) && switchnode->expr->type->kind != Type_Kind_Enum) {
             switchnode->switch_kind = Switch_Kind_Use_Equals;
         }
+        if (switchnode->expr->type->kind == Type_Kind_Union) {
+            switchnode->switch_kind = Switch_Kind_Union;
+        }
 
         switch (switchnode->switch_kind) {
             case Switch_Kind_Integer:
@@ -451,6 +454,11 @@ CheckStatus check_switch(AstSwitch* switchnode) {
                 bh_arr_new(global_heap_allocator, switchnode->case_exprs, 4);
                 break;
 
+            case Switch_Kind_Union:
+                switchnode->min_case = 1;
+                bh_imap_init(&switchnode->case_map, global_heap_allocator, 4);
+                break;
+
             default: assert(0);
         }
     }
@@ -474,11 +482,17 @@ CheckStatus check_switch(AstSwitch* switchnode) {
 
     fori (i, switchnode->yield_return_index, bh_arr_length(switchnode->cases)) {
         AstSwitchCase *sc = switchnode->cases[i];
-        CHECK(block, sc->block);
+
+        if (sc->capture && bh_arr_length(sc->values) != 1) {
+            ERROR(sc->token->pos, "Expected exactly one value in switch-case when using a capture.");
+        }
+
+        if (sc->flags & Ast_Flag_Has_Been_Checked) goto check_switch_case_block;
 
         bh_arr_each(AstTyped *, value, sc->values) {
             CHECK(expression, value);
 
+            // Handle case 1 .. 10
             if (switchnode->switch_kind == Switch_Kind_Integer && (*value)->kind == Ast_Kind_Range_Literal) {
                 AstRangeLiteral* rl = (AstRangeLiteral *) (*value);
                 resolve_expression_type(rl->low);
@@ -503,16 +517,38 @@ CheckStatus check_switch(AstSwitch* switchnode) {
                 continue;
             }
 
-            TYPE_CHECK(value, resolved_expr_type) {
-                OnyxToken* tkn = sc->block->token;
-                if ((*value)->token) tkn = (*value)->token;
+            if (switchnode->switch_kind == Switch_Kind_Union) {
+                TYPE_CHECK(value, resolved_expr_type->Union.tag_type) {
+                    OnyxToken* tkn = sc->block->token;
+                    if ((*value)->token) tkn = (*value)->token;
+
+                    ERROR_(tkn->pos, "'%b' is not a variant of '%s'.",
+                        (*value)->token->text, (*value)->token->length, type_get_name(resolved_expr_type));
+                }
+
+                // nocheckin Explain the -1
+                UnionVariant *union_variant = resolved_expr_type->Union.variants_ordered[get_expression_integer_value(*value, NULL) - 1];
+                if (sc->capture) {
+                    if (sc->capture_is_by_pointer) {
+                        sc->capture->type = type_make_pointer(context.ast_alloc, union_variant->type);
+                    } else {
+                        sc->capture->type = union_variant->type;
+                    }
+                }
+                
+            } else {
+                TYPE_CHECK(value, resolved_expr_type) {
+                    OnyxToken* tkn = sc->block->token;
+                    if ((*value)->token) tkn = (*value)->token;
 
-                ERROR_(tkn->pos, "Mismatched types in switch-case. Expected '%s', got '%s'.",
-                    type_get_name(resolved_expr_type), type_get_name((*value)->type));
+                    ERROR_(tkn->pos, "Mismatched types in switch-case. Expected '%s', got '%s'.",
+                        type_get_name(resolved_expr_type), type_get_name((*value)->type));
+                }
             }
 
             switch (switchnode->switch_kind) {
-                case Switch_Kind_Integer: {
+                case Switch_Kind_Integer:
+                case Switch_Kind_Union: {
                     b32 is_valid;
                     i64 integer_value = get_expression_integer_value(*value, &is_valid);
                     if (!is_valid)
@@ -549,6 +585,11 @@ CheckStatus check_switch(AstSwitch* switchnode) {
             }
         }
 
+        sc->flags |= Ast_Flag_Has_Been_Checked;
+
+      check_switch_case_block:
+        CHECK(block, sc->block);
+
         switchnode->yield_return_index += 1;
     }
 
index 4ce31361e7a85226c78aa38effd312f120aa22ce..40d88478be972646f3fc5eae2000cb28de94e17a 100644 (file)
@@ -39,6 +39,7 @@ static const char* token_type_names[] = {
     "", // end
 
     "->",
+    "=>",
     "<-",
     "---",
     "|>",
@@ -422,6 +423,7 @@ whitespace_skipped:
 
     case '=':
         LITERAL_TOKEN("==",          0, Token_Type_Equal_Equal);
+        LITERAL_TOKEN("=>",          0, Token_Type_Fat_Right_Arrow);
         break;
 
     case '!':
index 8ad16fb6b2ae6a468a0090271f9be35cf831e293..e5564092c6ff5bd8f78c9419b2200117197d45c0 100644 (file)
@@ -336,7 +336,7 @@ static void parse_arguments(OnyxParser* parser, TokenType end_token, Arguments*
         // This shouldn't be a named argument, but this should:
         //     f(g = x => x + 1)
         //
-        if (next_tokens_are(parser, 2, Token_Type_Symbol, '=') && peek_token(2)->type != '>') {
+        if (next_tokens_are(parser, 2, Token_Type_Symbol, '=')) {
             OnyxToken* name = expect_token(parser, Token_Type_Symbol);
             expect_token(parser, '=');
 
@@ -1319,6 +1319,19 @@ static AstSwitchCase* parse_case_stmt(OnyxParser* parser) {
         }
     }
 
+    if (consume_token_if_next(parser, Token_Type_Fat_Right_Arrow)) {
+        // Captured value for union switching
+        b32 is_pointer = 0;
+        if (consume_token_if_next(parser, '&') || consume_token_if_next(parser, '^'))
+            is_pointer = 1;
+        
+        OnyxToken *capture_symbol = expect_token(parser, Token_Type_Symbol);
+        AstLocal *capture = make_local(parser->allocator, capture_symbol, NULL);
+        
+        sc_node->capture = capture;
+        sc_node->capture_is_by_pointer = is_pointer;
+    }
+
     sc_node->block = parse_block(parser, 1, NULL);
 
     return sc_node;
@@ -2705,8 +2718,7 @@ static AstFunction* parse_function_definition(OnyxParser* parser, OnyxToken* tok
         name = bh_aprintf(global_heap_allocator, "%b", current_symbol->text, current_symbol->length);
     }
 
-    if (consume_token_if_next(parser, '=')) {
-        expect_token(parser, '>');
+    if (consume_token_if_next(parser, Token_Type_Fat_Right_Arrow)) {
         func_def->return_type = (AstType *) &basic_type_auto_return;
 
         if (parser->curr->type == '{') {
@@ -2798,7 +2810,7 @@ static b32 parse_possible_function_definition_no_consume(OnyxParser* parser) {
         OnyxToken* matching_paren = find_matching_paren(parser->curr);
         if (matching_paren == NULL) return 0;
 
-        if (next_tokens_are(parser, 4, '(', ')', '=', '>')) return 0;
+        if (next_tokens_are(parser, 3, '(', ')', Token_Type_Fat_Right_Arrow)) return 0;
 
         // :LinearTokenDependent
         OnyxToken* token_after_paren = matching_paren + 1;
@@ -2807,7 +2819,7 @@ static b32 parse_possible_function_definition_no_consume(OnyxParser* parser) {
             && token_after_paren->type != Token_Type_Keyword_Do
             && token_after_paren->type != Token_Type_Empty_Block
             && token_after_paren->type != Token_Type_Keyword_Where
-            && (token_after_paren->type != '=' || (token_after_paren + 1)->type != '>'))
+            && token_after_paren->type != Token_Type_Fat_Right_Arrow)
             return 0;
 
         // :LinearTokenDependent
@@ -2849,7 +2861,7 @@ typedef struct QuickParam {
 static b32 parse_possible_quick_function_definition_no_consume(OnyxParser* parser) {
     //
     // x => x + 1 case.
-    if (next_tokens_are(parser, 3, Token_Type_Symbol, '=', '>')) {
+    if (next_tokens_are(parser, 2, Token_Type_Symbol, Token_Type_Fat_Right_Arrow)) {
         return 1;
     }
 
@@ -2860,7 +2872,7 @@ static b32 parse_possible_quick_function_definition_no_consume(OnyxParser* parse
 
     // :LinearTokenDependent
     OnyxToken* token_after_paren = matching_paren + 1;
-    if (token_after_paren->type != '=' || (token_after_paren + 1)->type != '>')
+    if (token_after_paren->type != Token_Type_Fat_Right_Arrow)
         return 0;
 
     return 1;
@@ -2903,8 +2915,7 @@ static b32 parse_possible_quick_function_definition(OnyxParser* parser, AstTyped
         }
     }
 
-    expect_token(parser, '=');
-    expect_token(parser, '>');
+    expect_token(parser, Token_Type_Fat_Right_Arrow);
 
     bh_arr(AstNode *) poly_params=NULL;
     bh_arr_new(global_heap_allocator, poly_params, bh_arr_length(params));
index 086cc23a6805abc511da5714e46897e09c890226..b00a609a64f042026d501f3f79f2a94fb80d71bf 100644 (file)
@@ -815,6 +815,11 @@ static SymresStatus symres_case(AstSwitchCase *casenode) {
         }
     }
 
+    if (casenode->capture) {
+        casenode->block->scope = scope_create(context.ast_alloc, current_scope, casenode->block->token->pos);
+        symbol_introduce(casenode->block->scope, casenode->capture->token, (AstNode *) casenode->capture);
+    }
+
     SYMRES(block, casenode->block);
     return Symres_Success;
 }
index cd2a4082aad8535753e62c23a3a055548cd24967..6fc88817a5181a247bb6c3f2dd2555f6d15f1465 100644 (file)
@@ -689,7 +689,9 @@ static Type* type_build_from_ast_inner(bh_allocator alloc, AstType* type_node, b
                 type_register(u_type);
 
                 u_type->Union.variants = NULL;
+                u_type->Union.variants_ordered = NULL;
                 sh_new_arena(u_type->Union.variants);
+                bh_arr_new(global_heap_allocator, u_type->Union.variants_ordered, bh_arr_length(union_->variants));
             } else {
                 u_type = union_->pending_type;
             }
@@ -770,6 +772,8 @@ static Type* type_build_from_ast_inner(bh_allocator alloc, AstType* type_node, b
                 shput(u_type->Union.variants, variant->token->text, uv);
                 token_toggle_end(variant->token);
 
+                bh_arr_push(u_type->Union.variants_ordered, uv);
+
                 AstEnumValue *ev = onyx_ast_node_new(alloc, sizeof(AstEnumValue), Ast_Kind_Enum_Value);
                 ev->token = uv->token;
                 ev->value = (AstTyped *) make_int_literal(alloc, uv->tag_value);
index 2ec076f19a15922f6904450e9ce9299778e3290a..48281539451d94edd5006255fdd9cafcc2ad0e8b 100644 (file)
@@ -1625,8 +1625,11 @@ EMIT_FUNC(switch, AstSwitch* switch_node) {
         block_num++;
     }
 
+    u64 union_capture_idx = 0;
+
     switch (switch_node->switch_kind) {
-        case Switch_Kind_Integer: {
+        case Switch_Kind_Integer:
+        case Switch_Kind_Union: {
             u64 count = switch_node->max_case + 1 - switch_node->min_case;
             BranchTable* bt = bh_alloc(mod->extended_instr_alloc, sizeof(BranchTable) + sizeof(u32) * count);
             bt->count = count;
@@ -1651,6 +1654,14 @@ EMIT_FUNC(switch, AstSwitch* switch_node) {
             // to the second block, and so on.
             WID(switch_node->expr->token, WI_BLOCK_START, 0x40);
             emit_expression(mod, &code, switch_node->expr);
+
+            if (switch_node->switch_kind == Switch_Kind_Union) {
+                assert(switch_node->expr->type->kind == Type_Kind_Union);
+                union_capture_idx = local_raw_allocate(mod->local_alloc, WASM_TYPE_PTR);
+                WIL(NULL, WI_LOCAL_TEE, union_capture_idx);
+                emit_load_instruction(mod, &code, switch_node->expr->type->Union.tag_type, 0);
+            }
+
             if (switch_node->min_case != 0) {
                 if (onyx_type_to_wasm_type(switch_node->expr->type) == WASM_TYPE_INT64) {
                     WI(switch_node->expr->token, WI_I32_FROM_I64);
@@ -1688,6 +1699,34 @@ EMIT_FUNC(switch, AstSwitch* switch_node) {
 
         u64 bn = bh_imap_get(&block_map, (u64) sc->block);
 
+        if (sc->capture) {
+            assert(union_capture_idx != 0);
+
+            if (sc->capture_is_by_pointer) {
+                u64 capture_pointer_local = emit_local_allocation(mod, &code, sc->capture);
+
+                WIL(NULL, WI_LOCAL_GET, union_capture_idx);
+                WIL(NULL, WI_PTR_CONST, switch_node->expr->type->Union.alignment);
+                WI(NULL, WI_PTR_ADD);
+
+                WIL(NULL, WI_LOCAL_SET, capture_pointer_local);
+
+            } else {
+                sc->capture->flags |= Ast_Flag_Decl_Followed_By_Init;
+                sc->capture->flags |= Ast_Flag_Address_Taken;
+
+                emit_local_allocation(mod, &code, sc->capture);
+                emit_location(mod, &code, sc->capture);
+
+                WIL(NULL, WI_LOCAL_GET, union_capture_idx);
+                WIL(NULL, WI_PTR_CONST, switch_node->expr->type->Union.alignment);
+                WI(NULL, WI_PTR_ADD);
+                
+                WIL(NULL, WI_I32_CONST, type_size_of(sc->capture->type));
+                emit_wasm_copy(mod, &code, NULL);
+            }
+        }
+
         // Maybe the Symbol Frame idea should be controlled as a block_flag?
         debug_enter_symbol_frame(mod);
         emit_block(mod, &code, sc->block, 0);
@@ -1705,6 +1744,7 @@ EMIT_FUNC(switch, AstSwitch* switch_node) {
         emit_block(mod, &code, switch_node->default_case, 0);
     }
 
+    if (union_capture_idx != 0) local_raw_free(mod->local_alloc, WASM_TYPE_PTR);
     emit_leave_structured_block(mod, &code);
 
     bh_imap_free(&block_map);
index 5bfb2fa484c32b7a54c9a098b3de00555abced25..f616c3012043914922f4598b99af302775a33ca0 100644 (file)
@@ -15,11 +15,18 @@ call_test :: (u_: SimpleUnion) {
     u := u_;
     __byte_dump(&u, sizeof typeof u);
 
-    switch cast(SimpleUnion.tag_enum) u {
+    switch u {
         case .a do println("It was a!");
-        case .b do println("It was B!");
+        case .b => v do printf("It was B! {}\n", v);
         case .empty do println("It was EMPTY!");
-        case .large do println("It was a large number!");
+
+        case .large => &value {
+            printf("It was a large number! {} {}\n", value, *value);
+        }
+
+        case #default {
+            printf("none of the above.\n");
+        }
     }
 }
 
@@ -27,14 +34,15 @@ simple_test :: () {
     u := SimpleUnion.{ b = .{ "asdf" } };
     u  = .{ a = 123 };
     u  = .{ empty = .{} };
-    u  = .{ large = 0x0123456789abcdef };
+    u  = .{ large = 123456789 };
+    u  = .{ b = .{ "Wow this works!!" } };
 
     println(cast(SimpleUnion.tag_enum) u);
 
     call_test(u);
 }
 
-main :: () {simple_test();} //link_test();}
+main :: () {simple_test(); link_test();}
 
 Link :: union {
     End: void;
@@ -44,15 +52,15 @@ Link :: union {
     }
 }
 
-/*
 print_links :: (l: Link) {
     walker := l;
     while true {
         switch walker {
-            case Link.End do break;
-            case Link.Next => &next {
-                printf("{} ", next.data);
-                walker = next.next;
+            case .End do break break;
+
+            case .Next => &next {
+                printf("{}\n", next.data);
+                walker = *next.next;
             }
         }
     }
@@ -63,14 +71,16 @@ link_test :: () {
         Next = .{
             data = 123,
             next = &Link.{
-                End = .{}
+                Next = .{
+                    data = 456,
+                    next = &Link.{ End = .{} },
+                }
             }
         }
     };
     
     print_links(l);
 }
-*/
 
 // main :: () { link_test(); }