From: Brendan Hansen Date: Wed, 9 Nov 2022 16:26:54 +0000 (-0600) Subject: performance improvements for binary operator overloads; polymorph X-Git-Url: https://git.brendanfh.com/?a=commitdiff_plain;h=c962f321c963dc68c041127ef713540309ba048b;p=onyx.git performance improvements for binary operator overloads; polymorph reduction bugfix --- diff --git a/compiler/include/astnodes.h b/compiler/include/astnodes.h index ca1d9389..05349313 100644 --- a/compiler/include/astnodes.h +++ b/compiler/include/astnodes.h @@ -285,7 +285,7 @@ typedef enum AstFlags { Ast_Flag_Dead = BH_BIT(22), - Ast_Flag_Extra_Field_Access = BH_BIT(23) + Ast_Flag_Extra_Field_Access = BH_BIT(23), } AstFlags; typedef enum UnaryOp { @@ -617,6 +617,7 @@ struct AstNumLit { } value; b32 was_hex_literal : 1; + b32 was_char_literal : 1; }; struct AstBinaryOp { AstTyped_base; @@ -697,7 +698,7 @@ struct AstArrayLiteral { bh_arr(AstTyped *) values; }; struct AstRangeLiteral { - AstTyped_base; + AstTyped_base; // HACK: Currently, range literals are parsed as binary operators, which means // the first sizeof(AstBinaryOp) bytes of this structure must match that of @@ -850,7 +851,7 @@ struct AstSwitchCase { // NOTE: All expressions that end up in this block bh_arr(AstTyped *) values; - + AstBlock *block; b32 is_default: 1; // Could this be inferred by the values array being null? @@ -939,7 +940,7 @@ struct AstStructType { // completely generated, but is a valid pointer to where the // type will be generated to. Type *pending_type; - + // NOTE: Used to store statically bound expressions in the struct. 
Scope* scope; @@ -1309,6 +1310,8 @@ struct AstDirectiveOperator { AstNode_base; BinaryOp operator; + + u64 precedence; AstTyped *overload; }; @@ -1402,7 +1405,7 @@ struct AstForeignBlock { typedef enum EntityState { Entity_State_Error, - + Entity_State_Parse_Builtin, Entity_State_Introduce_Symbols, Entity_State_Parse, @@ -1576,7 +1579,7 @@ struct CompileOptions { b32 print_notes : 1; b32 no_colors : 1; b32 no_file_contents : 1; - + b32 use_post_mvp_features : 1; b32 use_multi_threading : 1; b32 generate_foreign_info : 1; @@ -1717,8 +1720,15 @@ typedef enum TypeMatch { } TypeMatch; #define unify_node_and_type(node, type) (unify_node_and_type_((node), (type), 1)) TypeMatch unify_node_and_type_(AstTyped** pnode, Type* type, b32 permanent); + +// resolve_expression_type is a permanent action that modifies +// the node in whatever way is necessary to cement a type into it. Type* resolve_expression_type(AstTyped* node); +// query_expression_type does not modify the node at all, but +// does its best to deduce the type of the node without context. +Type* query_expression_type(AstTyped *node); + i64 get_expression_integer_value(AstTyped* node, b32 *out_is_valid); char *get_expression_string_value(AstTyped* node, b32 *out_is_valid); diff --git a/compiler/src/astnodes.c b/compiler/src/astnodes.c index a1c15386..0c2a72b8 100644 --- a/compiler/src/astnodes.c +++ b/compiler/src/astnodes.c @@ -835,6 +835,100 @@ TypeMatch unify_node_and_type_(AstTyped** pnode, Type* type, b32 permanent) { return TYPE_MATCH_FAILED; } +// TODO CLEANUP: Currently, query_expression_type and resolve_expression_type +// are almost the exact same function. Any logic that would be added to one +// will also HAVE TO BE ADDED TO THE OTHER. I would like to abstract the common +// code between them, but I think there are enough minor differences that that +// might not be possible. 
+ +Type* query_expression_type(AstTyped *node) { + if (node == NULL) return NULL; + + if (node->kind == Ast_Kind_Argument) { + return query_expression_type(((AstArgument *) node)->value); + } + + if (node->kind == Ast_Kind_If_Expression) { + AstIfExpression* if_expr = (AstIfExpression *) node; + return query_expression_type(if_expr->true_expr); + } + + if (node->kind == Ast_Kind_Alias) { + AstAlias* alias = (AstAlias *) node; + return query_expression_type(alias->alias); + } + + if (node_is_type((AstNode *) node)) { + return &basic_types[Basic_Kind_Type_Index]; + } + + if (node->kind == Ast_Kind_Array_Literal && node->type == NULL) { + AstArrayLiteral* al = (AstArrayLiteral *) node; + Type* elem_type = &basic_types[Basic_Kind_Void]; + if (bh_arr_length(al->values) > 0) { + elem_type = query_expression_type(al->values[0]); + } + + if (elem_type) { + return type_make_array(context.ast_alloc, elem_type, bh_arr_length(al->values)); + } + } + + if (node->kind == Ast_Kind_Struct_Literal && node->type == NULL) { + AstStructLiteral* sl = (AstStructLiteral *) node; + if (sl->stnode || sl->type_node) return NULL; + + // If values without names are given to a struct literal without + // a type, then we cannot implicitly build the type of the struct + // literal, as the name of every member cannot be known. Maybe we + // could implicitly do something like _1, _2, ... for the members + // that were not given names? + if (bh_arr_length(sl->args.values) > 0) { + return NULL; + } + + return type_build_implicit_type_of_struct_literal(context.ast_alloc, sl); + } + + // If polymorphic procedures HAVE to have a type, most likely + // because they are part of a `typeof` expression, they are + // assigned a void type. This is cleared before the procedure + // is solidified. 
+ if (node->kind == Ast_Kind_Polymorphic_Proc) { + return &basic_types[Basic_Kind_Void]; + } + + if (node->kind == Ast_Kind_Macro) { + return query_expression_type((AstTyped *) ((AstMacro *) node)->body); + } + + if (node->kind == Ast_Kind_Package) { + return type_build_from_ast(context.ast_alloc, node->type_node); + } + + if (node->type == NULL) + return type_build_from_ast(context.ast_alloc, node->type_node); + + if (node->kind == Ast_Kind_NumLit && node->type->kind == Type_Kind_Basic) { + if (node->type->Basic.kind == Basic_Kind_Int_Unsized) { + b32 big = bh_abs(((AstNumLit *) node)->value.l) >= (1ull << 32); + b32 unsign = ((AstNumLit *) node)->was_hex_literal; + + if (((AstNumLit *) node)->was_char_literal) return &basic_types[Basic_Kind_U8]; + else if ( big && !unsign) return &basic_types[Basic_Kind_I64]; + else if ( big && unsign) return &basic_types[Basic_Kind_U64]; + else if (!big && !unsign) return &basic_types[Basic_Kind_I32]; + else if (!big && unsign) return &basic_types[Basic_Kind_U32]; + } + else if (node->type->Basic.kind == Basic_Kind_Float_Unsized) { + return &basic_types[Basic_Kind_F64]; + } + } + + return node->type; +} + +// See note above about query_expression_type. 
Type* resolve_expression_type(AstTyped* node) { if (node == NULL) return NULL; @@ -930,7 +1024,8 @@ Type* resolve_expression_type(AstTyped* node) { b32 big = bh_abs(((AstNumLit *) node)->value.l) >= (1ull << 32); b32 unsign = ((AstNumLit *) node)->was_hex_literal; - if ( big && !unsign) convert_numlit_to_type((AstNumLit *) node, &basic_types[Basic_Kind_I64]); + if (((AstNumLit *) node)->was_char_literal) convert_numlit_to_type((AstNumLit *) node, &basic_types[Basic_Kind_U8]); + else if ( big && !unsign) convert_numlit_to_type((AstNumLit *) node, &basic_types[Basic_Kind_I64]); else if ( big && unsign) convert_numlit_to_type((AstNumLit *) node, &basic_types[Basic_Kind_U64]); else if (!big && !unsign) convert_numlit_to_type((AstNumLit *) node, &basic_types[Basic_Kind_I32]); else if (!big && unsign) convert_numlit_to_type((AstNumLit *) node, &basic_types[Basic_Kind_U32]); diff --git a/compiler/src/checker.c b/compiler/src/checker.c index 720342c3..7b3ab8cc 100644 --- a/compiler/src/checker.c +++ b/compiler/src/checker.c @@ -116,6 +116,9 @@ b32 expression_types_must_be_known = 0; b32 all_checks_are_final = 1; b32 inside_for_iterator = 0; bh_arr(AstFor *) for_node_stack = NULL; +static bh_imap __binop_impossible_cache[Binary_Op_Count]; +static AstCall __binop_maybe_overloaded; + #define STATEMENT_LEVEL 1 #define EXPRESSION_LEVEL 2 @@ -349,7 +352,7 @@ fornode_expr_checked: do { CheckStatus cs = check_block(fornode->stmt); inside_for_iterator = old_inside_for_iterator; - if (cs > Check_Errors_Start) return cs; + if (cs > Check_Errors_Start) return cs; } while(0); bh_arr_pop(for_node_stack); @@ -407,7 +410,7 @@ static CheckStatus collect_switch_case_blocks(AstSwitch* switchnode, AstBlock* r } else { if (static_if->false_stmt) collect_switch_case_blocks(switchnode, static_if->false_stmt); } - + break; } @@ -784,7 +787,7 @@ static void report_bad_binaryop(AstBinaryOp* binop) { } static AstCall* binaryop_try_operator_overload(AstBinaryOp* binop, AstTyped* third_argument) { - 
if (bh_arr_length(operator_overloads[binop->operation]) == 0) return NULL; + if (bh_arr_length(operator_overloads[binop->operation]) == 0) return &__binop_maybe_overloaded; if (binop->overload_args == NULL || binop->overload_args->values[1] == NULL) { if (binop->overload_args == NULL) { @@ -1061,6 +1064,12 @@ CheckStatus check_binaryop_bool(AstBinaryOp** pbinop) { return Check_Success; } +static inline b32 type_is_not_basic_or_pointer(Type *t) { + return (t != NULL + && (t->kind != Type_Kind_Basic || (t->Basic.flags & Basic_Flag_SIMD) != 0) + && (t->kind != Type_Kind_Pointer)); +} + CheckStatus check_binaryop(AstBinaryOp** pbinop) { AstBinaryOp* binop = *pbinop; @@ -1125,14 +1134,27 @@ CheckStatus check_binaryop(AstBinaryOp** pbinop) { } // NOTE: Try operator overloading before checking everything else. - if ((binop->left->type != NULL && (binop->left->type->kind != Type_Kind_Basic || (binop->left->type->Basic.flags & Basic_Flag_SIMD) != 0)) - || (binop->right->type != NULL && (binop->right->type->kind != Type_Kind_Basic || (binop->right->type->Basic.flags & Basic_Flag_SIMD) != 0))) { + if (type_is_not_basic_or_pointer(binop->left->type) || type_is_not_basic_or_pointer(binop->right->type)) { + + u64 cache_key = 0; + if (binop->left->type && binop->right->type) { + if (!__binop_impossible_cache[binop->operation].hashes) { + bh_imap_init(&__binop_impossible_cache[binop->operation], global_heap_allocator, 256); + } + + cache_key = ((u64) (binop->left->type->id) << 32ll) | (u64) binop->right->type->id; + + if (bh_imap_has(&__binop_impossible_cache[binop->operation], cache_key)) { + goto definitely_not_op_overload; + } + } + AstCall *implicit_call = binaryop_try_operator_overload(binop, NULL); if (implicit_call == (AstCall *) &node_that_signals_a_yield) YIELD(binop->token->pos, "Trying to resolve operator overload."); - if (implicit_call != NULL) { + if (implicit_call != NULL && implicit_call != &__binop_maybe_overloaded) { // NOTE: Not a binary op implicit_call->next = 
binop->next; *pbinop = (AstBinaryOp *) implicit_call; @@ -1140,8 +1162,14 @@ CheckStatus check_binaryop(AstBinaryOp** pbinop) { CHECK(call, (AstCall **) pbinop); return Check_Success; } + + if (cache_key && implicit_call != &__binop_maybe_overloaded) { + bh_imap_put(&__binop_impossible_cache[binop->operation], cache_key, 1); + } } + definitely_not_op_overload: + if (binop_is_assignment(binop->operation)) return check_binaryop_assignment(pbinop); if (binop->left->type == NULL && binop->left->entity && binop->left->entity->state <= Entity_State_Check_Types) { @@ -1323,8 +1351,10 @@ CheckStatus check_struct_literal(AstStructLiteral* sl) { } TYPE_CHECK(&sl->args.values[0], type_to_match) { - ERROR(sl->token->pos, - "Mismatched type in initialized type. FIX ME"); + ERROR_(sl->token->pos, + "Mismatched type in initialized type. Expected something of type '%s', got '%s'.", + type_get_name(type_to_match), + type_get_name(sl->args.values[0]->type)); } sl->flags |= Ast_Flag_Has_Been_Checked; @@ -1582,7 +1612,7 @@ CheckStatus check_address_of(AstAddressOf** paof) { if (node_is_addressable_literal((AstNode *) aof->expr)) { resolve_expression_type(aof->expr); } - + if (aof->expr->type == NULL) { YIELD(aof->token->pos, "Trying to resolve type of expression to take a reference."); } @@ -1645,7 +1675,8 @@ CheckStatus check_subscript(AstSubscript** psub) { // NOTE: Try operator overloading before checking everything else. 
if (sub->expr->type != NULL && - (sub->addr->type->kind != Type_Kind_Basic || sub->expr->type->kind != Type_Kind_Basic)) { + (sub->addr->type->kind != Type_Kind_Basic || sub->expr->type->kind != Type_Kind_Basic) + && !(type_is_array_accessible(sub->addr->type))) { // AstSubscript is the same as AstBinaryOp for the first sizeof(AstBinaryOp) bytes AstBinaryOp* binop = (AstBinaryOp *) sub; AstCall *implicit_call = binaryop_try_operator_overload(binop, NULL); @@ -3192,7 +3223,7 @@ CheckStatus check_polyquery(AstPolyQuery *query) { case TYPE_MATCH_SPECIAL: if (solved_something || query->successful_symres) { - return Check_Return_To_Symres; + return Check_Return_To_Symres; } else { return Check_Yield_Macro; } diff --git a/compiler/src/parser.c b/compiler/src/parser.c index e74a1290..e79b6acd 100644 --- a/compiler/src/parser.c +++ b/compiler/src/parser.c @@ -632,6 +632,7 @@ static AstTyped* parse_factor(OnyxParser* parser) { char_lit->flags |= Ast_Flag_Comptime; char_lit->type_node = (AstType *) &basic_type_int_unsized; char_lit->token = expect_token(parser, Token_Type_Literal_String); + char_lit->was_char_literal = 1; i8 dest = '\0'; i32 length = string_process_escape_seqs((char *) &dest, char_lit->token->text, 1); @@ -3244,6 +3245,16 @@ static void parse_top_level_statement(OnyxParser* parser) { } operator_determined: + if (parse_possible_directive(parser, "precedence")) { + AstNumLit* pre = parse_int_literal(parser); + if (parser->hit_unexpected_token) return; + + operator->precedence = bh_max(pre->value.l, 0); + + } else { + operator->precedence = parser->overload_count++; + } + operator->overload = parse_expression(parser, 0); ENTITY_SUBMIT(operator); diff --git a/compiler/src/polymorph.h b/compiler/src/polymorph.h index ae79d30b..91de95c4 100644 --- a/compiler/src/polymorph.h +++ b/compiler/src/polymorph.h @@ -44,7 +44,7 @@ void insert_poly_sln_into_scope(Scope* scope, AstPolySolution *sln) { onyx_report_error(sln->value->token->pos, Error_Critical, "Expected value 
to be compile time known."); return; } - + node = (AstNode *) sln->value; break; } @@ -529,12 +529,12 @@ static void solve_for_polymorphic_param_type(PolySolveResult* resolved, AstFunct } } - if (all_types) + if (all_types) typed_param = try_lookup_based_on_partial_function_type((AstFunction *) potential, ft); skip_nested_polymorph_case: - actual_type = resolve_expression_type(typed_param); + actual_type = query_expression_type(typed_param); if (actual_type == NULL) return; break; @@ -910,7 +910,7 @@ AstFunction* polymorphic_proc_build_only_header_with_slns(AstFunction* pp, bh_ar // NOTE: Cache the function for later use. shput(pp->concrete_funcs, unique_key, solidified_func); - + return (AstFunction *) &node_that_signals_a_yield; } diff --git a/compiler/src/symres.c b/compiler/src/symres.c index 4ba8554a..4b93f276 100644 --- a/compiler/src/symres.c +++ b/compiler/src/symres.c @@ -1414,7 +1414,7 @@ static SymresStatus symres_process_directive(AstNode* directive) { return Symres_Error; } - add_overload_option(&operator_overloads[operator->operator], 0, operator->overload); + add_overload_option(&operator_overloads[operator->operator], operator->precedence, operator->overload); break; } diff --git a/compiler/src/types.c b/compiler/src/types.c index 1d9acf14..e74f0c78 100644 --- a/compiler/src/types.c +++ b/compiler/src/types.c @@ -392,7 +392,7 @@ Type* type_build_from_ast(bh_allocator alloc, AstType* type_node) { mem_alignment = type_alignment_of((*member)->type); if (mem_alignment <= 0) { - onyx_report_error((*member)->token->pos, Error_Critical, "Invalid member type: %s. Has alignment %d", type_get_name((*member)->type), mem_alignment); + onyx_report_error((*member)->token->pos, Error_Critical, "Invalid member type: %s. 
Has alignment %d", type_get_name((*member)->type), mem_alignment); return NULL; } @@ -486,7 +486,7 @@ Type* type_build_from_ast(bh_allocator alloc, AstType* type_node) { } case Ast_Kind_VarArg_Type: { - Type* va_type = type_make_varargs(alloc, type_build_from_ast(alloc, ((AstVarArgType *) type_node)->elem)); + Type* va_type = type_make_varargs(alloc, type_build_from_ast(alloc, ((AstVarArgType *) type_node)->elem)); if (va_type) va_type->ast_type = type_node; return va_type; } @@ -603,7 +603,7 @@ Type* type_build_from_ast(bh_allocator alloc, AstType* type_node) { if (type_of->resolved_type != NULL) { return type_of->resolved_type; } - + return NULL; } @@ -613,7 +613,7 @@ Type* type_build_from_ast(bh_allocator alloc, AstType* type_node) { Type *base_type = type_build_from_ast(alloc, distinct->base_type); if (base_type == NULL) return NULL; - if (base_type->kind != Type_Kind_Basic) { + if (base_type->kind != Type_Kind_Basic && base_type->kind != Type_Kind_Pointer) { onyx_report_error(distinct->token->pos, Error_Critical, "Distinct types can only be made out of primitive types. '%s' is not a primitive type.", type_get_name(base_type)); return NULL; } @@ -697,9 +697,9 @@ Type* type_build_compound_type(bh_allocator alloc, AstCompound* compound) { comp_type->Compound.types[i] = compound->exprs[i]->type; comp_type->Compound.size += bh_max(type_size_of(comp_type->Compound.types[i]), 4); } - + bh_align(comp_type->Compound.size, 4); - + comp_type->Compound.linear_members = NULL; bh_arr_new(global_heap_allocator, comp_type->Compound.linear_members, comp_type->Compound.count); build_linear_types_with_offset(comp_type, &comp_type->Compound.linear_members, 0); @@ -741,7 +741,7 @@ Type* type_build_implicit_type_of_struct_literal(bh_allocator alloc, AstStructLi } alignment = bh_max(alignment, mem_alignment); - + // Should these structs be packed or not? 
bh_align(offset, mem_alignment); @@ -876,7 +876,7 @@ Type* type_make_dynarray(bh_allocator alloc, Type* of) { Type* type_make_varargs(bh_allocator alloc, Type* of) { if (of == NULL) return NULL; if (of == (Type *) &node_that_signals_failure) return of; - + assert(of->id > 0); u64 vararg_id = bh_imap_get(&type_vararg_map, of->id); if (vararg_id > 0) { @@ -1233,7 +1233,7 @@ i32 type_linear_member_count(Type* type) { case Type_Kind_DynArray: return 5; case Type_Kind_Compound: return bh_arr_length(type->Compound.linear_members); case Type_Kind_Struct: return bh_arr_length(type->Struct.linear_members); - default: return 1; + default: return 1; } } @@ -1241,7 +1241,7 @@ b32 type_linear_member_lookup(Type* type, i32 idx, TypeWithOffset* two) { switch (type->kind) { case Type_Kind_Slice: case Type_Kind_VarArgs: { - if (idx == 0) { + if (idx == 0) { two->type = type_make_pointer(context.ast_alloc, type->Slice.elem); two->offset = 0; } @@ -1253,7 +1253,7 @@ b32 type_linear_member_lookup(Type* type, i32 idx, TypeWithOffset* two) { return 1; } case Type_Kind_DynArray: { - if (idx == 0) { + if (idx == 0) { two->type = type_make_pointer(context.ast_alloc, type->DynArray.elem); two->offset = 0; } @@ -1407,7 +1407,7 @@ b32 type_is_compound(Type* type) { // single non-compound value; in this situation, the structure can be // "dissolved" at compile-time and turn into the underlying type. 
// - + if (bh_arr_length(type->Struct.linear_members) != 1) return 1; return type_is_compound(type->Struct.linear_members[0].type); } diff --git a/core/net/net.onyx b/core/net/net.onyx index c27f0175..fc14abb8 100644 --- a/core/net/net.onyx +++ b/core/net/net.onyx @@ -150,7 +150,7 @@ socket_accept :: (s: ^Socket) -> (Socket, Socket_Address) { new_addr: Socket_Address; new_socket.handle = __net_accept(s.handle, ^new_addr); - if new_socket.handle >= 0 { + if cast(i32) new_socket.handle >= 0 { new_socket.vtable = ^__net_socket_vtable; } @@ -241,7 +241,7 @@ network_to_host :: #match {} #local __net_socket_vtable := io.Stream_Vtable.{ read = (use s: ^Socket, buffer: [] u8) -> (io.Error, u32) { - if handle == 0 do return .BadFile, 0; + if cast(i32) handle == 0 do return .BadFile, 0; would_block := false; bytes_read := __net_recv(handle, buffer, ^would_block); @@ -253,7 +253,7 @@ network_to_host :: #match {} }, write_byte = (use s: ^Socket, byte: u8) -> io.Error { - if handle == 0 do return .BadFile; + if cast(i32) handle == 0 do return .BadFile; bytes_written := __net_send(handle, .[ byte ]); if bytes_written < 0 { s.vtable = null; return .BufferFull; } @@ -261,7 +261,7 @@ network_to_host :: #match {} }, write = (use s: ^Socket, buffer: [] u8) -> (io.Error, u32) { - if handle == 0 do return .BadFile, 0; + if cast(i32) handle == 0 do return .BadFile, 0; bytes_written := __net_send(handle, buffer); if bytes_written < 0 { s.vtable = null; } diff --git a/core/std.onyx b/core/std.onyx index 4cdfa5fe..cced1f57 100644 --- a/core/std.onyx +++ b/core/std.onyx @@ -44,6 +44,7 @@ package core #load "./misc/arg_parse" #load "./misc/any_utils" +#load "./misc/method_ops" #local runtime :: package runtime #if runtime.runtime == .Wasi || runtime.runtime == .Onyx { diff --git a/tests/aoc-2021/day21.onyx b/tests/aoc-2021/day21.onyx index 1ae72403..cce34f15 100644 --- a/tests/aoc-2021/day21.onyx +++ b/tests/aoc-2021/day21.onyx @@ -96,6 +96,10 @@ Player :: struct { return h; } + __eq :: (p1, 
p2: Player) => { + return p1.square == p2.square && p1.score == p2.score; + } + move :: (p: Player, m: u32) -> Player { n := p; n.square += m; @@ -106,7 +110,12 @@ Player :: struct { } } -#operator == (p1, p2: Player) => p1.square == p2.square && p1.score == p2.score; + +#local __HasEqMethod :: interface (t: $T) { + { T.__eq(t, t) } -> bool; +} + +#operator == macro (t1: $T/__HasEqMethod, t2: typeof t1) => T.__eq(t1, t2); Case :: struct {