From: Brendan Hansen Date: Mon, 29 May 2023 16:35:02 +0000 (-0500) Subject: added: mandatory check to cover all union variants in switch X-Git-Url: https://git.brendanfh.com/?a=commitdiff_plain;h=6df9aff168daabb542180ea9c448ba70c9ce4b83;p=onyx.git added: mandatory check to cover all union variants in switch --- diff --git a/compiler/include/astnodes.h b/compiler/include/astnodes.h index 1cbbccf4..787e9a3c 100644 --- a/compiler/include/astnodes.h +++ b/compiler/include/astnodes.h @@ -911,6 +911,10 @@ struct AstSwitch { i32 yield_return_index; SwitchKind switch_kind; + // NOTE: This is an array of "bools" that says which union variants have + // been handled. + u8 *union_variants_handled; + union { struct { // NOTE: This is a mapping from the compile time known case value diff --git a/compiler/include/types.h b/compiler/include/types.h index b04f4d6f..ce8c773b 100644 --- a/compiler/include/types.h +++ b/compiler/include/types.h @@ -283,5 +283,7 @@ u32 type_structlike_is_simple(Type* type); b32 type_is_sl_constructable(Type* type); b32 type_constructed_from_poly(Type* base, struct AstType* from); Type* type_struct_is_just_one_basic_value(Type *type); +u32 type_union_get_variant_count(Type *type); +UnionVariant* type_lookup_union_variant_by_idx(Type* type, i32 idx); #endif // #ifndef ONYX_TYPES diff --git a/compiler/src/checker.c b/compiler/src/checker.c index c590c122..6bad112f 100644 --- a/compiler/src/checker.c +++ b/compiler/src/checker.c @@ -440,8 +440,8 @@ CheckStatus check_switch(AstSwitch* switchnode) { if (!type_is_integer(switchnode->expr->type) && switchnode->expr->type->kind != Type_Kind_Enum) { switchnode->switch_kind = Switch_Kind_Use_Equals; } - if (switchnode->expr->type->kind == Type_Kind_Union - || (switchnode->expr->type->kind == Type_Kind_Pointer && switchnode->expr->type->Pointer.elem->kind == Type_Kind_Union)) { + + if (type_union_get_variant_count(switchnode->expr->type) > 0) { switchnode->switch_kind = Switch_Kind_Union; } @@ -458,6 +458,9 @@ CheckStatus check_switch(AstSwitch* switchnode) { case Switch_Kind_Union: switchnode->min_case = 1; bh_imap_init(&switchnode->case_map, global_heap_allocator, 4); + + u32 variants = type_union_get_variant_count(switchnode->expr->type); + switchnode->union_variants_handled = bh_alloc_array(context.ast_alloc, u8, variants); break; default: assert(0); @@ -485,7 +488,7 @@ CheckStatus check_switch(AstSwitch* switchnode) { AstSwitchCase *sc = switchnode->cases[i]; if (sc->capture && bh_arr_length(sc->values) != 1) { - ERROR(sc->token->pos, "Expected exactly one value in switch-case when using a capture."); + ERROR(sc->token->pos, "Expected exactly one value in switch-case when using a capture, i.e. `case X => Y { ... }`."); } if (sc->capture && switchnode->switch_kind != Switch_Kind_Union) { @@ -537,8 +540,12 @@ CheckStatus check_switch(AstSwitch* switchnode) { (*value)->token->text, (*value)->token->length, type_get_name(union_expr_type)); } - // nocheckin Explain the -1 - UnionVariant *union_variant = union_expr_type->Union.variants_ordered[get_expression_integer_value(*value, NULL) - 1]; + // We subtract one here because variant numbering starts at 1, instead of 0. + // This is so a zeroed out block of memory does not have a valid variant. + i32 variant_number = get_expression_integer_value(*value, NULL) - 1; + switchnode->union_variants_handled[variant_number] = 1; + + UnionVariant *union_variant = union_expr_type->Union.variants_ordered[variant_number]; if (sc->capture) { if (sc->capture_is_by_pointer) { sc->capture->type = type_make_pointer(context.ast_alloc, union_variant->type); @@ -604,10 +611,41 @@ CheckStatus check_switch(AstSwitch* switchnode) { switchnode->yield_return_index += 1; } - if (switchnode->default_case) + if (switchnode->default_case) { CHECK(block, switchnode->default_case); - return 0; + } else if (switchnode->switch_kind == Switch_Kind_Union) { + // If there is no default case, and this is a union switch, + // make sure all cases are handled. + + bh_arr(char *) missed_variants = NULL; + + i32 variant_count = type_union_get_variant_count(switchnode->expr->type); + fori (i, 0, variant_count) { + if (!switchnode->union_variants_handled[i]) { + UnionVariant *uv = type_lookup_union_variant_by_idx(switchnode->expr->type, i); + assert(uv && uv->name); + bh_arr_push(missed_variants, uv->name); + } + } + + i32 missed_variant_count = bh_arr_length(missed_variants); + if (missed_variant_count > 0) { + char buf[1024] = {0}; + fori (i, 0, bh_min(missed_variant_count, 2)) { + if (i != 0) strncat(buf, ", ", 1023); + strncat(buf, missed_variants[i], 1023); + } + + if (missed_variant_count > 2) { + strncat(buf, bh_bprintf(" and %d more", missed_variant_count - 2), 1023); + } + + ERROR_(switchnode->token->pos, "Unhandled union variants: %s", buf); + } + } + + return Check_Success; } CheckStatus check_arguments(Arguments* args) { diff --git a/compiler/src/types.c b/compiler/src/types.c index 1795cbc4..aed11dde 100644 --- a/compiler/src/types.c +++ b/compiler/src/types.c @@ -1822,3 +1822,22 @@ Type* type_struct_is_just_one_basic_value(Type *type) { if (type->Struct.memarr[0]->type->kind != Type_Kind_Basic) return NULL; return type->Struct.memarr[0]->type; } + +u32 type_union_get_variant_count(Type *type) { + if (!type) return 0; + switch (type->kind) { + case Type_Kind_Union: return bh_arr_length(type->Union.variants_ordered); + case Type_Kind_Pointer: return type_union_get_variant_count(type->Pointer.elem); + default: return 0; + } +} + +UnionVariant* type_lookup_union_variant_by_idx(Type* type, i32 idx) { + if (!type) return NULL; + if (type->kind == Type_Kind_Pointer) type = type->Pointer.elem; + if (type->kind != Type_Kind_Union) return NULL; + if (idx < 0 || idx >= bh_arr_length(type->Union.variants_ordered)) return NULL; + + return type->Union.variants_ordered[idx]; +} +