From ba365c564f6009e224ea47e2af1e91b887392aeb Mon Sep 17 00:00:00 2001
From: Brendan Hansen
Date: Thu, 4 Feb 2021 14:21:01 -0600
Subject: [PATCH] bulked up core math library

---
 core/math.onyx                     | 285 ++++++++++++++++++++++-------
 tests/aoc-2020/day9.onyx           |   4 +-
 tests/compile_time_procedures.onyx |   2 +-
 3 files changed, 219 insertions(+), 72 deletions(-)

diff --git a/core/math.onyx b/core/math.onyx
index f77e60ea..241c1642 100644
--- a/core/math.onyx
+++ b/core/math.onyx
@@ -1,18 +1,33 @@
 package core.math
 
-use package core.intrinsics.wasm {
-    sqrt_f32, sqrt_f64,
-    abs_f32, abs_f64,
-    copysign_f32, copysign_f64
-}
-
-E :: 2.71828182845904523536f;
-PI :: 3.14159265f;
-TAU :: 6.28318330f;
-
-// Simple taylor series approximation of sin(t)
-sin :: (t_: f32) -> f32 {
-    t := t_;
+use package core.intrinsics.wasm as wasm
+
+// Things that are useful in any math library:
+// - Trigonometry
+// - modf, fmod
+
+// Other things that can be useful:
+// - Vector math
+// - Matrix math
+// - Why not tensor math??
+// - Complex numbers
+// - Dual numbers
+
+
+E :: 2.71828182845904523536f;
+PI :: 3.14159265f;
+TAU :: 6.28318330f;
+SQRT_2 :: 1.414213562f;
+
+//
+// Trigonometry
+// Basic trig functions have been implemented using Taylor series approximations. The
+// approximations are very accurate, but rather computationally expensive. Programs that
+// rely heavily on trig functions would greatly benefit from improvements to the
+// implementations of these functions.
+//
+
+sin :: (t: f32) -> f32 {
     while t >= PI do t -= TAU;
     while t <= -PI do t += TAU;
 
@@ -33,9 +48,7 @@ sin :: (t_: f32) -> f32 {
     return res;
 }
 
-// Simple taylor series approximation of cos(t)
-cos :: (t_: f32) -> f32 {
-    t := t_;
+cos :: (t: f32) -> f32 {
     while t >= PI do t -= TAU;
     while t <= -PI do t += TAU;
 
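A note for context on the trigonometric approximations above (an editorial sketch of the underlying math, not text from the commit): the unchanged bodies of sin and cos sum the leading terms of the Taylor series

    sin(t) = t - t^3/3! + t^5/5! - t^7/7! + ...
    cos(t) = 1 - t^2/2! + t^4/4! - t^6/6! + ...

after the two while loops have reduced the argument into [-PI, PI]. The range reduction is what keeps a fixed-degree truncation accurate, since the error of a truncated Taylor polynomial grows quickly as |t| moves away from zero.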
@@ -55,86 +68,220 @@ cos :: (t_: f32) -> f32 {
     return res;
 }
 
-max :: (a: $T, b: T) -> T {
-    if a >= b do return a;
-    return b;
+asin :: (t: f32) -> f32 {
+    assert(false, "asin is not implemented yet!");
+    return 0;
 }
 
-min :: (a: $T, b: T) -> T {
-    if a <= b do return a;
-    return b;
+acos :: (t: f32) -> f32 {
+    assert(false, "acos is not implemented yet!");
+    return 0;
 }
 
-sqrt_i32 :: (x: i32) -> i32 do return ~~sqrt_f32(~~x);
-sqrt_i64 :: (x: i64) -> i64 do return ~~sqrt_f64(~~x);
-sqrt :: proc { sqrt_f32, sqrt_f64, sqrt_i32, sqrt_i64 }
+atan :: (t: f32) -> f32 {
+    assert(false, "atan is not implemented yet!");
+    return 0;
+}
 
-copysign :: proc { copysign_f32, copysign_f64 }
+atan2 :: (y: f32, x: f32) -> f32 {
+    assert(false, "atan2 is not implemented yet!");
+    return 0;
+}
 
-abs_i32 :: (x: i32) -> i32 {
-    if x >= 0 do return x;
-    return -x;
+
+
+
+//
+// Hyperbolic trigonometry.
+// The hyperbolic trigonometry functions are implemented using the naive
+// definitions. There may be fancier, faster and far more precise methods
+// of implementing these, but these definitions should suffice.
+//
+
+sinh :: (t: $T) -> T {
+    et := exp(t);
+    return (et - (1 / et)) / 2;
 }
-abs_i64 :: (x: i64) -> i64 {
-    if x >= 0 do return x;
-    return -x;
+
+cosh :: (t: $T) -> T {
+    et := exp(t);
+    return (et + (1 / et)) / 2;
 }
 
-abs :: proc { abs_i32, abs_i64, abs_f32, abs_f64 }
-pow_int :: (base: $T, p: i32) -> T {
-    if base == 0 do return 0;
-    if p == 0 do return 1;
+tanh :: (t: $T) -> T {
+    et := exp(t);
+    one_over_et := 1 / et;
+    return (et - one_over_et) / (et + one_over_et);
+}
 
-    a: T = 1;
-    while p > 0 {
-        if p % 2 == 1 do a *= base;
-        p = p >> 1;
-        base *= base;
-    }
+asinh :: (t: $T) -> T {
+    return ~~ ln(cast(f32) (t + sqrt(t * t + 1)));
 }
 
-    return a;
+acosh :: (t: $T) -> T {
+    return ~~ ln(cast(f32) (t + sqrt(t * t - 1)));
 }
 
-pow_float :: (base: $T, p: T) -> T {
-    if p == 0 do return 1;
-    if p < 0 do return 1 / pow_float(base, -p);
+atanh :: (t: $T) -> T {
+    return ~~ ln(cast(f32) ((1 + t) / (1 - t))) / 2;
 }
 
-    if p >= 1 {
-        tmp := pow_float(p = p / 2, base = base);
-        return tmp * tmp;
-    }
-    low : T = 0;
-    high : T = 1;
-    sqr := sqrt(base);
-    acc := sqr;
-    mid := high / 2;
+
+//
+// Exponentials and logarithms.
+// Exponentials with floats are implemented with a binary search using square roots, since
+// square roots are intrinsic to WASM and therefore "fast". Exponentials with integers are
+// implemented using a fast algorithm that minimizes the number of multiplications that
+// are needed. Logarithms are implemented using a polynomial that is accurate in the range of
+// [1, 2], and then utilizes this identity for values outside of that range,
+//
+//     ln(x) = ln(2^n * v) = n * ln(2) + ln(v),    v is in [1, 2]
+//
-    while abs(mid - p) > 0.00001 {
-        sqr = sqrt(sqr);
+exp :: (p: $T) -> T do return pow(base = cast(T) E, p = p);
-        if mid <= p {
-            low = mid;
-            acc *= sqr;
-        } else {
-            high = mid;
-            acc /= sqr;
+pow :: proc {
+    // Fast implementation of power when raising to an integer power.
+    (base: $T, p: i32) -> T {
+        if base == 0 do return 0;
+        if p == 0 do return 1;
+
+        a: T = 1;
+        while p > 0 {
+            if p % 2 == 1 do a *= base;
+            p = p >> 1;
+            base *= base;
         }
-        mid = (low + high) / 2;
+        return a;
+    },
+
+    // Also make the implementation work for 64-bit integers.
+    (base: $T, p: i64) -> T do return pow(base, cast(i32) p);,
+
+    // Generic power implementation for non-integer exponents using square roots.
+    (base: $T, p: T) -> T {
+        if p == 0 do return 1;
+        if p < 0 do return 1 / pow(base, -p);
+
+        if p >= 1 {
+            tmp := pow(p = p / 2, base = base);
+            return tmp * tmp;
+        }
+
+        low : T = 0;
+        high : T = 1;
+
+        sqr := sqrt(base);
+        acc := sqr;
+        mid := high / 2;
+
+        while abs(mid - p) > 0.00001 {
+            sqr = sqrt(sqr);
+
+            if mid <= p {
+                low = mid;
+                acc *= sqr;
+            } else {
+                high = mid;
+                acc /= sqr;
+            }
+
+            mid = (low + high) / 2;
+        }
+
+        return acc;
+    }
+}
+
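A sketch of the reasoning behind the power overloads above and the logarithm below (an editorial note, not part of the commit):

    - Integer exponents: writing p in binary, base^p is the product of base^(2^k) over the set
      bits of p. The loop squares 'base' once per bit and multiplies it into 'a' whenever the
      low bit of p is set, so only O(log p) multiplications are performed.

    - Fractional exponents (0 < p < 1): the loop keeps the invariant acc = base^mid. Each pass
      takes one more repeated square root, sqr = base^(2^-(k+1)), and moves mid toward p by
      2^-(k+1), multiplying or dividing acc by sqr. In effect it is a bisection that recovers
      the binary expansion of p,

          base^p = base^(b1/2 + b2/4 + b3/8 + ...) = (base^(1/2))^b1 * (base^(1/4))^b2 * ...

      and it stops once |mid - p| <= 0.00001. Exponents p >= 1 are first reduced by halving:
      pow(base, p) = pow(base, p / 2)^2.

    - ln: for a >= 1 the code takes n = 63 - clz(cast(i64) a), so that v = a / 2^n lies in
      [1, 2], evaluates a degree-4 polynomial approximation of ln(v), and adds n * ln(2) back,
      which is exactly the identity quoted in the comment. Inputs below 1 are handled through
      ln(a) = -ln(1 / a), since truncating such a value to an i64 would just give 0.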
+ln :: (a: f32) -> f32 {
+    // FIX: This is probably not the most numerically stable solution.
+    if a < 1 {
+        return -ln(1 / a);
+        // log2 := 63 - cast(i32) clz_i64(cast(i64) (1 / a));
+        // x := a / cast(f32) (1 << log2);
+        // res := -8.6731532f + (129.946172f + (-558.971892f + (843.967330f - 409.109529f * x) * x) * x) * x;
+        // return res + cast(f32) log2 * 0.69314718f; // ln(2) = 0.69314718
+    }
-    return acc;
+    log2 := 63 - cast(i32) wasm.clz_i64(cast(i64) a);
+    x := a / cast(f32) (1 << log2);
+    res := -1.7417939f + (2.8212026f + (-1.4699568f + (0.44717955f - 0.056570851f * x) * x) * x) * x;
+    res += cast(f32) log2 * 0.69314718; // ln(2) = 0.69314718
+    return res;
 }
 
-pow :: proc { pow_int, pow_float }
-exp :: (p: $T) -> T do return pow(base = cast(T) E, p = p);
+log :: (a: $T, base: $R) -> T {
+    if a <= 0 || base <= 0 do return 0;
+    return ~~(ln(cast(f32) a) / ln(cast(f32) base));
 }
 
-ln :: (a: f32) -> f32 {
+
+
+// These functions are overloaded in order to use the builtin WASM intrinsics for the
+// operation first, and then default to a polymorphic function that works on any type.
+// The clunky part about these at the moment is that, if you wanted to pass 'max' to
+// a procedure, you would have to pass 'max_poly' instead, because overloaded functions
+// are not resolved when used by value, i.e. foo : (f32, f32) -> f32 = math.max; Even if
+// they were, the fact that these overloads are intrinsic means they cannot be
+// referenced from the element section and therefore cannot be passed around or used as
+// values.
+max :: proc { wasm.max_f32, wasm.max_f64, max_poly }
+max_poly :: (a: $T, b: T) -> T {
+    if a >= b do return a;
+    return b;
 }
 
-log :: (a: $T, b: $R) ->T {
+min :: proc { wasm.min_f32, wasm.min_f64, min_poly }
+min_poly :: (a: $T, b: T) -> T {
+    if a <= b do return a;
+    return b;
+}
+
+sqrt :: proc { wasm.sqrt_f32, wasm.sqrt_f64, sqrt_poly }
+sqrt_poly :: proc (x: $T) -> T {
+    return ~~ wasm.sqrt_f64(~~ x);
 }
+
+abs :: proc { wasm.abs_f32, wasm.abs_f64, abs_poly }
+abs_poly :: (x: $T) -> T {
+    if x >= 0 do return x;
+    return -x;
+}
+
+sign :: (x: $T) -> T {
+    if x > 0 do return 1;
+    if x < 0 do return cast(T) -1;
+    return 0;
+}
+
+copysign :: proc { wasm.copysign_f32, wasm.copysign_f64, copysign_poly }
+copysign_poly :: (x: $T, y: T) -> T {
+    return abs(x) * sign(y);
+}
+
+
+
+//
+// Floating point rounding
+//
+
+ceil :: proc { wasm.ceil_f32, wasm.ceil_f64 }
+floor :: proc { wasm.floor_f32, wasm.floor_f64 }
+trunc :: proc { wasm.trunc_f32, wasm.trunc_f64 }
+nearest :: proc { wasm.nearest_f32, wasm.nearest_f64 }
+
+
+
+//
+// Integer operations
+//
+
+clz :: proc { wasm.clz_i32, wasm.clz_i64 }
+ctz :: proc { wasm.ctz_i32, wasm.ctz_i64 }
+popcnt :: proc { wasm.popcnt_i32, wasm.popcnt_i64 }
+rotate_left :: proc { wasm.rotl_i32, wasm.rotl_i64 }
+rotate_right :: proc { wasm.rotr_i32, wasm.rotr_i64 }
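The test changes below are a direct consequence of the comment above max: calling an overload group like 'max' works, but the group cannot be used as a value once intrinsics are involved, so code that needs a function value switches to the *_poly procedures. A hypothetical illustration in Onyx (names invented here, not code from the repository), as it might appear inside some procedure:

    larger := math.max(1.0f, 2.0f);                // fine: the overload is resolved at the call site
    // compare : (f32, f32) -> f32 = math.max;     // not allowed: 'max' is an intrinsic-backed overload group
    compare : (f32, f32) -> f32 = math.max_poly;   // ok: a single polymorphic procedure can be used as a value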
diff --git a/tests/aoc-2020/day9.onyx b/tests/aoc-2020/day9.onyx
index 15d6b292..92f387a9 100644
--- a/tests/aoc-2020/day9.onyx
+++ b/tests/aoc-2020/day9.onyx
@@ -67,8 +67,8 @@ main :: proc (args: [] cstr) {
 
     se := find_contiguous_subarray_with_sum(nums, invalid);
 
-    max := array.fold_slice(nums.data[se.start .. se.end + 1], cast(u64) 0, math.max);
-    min := array.fold_slice(nums.data[se.start .. se.end + 1], cast(u64) max, math.min);
+    max := array.fold_slice(nums.data[se.start .. se.end + 1], cast(u64) 0, math.max_poly);
+    min := array.fold_slice(nums.data[se.start .. se.end + 1], cast(u64) max, math.min_poly);
 
     printf("Extrema sum: %l\n", min + max);
 }
diff --git a/tests/compile_time_procedures.onyx b/tests/compile_time_procedures.onyx
index 518fe9ca..a8503261 100644
--- a/tests/compile_time_procedures.onyx
+++ b/tests/compile_time_procedures.onyx
@@ -39,7 +39,7 @@ main :: proc (args: [] cstr) {
     t.vt = ^test_vtable2;
     println(do_a_thing(^t));
 
-    tmp_vt := Test_VTable.{ first = test_vtable1.second, second = test_vtable2.first, third = math.max };
+    tmp_vt := Test_VTable.{ first = test_vtable1.second, second = test_vtable2.first, third = math.max_poly };
     t.vt = ^tmp_vt;
     println(do_a_thing(^t));
 
-- 
2.25.1
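For a sense of how the reorganized library is called from user code, here is a small hypothetical usage sketch in Onyx. It is not part of the commit, and it assumes the same scaffolding as the existing test files (loads and imports not shown), where this package is visible under the name 'math':

    main :: proc (args: [] cstr) {
        cubed := math.pow(3, 4);        // integer overload: exponentiation by squaring
        printf("pow(3, 4) = %l\n", cast(u64) cubed);

        root := math.sqrt(2.0f);        // resolves to the wasm sqrt_f32 intrinsic in the overload group
        printf("floor(100 * sqrt(2)) = %l\n", cast(u64) math.floor(100.0f * root));
    }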