added implicit cast to bool for pointers and arrays
authorBrendan Hansen <brendan.f.hansen@gmail.com>
Mon, 31 Oct 2022 03:21:58 +0000 (22:21 -0500)
committerBrendan Hansen <brendan.f.hansen@gmail.com>
Mon, 31 Oct 2022 03:21:58 +0000 (22:21 -0500)
compiler/include/astnodes.h
compiler/src/astnodes.c
compiler/src/checker.c
compiler/src/parser.c

index c58df410c3046dcb3bb94f430cea90c95012c302..0b3365936879dc3f2ef360b3f404db82e4263569 100644 (file)
@@ -1723,6 +1723,8 @@ char *get_expression_string_value(AstTyped* node, b32 *out_is_valid);
 b32 cast_is_legal(Type* from_, Type* to_, char** err_msg);
 char* get_function_name(AstFunction* func);
 
+b32 implicit_cast_to_bool(AstTyped **pnode);
+
 AstNode* strip_aliases(AstNode* node);
 
 AstNumLit*       make_bool_literal(bh_allocator, b32 b);
index 70601550990829b8fbc643b2c25d511613e9705e..e09c9402359b477288495ebd6b31721648630ad3 100644 (file)
@@ -1185,6 +1185,48 @@ b32 cast_is_legal(Type* from_, Type* to_, char** err_msg) {
     return 1;
 }
 
+b32 implicit_cast_to_bool(AstTyped **pnode) {
+    AstTyped *node = *pnode;
+
+    if (node->type->kind == Type_Kind_Pointer) {
+        AstNumLit *zero = make_int_literal(context.ast_alloc, 0);
+        zero->type = &basic_types[Basic_Kind_Rawptr];
+
+        AstBinaryOp* cmp = make_binary_op(context.ast_alloc, Binary_Op_Not_Equal, node, (AstTyped *) zero);
+        cmp->token = node->token;
+        cmp->type = &basic_types[Basic_Kind_Bool];
+
+        *pnode = (AstTyped *) cmp;
+        return 1;
+    }
+
+    if (node->type->kind == Type_Kind_Slice ||
+        node->type->kind == Type_Kind_DynArray ||
+        node->type->kind == Type_Kind_VarArgs) {
+        StructMember smem;
+        assert(type_lookup_member(node->type, "count", &smem));
+
+        // These fields are filled out here in order to prevent
+        // going through the type checker one more time.
+        AstFieldAccess *field = make_field_access(context.ast_alloc, node, "count");
+        field->offset = smem.offset;
+        field->idx = smem.idx;
+        field->type = smem.type;
+        field->flags |= Ast_Flag_Has_Been_Checked;
+
+        AstNumLit *zero = make_int_literal(context.ast_alloc, 0);
+        zero->type = smem.type;
+
+        AstBinaryOp* cmp = make_binary_op(context.ast_alloc, Binary_Op_Not_Equal, (AstTyped *) field, (AstTyped *) zero);
+        cmp->type = &basic_types[Basic_Kind_Bool];
+
+        *pnode = (AstTyped *) cmp;
+        return 1;
+    }
+
+    return 0;
+}
+
 char* get_function_name(AstFunction* func) {
     if (func->kind != Ast_Kind_Function) return "<procedure>";
 
index 2f2ad2f6bc78f28512992eb888851e29786329ff..f879a45c9ed183629672e70f970949ee2e8e99e5 100644 (file)
@@ -203,7 +203,9 @@ CheckStatus check_if(AstIfWhile* ifnode) {
         CHECK(expression, &ifnode->cond);
 
         if (!type_is_bool(ifnode->cond->type)) {
-            ERROR_(ifnode->cond->token->pos, "Expected expression of type 'bool' for condition, got '%s'", type_get_name(ifnode->cond->type));
+            if (!implicit_cast_to_bool(&ifnode->cond)) {
+                ERROR_(ifnode->token->pos, "Expected expression of type 'bool' for condition, got '%s'", type_get_name(ifnode->cond->type));
+            }
         }
 
         if (ifnode->true_stmt)  CHECK(statement, (AstNode **) &ifnode->true_stmt);
@@ -219,7 +221,9 @@ CheckStatus check_while(AstIfWhile* whilenode) {
     CHECK(expression, &whilenode->cond);
 
     if (!type_is_bool(whilenode->cond->type)) {
-        ERROR_(whilenode->cond->token->pos, "Expected expression of type 'bool' for condition, got '%s'", type_get_name(whilenode->cond->type));
+        if (!implicit_cast_to_bool(&whilenode->cond)) {
+            ERROR_(whilenode->token->pos, "Expected expression of type 'bool' for condition, got '%s'", type_get_name(whilenode->cond->type));
+        }
     }
 
     if (whilenode->true_stmt)  CHECK(statement, (AstNode **) &whilenode->true_stmt);
@@ -1028,7 +1032,22 @@ CheckStatus check_binaryop_compare(AstBinaryOp** pbinop) {
 CheckStatus check_binaryop_bool(AstBinaryOp** pbinop) {
     AstBinaryOp* binop = *pbinop;
 
-    if (!type_is_bool(binop->left->type) || !type_is_bool(binop->right->type)) {
+    b32 left_is_bool = 0;
+    b32 right_is_bool = 0;
+
+    if (type_is_bool(binop->left->type)) {
+        left_is_bool = 1;
+    } else if (implicit_cast_to_bool(&binop->left)) {
+        left_is_bool = 1;
+    }
+
+    if (type_is_bool(binop->right->type)) {
+        right_is_bool = 1;
+    } else if (implicit_cast_to_bool(&binop->right)) {
+        right_is_bool = 1;
+    }
+
+    if (!left_is_bool || !right_is_bool) {
         report_bad_binaryop(binop);
         return Check_Error;
     }
@@ -1227,16 +1246,15 @@ CheckStatus check_unaryop(AstUnaryOp** punop) {
         if (!cast_is_legal(unaryop->expr->type, unaryop->type, &err)) {
             ERROR_(unaryop->token->pos, "Cast Error: %s", err);
         }
-
-    } else {
-        unaryop->type = unaryop->expr->type;
     }
 
     if (unaryop->operation == Unary_Op_Not) {
         if (!type_is_bool(unaryop->expr->type)) {
-            ERROR_(unaryop->token->pos,
-                    "Bool negation operator expected bool type, got '%s'.",
-                    node_get_type_name(unaryop->expr));
+            if (!implicit_cast_to_bool(&unaryop->expr)) {
+                ERROR_(unaryop->token->pos,
+                        "Bool negation operator expected bool type, got '%s'.",
+                        node_get_type_name(unaryop->expr));
+            }
         }
     }
 
@@ -1248,6 +1266,10 @@ CheckStatus check_unaryop(AstUnaryOp** punop) {
         }
     }
 
+    if (unaryop->operation != Unary_Op_Cast) {
+        unaryop->type = unaryop->expr->type;
+    }
+
     if (unaryop->expr->flags & Ast_Flag_Comptime) {
         unaryop->flags |= Ast_Flag_Comptime;
         // NOTE: Not a unary op
@@ -1475,8 +1497,10 @@ CheckStatus check_if_expression(AstIfExpression* if_expr) {
     CHECK(expression, &if_expr->false_expr);
 
     TYPE_CHECK(&if_expr->cond, &basic_types[Basic_Kind_Bool]) {
-        ERROR_(if_expr->token->pos, "If-expression expected boolean for condition, got '%s'.",
-            type_get_name(if_expr->cond->type));
+        if (!implicit_cast_to_bool(&if_expr->cond)) {
+            ERROR_(if_expr->token->pos, "If-expression expected boolean for condition, got '%s'.",
+                type_get_name(if_expr->cond->type));
+        }
     }
 
     resolve_expression_type((AstTyped *) if_expr);
index f5ad33887192c6c8b707dfee49ba2c8bb91fcc9c..e74a1290bdc4a040d085350d0cfc9ba77572cb40 100644 (file)
@@ -1968,8 +1968,13 @@ static void struct_type_create_scope(OnyxParser *parser, AstStructType *s_node)
         s_node->scope = scope_create(context.ast_alloc, parser->current_scope, s_node->token->pos);
         parser->current_scope = s_node->scope;
 
-        OnyxToken* current_symbol = bh_arr_last(parser->current_symbol_stack);
-        s_node->scope->name = bh_aprintf(global_heap_allocator, "%b", current_symbol->text, current_symbol->length);
+        if (bh_arr_length(parser->current_symbol_stack) == 0) {
+            s_node->scope->name = "<anonymous>";
+
+        } else {
+            OnyxToken* current_symbol = bh_arr_last(parser->current_symbol_stack);
+            s_node->scope->name = bh_aprintf(global_heap_allocator, "%b", current_symbol->text, current_symbol->length);
+        }
     }
 }
 
@@ -2046,6 +2051,8 @@ static AstStructType* parse_struct(OnyxParser* parser) {
 
     expect_token(parser, '{');
 
+    struct_type_create_scope(parser, s_node);
+
     b32 member_is_used = 0;
     bh_arr(OnyxToken *) member_list_temp = NULL;
     bh_arr_new(global_heap_allocator, member_list_temp, 4);
@@ -2054,8 +2061,6 @@ static AstStructType* parse_struct(OnyxParser* parser) {
         if (parser->hit_unexpected_token) return s_node;
 
         if (parse_possible_directive(parser, "persist")) {
-            struct_type_create_scope(parser, s_node);
-            
             b32 thread_local = parse_possible_directive(parser, "thread_local");
 
             OnyxToken* symbol = expect_token(parser, Token_Type_Symbol);
@@ -2070,8 +2075,6 @@ static AstStructType* parse_struct(OnyxParser* parser) {
         }
 
         if (next_tokens_are(parser, 3, Token_Type_Symbol, ':', ':')) {
-            struct_type_create_scope(parser, s_node);
-
             OnyxToken* binding_name = expect_token(parser, Token_Type_Symbol);
             consume_token(parser);