From a5c480690a0c43765325f4d6b204a24834f00079 Mon Sep 17 00:00:00 2001 From: Brendan Hansen Date: Tue, 30 Nov 2021 10:23:34 -0600 Subject: [PATCH] switch statements can use '==' --- core/type_info/helper.onyx | 16 ++++++- include/astnodes.h | 35 ++++++++++++--- src/astnodes.c | 17 ++++--- src/checker.c | 82 +++++++++++++++++++++++----------- src/symres.c | 6 +++ src/types.c | 2 +- src/wasm_emit.c | 82 +++++++++++++++++++++------------- tests/switch_using_equals | 4 ++ tests/switch_using_equals.onyx | 25 +++++++++++ 9 files changed, 202 insertions(+), 67 deletions(-) create mode 100644 tests/switch_using_equals create mode 100644 tests/switch_using_equals.onyx diff --git a/core/type_info/helper.onyx b/core/type_info/helper.onyx index 213c2407..d6ed6a5f 100644 --- a/core/type_info/helper.onyx +++ b/core/type_info/helper.onyx @@ -125,7 +125,6 @@ write_type_name :: (writer: ^io.Writer, t: type_expr) { } } - offset_of :: (T: type_expr, member: str) -> u32 { info := get_type_info(T); if info == null do return 0; @@ -140,6 +139,21 @@ offset_of :: (T: type_expr, member: str) -> u32 { return 0; } +get_tags_for_member :: (S: type_expr, member_name: str) -> [] any { + use type_info; + + ti := get_type_info(S); + if ti.kind != .Struct do return .[]; + + for ^ (cast(^Type_Info_Struct) ti).members { + if it.name == member_name { + return it.tags; + } + } + + return .[]; +} + struct_constructed_from :: (struct_type: type_expr, base_type: type_expr) -> bool { struct_info := get_type_info(struct_type); if struct_info.kind != .Struct do return false; diff --git a/include/astnodes.h b/include/astnodes.h index e1f82cc7..581391a9 100644 --- a/include/astnodes.h +++ b/include/astnodes.h @@ -772,12 +772,24 @@ struct AstIfWhile { typedef struct AstIfWhile AstIf; typedef struct AstIfWhile AstWhile; +typedef enum SwitchKind { + Switch_Kind_Integer, + Switch_Kind_Use_Equals, +} SwitchKind; + +typedef struct CaseToBlock { + AstTyped *original_value; + AstBinaryOp *comparison; + AstBlock *block; +} CaseToBlock; + struct AstSwitchCase { // NOTE: All expressions that end up in this block bh_arr(AstTyped *) values; AstBlock *block; }; + struct AstSwitch { AstNode_base; @@ -789,10 +801,23 @@ struct AstSwitch { bh_arr(AstSwitchCase) cases; AstBlock *default_case; - // NOTE: This is a mapping from the compile time known case value - // to a pointer to the block that it is associated with. - bh_imap case_map; - u64 min_case, max_case; + i32 yield_return_index; + SwitchKind switch_kind; + + union { + struct { + // NOTE: This is a mapping from the compile time known case value + // to a pointer to the block that it is associated with. + bh_imap case_map; + u64 min_case, max_case; + }; + + struct { + // NOTE: This is a mapping from the '==' binary op node to + // a pointer to the block that it is associated with. + bh_arr(CaseToBlock) case_exprs; + }; + }; }; // Type Nodes @@ -1539,7 +1564,7 @@ typedef enum TypeMatch { #define unify_node_and_type(node, type) (unify_node_and_type_((node), (type), 1)) TypeMatch unify_node_and_type_(AstTyped** pnode, Type* type, b32 permanent); Type* resolve_expression_type(AstTyped* node); -i64 get_expression_integer_value(AstTyped* node); +i64 get_expression_integer_value(AstTyped* node, b32 *out_is_valid); b32 cast_is_legal(Type* from_, Type* to_, char** err_msg); char* get_function_name(AstFunction* func); diff --git a/src/astnodes.c b/src/astnodes.c index 5ec5420e..77fe25f8 100644 --- a/src/astnodes.c +++ b/src/astnodes.c @@ -800,9 +800,11 @@ Type* resolve_expression_type(AstTyped* node) { return node->type; } -i64 get_expression_integer_value(AstTyped* node) { +i64 get_expression_integer_value(AstTyped* node, b32 *is_valid) { resolve_expression_type(node); + if (is_valid) *is_valid = 1; + if (node->kind == Ast_Kind_NumLit && type_is_integer(node->type)) { return ((AstNumLit *) node)->value.l; } @@ -812,7 +814,7 @@ i64 get_expression_integer_value(AstTyped* node) { } if (node->kind == Ast_Kind_Argument) { - return get_expression_integer_value(((AstArgument *) node)->value); + return get_expression_integer_value(((AstArgument *) node)->value, is_valid); } if (node->kind == Ast_Kind_Size_Of) { @@ -824,14 +826,19 @@ i64 get_expression_integer_value(AstTyped* node) { } if (node->kind == Ast_Kind_Alias) { - return get_expression_integer_value(((AstAlias *) node)->alias); + return get_expression_integer_value(((AstAlias *) node)->alias, is_valid); + } + + if (node->kind == Ast_Kind_Enum_Value) { + return get_expression_integer_value(((AstEnumValue *) node)->value, is_valid); } if (node_is_type((AstNode*) node)) { Type* type = type_build_from_ast(context.ast_alloc, (AstType *) node); - return type->id; + if (type) return type->id; } + if (is_valid) *is_valid = 0; return 0; } @@ -1227,7 +1234,7 @@ b32 static_if_resolution(AstIf* static_if) { if (static_if->kind != Ast_Kind_Static_If) return 0; // assert(condition_value->kind == Ast_Kind_NumLit); // This should be right, right? - i64 value = get_expression_integer_value(static_if->cond); + i64 value = get_expression_integer_value(static_if->cond, NULL); return value != 0; } diff --git a/src/checker.c b/src/checker.c index 7b6df3f3..37a52981 100644 --- a/src/checker.c +++ b/src/checker.c @@ -294,6 +294,8 @@ CheckStatus check_for(AstFor* fornode) { } static b32 add_case_to_switch_statement(AstSwitch* switchnode, u64 case_value, AstBlock* block, OnyxFilePos pos) { + assert(switchnode->switch_kind == Switch_Kind_Integer); + switchnode->min_case = bh_min(switchnode->min_case, case_value); switchnode->max_case = bh_max(switchnode->max_case, case_value); @@ -311,23 +313,39 @@ CheckStatus check_switch(AstSwitch* switchnode) { CHECK(expression, &switchnode->expr); Type* resolved_expr_type = resolve_expression_type(switchnode->expr); - if (!type_is_integer(switchnode->expr->type) && switchnode->expr->type->kind != Type_Kind_Enum) { - ERROR(switchnode->expr->token->pos, "expected integer or enum type for switch expression"); - } - // LEAK if this has to be yielded - bh_imap_init(&switchnode->case_map, global_heap_allocator, bh_arr_length(switchnode->cases) * 2); + if (!(switchnode->flags & Ast_Flag_Has_Been_Checked)) { + if (resolved_expr_type == NULL) YIELD(switchnode->token->pos, "Waiting for expression type to be known."); + + switchnode->switch_kind = Switch_Kind_Integer; + if (!type_is_integer(switchnode->expr->type) && switchnode->expr->type->kind != Type_Kind_Enum) { + switchnode->switch_kind = Switch_Kind_Use_Equals; + } + + switch (switchnode->switch_kind) { + case Switch_Kind_Integer: + switchnode->min_case = 0xffffffffffffffff; + bh_imap_init(&switchnode->case_map, global_heap_allocator, bh_arr_length(switchnode->cases) * 2); + break; + + case Switch_Kind_Use_Equals: + // Guessing the maximum number of case expressions there will be. + bh_arr_new(global_heap_allocator, switchnode->case_exprs, bh_arr_length(switchnode->cases) * 2); + break; - switchnode->min_case = 0xffffffffffffffff; + default: assert(0); + } + } + switchnode->flags |= Ast_Flag_Has_Been_Checked; - // Umm, this doesn't check the type of the case expression to the type of the expression - bh_arr_each(AstSwitchCase, sc, switchnode->cases) { + fori (i, switchnode->yield_return_index, bh_arr_length(switchnode->cases)) { + AstSwitchCase *sc = &switchnode->cases[i]; CHECK(block, sc->block); bh_arr_each(AstTyped *, value, sc->values) { CHECK(expression, value); - if ((*value)->kind == Ast_Kind_Range_Literal) { + if (switchnode->switch_kind == Switch_Kind_Integer && (*value)->kind == Ast_Kind_Range_Literal) { AstRangeLiteral* rl = (AstRangeLiteral *) (*value); resolve_expression_type(rl->low); resolve_expression_type(rl->high); @@ -359,29 +377,43 @@ CheckStatus check_switch(AstSwitch* switchnode) { type_get_name(resolved_expr_type), type_get_name((*value)->type)); } - if (node_is_type((AstNode*) (*value))) { - Type* type = type_build_from_ast(context.ast_alloc, (AstType*) (*value)); + switch (switchnode->switch_kind) { + case Switch_Kind_Integer: { + b32 is_valid; + i64 integer_value = get_expression_integer_value(*value, &is_valid); + if (!is_valid) + ERROR_((*value)->token->pos, "Case statement expected compile time known integer. Got '%s'.", onyx_ast_node_kind_string((*value)->kind)); - if (add_case_to_switch_statement(switchnode, type->id, sc->block, sc->block->token->pos)) - return Check_Error; + if (add_case_to_switch_statement(switchnode, integer_value, sc->block, sc->block->token->pos)) + return Check_Error; - continue; - } + break; + } - if ((*value)->kind == Ast_Kind_Enum_Value) { - (*value) = (AstTyped *) ((AstEnumValue *) (*value))->value; - } + case Switch_Kind_Use_Equals: { + bh_arr_each(CaseToBlock, ctb, switchnode->case_exprs) { + if (ctb->original_value == *value) { + CHECK(expression, (AstTyped **) &ctb->comparison); + goto value_checked; + } + } - if ((*value)->kind != Ast_Kind_NumLit) { - ERROR((*value)->token->pos, "case statement expected compile time known integer"); - } + CaseToBlock ctb; + ctb.block = sc->block; + ctb.original_value = *value; + ctb.comparison = make_binary_op(context.ast_alloc, Binary_Op_Equal, switchnode->expr, *value); + ctb.comparison->token = (*value)->token; + bh_arr_push(switchnode->case_exprs, ctb); - resolve_expression_type((*value)); - // promote_numlit_to_larger((AstNumLit *) (*value)); + CHECK(binaryop, &bh_arr_last(switchnode->case_exprs).comparison); + break; + } + } - if (add_case_to_switch_statement(switchnode, ((AstNumLit *) (*value))->value.l, sc->block, sc->block->token->pos)) - return Check_Error; + value_checked: } + + switchnode->yield_return_index += 1; } if (switchnode->default_case) diff --git a/src/symres.c b/src/symres.c index c6345789..24b341bf 100644 --- a/src/symres.c +++ b/src/symres.c @@ -618,6 +618,12 @@ static SymresStatus symres_switch(AstSwitch* switchnode) { if (switchnode->default_case) SYMRES(block, switchnode->default_case); + if (switchnode->switch_kind == Switch_Kind_Use_Equals && switchnode->case_exprs) { + bh_arr_each(CaseToBlock, ctb, switchnode->case_exprs) { + SYMRES(expression, (AstTyped **) &ctb->comparison); + } + } + if (switchnode->initialization != NULL) scope_leave(); return Symres_Success; diff --git a/src/types.c b/src/types.c index 2a3a5d92..aeaf3c93 100644 --- a/src/types.c +++ b/src/types.c @@ -324,7 +324,7 @@ Type* type_build_from_ast(bh_allocator alloc, AstType* type_node) { return NULL; } - count = get_expression_integer_value(a_node->count_expr); + count = get_expression_integer_value(a_node->count_expr, NULL); } a_type->Array.count = count; diff --git a/src/wasm_emit.c b/src/wasm_emit.c index ad526f5d..a624a905 100644 --- a/src/wasm_emit.c +++ b/src/wasm_emit.c @@ -1140,36 +1140,58 @@ EMIT_FUNC(switch, AstSwitch* switch_node) { block_num++; } - u64 count = switch_node->max_case + 1 - switch_node->min_case; - BranchTable* bt = bh_alloc(mod->extended_instr_alloc, sizeof(BranchTable) + sizeof(u32) * count); - bt->count = count; - bt->default_case = block_num; - fori (i, 0, bt->count) bt->cases[i] = bt->default_case; - - bh_arr_each(bh__imap_entry, sc, switch_node->case_map.entries) { - bt->cases[sc->key - switch_node->min_case] = bh_imap_get(&block_map, (u64) sc->value); - } - - // CLEANUP: We enter a new block here in order to setup the correct - // indicies for the jump targets in the branch table. For example, - // - // - // jump_table - // label0: - // ... - // label1: - // ... - // - // If we didn't enter a new block, then jumping to label 0, would jump - // to the second block, and so on. - WID(WI_BLOCK_START, 0x40); - emit_expression(mod, &code, switch_node->expr); - if (switch_node->min_case != 0) { - WID(WI_I32_CONST, switch_node->min_case); - WI(WI_I32_SUB); - } - WIP(WI_JUMP_TABLE, bt); - WI(WI_BLOCK_END); + switch (switch_node->switch_kind) { + case Switch_Kind_Integer: { + u64 count = switch_node->max_case + 1 - switch_node->min_case; + BranchTable* bt = bh_alloc(mod->extended_instr_alloc, sizeof(BranchTable) + sizeof(u32) * count); + bt->count = count; + bt->default_case = block_num; + fori (i, 0, bt->count) bt->cases[i] = bt->default_case; + + bh_arr_each(bh__imap_entry, sc, switch_node->case_map.entries) { + bt->cases[sc->key - switch_node->min_case] = bh_imap_get(&block_map, (u64) sc->value); + } + + // NOTE: We enter a new block here in order to setup the correct + // indicies for the jump targets in the branch table. For example, + // + // + // jump_table + // label0: + // ... + // label1: + // ... + // + // If we didn't enter a new block, then jumping to label 0, would jump + // to the second block, and so on. + WID(WI_BLOCK_START, 0x40); + emit_expression(mod, &code, switch_node->expr); + if (switch_node->min_case != 0) { + WID(WI_I32_CONST, switch_node->min_case); + WI(WI_I32_SUB); + } + WIP(WI_JUMP_TABLE, bt); + WI(WI_BLOCK_END); + break; + } + + case Switch_Kind_Use_Equals: { + WID(WI_BLOCK_START, 0x40); + + bh_arr_each(CaseToBlock, ctb, switch_node->case_exprs) { + emit_expression(mod, &code, (AstTyped *) ctb->comparison); + + u64 bn = bh_imap_get(&block_map, (u64) ctb->block); + WID(WI_IF_START, 0x40); + WID(WI_JUMP, bn + 1); + WI(WI_IF_END); + } + + WID(WI_JUMP, block_num); + WI(WI_BLOCK_END); + break; + } + } bh_arr_each(AstSwitchCase, sc, switch_node->cases) { if (bh_imap_get(&block_map, (u64) sc->block) == 0xdeadbeef) continue; diff --git a/tests/switch_using_equals b/tests/switch_using_equals new file mode 100644 index 00000000..81598e83 --- /dev/null +++ b/tests/switch_using_equals @@ -0,0 +1,4 @@ +Got some! +Got thing! +Got default! +1, 0 diff --git a/tests/switch_using_equals.onyx b/tests/switch_using_equals.onyx new file mode 100644 index 00000000..803c6540 --- /dev/null +++ b/tests/switch_using_equals.onyx @@ -0,0 +1,25 @@ +#load "core/std" + +use package core + +Vector2 :: struct { x, y: i32; } +#operator == macro (v1: Vector2, v2: Vector2) => v1.x == v2.x && v1.y == v2.y; + +main :: (args: [] cstr) { + for .[ "Some", "Thing", "Other" ] { + switch it { + case "Thing" do println("Got thing!"); + case "Some" do println("Got some!"); + case #default do println("Got default!"); + } + } + + v := Vector2.{ 1, 0 }; + switch v { + case .{ 0, 0 } do println("0, 0"); + case .{ 0, 1 } do println("0, 1"); + case .{ 1, 0 } do println("1, 0"); + case .{ 1, 1 } do println("1, 1"); + case #default do println("none of the above."); + } +} \ No newline at end of file -- 2.25.1