From: Judah Caruso Date: Mon, 4 Dec 2023 02:01:29 +0000 (-0700) Subject: add array member access; add tests X-Git-Url: https://git.brendanfh.com/?a=commitdiff_plain;h=a4cddeb4cfd706a9e65dcf5d57f63244b22121ca;p=onyx.git add array member access; add tests --- diff --git a/compiler/src/checker.c b/compiler/src/checker.c index 2d7f0126..ab4e71d8 100644 --- a/compiler/src/checker.c +++ b/compiler/src/checker.c @@ -128,7 +128,7 @@ static inline void fill_in_type(AstTyped* node) { CheckStatus check_return(AstReturn* retnode) { Type ** expected_return_type; - + if (retnode->count >= (u32) bh_arr_length(context.checker.expected_return_type_stack)) { ERROR_(retnode->token->pos, "Too many repeated 'return's here. Expected a maximum of %d.", bh_arr_length(context.checker.expected_return_type_stack)); @@ -554,7 +554,7 @@ CheckStatus check_switch(AstSwitch* switchnode) { sc->capture->type = union_variant->type; } } - + } else { TYPE_CHECK(value, resolved_expr_type) { OnyxToken* tkn = sc->block->token; @@ -623,7 +623,7 @@ CheckStatus check_switch(AstSwitch* switchnode) { if (switchnode->type == NULL) { switchnode->type = resolve_expression_type(sc->expr); } else { - TYPE_CHECK(&sc->expr, switchnode->type) { + TYPE_CHECK(&sc->expr, switchnode->type) { ERROR_(sc->token->pos, "Expected case expression to be of type '%s', got '%s'.", type_get_name(switchnode->type), type_get_name(sc->expr->type)); @@ -643,7 +643,7 @@ CheckStatus check_switch(AstSwitch* switchnode) { CHECK(expression, default_case); if (switchnode->type) { - TYPE_CHECK(default_case, switchnode->type) { + TYPE_CHECK(default_case, switchnode->type) { ERROR_((*default_case)->token->pos, "Expected case expression to be of type '%s', got '%s'.", type_get_name(switchnode->type), type_get_name((*default_case)->type)); @@ -723,7 +723,7 @@ static CheckStatus check_resolve_callee(AstCall* call, AstTyped** effective_call if (context.cycle_almost_detected < 2) { YIELD(call->token->pos, "Waiting to know all options for overloaded function"); } - + report_unable_to_match_overload(call, ((AstOverloadedFunction *) callee)->overloads); return Check_Error; } @@ -1571,7 +1571,7 @@ CheckStatus check_struct_literal(AstStructLiteral* sl) { if (sl->type == NULL) YIELD(sl->token->pos, "Trying to resolve type of struct literal."); } - + if (sl->type->kind == Type_Kind_Union) { if ((sl->flags & Ast_Flag_Has_Been_Checked) != 0) return Check_Success; @@ -1581,7 +1581,7 @@ CheckStatus check_struct_literal(AstStructLiteral* sl) { // Produce an empty value of the first union type. UnionVariant *uv = union_type->Union.variants[0].value; - AstNumLit *tag_value = make_int_literal(context.ast_alloc, uv->tag_value); + AstNumLit *tag_value = make_int_literal(context.ast_alloc, uv->tag_value); tag_value->type = union_type->Union.tag_type; bh_arr_push(sl->args.values, (AstTyped *) tag_value); @@ -2080,10 +2080,35 @@ CheckStatus check_field_access(AstFieldAccess** pfield) { StructMember smem; if (!type_lookup_member(field->expr->type, field->field, &smem)) { if (field->expr->type->kind == Type_Kind_Array) { + u32 field_count = field->expr->type->Array.count; + if (!strcmp(field->field, "count")) { - *pfield = (AstFieldAccess *) make_int_literal(context.ast_alloc, field->expr->type->Array.count); + *pfield = (AstFieldAccess *) make_int_literal(context.ast_alloc, field_count); return Check_Success; } + + // This allows simple field access on fixed-size arrays. + // Index 0-3 are mapped to x, y, z, w (vectors) or r, g, b, a (colors). + if (field_count <= 4) { + u32 index; + b32 valid = 0; + char* accessor = field->field; + + // @todo(judah): this should be a small lookup table rather than multiple strcmps. + if (!strcmp(accessor, "x") || !strcmp(accessor, "r")) valid = field_count >= 1, index = 0; + else if (!strcmp(accessor, "y") || !strcmp(accessor, "g")) valid = field_count >= 2, index = 1; + else if (!strcmp(accessor, "z") || !strcmp(accessor, "b")) valid = field_count >= 3, index = 2; + else if (!strcmp(accessor, "w") || !strcmp(accessor, "a")) valid = field_count >= 4, index = 3; + + if (valid) { + *pfield = make_field_access(context.ast_alloc, field->expr, field->field); + (*pfield)->type = field->expr->type->Array.elem; + (*pfield)->offset = index * type_size_of(field->expr->type->Array.elem); + (*pfield)->idx = index; + (*pfield)->flags |= Ast_Flag_Has_Been_Checked; + return Check_Success; + } + } } if (type_union_get_variant_count(field->expr->type) > 0) { @@ -2194,11 +2219,11 @@ CheckStatus check_field_access(AstFieldAccess** pfield) { char* closest = find_closest_symbol_in_node((AstNode *) type_node, field->field); if (closest) { - ERROR_(field->token->pos, "Field '%s' does not exists on '%s'. Did you mean '%s'?", field->field, type_name, closest); + ERROR_(field->token->pos, "Field '%s' does not exist on '%s'. Did you mean '%s'?", field->field, type_name, closest); } closest_not_found: - ERROR_(field->token->pos, "Field '%s' does not exists on '%s'.", field->field, type_name); + ERROR_(field->token->pos, "Field '%s' does not exist on '%s'.", field->field, type_name); } CheckStatus check_method_call(AstBinaryOp** pmcall) { @@ -2908,7 +2933,7 @@ CheckStatus check_overloaded_function(AstOverloadedFunction* ofunc) { if (expected_return_node->type_id) { ofunc->expected_return_type = type_lookup_by_id(expected_return_node->type_id); - + // Return early here because the following code does not work with a // polymorphic expected return type. bh_imap_free(&all_overloads); @@ -2939,7 +2964,7 @@ CheckStatus check_overloaded_function(AstOverloadedFunction* ofunc) { } } } - + bh_imap_free(&all_overloads); return Check_Success; @@ -2999,7 +3024,7 @@ CheckStatus check_struct(AstStructType* s_node) { if (i >= bh_arr_length(s_node->polymorphic_arguments) || !s_node->polymorphic_arguments[i].value) continue; - + TYPE_CHECK(&s_node->polymorphic_arguments[i].value, arg_type) { ERROR_(s_node->polymorphic_arguments[i].value->token->pos, "Expected value of type %s, got %s.", type_get_name(arg_type), @@ -3102,7 +3127,7 @@ CheckStatus check_union(AstUnionType *u_node) { if (!type_is_integer(tag_type)) { ERROR_(u_node->token->pos, "Union tag types must be an integer, got '%s'.", type_get_name(tag_type)); } - + if (u_node->polymorphic_argument_types) { assert(u_node->polymorphic_arguments); @@ -3118,7 +3143,7 @@ CheckStatus check_union(AstUnionType *u_node) { if (i >= bh_arr_length(u_node->polymorphic_arguments) || !u_node->polymorphic_arguments[i].value) continue; - + TYPE_CHECK(&u_node->polymorphic_arguments[i].value, arg_type) { ERROR_(u_node->polymorphic_arguments[i].value->token->pos, "Expected value of type %s, got %s.", type_get_name(arg_type), @@ -3275,7 +3300,7 @@ CheckStatus check_function_header(AstFunction* func) { func->type = type_build_function_type(context.ast_alloc, func); if (func->type == NULL) YIELD(func->token->pos, "Waiting for function type to be constructed"); - if (func->foreign.import_name) { + if (func->foreign.import_name) { CHECK(expression, &func->foreign.module_name); CHECK(expression, &func->foreign.import_name); } diff --git a/tests/array_accessors b/tests/array_accessors new file mode 100644 index 00000000..bef6ddca --- /dev/null +++ b/tests/array_accessors @@ -0,0 +1,22 @@ +true +true +true +true +true +true +true +true +true +true +true +true +true +true +true +true +true +true +true +true +true +true diff --git a/tests/array_accessors.onyx b/tests/array_accessors.onyx new file mode 100644 index 00000000..671bc49b --- /dev/null +++ b/tests/array_accessors.onyx @@ -0,0 +1,46 @@ +#load "core/module" + +use core {*} + +main :: () { + a: [2]f32; + + println(a.x == 0); + println(a.y == 0); + println(a.r == 0); + println(a.g == 0); + + a.x = 10; + a.y = 20; + + println(a.x == 10); + println(a.y == 20); + println(a.r == 10); + println(a.g == 20); + + b := f32.[0, 0, 0]; + b.r = 3; + b.g = 2; + b.b = 1; + + println(b.x == 3); + println(b.y == 2); + println(b.z == 1); + println(b.r == 3); + println(b.g == 2); + println(b.b == 1); + + c := f32.[0, 0, 0, 0]; + c.r, c.g, c.b, c.a = 1, 2, 3, 4; + c.x, c.y, c.z, c.w = 1, 2, 3, 4; + + println(c.x == 1); + println(c.y == 2); + println(c.z == 3); + println(c.w == 4); + + println(c.r == 1); + println(c.g == 2); + println(c.b == 3); + println(c.a == 4); +}