Support for constant expressions in where clauses
authorJudah Caruso <judah@tuta.io>
Fri, 8 Dec 2023 01:22:43 +0000 (18:22 -0700)
committerJudah Caruso <judah@tuta.io>
Fri, 8 Dec 2023 01:22:43 +0000 (18:22 -0700)
compiler/include/astnodes.h
compiler/src/checker.c
compiler/src/parser.c

index 73c5138024b4abed4fd9eb5fa4b0c4c9744d5e1d..134c9fc64d5974fdbdc1fca6807ecb0bc50a74f6 100644 (file)
@@ -311,6 +311,8 @@ typedef enum AstFlags {
 
     Ast_Flag_Function_Is_Lambda    = BH_BIT(26),
     Ast_Flag_Function_Is_Lambda_Inside_PolyProc = BH_BIT(27),
+
+    Ast_Flag_Constraint_Is_Expression = BH_BIT(28),
 } AstFlags;
 
 typedef enum UnaryOp {
@@ -1241,16 +1243,19 @@ typedef enum ConstraintPhase {
 struct AstConstraint {
     AstNode_base;
 
-    ConstraintPhase phase;
-
-    AstInterface *interface;
-    bh_arr(AstType *) type_args;
-
-    ConstraintCheckStatus *report_status;
+    ConstraintPhase        phase;
+    ConstraintCheckStatus* report_status;
 
-    Scope* scope;
-    bh_arr(InterfaceConstraint) exprs;
-    u32 expr_idx;
+    union {
+        struct {
+            AstInterface *              interface;
+            bh_arr(AstType *)           type_args;
+            Scope*                      scope;
+            bh_arr(InterfaceConstraint) exprs;
+            u32                         expr_idx;
+        };
+        AstTyped *const_expr; // only used when flags & Ast_Flag_Constraint_Is_Expression
+    };
 };
 
 
index f09f8e9a18d87b320315f4b62b6d4a6b56ea2f36..4f929d2a27cfea128816717dd0ca3af67a625519 100644 (file)
@@ -3620,73 +3620,113 @@ CheckStatus check_macro(AstMacro* macro) {
     return Check_Success;
 }
 
-CheckStatus check_constraint(AstConstraint *constraint) {
-    switch (constraint->phase) {
-        case Constraint_Phase_Cloning_Expressions: {
-            if (constraint->interface->kind == Ast_Kind_Symbol) {
-                return Check_Return_To_Symres;
-            }
+CheckStatus check_interface_constraint(AstConstraint *constraint) {
+    if (constraint->interface->kind != Ast_Kind_Interface) {
+        // CLEANUP: This error message might not look totally right in some cases.
+        ERROR_(constraint->token->pos, "'%b' is not an interface. It is a '%s'.",
+            constraint->token->text, constraint->token->length,
+            onyx_ast_node_kind_string(constraint->interface->kind));
+    }
+
+    // #intrinsic interfaces
+    if (constraint->interface->is_intrinsic) {
+        b32 success = resolve_intrinsic_interface_constraint(constraint);
+        if (success) {
+            *constraint->report_status = Constraint_Check_Status_Success;
+            return Check_Complete;
+        } else {
+            *constraint->report_status = Constraint_Check_Status_Failed;
+            return Check_Failed;
+        }
+    }
 
-            if (constraint->interface->kind != Ast_Kind_Interface) {
-                // CLEANUP: This error message might not look totally right in some cases.
-                ERROR_(constraint->token->pos, "'%b' is not an interface. It is a '%s'.",
-                    constraint->token->text, constraint->token->length,
-                    onyx_ast_node_kind_string(constraint->interface->kind));
-            }
+    bh_arr_new(global_heap_allocator, constraint->exprs, bh_arr_length(constraint->interface->exprs));
+    bh_arr_each(InterfaceConstraint, ic, constraint->interface->exprs) {
+        InterfaceConstraint new_ic = {0};
+        new_ic.expr = (AstTyped *) ast_clone(context.ast_alloc, (AstNode *) ic->expr);
+        new_ic.expected_type_expr = (AstType *) ast_clone(context.ast_alloc, (AstNode *) ic->expected_type_expr);
+        new_ic.invert_condition = ic->invert_condition;
+        bh_arr_push(constraint->exprs, new_ic);
+    }
 
-            // #intrinsic interfaces
-            if (constraint->interface->is_intrinsic) {
-                b32 success = resolve_intrinsic_interface_constraint(constraint);
-                if (success) {
-                    *constraint->report_status = Constraint_Check_Status_Success;
-                    return Check_Complete;
-                } else {
-                    *constraint->report_status = Constraint_Check_Status_Failed;
-                    return Check_Failed;
-                }
-            }
+    assert(constraint->interface->entity && constraint->interface->entity->scope);
+    assert(constraint->interface->scope);
+    assert(constraint->interface->scope->parent == constraint->interface->entity->scope);
 
-            bh_arr_new(global_heap_allocator, constraint->exprs, bh_arr_length(constraint->interface->exprs));
-            bh_arr_each(InterfaceConstraint, ic, constraint->interface->exprs) {
-                InterfaceConstraint new_ic = {0};
-                new_ic.expr = (AstTyped *) ast_clone(context.ast_alloc, (AstNode *) ic->expr);
-                new_ic.expected_type_expr = (AstType *) ast_clone(context.ast_alloc, (AstNode *) ic->expected_type_expr);
-                new_ic.invert_condition = ic->invert_condition;
-                bh_arr_push(constraint->exprs, new_ic);
-            }
+    constraint->scope = scope_create(context.ast_alloc, constraint->interface->scope, constraint->token->pos);
 
-            assert(constraint->interface->entity && constraint->interface->entity->scope);
-            assert(constraint->interface->scope);
-            assert(constraint->interface->scope->parent == constraint->interface->entity->scope);
+    if (bh_arr_length(constraint->type_args) != bh_arr_length(constraint->interface->params)) {
+        ERROR_(constraint->token->pos, "Wrong number of arguments given to interface. Expected %d, got %d.",
+            bh_arr_length(constraint->interface->params),
+            bh_arr_length(constraint->type_args));
+    }
 
-            constraint->scope = scope_create(context.ast_alloc, constraint->interface->scope, constraint->token->pos);
+    fori (i, 0, bh_arr_length(constraint->interface->params)) {
+        InterfaceParam *ip = &constraint->interface->params[i];
 
-            if (bh_arr_length(constraint->type_args) != bh_arr_length(constraint->interface->params)) {
-                ERROR_(constraint->token->pos, "Wrong number of arguments given to interface. Expected %d, got %d.",
-                    bh_arr_length(constraint->interface->params),
-                    bh_arr_length(constraint->type_args));
-            }
+        AstTyped *sentinel = onyx_ast_node_new(context.ast_alloc, sizeof(AstTyped), Ast_Kind_Constraint_Sentinel);
+        sentinel->token = ip->value_token;
+        sentinel->type_node = constraint->type_args[i];
 
-            fori (i, 0, bh_arr_length(constraint->interface->params)) {
-                InterfaceParam *ip = &constraint->interface->params[i];
+        AstAlias *type_alias = onyx_ast_node_new(context.ast_alloc, sizeof(AstAlias), Ast_Kind_Alias);
+        type_alias->token = ip->type_token;
+        type_alias->alias = (AstTyped *) constraint->type_args[i];
 
-                AstTyped *sentinel = onyx_ast_node_new(context.ast_alloc, sizeof(AstTyped), Ast_Kind_Constraint_Sentinel);
-                sentinel->token = ip->value_token;
-                sentinel->type_node = constraint->type_args[i];
+        symbol_introduce(constraint->scope, ip->value_token, (AstNode *) sentinel);
+        symbol_introduce(constraint->scope, ip->type_token, (AstNode *) type_alias);
+    }
 
-                AstAlias *type_alias = onyx_ast_node_new(context.ast_alloc, sizeof(AstAlias), Ast_Kind_Alias);
-                type_alias->token = ip->type_token;
-                type_alias->alias = (AstTyped *) constraint->type_args[i];
+    assert(constraint->entity);
+    constraint->entity->scope = constraint->scope;
 
-                symbol_introduce(constraint->scope, ip->value_token, (AstNode *) sentinel);
-                symbol_introduce(constraint->scope, ip->type_token, (AstNode *) type_alias);
-            }
+    constraint->phase = Constraint_Phase_Checking_Expressions;
+    return Check_Return_To_Symres;
+}
 
-            assert(constraint->entity);
-            constraint->entity->scope = constraint->scope;
+CheckStatus check_expression_constraint(AstConstraint *constraint) {
+    onyx_errors_enable();
 
-            constraint->phase = Constraint_Phase_Checking_Expressions;
-            return Check_Return_To_Symres;
+    AstTyped* expr = constraint->const_expr;
+
+    context.checker.expression_types_must_be_known = 1;
+    CheckStatus result = check_expression(&expr);
+    context.checker.expression_types_must_be_known = 0;
+
+    if (result == Check_Yield_Macro) return Check_Yield_Macro;
+
+    if (result > Check_Errors_Start || !(expr->flags & Ast_Flag_Comptime)) {
+        ERROR(expr->token->pos, "Where clauses must be a constant expressions.");
+    }
+
+    if (!type_is_bool(expr->type)) {
+        ERROR(expr->token->pos, "Where clauses must result in a boolean.");
+    }
+
+    b32 value = (b32)get_expression_integer_value(expr, NULL);
+    if (!value) {
+        *constraint->report_status = Constraint_Check_Status_Failed;
+        return Check_Failed;
+    }
+
+    expr = (AstTyped *)make_bool_literal(context.ast_alloc, 1);
+    *constraint->report_status = Constraint_Check_Status_Success;
+
+    return Check_Complete;
+}
+
+CheckStatus check_constraint(AstConstraint *constraint) {
+    switch (constraint->phase) {
+        case Constraint_Phase_Cloning_Expressions: {
+            if (constraint->interface->kind == Ast_Kind_Symbol) {
+                return Check_Return_To_Symres;
+            }
+
+            if (constraint->flags & Ast_Flag_Constraint_Is_Expression) {
+                return check_expression_constraint(constraint);
+            }
+            else {
+                return check_interface_constraint(constraint);
+            }
         }
 
         case Constraint_Phase_Checking_Expressions: {
@@ -3803,9 +3843,21 @@ CheckStatus check_constraint_context(ConstraintContext *cc, Scope *scope, OnyxFi
                         error_pos = constraint->interface->token->pos;
                     }
 
-                    onyx_report_error(error_pos, Error_Critical, "Failed to satisfy constraint where %s.", constraint_map);
-                    if (error_msg) onyx_report_error(error_pos, Error_Critical, error_msg);
-                    onyx_report_error(constraint->token->pos, Error_Critical, "Here is where the interface was used.");
+                    if (constraint->flags & Ast_Flag_Constraint_Is_Expression) {
+                        onyx_report_error(error_pos, Error_Critical, "Where clause did not evaluate to true.");
+                    }
+                    else {
+                        onyx_report_error(error_pos, Error_Critical, "Failed to satisfy constraint where %s.", constraint_map);
+                    }
+
+                    if (error_msg) {
+                        onyx_report_error(error_pos, Error_Critical, error_msg);
+                    }
+
+                    if (!(constraint->flags & Ast_Flag_Constraint_Is_Expression)) {
+                        onyx_report_error(constraint->token->pos, Error_Critical, "Here is where the interface was used.");
+                    }
+
                     onyx_report_error(pos, Error_Critical, "Here is the code that caused this constraint to be checked.");
 
                     return Check_Error;
index dabf575d57b85071607c0ca6e332c4493c0bfa26..0898ba6c4e79a701698cdced76e58be47d14dce3 100644 (file)
@@ -2544,26 +2544,62 @@ static AstInterface* parse_interface(OnyxParser* parser) {
     return interface;
 }
 
+// :InterfacesAsExpressionsRefactor
+
+// @Todo(Judah): This can be significantly improved by not treating interface calls as a separate thing.
+// Maybe a flag on parser called 'calls_are_interfaces' that tells parse_expression to do what we do below?
+// This should allow us to remove the N token lookahead and give better error reporting for invalid interface usage.
 static AstConstraint* parse_constraint(OnyxParser* parser) {
     AstConstraint* constraint = make_node(AstConstraint, Ast_Kind_Constraint);
 
-    parser->parse_calls = 0;
-    constraint->interface = (AstInterface *) parse_factor(parser);
-    parser->parse_calls = 1;
+    // @Note(Judah): We lookahead N tokens to see which kind of constraint
+    // this is. A simple check will not match things like: 'foo.bar.Baz(T)'
 
-    constraint->token = constraint->interface->token;
+    i32 i               = 0;
+    b32 parse_interface = 0;
+    while (1) {
+        OnyxToken* next = peek_token(i);
+        if (!next || next->type == '{') break;
 
-    bh_arr_new(global_heap_allocator, constraint->type_args, 2);
+        if (next->type == Token_Type_Symbol) {
+            next = peek_token(i + 1);
+            if (next && next->type == '(') {
+                parse_interface = 1;
+                break;
+            }
+        }
 
-    expect_token(parser, '(');
-    while (!consume_token_if_next(parser, ')')) {
-        if (parser->hit_unexpected_token) return constraint;
+        i += 1;
+    }
 
-        AstType* type_node = parse_type(parser);
-        bh_arr_push(constraint->type_args, type_node);
+    // Interface constraint: Foo(T)
+    if (parse_interface) {
+        parser->parse_calls = 0;
+        constraint->interface = (AstInterface *) parse_factor(parser);
+        parser->parse_calls = 1;
 
-        if (parser->curr->type != ')')
-            expect_token(parser, ',');
+        constraint->token = constraint->interface->token;
+
+        bh_arr_new(global_heap_allocator, constraint->type_args, 2);
+
+        expect_token(parser, '(');
+        while (!consume_token_if_next(parser, ')')) {
+            if (parser->hit_unexpected_token) return constraint;
+
+            AstType* type_node = parse_type(parser);
+            bh_arr_push(constraint->type_args, type_node);
+
+            if (parser->curr->type != ')')
+                expect_token(parser, ',');
+        }
+    }
+    // Expression constraint: T == X
+    else {
+        constraint->const_expr = parse_expression(parser, 0);
+        if (parser->hit_unexpected_token || !constraint->const_expr) return constraint;
+
+        constraint->token  = constraint->const_expr->token;
+        constraint->flags |= Ast_Flag_Constraint_Is_Expression;
     }
 
     return constraint;