added: explicit sizing of union tags with `#tag_type`
authorBrendan Hansen <brendan.f.hansen@gmail.com>
Sun, 25 Jun 2023 20:36:08 +0000 (15:36 -0500)
committerBrendan Hansen <brendan.f.hansen@gmail.com>
Sun, 25 Jun 2023 20:36:08 +0000 (15:36 -0500)
compiler/include/astnodes.h
compiler/src/checker.c
compiler/src/parser.c
compiler/src/symres.c
compiler/src/types.c
core/conv/format.onyx
tests/tagged_unions
tests/tagged_unions.onyx

index be3d5f18466853fe23e47c445509da6040107acd..a9afab4d59877703d562dccbc7d929a242e8f279 100644 (file)
@@ -1041,6 +1041,8 @@ struct AstUnionType {
     bh_arr(AstUnionVariant *) variants;
     bh_arr(AstTyped *) meta_tags;
 
+    AstType *tag_backing_type;
+
     // NOTE: Used to cache the actual type, since building
     // a union type is kind of complicated and should
     // only happen once.
index 44376efded40d6f43d94ec4e9edfc146a8507d93..198b7c6a0907527d0832b8490e63835810727ccc 100644 (file)
@@ -3044,6 +3044,13 @@ CheckStatus check_struct_defaults(AstStructType* s_node) {
 }
 
 CheckStatus check_union(AstUnionType *u_node) {
+    CHECK(type, &u_node->tag_backing_type);
+
+    Type *tag_type = type_build_from_ast(context.ast_alloc, u_node->tag_backing_type);
+    if (!type_is_integer(tag_type)) {
+        ERROR_(u_node->token->pos, "Union tag types must be an integer, got '%s'.", type_get_name(tag_type));
+    }
+    
     if (u_node->polymorphic_argument_types) {
         assert(u_node->polymorphic_arguments);
 
index 98e39149d744f70cc96dd614898d92507458dfe3..eac6a14718b400f69e5ae36bc1aac335619de417 100644 (file)
@@ -2357,6 +2357,13 @@ static AstUnionType* parse_union(OnyxParser* parser) {
     Scope *scope_to_restore_parser_to = parser->current_scope;
     Scope *scope_symbols_in_unions_should_be_bound_to = u_node->scope;
 
+    if (parse_possible_directive(parser, "tag_type")) {
+        AstType *backing_type = parse_type(parser);
+        u_node->tag_backing_type = backing_type;
+    } else {
+        u_node->tag_backing_type = &basic_type_u32;
+    }
+
     if (consume_token_if_next(parser, '(')) {
         bh_arr(AstPolyStructParam) poly_params = NULL;
         bh_arr_new(global_heap_allocator, poly_params, 1);
index 781df2b51db40d9c62f7d9e64d474bcafbd59443..50a1163c714a25e907c046c9a74a2655ac2b6023 100644 (file)
@@ -161,6 +161,8 @@ static SymresStatus symres_union_type(AstUnionType* u_node) {
     if (u_node->flags & Ast_Flag_Type_Is_Resolved) return Symres_Success;
     u_node->flags |= Ast_Flag_Comptime;
 
+    SYMRES(type, &u_node->tag_backing_type);
+
     if (u_node->meta_tags) {
         bh_arr_each(AstTyped *, meta, u_node->meta_tags) {
             SYMRES(expression, meta);
index 29c955a9091d7ad5fdf85561e51c677473fde157..9ebd494e71ed0b9966e29a07ac77135d7723723a 100644 (file)
@@ -757,7 +757,7 @@ static Type* type_build_from_ast_inner(bh_allocator alloc, AstType* type_node, b
             AstEnumType* tag_enum_node = onyx_ast_node_new(alloc, sizeof(AstEnumType), Ast_Kind_Enum_Type);
             tag_enum_node->token = union_->token;
             tag_enum_node->name = bh_aprintf(alloc, "%s.tag_enum", union_->name);
-            tag_enum_node->backing_type = &basic_types[Basic_Kind_U32];
+            tag_enum_node->backing_type = type_build_from_ast(alloc, union_->tag_backing_type);
             bh_arr_new(alloc, tag_enum_node->values, bh_arr_length(union_->variants));
 
             void add_entities_for_node(bh_arr(Entity *) *target_arr, AstNode* node, Scope* scope, Package* package); // HACK
@@ -805,7 +805,7 @@ static Type* type_build_from_ast_inner(bh_allocator alloc, AstType* type_node, b
                 bh_arr_push(tag_enum_node->values, ev);
             }
 
-            alignment = bh_max(alignment, 4);
+            alignment = bh_max(alignment, type_alignment_of(tag_enum_node->backing_type));
             bh_align(size, alignment);
 
             u_type->Union.alignment = alignment;
index 91a66fa1ecb9f86187575ad1b02f39f13007f1b7..b86c94b6a8774492b80f6d8a309062b3bb7f9be6 100644 (file)
@@ -729,9 +729,16 @@ format_any :: (output: &Format_Output, formatting: &Format, v: any) {
             if info.kind == .Union {
                 u := cast(&Type_Info_Union) info;
 
-                tag_value := *cast(&u32, v.data);
+                tag_value: u64;
+                switch e := u.tag_enum->info()->as_enum(); e.backing_type {
+                    case i8,  u8  do tag_value = cast(u64) *(cast(&u8) v.data);
+                    case i16, u16 do tag_value = cast(u64) *(cast(&u16) v.data);
+                    case i32, u32 do tag_value = cast(u64) *(cast(&u32) v.data);
+                    case i64, u64 do tag_value = cast(u64) *(cast(&u64) v.data);
+                    case #default do assert(false, "Bad union backing type");
+                }
 
-                variant := array.first(u.variants, [x](x.tag_value == tag_value));
+                variant := array.first(u.variants, [x](x.tag_value == ~~tag_value));
 
                 if !variant {
                     output->write("unknown_variant");
index a3ee253c5ccd0edabd785d448624af58a4f92d60..4d5972526e09dfe18a47ca14a14dd4eef653e562 100644 (file)
@@ -7,3 +7,5 @@ This is an integer: 9876
 456
 Error(Process_Error("bad data"))
 A wrapped value
+2
+Bool(true)
index ececc1c807c8b0c21b1c860f5d21329fa6556a71..c0ad0c3a5e54b7b578fd9af9300a75aeed93c2eb 100644 (file)
@@ -156,6 +156,18 @@ direct_access_is_an_optional :: () {
     println(the_string);
 }
 
+sized_tagged_union :: () {
+    Smol :: union #tag_type u8 {
+        Void: void;
+        Bool: bool;
+    }
+
+    println(sizeof Smol);
+
+    v := Smol.{ Bool = true };
+    println(v);
+}
+
 
 main :: () {
     simple_example();
@@ -164,4 +176,5 @@ main :: () {
     linked_list_example();
     polymorphic_example();
     direct_access_is_an_optional();
+    sized_tagged_union();
 }