added: mandatory check to cover all union variants in switch
authorBrendan Hansen <brendan.f.hansen@gmail.com>
Mon, 29 May 2023 16:35:02 +0000 (11:35 -0500)
committerBrendan Hansen <brendan.f.hansen@gmail.com>
Mon, 29 May 2023 16:36:00 +0000 (11:36 -0500)
compiler/include/astnodes.h
compiler/include/types.h
compiler/src/checker.c
compiler/src/types.c

index 1cbbccf4e2f6b1a0e0d9fc4e10e3467eb06d9f49..787e9a3cc70134075c6e2f6d606434ff31eec9ea 100644 (file)
@@ -911,6 +911,10 @@ struct AstSwitch {
     i32 yield_return_index;
     SwitchKind switch_kind;
 
+    // NOTE: This is an array of "bools" that says which union variants have
+    // been handled.
+    u8 *union_variants_handled;
+
     union {
         struct {
             // NOTE: This is a mapping from the compile time known case value
index b04f4d6faadaa186e08ffea7b872a31d59b0611e..ce8c773b547b17501c2780f085bb75a0f47702c1 100644 (file)
@@ -283,5 +283,7 @@ u32 type_structlike_is_simple(Type* type);
 b32 type_is_sl_constructable(Type* type);
 b32 type_constructed_from_poly(Type* base, struct AstType* from);
 Type* type_struct_is_just_one_basic_value(Type *type);
+u32 type_union_get_variant_count(Type *type);
+UnionVariant* type_lookup_union_variant_by_idx(Type* type, i32 idx);
 
 #endif // #ifndef ONYX_TYPES
index c590c122d600c7046357decb42728410ff7bc63c..6bad112f2698d3acb1d9e53487a285334c3741e9 100644 (file)
@@ -440,8 +440,8 @@ CheckStatus check_switch(AstSwitch* switchnode) {
         if (!type_is_integer(switchnode->expr->type) && switchnode->expr->type->kind != Type_Kind_Enum) {
             switchnode->switch_kind = Switch_Kind_Use_Equals;
         }
-        if (switchnode->expr->type->kind == Type_Kind_Union
-            || (switchnode->expr->type->kind == Type_Kind_Pointer && switchnode->expr->type->Pointer.elem->kind == Type_Kind_Union)) {
+
+        if (type_union_get_variant_count(switchnode->expr->type) > 0) {
             switchnode->switch_kind = Switch_Kind_Union;
         }
 
@@ -458,6 +458,9 @@ CheckStatus check_switch(AstSwitch* switchnode) {
             case Switch_Kind_Union:
                 switchnode->min_case = 1;
                 bh_imap_init(&switchnode->case_map, global_heap_allocator, 4);
+
+                u32 variants = type_union_get_variant_count(switchnode->expr->type);
+                switchnode->union_variants_handled = bh_alloc_array(context.ast_alloc, u8, variants);
                 break;
 
             default: assert(0);
@@ -485,7 +488,7 @@ CheckStatus check_switch(AstSwitch* switchnode) {
         AstSwitchCase *sc = switchnode->cases[i];
 
         if (sc->capture && bh_arr_length(sc->values) != 1) {
-            ERROR(sc->token->pos, "Expected exactly one value in switch-case when using a capture.");
+            ERROR(sc->token->pos, "Expected exactly one value in switch-case when using a capture, i.e. `case X => Y { ... }`.");
         }
 
         if (sc->capture && switchnode->switch_kind != Switch_Kind_Union) {
@@ -537,8 +540,12 @@ CheckStatus check_switch(AstSwitch* switchnode) {
                         (*value)->token->text, (*value)->token->length, type_get_name(union_expr_type));
                 }
 
-                // nocheckin Explain the -1
-                UnionVariant *union_variant = union_expr_type->Union.variants_ordered[get_expression_integer_value(*value, NULL) - 1];
+                // We subtract one here because variant numbering starts at 1, instead of 0.
+                // This is so a zeroed out block of memory does not have a valid variant.
+                i32 variant_number = get_expression_integer_value(*value, NULL) - 1;
+                switchnode->union_variants_handled[variant_number] = 1;
+
+                UnionVariant *union_variant = union_expr_type->Union.variants_ordered[variant_number];
                 if (sc->capture) {
                     if (sc->capture_is_by_pointer) {
                         sc->capture->type = type_make_pointer(context.ast_alloc, union_variant->type);
@@ -604,10 +611,41 @@ CheckStatus check_switch(AstSwitch* switchnode) {
         switchnode->yield_return_index += 1;
     }
 
-    if (switchnode->default_case)
+    if (switchnode->default_case) {
         CHECK(block, switchnode->default_case);
 
-    return 0;
+    } else if (switchnode->switch_kind == Switch_Kind_Union) {
+        // If there is no default case, and this is a union switch,
+        // make sure all cases are handled.
+
+        bh_arr(char *) missed_variants = NULL;
+
+        i32 variant_count = type_union_get_variant_count(switchnode->expr->type);
+        fori (i, 0, variant_count) {
+            if (!switchnode->union_variants_handled[i]) {
+                UnionVariant *uv = type_lookup_union_variant_by_idx(switchnode->expr->type, i);
+                assert(uv && uv->name);
+                bh_arr_push(missed_variants, uv->name);
+            }
+        }
+
+        i32 missed_variant_count = bh_arr_length(missed_variants);
+        if (missed_variant_count > 0) {
+            char  buf[1024] = {0};
+            fori (i, 0, bh_min(missed_variant_count, 2)) {
+                if (i != 0) strncat(buf, ", ", 1023);
+                strncat(buf, missed_variants[i], 1023);
+            }
+
+            if (missed_variant_count > 2) {
+                strncat(buf, bh_bprintf(" and %d more", missed_variant_count - 2), 1023);
+            }
+
+            ERROR_(switchnode->token->pos, "Unhandled union variants: %s", buf);
+        }
+    }
+
+    return Check_Success;
 }
 
 CheckStatus check_arguments(Arguments* args) {
index 1795cbc43480b4d58e561d5f4df3cf4ba2e73735..aed11dde7a7ba040caa809e4f14e1eb033fe1be4 100644 (file)
@@ -1822,3 +1822,22 @@ Type* type_struct_is_just_one_basic_value(Type *type) {
     if (type->Struct.memarr[0]->type->kind != Type_Kind_Basic) return NULL;
     return type->Struct.memarr[0]->type;
 }
+
+u32 type_union_get_variant_count(Type *type) {
+    if (!type) return 0;
+    switch (type->kind) {
+        case Type_Kind_Union: return bh_arr_length(type->Union.variants_ordered);
+        case Type_Kind_Pointer: return type_union_get_variant_count(type->Pointer.elem);
+        default: return 0;
+    }
+}
+
+UnionVariant* type_lookup_union_variant_by_idx(Type* type, i32 idx) {
+    if (!type) return NULL;
+    if (type->kind == Type_Kind_Pointer) type = type->Pointer.elem;
+    if (type->kind != Type_Kind_Union) return NULL;
+    if (idx < 0 || idx >= bh_arr_length(type->Union.variants_ordered)) return NULL;
+
+    return type->Union.variants_ordered[idx];
+}
+