add array member access; add tests
authorJudah Caruso <judah@tuta.io>
Mon, 4 Dec 2023 02:01:29 +0000 (19:01 -0700)
committerJudah Caruso <judah@tuta.io>
Mon, 4 Dec 2023 02:01:29 +0000 (19:01 -0700)
compiler/src/checker.c
tests/array_accessors [new file with mode: 0644]
tests/array_accessors.onyx [new file with mode: 0644]

index 2d7f01264120161267971a07300215d20ddfa2d6..ab4e71d839e1eb5c072b7f560e729c08eaf75c60 100644 (file)
@@ -128,7 +128,7 @@ static inline void fill_in_type(AstTyped* node) {
 
 CheckStatus check_return(AstReturn* retnode) {
     Type ** expected_return_type;
-    
+
     if (retnode->count >= (u32) bh_arr_length(context.checker.expected_return_type_stack)) {
         ERROR_(retnode->token->pos, "Too many repeated 'return's here. Expected a maximum of %d.",
                 bh_arr_length(context.checker.expected_return_type_stack));
@@ -554,7 +554,7 @@ CheckStatus check_switch(AstSwitch* switchnode) {
                         sc->capture->type = union_variant->type;
                     }
                 }
-                
+
             } else {
                 TYPE_CHECK(value, resolved_expr_type) {
                     OnyxToken* tkn = sc->block->token;
@@ -623,7 +623,7 @@ CheckStatus check_switch(AstSwitch* switchnode) {
             if (switchnode->type == NULL) {
                 switchnode->type = resolve_expression_type(sc->expr);
             } else {
-                TYPE_CHECK(&sc->expr, switchnode->type) { 
+                TYPE_CHECK(&sc->expr, switchnode->type) {
                     ERROR_(sc->token->pos, "Expected case expression to be of type '%s', got '%s'.",
                         type_get_name(switchnode->type),
                         type_get_name(sc->expr->type));
@@ -643,7 +643,7 @@ CheckStatus check_switch(AstSwitch* switchnode) {
             CHECK(expression, default_case);
 
             if (switchnode->type) {
-                TYPE_CHECK(default_case, switchnode->type) { 
+                TYPE_CHECK(default_case, switchnode->type) {
                     ERROR_((*default_case)->token->pos, "Expected case expression to be of type '%s', got '%s'.",
                         type_get_name(switchnode->type),
                         type_get_name((*default_case)->type));
@@ -723,7 +723,7 @@ static CheckStatus check_resolve_callee(AstCall* call, AstTyped** effective_call
             if (context.cycle_almost_detected < 2) {
                 YIELD(call->token->pos, "Waiting to know all options for overloaded function");
             }
+
             report_unable_to_match_overload(call, ((AstOverloadedFunction *) callee)->overloads);
             return Check_Error;
         }
@@ -1571,7 +1571,7 @@ CheckStatus check_struct_literal(AstStructLiteral* sl) {
         if (sl->type == NULL)
             YIELD(sl->token->pos, "Trying to resolve type of struct literal.");
     }
-    
+
     if (sl->type->kind == Type_Kind_Union) {
         if ((sl->flags & Ast_Flag_Has_Been_Checked) != 0) return Check_Success;
 
@@ -1581,7 +1581,7 @@ CheckStatus check_struct_literal(AstStructLiteral* sl) {
             // Produce an empty value of the first union type.
             UnionVariant *uv = union_type->Union.variants[0].value;
 
-            AstNumLit *tag_value = make_int_literal(context.ast_alloc, uv->tag_value); 
+            AstNumLit *tag_value = make_int_literal(context.ast_alloc, uv->tag_value);
             tag_value->type = union_type->Union.tag_type;
 
             bh_arr_push(sl->args.values, (AstTyped *) tag_value);
@@ -2080,10 +2080,35 @@ CheckStatus check_field_access(AstFieldAccess** pfield) {
     StructMember smem;
     if (!type_lookup_member(field->expr->type, field->field, &smem)) {
         if (field->expr->type->kind == Type_Kind_Array) {
+            u32 field_count = field->expr->type->Array.count;
+
             if (!strcmp(field->field, "count")) {
-                *pfield = (AstFieldAccess *) make_int_literal(context.ast_alloc, field->expr->type->Array.count);
+                *pfield = (AstFieldAccess *) make_int_literal(context.ast_alloc, field_count);
                 return Check_Success;
             }
+
+            // This allows simple field access on fixed-size arrays.
+            // Index 0-3 are mapped to x, y, z, w (vectors) or r, g, b, a (colors).
+            if (field_count <= 4) {
+                u32   index;
+                b32   valid    = 0;
+                char* accessor = field->field;
+
+                // @todo(judah): this should be a small lookup table rather than multiple strcmps.
+                     if (!strcmp(accessor, "x") || !strcmp(accessor, "r")) valid = field_count >= 1, index = 0;
+                else if (!strcmp(accessor, "y") || !strcmp(accessor, "g")) valid = field_count >= 2, index = 1;
+                else if (!strcmp(accessor, "z") || !strcmp(accessor, "b")) valid = field_count >= 3, index = 2;
+                else if (!strcmp(accessor, "w") || !strcmp(accessor, "a")) valid = field_count >= 4, index = 3;
+
+                if (valid) {
+                    *pfield = make_field_access(context.ast_alloc, field->expr, field->field);
+                    (*pfield)->type   = field->expr->type->Array.elem;
+                    (*pfield)->offset = index * type_size_of(field->expr->type->Array.elem);
+                    (*pfield)->idx    = index;
+                    (*pfield)->flags |= Ast_Flag_Has_Been_Checked;
+                    return Check_Success;
+                }
+            }
         }
 
         if (type_union_get_variant_count(field->expr->type) > 0) {
@@ -2194,11 +2219,11 @@ CheckStatus check_field_access(AstFieldAccess** pfield) {
 
     char* closest = find_closest_symbol_in_node((AstNode *) type_node, field->field);
     if (closest) {
-        ERROR_(field->token->pos, "Field '%s' does not exists on '%s'. Did you mean '%s'?", field->field, type_name, closest);
+        ERROR_(field->token->pos, "Field '%s' does not exist on '%s'. Did you mean '%s'?", field->field, type_name, closest);
     }
 
   closest_not_found:
-    ERROR_(field->token->pos, "Field '%s' does not exists on '%s'.", field->field, type_name);
+    ERROR_(field->token->pos, "Field '%s' does not exist on '%s'.", field->field, type_name);
 }
 
 CheckStatus check_method_call(AstBinaryOp** pmcall) {
@@ -2908,7 +2933,7 @@ CheckStatus check_overloaded_function(AstOverloadedFunction* ofunc) {
 
             if (expected_return_node->type_id) {
                 ofunc->expected_return_type = type_lookup_by_id(expected_return_node->type_id);
-                
+
                 // Return early here because the following code does not work with a
                 // polymorphic expected return type.
                 bh_imap_free(&all_overloads);
@@ -2939,7 +2964,7 @@ CheckStatus check_overloaded_function(AstOverloadedFunction* ofunc) {
             }
         }
     }
-    
+
 
     bh_imap_free(&all_overloads);
     return Check_Success;
@@ -2999,7 +3024,7 @@ CheckStatus check_struct(AstStructType* s_node) {
             if (i >= bh_arr_length(s_node->polymorphic_arguments)
                 || !s_node->polymorphic_arguments[i].value) continue;
 
-            
+
             TYPE_CHECK(&s_node->polymorphic_arguments[i].value, arg_type) {
                 ERROR_(s_node->polymorphic_arguments[i].value->token->pos, "Expected value of type %s, got %s.",
                     type_get_name(arg_type),
@@ -3102,7 +3127,7 @@ CheckStatus check_union(AstUnionType *u_node) {
     if (!type_is_integer(tag_type)) {
         ERROR_(u_node->token->pos, "Union tag types must be an integer, got '%s'.", type_get_name(tag_type));
     }
-    
+
     if (u_node->polymorphic_argument_types) {
         assert(u_node->polymorphic_arguments);
 
@@ -3118,7 +3143,7 @@ CheckStatus check_union(AstUnionType *u_node) {
             if (i >= bh_arr_length(u_node->polymorphic_arguments)
                 || !u_node->polymorphic_arguments[i].value) continue;
 
-            
+
             TYPE_CHECK(&u_node->polymorphic_arguments[i].value, arg_type) {
                 ERROR_(u_node->polymorphic_arguments[i].value->token->pos, "Expected value of type %s, got %s.",
                     type_get_name(arg_type),
@@ -3275,7 +3300,7 @@ CheckStatus check_function_header(AstFunction* func) {
     func->type = type_build_function_type(context.ast_alloc, func);
     if (func->type == NULL) YIELD(func->token->pos, "Waiting for function type to be constructed");
 
-    if (func->foreign.import_name) { 
+    if (func->foreign.import_name) {
         CHECK(expression, &func->foreign.module_name);
         CHECK(expression, &func->foreign.import_name);
     }
diff --git a/tests/array_accessors b/tests/array_accessors
new file mode 100644 (file)
index 0000000..bef6ddc
--- /dev/null
@@ -0,0 +1,22 @@
+true
+true
+true
+true
+true
+true
+true
+true
+true
+true
+true
+true
+true
+true
+true
+true
+true
+true
+true
+true
+true
+true
diff --git a/tests/array_accessors.onyx b/tests/array_accessors.onyx
new file mode 100644 (file)
index 0000000..671bc49
--- /dev/null
@@ -0,0 +1,46 @@
+#load "core/module"
+
+use core {*}
+
+main :: () {
+   a: [2]f32;
+
+   println(a.x == 0);
+   println(a.y == 0);
+   println(a.r == 0);
+   println(a.g == 0);
+
+   a.x = 10;
+   a.y = 20;
+
+   println(a.x == 10);
+   println(a.y == 20);
+   println(a.r == 10);
+   println(a.g == 20);
+
+   b := f32.[0, 0, 0];
+   b.r = 3;
+   b.g = 2;
+   b.b = 1;
+
+   println(b.x == 3);
+   println(b.y == 2);
+   println(b.z == 1);
+   println(b.r == 3);
+   println(b.g == 2);
+   println(b.b == 1);
+
+   c := f32.[0, 0, 0, 0];
+   c.r, c.g, c.b, c.a = 1, 2, 3, 4;
+   c.x, c.y, c.z, c.w = 1, 2, 3, 4;
+
+   println(c.x == 1);
+   println(c.y == 2);
+   println(c.z == 3);
+   println(c.w == 4);
+
+   println(c.r == 1);
+   println(c.g == 2);
+   println(c.b == 3);
+   println(c.a == 4);
+}