From: Brendan Hansen Date: Sun, 25 Jun 2023 20:36:08 +0000 (-0500) Subject: added: explicit sizing of union tags with `#tag_type` X-Git-Url: https://git.brendanfh.com/?a=commitdiff_plain;h=8c5b09c10593565ce573a381efa7167d47ff40a1;p=onyx.git added: explicit sizing of union tags with `#tag_type` --- diff --git a/compiler/include/astnodes.h b/compiler/include/astnodes.h index be3d5f18..a9afab4d 100644 --- a/compiler/include/astnodes.h +++ b/compiler/include/astnodes.h @@ -1041,6 +1041,8 @@ struct AstUnionType { bh_arr(AstUnionVariant *) variants; bh_arr(AstTyped *) meta_tags; + AstType *tag_backing_type; + // NOTE: Used to cache the actual type, since building // a union type is kind of complicated and should // only happen once. diff --git a/compiler/src/checker.c b/compiler/src/checker.c index 44376efd..198b7c6a 100644 --- a/compiler/src/checker.c +++ b/compiler/src/checker.c @@ -3044,6 +3044,13 @@ CheckStatus check_struct_defaults(AstStructType* s_node) { } CheckStatus check_union(AstUnionType *u_node) { + CHECK(type, &u_node->tag_backing_type); + + Type *tag_type = type_build_from_ast(context.ast_alloc, u_node->tag_backing_type); + if (!type_is_integer(tag_type)) { + ERROR_(u_node->token->pos, "Union tag types must be an integer, got '%s'.", type_get_name(tag_type)); + } + if (u_node->polymorphic_argument_types) { assert(u_node->polymorphic_arguments); diff --git a/compiler/src/parser.c b/compiler/src/parser.c index 98e39149..eac6a147 100644 --- a/compiler/src/parser.c +++ b/compiler/src/parser.c @@ -2357,6 +2357,13 @@ static AstUnionType* parse_union(OnyxParser* parser) { Scope *scope_to_restore_parser_to = parser->current_scope; Scope *scope_symbols_in_unions_should_be_bound_to = u_node->scope; + if (parse_possible_directive(parser, "tag_type")) { + AstType *backing_type = parse_type(parser); + u_node->tag_backing_type = backing_type; + } else { + u_node->tag_backing_type = &basic_type_u32; + } + if (consume_token_if_next(parser, '(')) { bh_arr(AstPolyStructParam) poly_params = NULL; bh_arr_new(global_heap_allocator, poly_params, 1); diff --git a/compiler/src/symres.c b/compiler/src/symres.c index 781df2b5..50a1163c 100644 --- a/compiler/src/symres.c +++ b/compiler/src/symres.c @@ -161,6 +161,8 @@ static SymresStatus symres_union_type(AstUnionType* u_node) { if (u_node->flags & Ast_Flag_Type_Is_Resolved) return Symres_Success; u_node->flags |= Ast_Flag_Comptime; + SYMRES(type, &u_node->tag_backing_type); + if (u_node->meta_tags) { bh_arr_each(AstTyped *, meta, u_node->meta_tags) { SYMRES(expression, meta); diff --git a/compiler/src/types.c b/compiler/src/types.c index 29c955a9..9ebd494e 100644 --- a/compiler/src/types.c +++ b/compiler/src/types.c @@ -757,7 +757,7 @@ static Type* type_build_from_ast_inner(bh_allocator alloc, AstType* type_node, b AstEnumType* tag_enum_node = onyx_ast_node_new(alloc, sizeof(AstEnumType), Ast_Kind_Enum_Type); tag_enum_node->token = union_->token; tag_enum_node->name = bh_aprintf(alloc, "%s.tag_enum", union_->name); - tag_enum_node->backing_type = &basic_types[Basic_Kind_U32]; + tag_enum_node->backing_type = type_build_from_ast(alloc, union_->tag_backing_type); bh_arr_new(alloc, tag_enum_node->values, bh_arr_length(union_->variants)); void add_entities_for_node(bh_arr(Entity *) *target_arr, AstNode* node, Scope* scope, Package* package); // HACK @@ -805,7 +805,7 @@ static Type* type_build_from_ast_inner(bh_allocator alloc, AstType* type_node, b bh_arr_push(tag_enum_node->values, ev); } - alignment = bh_max(alignment, 4); + alignment = bh_max(alignment, type_alignment_of(tag_enum_node->backing_type)); bh_align(size, alignment); u_type->Union.alignment = alignment; diff --git a/core/conv/format.onyx b/core/conv/format.onyx index 91a66fa1..b86c94b6 100644 --- a/core/conv/format.onyx +++ b/core/conv/format.onyx @@ -729,9 +729,16 @@ format_any :: (output: &Format_Output, formatting: &Format, v: any) { if info.kind == .Union { u := cast(&Type_Info_Union) info; - tag_value := *cast(&u32, v.data); + tag_value: u64; + switch e := u.tag_enum->info()->as_enum(); e.backing_type { + case i8, u8 do tag_value = cast(u64) *(cast(&u8) v.data); + case i16, u16 do tag_value = cast(u64) *(cast(&u16) v.data); + case i32, u32 do tag_value = cast(u64) *(cast(&u32) v.data); + case i64, u64 do tag_value = cast(u64) *(cast(&u64) v.data); + case #default do assert(false, "Bad union backing type"); + } - variant := array.first(u.variants, [x](x.tag_value == tag_value)); + variant := array.first(u.variants, [x](x.tag_value == ~~tag_value)); if !variant { output->write("unknown_variant"); diff --git a/tests/tagged_unions b/tests/tagged_unions index a3ee253c..4d597252 100644 --- a/tests/tagged_unions +++ b/tests/tagged_unions @@ -7,3 +7,5 @@ This is an integer: 9876 456 Error(Process_Error("bad data")) A wrapped value +2 +Bool(true) diff --git a/tests/tagged_unions.onyx b/tests/tagged_unions.onyx index ececc1c8..c0ad0c3a 100644 --- a/tests/tagged_unions.onyx +++ b/tests/tagged_unions.onyx @@ -156,6 +156,18 @@ direct_access_is_an_optional :: () { println(the_string); } +sized_tagged_union :: () { + Smol :: union #tag_type u8 { + Void: void; + Bool: bool; + } + + println(sizeof Smol); + + v := Smol.{ Bool = true }; + println(v); +} + main :: () { simple_example(); @@ -164,4 +176,5 @@ main :: () { linked_list_example(); polymorphic_example(); direct_access_is_an_optional(); + sized_tagged_union(); }