switch statements can use '=='
authorBrendan Hansen <brendan.f.hansen@gmail.com>
Tue, 30 Nov 2021 16:23:34 +0000 (10:23 -0600)
committerBrendan Hansen <brendan.f.hansen@gmail.com>
Tue, 30 Nov 2021 16:23:34 +0000 (10:23 -0600)
core/type_info/helper.onyx
include/astnodes.h
src/astnodes.c
src/checker.c
src/symres.c
src/types.c
src/wasm_emit.c
tests/switch_using_equals [new file with mode: 0644]
tests/switch_using_equals.onyx [new file with mode: 0644]

index 213c24079ad7a5be0029f2e0facbdc88efcc1d6b..d6ed6a5fe688e944547eed378cad8d1a6b761862 100644 (file)
@@ -125,7 +125,6 @@ write_type_name :: (writer: ^io.Writer, t: type_expr) {
     }
 }
 
-
 offset_of :: (T: type_expr, member: str) -> u32 {
     info := get_type_info(T);
     if info == null         do return 0;
@@ -140,6 +139,21 @@ offset_of :: (T: type_expr, member: str) -> u32 {
     return 0;
 }
 
+get_tags_for_member :: (S: type_expr, member_name: str) -> [] any {
+    use type_info;
+    
+    ti := get_type_info(S);
+    if ti.kind != .Struct do return .[];
+    
+    for ^ (cast(^Type_Info_Struct) ti).members {
+        if it.name == member_name {
+            return it.tags;
+        }
+    }
+    
+    return .[];
+}
+
 struct_constructed_from :: (struct_type: type_expr, base_type: type_expr) -> bool {
     struct_info := get_type_info(struct_type);
     if struct_info.kind != .Struct do return false;
index e1f82cc7623beb1b7bef9de9be48f50bd1a08203..581391a99beb03e1fc5aaca5a19c44bc996f6acc 100644 (file)
@@ -772,12 +772,24 @@ struct AstIfWhile {
 typedef struct AstIfWhile AstIf;
 typedef struct AstIfWhile AstWhile;
 
+typedef enum SwitchKind {
+    Switch_Kind_Integer,
+    Switch_Kind_Use_Equals,
+} SwitchKind;
+
+typedef struct CaseToBlock {
+    AstTyped *original_value;
+    AstBinaryOp *comparison;
+    AstBlock *block;
+} CaseToBlock;
+
 struct AstSwitchCase {
     // NOTE: All expressions that end up in this block
     bh_arr(AstTyped *) values;
     
     AstBlock *block;
 };
+
 struct AstSwitch {
     AstNode_base;
 
@@ -789,10 +801,23 @@ struct AstSwitch {
     bh_arr(AstSwitchCase) cases;
     AstBlock *default_case;
 
-    // NOTE: This is a mapping from the compile time known case value
-    // to a pointer to the block that it is associated with.
-    bh_imap case_map;
-    u64 min_case, max_case;
+    i32 yield_return_index;
+    SwitchKind switch_kind;
+
+    union {
+        struct {
+            // NOTE: This is a mapping from the compile time known case value
+            // to a pointer to the block that it is associated with.
+            bh_imap case_map;
+            u64 min_case, max_case;
+        };
+
+        struct {
+            // NOTE: This is a mapping from the '==' binary op node to
+            // a pointer to the block that it is associated with.
+            bh_arr(CaseToBlock) case_exprs;
+        };
+    };
 };
 
 // Type Nodes
@@ -1539,7 +1564,7 @@ typedef enum TypeMatch {
 #define unify_node_and_type(node, type) (unify_node_and_type_((node), (type), 1))
 TypeMatch unify_node_and_type_(AstTyped** pnode, Type* type, b32 permanent);
 Type* resolve_expression_type(AstTyped* node);
-i64 get_expression_integer_value(AstTyped* node);
+i64 get_expression_integer_value(AstTyped* node, b32 *out_is_valid);
 
 b32 cast_is_legal(Type* from_, Type* to_, char** err_msg);
 char* get_function_name(AstFunction* func);
index 5ec5420eac25fef687e57ff9847e72a404a3d222..77fe25f81ab43ce865f50d629bf7eece3780e408 100644 (file)
@@ -800,9 +800,11 @@ Type* resolve_expression_type(AstTyped* node) {
     return node->type;
 }
 
-i64 get_expression_integer_value(AstTyped* node) {
+i64 get_expression_integer_value(AstTyped* node, b32 *is_valid) {
     resolve_expression_type(node);
 
+    if (is_valid) *is_valid = 1;
+
     if (node->kind == Ast_Kind_NumLit && type_is_integer(node->type)) {
         return ((AstNumLit *) node)->value.l;
     }
@@ -812,7 +814,7 @@ i64 get_expression_integer_value(AstTyped* node) {
     }
 
     if (node->kind == Ast_Kind_Argument) {
-        return get_expression_integer_value(((AstArgument *) node)->value);
+        return get_expression_integer_value(((AstArgument *) node)->value, is_valid);
     }
 
     if (node->kind == Ast_Kind_Size_Of) {
@@ -824,14 +826,19 @@ i64 get_expression_integer_value(AstTyped* node) {
     }
 
     if (node->kind == Ast_Kind_Alias) {
-        return get_expression_integer_value(((AstAlias *) node)->alias);
+        return get_expression_integer_value(((AstAlias *) node)->alias, is_valid);
+    }
+
+    if (node->kind == Ast_Kind_Enum_Value) {
+        return get_expression_integer_value(((AstEnumValue *) node)->value, is_valid);
     }
 
     if (node_is_type((AstNode*) node)) {
         Type* type = type_build_from_ast(context.ast_alloc, (AstType *) node);
-        return type->id;
+        if (type) return type->id;
     }
 
+    if (is_valid) *is_valid = 0;
     return 0;
 }
 
@@ -1227,7 +1234,7 @@ b32 static_if_resolution(AstIf* static_if) {
     if (static_if->kind != Ast_Kind_Static_If) return 0;
 
     // assert(condition_value->kind == Ast_Kind_NumLit); // This should be right, right?
-    i64 value = get_expression_integer_value(static_if->cond);
+    i64 value = get_expression_integer_value(static_if->cond, NULL);
 
     return value != 0;
 }
index 7b6df3f39a880343293ed1ebed630f983c4db8dc..37a52981df50930dce5f2ad28a9ef7fcbbc5bdcf 100644 (file)
@@ -294,6 +294,8 @@ CheckStatus check_for(AstFor* fornode) {
 }
 
 static b32 add_case_to_switch_statement(AstSwitch* switchnode, u64 case_value, AstBlock* block, OnyxFilePos pos) {
+    assert(switchnode->switch_kind == Switch_Kind_Integer);
+
     switchnode->min_case = bh_min(switchnode->min_case, case_value);
     switchnode->max_case = bh_max(switchnode->max_case, case_value);
 
@@ -311,23 +313,39 @@ CheckStatus check_switch(AstSwitch* switchnode) {
 
     CHECK(expression, &switchnode->expr);
     Type* resolved_expr_type = resolve_expression_type(switchnode->expr);
-    if (!type_is_integer(switchnode->expr->type) && switchnode->expr->type->kind != Type_Kind_Enum) {
-        ERROR(switchnode->expr->token->pos, "expected integer or enum type for switch expression");
-    }
 
-    // LEAK if this has to be yielded
-    bh_imap_init(&switchnode->case_map, global_heap_allocator, bh_arr_length(switchnode->cases) * 2);
+    if (!(switchnode->flags & Ast_Flag_Has_Been_Checked)) {
+        if (resolved_expr_type == NULL) YIELD(switchnode->token->pos, "Waiting for expression type to be known.");
+
+        switchnode->switch_kind = Switch_Kind_Integer;
+        if (!type_is_integer(switchnode->expr->type) && switchnode->expr->type->kind != Type_Kind_Enum) {
+            switchnode->switch_kind = Switch_Kind_Use_Equals;
+        }
+
+        switch (switchnode->switch_kind) {
+            case Switch_Kind_Integer:
+                switchnode->min_case = 0xffffffffffffffff;
+                bh_imap_init(&switchnode->case_map, global_heap_allocator, bh_arr_length(switchnode->cases) * 2);
+                break;
+
+            case Switch_Kind_Use_Equals:
+                // Guessing the maximum number of case expressions there will be.
+                bh_arr_new(global_heap_allocator, switchnode->case_exprs, bh_arr_length(switchnode->cases) * 2);
+                break;
 
-    switchnode->min_case = 0xffffffffffffffff;
+            default: assert(0);
+        }
+    }
+    switchnode->flags |= Ast_Flag_Has_Been_Checked;
 
-    // Umm, this doesn't check the type of the case expression to the type of the expression
-    bh_arr_each(AstSwitchCase, sc, switchnode->cases) {
+    fori (i, switchnode->yield_return_index, bh_arr_length(switchnode->cases)) {
+        AstSwitchCase *sc = &switchnode->cases[i];
         CHECK(block, sc->block);
 
         bh_arr_each(AstTyped *, value, sc->values) {
             CHECK(expression, value);
 
-            if ((*value)->kind == Ast_Kind_Range_Literal) {
+            if (switchnode->switch_kind == Switch_Kind_Integer && (*value)->kind == Ast_Kind_Range_Literal) {
                 AstRangeLiteral* rl = (AstRangeLiteral *) (*value);
                 resolve_expression_type(rl->low);
                 resolve_expression_type(rl->high);
@@ -359,29 +377,43 @@ CheckStatus check_switch(AstSwitch* switchnode) {
                     type_get_name(resolved_expr_type), type_get_name((*value)->type));
             }
 
-            if (node_is_type((AstNode*) (*value))) {
-                Type* type = type_build_from_ast(context.ast_alloc, (AstType*) (*value));
+            switch (switchnode->switch_kind) {
+                case Switch_Kind_Integer: {
+                    b32 is_valid;
+                    i64 integer_value = get_expression_integer_value(*value, &is_valid);
+                    if (!is_valid)
+                        ERROR_((*value)->token->pos, "Case statement expected compile time known integer. Got '%s'.", onyx_ast_node_kind_string((*value)->kind));
 
-                if (add_case_to_switch_statement(switchnode, type->id, sc->block, sc->block->token->pos))
-                    return Check_Error;
+                    if (add_case_to_switch_statement(switchnode, integer_value, sc->block, sc->block->token->pos))
+                        return Check_Error;
 
-                continue;
-            }
+                    break;
+                }
 
-            if ((*value)->kind == Ast_Kind_Enum_Value) {
-                (*value) = (AstTyped *) ((AstEnumValue *) (*value))->value;
-            }
+                case Switch_Kind_Use_Equals: {
+                    bh_arr_each(CaseToBlock, ctb, switchnode->case_exprs) {
+                        if (ctb->original_value == *value) {
+                            CHECK(expression, (AstTyped **) &ctb->comparison);
+                            goto value_checked;
+                        }
+                    }
 
-            if ((*value)->kind != Ast_Kind_NumLit) {
-                ERROR((*value)->token->pos, "case statement expected compile time known integer");
-            }
+                    CaseToBlock ctb;
+                    ctb.block = sc->block;
+                    ctb.original_value = *value;
+                    ctb.comparison = make_binary_op(context.ast_alloc, Binary_Op_Equal, switchnode->expr, *value);
+                    ctb.comparison->token = (*value)->token;
+                    bh_arr_push(switchnode->case_exprs, ctb);
 
-            resolve_expression_type((*value));
-            // promote_numlit_to_larger((AstNumLit *) (*value));
+                    CHECK(binaryop, &bh_arr_last(switchnode->case_exprs).comparison);
+                    break;
+                }
+            }
 
-            if (add_case_to_switch_statement(switchnode, ((AstNumLit *) (*value))->value.l, sc->block, sc->block->token->pos))
-                return Check_Error;
+        value_checked:
         }
+
+        switchnode->yield_return_index += 1;
     }
 
     if (switchnode->default_case)
index c6345789d50f06f980668cee6514e4ae2419439a..24b341bf82be462913c1979d20aedd066b434354 100644 (file)
@@ -618,6 +618,12 @@ static SymresStatus symres_switch(AstSwitch* switchnode) {
     if (switchnode->default_case)
         SYMRES(block, switchnode->default_case);
 
+    if (switchnode->switch_kind == Switch_Kind_Use_Equals && switchnode->case_exprs) {
+        bh_arr_each(CaseToBlock, ctb, switchnode->case_exprs) {
+            SYMRES(expression, (AstTyped **) &ctb->comparison);
+        }
+    }
+
     if (switchnode->initialization != NULL) scope_leave();
 
     return Symres_Success;
index 2a3a5d921bd62dd5cf586fc55ca446606b07a460..aeaf3c93dbeff0b7295c5c8304bcda315f5f0d0c 100644 (file)
@@ -324,7 +324,7 @@ Type* type_build_from_ast(bh_allocator alloc, AstType* type_node) {
                     return NULL;
                 }
 
-                count = get_expression_integer_value(a_node->count_expr);
+                count = get_expression_integer_value(a_node->count_expr, NULL);
             }
 
             a_type->Array.count = count;
index ad526f5d3e644559e41b84d0501220cd51785524..a624a9058e2e929c29191df0160e9a4c313cadcf 100644 (file)
@@ -1140,36 +1140,58 @@ EMIT_FUNC(switch, AstSwitch* switch_node) {
         block_num++;
     }
 
-    u64 count = switch_node->max_case + 1 - switch_node->min_case;
-    BranchTable* bt = bh_alloc(mod->extended_instr_alloc, sizeof(BranchTable) + sizeof(u32) * count);
-    bt->count = count;
-    bt->default_case = block_num;
-    fori (i, 0, bt->count) bt->cases[i] = bt->default_case;
-
-    bh_arr_each(bh__imap_entry, sc, switch_node->case_map.entries) {
-        bt->cases[sc->key - switch_node->min_case] = bh_imap_get(&block_map, (u64) sc->value);
-    }
-
-    // CLEANUP: We enter a new block here in order to setup the correct
-    // indicies for the jump targets in the branch table. For example,
-    //
-    // <expr>
-    // jump_table
-    // label0:
-    // ...
-    // label1:
-    // ...
-    //
-    // If we didn't enter a new block, then jumping to label 0, would jump
-    // to the second block, and so on.
-    WID(WI_BLOCK_START, 0x40);
-    emit_expression(mod, &code, switch_node->expr);
-    if (switch_node->min_case != 0) {
-        WID(WI_I32_CONST, switch_node->min_case);
-        WI(WI_I32_SUB);
-    }
-    WIP(WI_JUMP_TABLE, bt);
-    WI(WI_BLOCK_END);
+    switch (switch_node->switch_kind) {
+        case Switch_Kind_Integer: {
+            u64 count = switch_node->max_case + 1 - switch_node->min_case;
+            BranchTable* bt = bh_alloc(mod->extended_instr_alloc, sizeof(BranchTable) + sizeof(u32) * count);
+            bt->count = count;
+            bt->default_case = block_num;
+            fori (i, 0, bt->count) bt->cases[i] = bt->default_case;
+
+            bh_arr_each(bh__imap_entry, sc, switch_node->case_map.entries) {
+                bt->cases[sc->key - switch_node->min_case] = bh_imap_get(&block_map, (u64) sc->value);
+            }
+
+            // NOTE: We enter a new block here in order to setup the correct
+            // indicies for the jump targets in the branch table. For example,
+            //
+            // <expr>
+            // jump_table
+            // label0:
+            // ...
+            // label1:
+            // ...
+            //
+            // If we didn't enter a new block, then jumping to label 0, would jump
+            // to the second block, and so on.
+            WID(WI_BLOCK_START, 0x40);
+            emit_expression(mod, &code, switch_node->expr);
+            if (switch_node->min_case != 0) {
+                WID(WI_I32_CONST, switch_node->min_case);
+                WI(WI_I32_SUB);
+            }
+            WIP(WI_JUMP_TABLE, bt);
+            WI(WI_BLOCK_END);
+            break;
+        }
+
+        case Switch_Kind_Use_Equals: {
+            WID(WI_BLOCK_START, 0x40);
+
+            bh_arr_each(CaseToBlock, ctb, switch_node->case_exprs) {
+                emit_expression(mod, &code, (AstTyped *) ctb->comparison);
+
+                u64 bn = bh_imap_get(&block_map, (u64) ctb->block);
+                WID(WI_IF_START, 0x40);
+                WID(WI_JUMP, bn + 1);
+                WI(WI_IF_END);
+            }
+
+            WID(WI_JUMP, block_num);
+            WI(WI_BLOCK_END);
+            break;
+        }
+    }
 
     bh_arr_each(AstSwitchCase, sc, switch_node->cases) {
         if (bh_imap_get(&block_map, (u64) sc->block) == 0xdeadbeef) continue;
diff --git a/tests/switch_using_equals b/tests/switch_using_equals
new file mode 100644 (file)
index 0000000..81598e8
--- /dev/null
@@ -0,0 +1,4 @@
+Got some!
+Got thing!
+Got default!
+1, 0
diff --git a/tests/switch_using_equals.onyx b/tests/switch_using_equals.onyx
new file mode 100644 (file)
index 0000000..803c654
--- /dev/null
@@ -0,0 +1,25 @@
+#load "core/std"
+
+use package core
+
+Vector2 :: struct { x, y: i32; }
+#operator == macro (v1: Vector2, v2: Vector2) => v1.x == v2.x && v1.y == v2.y;
+
+main :: (args: [] cstr) {
+    for .[ "Some", "Thing", "Other" ] {
+        switch it {
+            case "Thing" do println("Got thing!");
+            case "Some"  do println("Got some!");
+            case #default do println("Got default!");
+        }
+    }
+
+    v := Vector2.{ 1, 0 };
+    switch v {
+        case .{ 0, 0 } do println("0, 0");
+        case .{ 0, 1 } do println("0, 1");
+        case .{ 1, 0 } do println("1, 0");
+        case .{ 1, 1 } do println("1, 1");
+        case #default do println("none of the above.");
+    }
+}
\ No newline at end of file