bugfix: misc with unions
authorBrendan Hansen <brendan.f.hansen@gmail.com>
Mon, 22 May 2023 19:45:28 +0000 (14:45 -0500)
committerBrendan Hansen <brendan.f.hansen@gmail.com>
Mon, 22 May 2023 19:45:28 +0000 (14:45 -0500)
compiler/include/parser.h
compiler/src/checker.c
compiler/src/clone.c
compiler/src/parser.c
compiler/src/polymorph.h
core/container/slice.onyx
tests/tagged_unions.onyx

index a4c965889d37ce0721fe77a3cfff9eece9c57aee..237f5f8fdb4b53d556357fb190d714a51811f39c 100644 (file)
@@ -56,6 +56,7 @@ typedef struct OnyxParser {
 
     b32 hit_unexpected_token : 1;
     b32 parse_calls : 1;
+    b32 parse_quick_functions : 1;
 
     // Currently, package expressions are only allowed in certain places.
     b32 allow_package_expressions : 1;
index e9140f79b558065ac9e5c9aa04278983135ba4d7..568e315aef00ebf41a13433cffeb29be0483b3e7 100644 (file)
@@ -487,6 +487,10 @@ CheckStatus check_switch(AstSwitch* switchnode) {
             ERROR(sc->token->pos, "Expected exactly one value in switch-case when using a capture.");
         }
 
+        if (sc->capture && switchnode->switch_kind != Switch_Kind_Union) {
+            ERROR(sc->capture->token->pos, "Captures in switch cases are only allowed when switching over a union type.");
+        }
+
         if (sc->flags & Ast_Flag_Has_Been_Checked) goto check_switch_case_block;
 
         bh_arr_each(AstTyped *, value, sc->values) {
@@ -1483,6 +1487,8 @@ CheckStatus check_struct_literal(AstStructLiteral* sl) {
     }
     
     if (sl->type->kind == Type_Kind_Union) {
+        if ((sl->flags & Ast_Flag_Has_Been_Checked) != 0) return Check_Success;
+
         if (bh_arr_length(sl->args.values) != 0 || bh_arr_length(sl->args.named_values) != 1) {
             ERROR_(sl->token->pos, "Expected exactly one named member when constructing an instance of a union type, '%s'.", type_get_name(sl->type));
         }
@@ -1513,6 +1519,8 @@ CheckStatus check_struct_literal(AstStructLiteral* sl) {
 
         bh_arr_push(sl->args.values, (AstTyped *) tag_value);
         bh_arr_push(sl->args.values, value->value);
+
+        sl->flags |= Ast_Flag_Has_Been_Checked;
         return Check_Success;
     }
 
index c44ec994619a193d16cc2978f60e1936d96ba381..fdbe7a8c53069c7a83f9c48e4e3da314d7c385ba 100644 (file)
@@ -304,6 +304,7 @@ AstNode* ast_clone(bh_allocator a, void* n) {
             break;
 
         case Ast_Kind_Switch_Case: {
+            C(AstSwitchCase, capture);
             C(AstSwitchCase, block);
 
             AstSwitchCase *dw = (AstSwitchCase *) nn;
index e5564092c6ff5bd8f78c9419b2200117197d45c0..ad0370200c7faaefc4b18da7b04599fd032104d8 100644 (file)
@@ -1309,6 +1309,7 @@ static AstSwitchCase* parse_case_stmt(OnyxParser* parser) {
     } else {
         bh_arr_new(global_heap_allocator, sc_node->values, 1);
 
+        parser->parse_quick_functions = 0;
         AstTyped* value = parse_expression(parser, 1);
         bh_arr_push(sc_node->values, value);
         while (consume_token_if_next(parser, ',')) {
@@ -1317,6 +1318,8 @@ static AstSwitchCase* parse_case_stmt(OnyxParser* parser) {
             value = parse_expression(parser, 1);
             bh_arr_push(sc_node->values, value);
         }
+
+        parser->parse_quick_functions = 1;
     }
 
     if (consume_token_if_next(parser, Token_Type_Fat_Right_Arrow)) {
@@ -2859,6 +2862,8 @@ typedef struct QuickParam {
 } QuickParam;
 
 static b32 parse_possible_quick_function_definition_no_consume(OnyxParser* parser) {
+    if (!parser->parse_quick_functions) return 0;
+
     //
     // x => x + 1 case.
     if (next_tokens_are(parser, 2, Token_Type_Symbol, Token_Type_Fat_Right_Arrow)) {
@@ -3933,6 +3938,7 @@ OnyxParser onyx_parser_create(bh_allocator alloc, OnyxTokenizer *tokenizer) {
     parser.scope_flags = NULL;
     parser.stored_tags = NULL;
     parser.parse_calls = 1;
+    parser.parse_quick_functions = 1;
     parser.tag_depth = 0;
     parser.overload_count = 0;
     parser.injection_point = NULL;
index de83db133ba682099d269ae2768ef302b48ece28..440688e92128772d28e9075ffb565c189d6e48df 100644 (file)
@@ -648,11 +648,6 @@ static void solve_for_polymorphic_param_value(PolySolveResult* resolved, AstFunc
     } else {
         resolve_expression_type(value);
 
-        if ((value->flags & Ast_Flag_Comptime) == 0) {
-            if (err_msg) *err_msg = "Expected compile-time known argument here.";
-            return;
-        }
-
         param_type = type_build_from_ast(context.ast_alloc, param_type_expr);
         if (param_type == NULL) {
             flag_to_yield = 1;
@@ -679,7 +674,12 @@ static void solve_for_polymorphic_param_value(PolySolveResult* resolved, AstFunc
 
         if (tm == TYPE_MATCH_YIELD) flag_to_yield = 1;
 
-        *resolved = ((PolySolveResult) { PSK_Value, value });
+        if ((value_to_use->flags & Ast_Flag_Comptime) == 0) {
+            if (err_msg) *err_msg = "Expected compile-time known argument here.";
+            return;
+        }
+
+        *resolved = ((PolySolveResult) { PSK_Value, value_to_use });
     }
 
     if (orig_value->kind == Ast_Kind_Argument) {
index 2b37b24bbfbafaac451e8ee0c937000ef7cccaa2..e5fed63857d83417b7c9b3deeccf7adc51312ad5 100644 (file)
@@ -160,14 +160,14 @@ set :: (arr: [] $T, idx: i32, value: T) {
 }
 
 contains :: #match #locked {
-    macro (arr: [] $T, $cmp: Code) -> bool {
-        for it: arr do if #unquote cmp do return true;
-        return false;
-    },
-
     (arr: [] $T, x: T) -> bool {
         for it: arr do if it == x do return true;
         return false;
+    }, 
+
+    macro (arr: [] $T, $cmp: Code) -> bool {
+        for it: arr do if #unquote cmp do return true;
+        return false;
     }
 }
 
index f616c3012043914922f4598b99af302775a33ca0..7170402075b6246628edd0c2617b54eb634a5379 100644 (file)
@@ -1,6 +1,16 @@
 
 use core {*}
 
+extract_variant :: macro (u: $U, $variant: U.tag_enum) => {
+    switch u {
+        case variant => v {
+            return Optional.make(v);
+        }
+    }
+
+    return .{};
+}
+
 SimpleUnion :: union {
     a: i32;
     b: struct {
@@ -40,6 +50,10 @@ simple_test :: () {
     println(cast(SimpleUnion.tag_enum) u);
 
     call_test(u);
+
+    if cast(SimpleUnion.tag_enum, u) == .b {
+        println(extract_variant(u, .b)?);
+    }
 }
 
 main :: () {simple_test(); link_test();}