From 8dd2128819d768299ae5b9f3445175e4d44cb8c5 Mon Sep 17 00:00:00 2001 From: Brendan Hansen Date: Sun, 30 Oct 2022 22:21:58 -0500 Subject: [PATCH] added implicit cast to bool for pointers and arrays --- compiler/include/astnodes.h | 2 ++ compiler/src/astnodes.c | 42 +++++++++++++++++++++++++++++++++ compiler/src/checker.c | 46 ++++++++++++++++++++++++++++--------- compiler/src/parser.c | 15 +++++++----- 4 files changed, 88 insertions(+), 17 deletions(-) diff --git a/compiler/include/astnodes.h b/compiler/include/astnodes.h index c58df410..0b336593 100644 --- a/compiler/include/astnodes.h +++ b/compiler/include/astnodes.h @@ -1723,6 +1723,8 @@ char *get_expression_string_value(AstTyped* node, b32 *out_is_valid); b32 cast_is_legal(Type* from_, Type* to_, char** err_msg); char* get_function_name(AstFunction* func); +b32 implicit_cast_to_bool(AstTyped **pnode); + AstNode* strip_aliases(AstNode* node); AstNumLit* make_bool_literal(bh_allocator, b32 b); diff --git a/compiler/src/astnodes.c b/compiler/src/astnodes.c index 70601550..e09c9402 100644 --- a/compiler/src/astnodes.c +++ b/compiler/src/astnodes.c @@ -1185,6 +1185,48 @@ b32 cast_is_legal(Type* from_, Type* to_, char** err_msg) { return 1; } +b32 implicit_cast_to_bool(AstTyped **pnode) { + AstTyped *node = *pnode; + + if (node->type->kind == Type_Kind_Pointer) { + AstNumLit *zero = make_int_literal(context.ast_alloc, 0); + zero->type = &basic_types[Basic_Kind_Rawptr]; + + AstBinaryOp* cmp = make_binary_op(context.ast_alloc, Binary_Op_Not_Equal, node, (AstTyped *) zero); + cmp->token = node->token; + cmp->type = &basic_types[Basic_Kind_Bool]; + + *pnode = (AstTyped *) cmp; + return 1; + } + + if (node->type->kind == Type_Kind_Slice || + node->type->kind == Type_Kind_DynArray || + node->type->kind == Type_Kind_VarArgs) { + StructMember smem; + assert(type_lookup_member(node->type, "count", &smem)); + + // These fields are filled out here in order to prevent + // going through the type checker one more time. + AstFieldAccess *field = make_field_access(context.ast_alloc, node, "count"); + field->offset = smem.offset; + field->idx = smem.idx; + field->type = smem.type; + field->flags |= Ast_Flag_Has_Been_Checked; + + AstNumLit *zero = make_int_literal(context.ast_alloc, 0); + zero->type = smem.type; + + AstBinaryOp* cmp = make_binary_op(context.ast_alloc, Binary_Op_Not_Equal, (AstTyped *) field, (AstTyped *) zero); + cmp->type = &basic_types[Basic_Kind_Bool]; + + *pnode = (AstTyped *) cmp; + return 1; + } + + return 0; +} + char* get_function_name(AstFunction* func) { if (func->kind != Ast_Kind_Function) return ""; diff --git a/compiler/src/checker.c b/compiler/src/checker.c index 2f2ad2f6..f879a45c 100644 --- a/compiler/src/checker.c +++ b/compiler/src/checker.c @@ -203,7 +203,9 @@ CheckStatus check_if(AstIfWhile* ifnode) { CHECK(expression, &ifnode->cond); if (!type_is_bool(ifnode->cond->type)) { - ERROR_(ifnode->cond->token->pos, "Expected expression of type 'bool' for condition, got '%s'", type_get_name(ifnode->cond->type)); + if (!implicit_cast_to_bool(&ifnode->cond)) { + ERROR_(ifnode->token->pos, "Expected expression of type 'bool' for condition, got '%s'", type_get_name(ifnode->cond->type)); + } } if (ifnode->true_stmt) CHECK(statement, (AstNode **) &ifnode->true_stmt); @@ -219,7 +221,9 @@ CheckStatus check_while(AstIfWhile* whilenode) { CHECK(expression, &whilenode->cond); if (!type_is_bool(whilenode->cond->type)) { - ERROR_(whilenode->cond->token->pos, "Expected expression of type 'bool' for condition, got '%s'", type_get_name(whilenode->cond->type)); + if (!implicit_cast_to_bool(&whilenode->cond)) { + ERROR_(whilenode->token->pos, "Expected expression of type 'bool' for condition, got '%s'", type_get_name(whilenode->cond->type)); + } } if (whilenode->true_stmt) CHECK(statement, (AstNode **) &whilenode->true_stmt); @@ -1028,7 +1032,22 @@ CheckStatus check_binaryop_compare(AstBinaryOp** pbinop) { CheckStatus check_binaryop_bool(AstBinaryOp** pbinop) { AstBinaryOp* binop = *pbinop; - if (!type_is_bool(binop->left->type) || !type_is_bool(binop->right->type)) { + b32 left_is_bool = 0; + b32 right_is_bool = 0; + + if (type_is_bool(binop->left->type)) { + left_is_bool = 1; + } else if (implicit_cast_to_bool(&binop->left)) { + left_is_bool = 1; + } + + if (type_is_bool(binop->right->type)) { + right_is_bool = 1; + } else if (implicit_cast_to_bool(&binop->right)) { + right_is_bool = 1; + } + + if (!left_is_bool || !right_is_bool) { report_bad_binaryop(binop); return Check_Error; } @@ -1227,16 +1246,15 @@ CheckStatus check_unaryop(AstUnaryOp** punop) { if (!cast_is_legal(unaryop->expr->type, unaryop->type, &err)) { ERROR_(unaryop->token->pos, "Cast Error: %s", err); } - - } else { - unaryop->type = unaryop->expr->type; } if (unaryop->operation == Unary_Op_Not) { if (!type_is_bool(unaryop->expr->type)) { - ERROR_(unaryop->token->pos, - "Bool negation operator expected bool type, got '%s'.", - node_get_type_name(unaryop->expr)); + if (!implicit_cast_to_bool(&unaryop->expr)) { + ERROR_(unaryop->token->pos, + "Bool negation operator expected bool type, got '%s'.", + node_get_type_name(unaryop->expr)); + } } } @@ -1248,6 +1266,10 @@ CheckStatus check_unaryop(AstUnaryOp** punop) { } } + if (unaryop->operation != Unary_Op_Cast) { + unaryop->type = unaryop->expr->type; + } + if (unaryop->expr->flags & Ast_Flag_Comptime) { unaryop->flags |= Ast_Flag_Comptime; // NOTE: Not a unary op @@ -1475,8 +1497,10 @@ CheckStatus check_if_expression(AstIfExpression* if_expr) { CHECK(expression, &if_expr->false_expr); TYPE_CHECK(&if_expr->cond, &basic_types[Basic_Kind_Bool]) { - ERROR_(if_expr->token->pos, "If-expression expected boolean for condition, got '%s'.", - type_get_name(if_expr->cond->type)); + if (!implicit_cast_to_bool(&if_expr->cond)) { + ERROR_(if_expr->token->pos, "If-expression expected boolean for condition, got '%s'.", + type_get_name(if_expr->cond->type)); + } } resolve_expression_type((AstTyped *) if_expr); diff --git a/compiler/src/parser.c b/compiler/src/parser.c index f5ad3388..e74a1290 100644 --- a/compiler/src/parser.c +++ b/compiler/src/parser.c @@ -1968,8 +1968,13 @@ static void struct_type_create_scope(OnyxParser *parser, AstStructType *s_node) s_node->scope = scope_create(context.ast_alloc, parser->current_scope, s_node->token->pos); parser->current_scope = s_node->scope; - OnyxToken* current_symbol = bh_arr_last(parser->current_symbol_stack); - s_node->scope->name = bh_aprintf(global_heap_allocator, "%b", current_symbol->text, current_symbol->length); + if (bh_arr_length(parser->current_symbol_stack) == 0) { + s_node->scope->name = ""; + + } else { + OnyxToken* current_symbol = bh_arr_last(parser->current_symbol_stack); + s_node->scope->name = bh_aprintf(global_heap_allocator, "%b", current_symbol->text, current_symbol->length); + } } } @@ -2046,6 +2051,8 @@ static AstStructType* parse_struct(OnyxParser* parser) { expect_token(parser, '{'); + struct_type_create_scope(parser, s_node); + b32 member_is_used = 0; bh_arr(OnyxToken *) member_list_temp = NULL; bh_arr_new(global_heap_allocator, member_list_temp, 4); @@ -2054,8 +2061,6 @@ static AstStructType* parse_struct(OnyxParser* parser) { if (parser->hit_unexpected_token) return s_node; if (parse_possible_directive(parser, "persist")) { - struct_type_create_scope(parser, s_node); - b32 thread_local = parse_possible_directive(parser, "thread_local"); OnyxToken* symbol = expect_token(parser, Token_Type_Symbol); @@ -2070,8 +2075,6 @@ static AstStructType* parse_struct(OnyxParser* parser) { } if (next_tokens_are(parser, 3, Token_Type_Symbol, ':', ':')) { - struct_type_create_scope(parser, s_node); - OnyxToken* binding_name = expect_token(parser, Token_Type_Symbol); consume_token(parser); -- 2.25.1