From: Brendan Hansen Date: Mon, 22 May 2023 16:42:25 +0000 (-0500) Subject: added: `switch` over tagged unions X-Git-Url: https://git.brendanfh.com/?a=commitdiff_plain;h=148d0d81a7cd222f93922a93749106f16d4c3048;p=onyx.git added: `switch` over tagged unions --- diff --git a/compiler/include/astnodes.h b/compiler/include/astnodes.h index fc9ecacf..f5aab0af 100644 --- a/compiler/include/astnodes.h +++ b/compiler/include/astnodes.h @@ -872,6 +872,7 @@ typedef struct AstIfWhile AstWhile; typedef enum SwitchKind { Switch_Kind_Integer, Switch_Kind_Use_Equals, + Switch_Kind_Union, } SwitchKind; typedef struct CaseToBlock { @@ -888,7 +889,10 @@ struct AstSwitchCase { AstBlock *block; + AstLocal *capture; + b32 is_default: 1; // Could this be inferred by the values array being null? + b32 capture_is_by_pointer: 1; }; struct AstSwitch { diff --git a/compiler/include/lex.h b/compiler/include/lex.h index 37c4e084..bc9b0cf3 100644 --- a/compiler/include/lex.h +++ b/compiler/include/lex.h @@ -44,6 +44,7 @@ typedef enum TokenType { Token_Type_Keyword_End, Token_Type_Right_Arrow, + Token_Type_Fat_Right_Arrow, Token_Type_Left_Arrow, Token_Type_Empty_Block, Token_Type_Pipe, diff --git a/compiler/include/types.h b/compiler/include/types.h index 430e4f8f..29ce2659 100644 --- a/compiler/include/types.h +++ b/compiler/include/types.h @@ -168,6 +168,7 @@ typedef struct UnionVariant { char* name; \ Type* tag_type; \ Table(UnionVariant *) variants; \ + bh_arr(UnionVariant *) variants_ordered; \ bh_arr(struct AstPolySolution) poly_sln; \ struct AstType *constructed_from; \ bh_arr(struct AstTyped *) meta_tags; \ diff --git a/compiler/src/checker.c b/compiler/src/checker.c index fe492897..e9140f79 100644 --- a/compiler/src/checker.c +++ b/compiler/src/checker.c @@ -363,7 +363,7 @@ fornode_expr_checked: } static b32 add_case_to_switch_statement(AstSwitch* switchnode, u64 case_value, AstBlock* block, OnyxFilePos pos) { - assert(switchnode->switch_kind == Switch_Kind_Integer); + assert(switchnode->switch_kind == Switch_Kind_Integer || switchnode->switch_kind == Switch_Kind_Union); switchnode->min_case = bh_min(switchnode->min_case, case_value); switchnode->max_case = bh_max(switchnode->max_case, case_value); @@ -440,6 +440,9 @@ CheckStatus check_switch(AstSwitch* switchnode) { if (!type_is_integer(switchnode->expr->type) && switchnode->expr->type->kind != Type_Kind_Enum) { switchnode->switch_kind = Switch_Kind_Use_Equals; } + if (switchnode->expr->type->kind == Type_Kind_Union) { + switchnode->switch_kind = Switch_Kind_Union; + } switch (switchnode->switch_kind) { case Switch_Kind_Integer: @@ -451,6 +454,11 @@ CheckStatus check_switch(AstSwitch* switchnode) { bh_arr_new(global_heap_allocator, switchnode->case_exprs, 4); break; + case Switch_Kind_Union: + switchnode->min_case = 1; + bh_imap_init(&switchnode->case_map, global_heap_allocator, 4); + break; + default: assert(0); } } @@ -474,11 +482,17 @@ CheckStatus check_switch(AstSwitch* switchnode) { fori (i, switchnode->yield_return_index, bh_arr_length(switchnode->cases)) { AstSwitchCase *sc = switchnode->cases[i]; - CHECK(block, sc->block); + + if (sc->capture && bh_arr_length(sc->values) != 1) { + ERROR(sc->token->pos, "Expected exactly one value in switch-case when using a capture."); + } + + if (sc->flags & Ast_Flag_Has_Been_Checked) goto check_switch_case_block; bh_arr_each(AstTyped *, value, sc->values) { CHECK(expression, value); + // Handle case 1 .. 10 if (switchnode->switch_kind == Switch_Kind_Integer && (*value)->kind == Ast_Kind_Range_Literal) { AstRangeLiteral* rl = (AstRangeLiteral *) (*value); resolve_expression_type(rl->low); @@ -503,16 +517,38 @@ CheckStatus check_switch(AstSwitch* switchnode) { continue; } - TYPE_CHECK(value, resolved_expr_type) { - OnyxToken* tkn = sc->block->token; - if ((*value)->token) tkn = (*value)->token; + if (switchnode->switch_kind == Switch_Kind_Union) { + TYPE_CHECK(value, resolved_expr_type->Union.tag_type) { + OnyxToken* tkn = sc->block->token; + if ((*value)->token) tkn = (*value)->token; + + ERROR_(tkn->pos, "'%b' is not a variant of '%s'.", + (*value)->token->text, (*value)->token->length, type_get_name(resolved_expr_type)); + } + + // nocheckin Explain the -1 + UnionVariant *union_variant = resolved_expr_type->Union.variants_ordered[get_expression_integer_value(*value, NULL) - 1]; + if (sc->capture) { + if (sc->capture_is_by_pointer) { + sc->capture->type = type_make_pointer(context.ast_alloc, union_variant->type); + } else { + sc->capture->type = union_variant->type; + } + } + + } else { + TYPE_CHECK(value, resolved_expr_type) { + OnyxToken* tkn = sc->block->token; + if ((*value)->token) tkn = (*value)->token; - ERROR_(tkn->pos, "Mismatched types in switch-case. Expected '%s', got '%s'.", - type_get_name(resolved_expr_type), type_get_name((*value)->type)); + ERROR_(tkn->pos, "Mismatched types in switch-case. Expected '%s', got '%s'.", + type_get_name(resolved_expr_type), type_get_name((*value)->type)); + } } switch (switchnode->switch_kind) { - case Switch_Kind_Integer: { + case Switch_Kind_Integer: + case Switch_Kind_Union: { b32 is_valid; i64 integer_value = get_expression_integer_value(*value, &is_valid); if (!is_valid) @@ -549,6 +585,11 @@ CheckStatus check_switch(AstSwitch* switchnode) { } } + sc->flags |= Ast_Flag_Has_Been_Checked; + + check_switch_case_block: + CHECK(block, sc->block); + switchnode->yield_return_index += 1; } diff --git a/compiler/src/lex.c b/compiler/src/lex.c index 4ce31361..40d88478 100644 --- a/compiler/src/lex.c +++ b/compiler/src/lex.c @@ -39,6 +39,7 @@ static const char* token_type_names[] = { "", // end "->", + "=>", "<-", "---", "|>", @@ -422,6 +423,7 @@ whitespace_skipped: case '=': LITERAL_TOKEN("==", 0, Token_Type_Equal_Equal); + LITERAL_TOKEN("=>", 0, Token_Type_Fat_Right_Arrow); break; case '!': diff --git a/compiler/src/parser.c b/compiler/src/parser.c index 8ad16fb6..e5564092 100644 --- a/compiler/src/parser.c +++ b/compiler/src/parser.c @@ -336,7 +336,7 @@ static void parse_arguments(OnyxParser* parser, TokenType end_token, Arguments* // This shouldn't be a named argument, but this should: // f(g = x => x + 1) // - if (next_tokens_are(parser, 2, Token_Type_Symbol, '=') && peek_token(2)->type != '>') { + if (next_tokens_are(parser, 2, Token_Type_Symbol, '=')) { OnyxToken* name = expect_token(parser, Token_Type_Symbol); expect_token(parser, '='); @@ -1319,6 +1319,19 @@ static AstSwitchCase* parse_case_stmt(OnyxParser* parser) { } } + if (consume_token_if_next(parser, Token_Type_Fat_Right_Arrow)) { + // Captured value for union switching + b32 is_pointer = 0; + if (consume_token_if_next(parser, '&') || consume_token_if_next(parser, '^')) + is_pointer = 1; + + OnyxToken *capture_symbol = expect_token(parser, Token_Type_Symbol); + AstLocal *capture = make_local(parser->allocator, capture_symbol, NULL); + + sc_node->capture = capture; + sc_node->capture_is_by_pointer = is_pointer; + } + sc_node->block = parse_block(parser, 1, NULL); return sc_node; @@ -2705,8 +2718,7 @@ static AstFunction* parse_function_definition(OnyxParser* parser, OnyxToken* tok name = bh_aprintf(global_heap_allocator, "%b", current_symbol->text, current_symbol->length); } - if (consume_token_if_next(parser, '=')) { - expect_token(parser, '>'); + if (consume_token_if_next(parser, Token_Type_Fat_Right_Arrow)) { func_def->return_type = (AstType *) &basic_type_auto_return; if (parser->curr->type == '{') { @@ -2798,7 +2810,7 @@ static b32 parse_possible_function_definition_no_consume(OnyxParser* parser) { OnyxToken* matching_paren = find_matching_paren(parser->curr); if (matching_paren == NULL) return 0; - if (next_tokens_are(parser, 4, '(', ')', '=', '>')) return 0; + if (next_tokens_are(parser, 3, '(', ')', Token_Type_Fat_Right_Arrow)) return 0; // :LinearTokenDependent OnyxToken* token_after_paren = matching_paren + 1; @@ -2807,7 +2819,7 @@ static b32 parse_possible_function_definition_no_consume(OnyxParser* parser) { && token_after_paren->type != Token_Type_Keyword_Do && token_after_paren->type != Token_Type_Empty_Block && token_after_paren->type != Token_Type_Keyword_Where - && (token_after_paren->type != '=' || (token_after_paren + 1)->type != '>')) + && token_after_paren->type != Token_Type_Fat_Right_Arrow) return 0; // :LinearTokenDependent @@ -2849,7 +2861,7 @@ typedef struct QuickParam { static b32 parse_possible_quick_function_definition_no_consume(OnyxParser* parser) { // // x => x + 1 case. - if (next_tokens_are(parser, 3, Token_Type_Symbol, '=', '>')) { + if (next_tokens_are(parser, 2, Token_Type_Symbol, Token_Type_Fat_Right_Arrow)) { return 1; } @@ -2860,7 +2872,7 @@ static b32 parse_possible_quick_function_definition_no_consume(OnyxParser* parse // :LinearTokenDependent OnyxToken* token_after_paren = matching_paren + 1; - if (token_after_paren->type != '=' || (token_after_paren + 1)->type != '>') + if (token_after_paren->type != Token_Type_Fat_Right_Arrow) return 0; return 1; @@ -2903,8 +2915,7 @@ static b32 parse_possible_quick_function_definition(OnyxParser* parser, AstTyped } } - expect_token(parser, '='); - expect_token(parser, '>'); + expect_token(parser, Token_Type_Fat_Right_Arrow); bh_arr(AstNode *) poly_params=NULL; bh_arr_new(global_heap_allocator, poly_params, bh_arr_length(params)); diff --git a/compiler/src/symres.c b/compiler/src/symres.c index 086cc23a..b00a609a 100644 --- a/compiler/src/symres.c +++ b/compiler/src/symres.c @@ -815,6 +815,11 @@ static SymresStatus symres_case(AstSwitchCase *casenode) { } } + if (casenode->capture) { + casenode->block->scope = scope_create(context.ast_alloc, current_scope, casenode->block->token->pos); + symbol_introduce(casenode->block->scope, casenode->capture->token, (AstNode *) casenode->capture); + } + SYMRES(block, casenode->block); return Symres_Success; } diff --git a/compiler/src/types.c b/compiler/src/types.c index cd2a4082..6fc88817 100644 --- a/compiler/src/types.c +++ b/compiler/src/types.c @@ -689,7 +689,9 @@ static Type* type_build_from_ast_inner(bh_allocator alloc, AstType* type_node, b type_register(u_type); u_type->Union.variants = NULL; + u_type->Union.variants_ordered = NULL; sh_new_arena(u_type->Union.variants); + bh_arr_new(global_heap_allocator, u_type->Union.variants_ordered, bh_arr_length(union_->variants)); } else { u_type = union_->pending_type; } @@ -770,6 +772,8 @@ static Type* type_build_from_ast_inner(bh_allocator alloc, AstType* type_node, b shput(u_type->Union.variants, variant->token->text, uv); token_toggle_end(variant->token); + bh_arr_push(u_type->Union.variants_ordered, uv); + AstEnumValue *ev = onyx_ast_node_new(alloc, sizeof(AstEnumValue), Ast_Kind_Enum_Value); ev->token = uv->token; ev->value = (AstTyped *) make_int_literal(alloc, uv->tag_value); diff --git a/compiler/src/wasm_emit.c b/compiler/src/wasm_emit.c index 2ec076f1..48281539 100644 --- a/compiler/src/wasm_emit.c +++ b/compiler/src/wasm_emit.c @@ -1625,8 +1625,11 @@ EMIT_FUNC(switch, AstSwitch* switch_node) { block_num++; } + u64 union_capture_idx = 0; + switch (switch_node->switch_kind) { - case Switch_Kind_Integer: { + case Switch_Kind_Integer: + case Switch_Kind_Union: { u64 count = switch_node->max_case + 1 - switch_node->min_case; BranchTable* bt = bh_alloc(mod->extended_instr_alloc, sizeof(BranchTable) + sizeof(u32) * count); bt->count = count; @@ -1651,6 +1654,14 @@ EMIT_FUNC(switch, AstSwitch* switch_node) { // to the second block, and so on. WID(switch_node->expr->token, WI_BLOCK_START, 0x40); emit_expression(mod, &code, switch_node->expr); + + if (switch_node->switch_kind == Switch_Kind_Union) { + assert(switch_node->expr->type->kind == Type_Kind_Union); + union_capture_idx = local_raw_allocate(mod->local_alloc, WASM_TYPE_PTR); + WIL(NULL, WI_LOCAL_TEE, union_capture_idx); + emit_load_instruction(mod, &code, switch_node->expr->type->Union.tag_type, 0); + } + if (switch_node->min_case != 0) { if (onyx_type_to_wasm_type(switch_node->expr->type) == WASM_TYPE_INT64) { WI(switch_node->expr->token, WI_I32_FROM_I64); @@ -1688,6 +1699,34 @@ EMIT_FUNC(switch, AstSwitch* switch_node) { u64 bn = bh_imap_get(&block_map, (u64) sc->block); + if (sc->capture) { + assert(union_capture_idx != 0); + + if (sc->capture_is_by_pointer) { + u64 capture_pointer_local = emit_local_allocation(mod, &code, sc->capture); + + WIL(NULL, WI_LOCAL_GET, union_capture_idx); + WIL(NULL, WI_PTR_CONST, switch_node->expr->type->Union.alignment); + WI(NULL, WI_PTR_ADD); + + WIL(NULL, WI_LOCAL_SET, capture_pointer_local); + + } else { + sc->capture->flags |= Ast_Flag_Decl_Followed_By_Init; + sc->capture->flags |= Ast_Flag_Address_Taken; + + emit_local_allocation(mod, &code, sc->capture); + emit_location(mod, &code, sc->capture); + + WIL(NULL, WI_LOCAL_GET, union_capture_idx); + WIL(NULL, WI_PTR_CONST, switch_node->expr->type->Union.alignment); + WI(NULL, WI_PTR_ADD); + + WIL(NULL, WI_I32_CONST, type_size_of(sc->capture->type)); + emit_wasm_copy(mod, &code, NULL); + } + } + // Maybe the Symbol Frame idea should be controlled as a block_flag? debug_enter_symbol_frame(mod); emit_block(mod, &code, sc->block, 0); @@ -1705,6 +1744,7 @@ EMIT_FUNC(switch, AstSwitch* switch_node) { emit_block(mod, &code, switch_node->default_case, 0); } + if (union_capture_idx != 0) local_raw_free(mod->local_alloc, WASM_TYPE_PTR); emit_leave_structured_block(mod, &code); bh_imap_free(&block_map); diff --git a/tests/tagged_unions.onyx b/tests/tagged_unions.onyx index 5bfb2fa4..f616c301 100644 --- a/tests/tagged_unions.onyx +++ b/tests/tagged_unions.onyx @@ -15,11 +15,18 @@ call_test :: (u_: SimpleUnion) { u := u_; __byte_dump(&u, sizeof typeof u); - switch cast(SimpleUnion.tag_enum) u { + switch u { case .a do println("It was a!"); - case .b do println("It was B!"); + case .b => v do printf("It was B! {}\n", v); case .empty do println("It was EMPTY!"); - case .large do println("It was a large number!"); + + case .large => &value { + printf("It was a large number! {} {}\n", value, *value); + } + + case #default { + printf("none of the above.\n"); + } } } @@ -27,14 +34,15 @@ simple_test :: () { u := SimpleUnion.{ b = .{ "asdf" } }; u = .{ a = 123 }; u = .{ empty = .{} }; - u = .{ large = 0x0123456789abcdef }; + u = .{ large = 123456789 }; + u = .{ b = .{ "Wow this works!!" } }; println(cast(SimpleUnion.tag_enum) u); call_test(u); } -main :: () {simple_test();} //link_test();} +main :: () {simple_test(); link_test();} Link :: union { End: void; @@ -44,15 +52,15 @@ Link :: union { } } -/* print_links :: (l: Link) { walker := l; while true { switch walker { - case Link.End do break; - case Link.Next => &next { - printf("{} ", next.data); - walker = next.next; + case .End do break break; + + case .Next => &next { + printf("{}\n", next.data); + walker = *next.next; } } } @@ -63,14 +71,16 @@ link_test :: () { Next = .{ data = 123, next = &Link.{ - End = .{} + Next = .{ + data = 456, + next = &Link.{ End = .{} }, + } } } }; print_links(l); } -*/ // main :: () { link_test(); }