bulked up core math library
authorBrendan Hansen <brendan.f.hansen@gmail.com>
Thu, 4 Feb 2021 20:21:01 +0000 (14:21 -0600)
committerBrendan Hansen <brendan.f.hansen@gmail.com>
Thu, 4 Feb 2021 20:21:01 +0000 (14:21 -0600)
core/math.onyx
tests/aoc-2020/day9.onyx
tests/compile_time_procedures.onyx

index f77e60eabb7f0dc448beb361d1b2f846ae08f682..241c1642f8594cae68a04a30be72ef6121a70800 100644 (file)
@@ -1,18 +1,33 @@
 package core.math
 
-use package core.intrinsics.wasm {
-    sqrt_f32, sqrt_f64,
-    abs_f32,  abs_f64,
-    copysign_f32, copysign_f64
-}
-
-E   :: 2.71828182845904523536f;
-PI  :: 3.14159265f;
-TAU :: 6.28318330f;
-
-// Simple taylor series approximation of sin(t)
-sin :: (t_: f32) -> f32 {
-    t := t_;
+use package core.intrinsics.wasm as wasm
+
+// Things that are useful in any math library:
+//  - Trigonometry
+//  - modf, fmod
+
+// Other things that can be useful:
+//  - Vector math
+//  - Matrix math
+//  - Why not tensor math??
+//  - Complex numbers
+//  - Dual numbers
+
+
+E      :: 2.71828182845904523536f;
+PI     :: 3.14159265f;
+TAU    :: 6.28318330f;
+SQRT_2 :: 1.414213562f;
+
+//
+// Trigonometry
+// Basic trig functions have been implemented using taylor series approximations. The
+// approximations are very accurate, but rather computationally expensive. Programs that
+// rely heavily on trig functions would greatly benefit from improvements to the
+// implementations of these functions.
+//
+
+sin :: (t: f32) -> f32 {
     while t >=  PI do t -= TAU;
     while t <= -PI do t += TAU;
 
@@ -33,9 +48,7 @@ sin :: (t_: f32) -> f32 {
     return res;
 }
 
-// Simple taylor series approximation of cos(t)
-cos :: (t_: f32) -> f32 {
-    t := t_;
+cos :: (t: f32) -> f32 {
     while t >=  PI do t -= TAU;
     while t <= -PI do t += TAU;
 
@@ -55,86 +68,220 @@ cos :: (t_: f32) -> f32 {
     return res;
 }
 
-max :: (a: $T, b: T) -> T {
-    if a >= b do return a;
-    return b;
+asin :: (t: f32) -> f32 {
+    assert(false, "asin is not implemented yet!");
+    return 0;
 }
 
-min :: (a: $T, b: T) -> T {
-    if a <= b do return a;
-    return b;
+acos :: (t: f32) -> f32 {
+    assert(false, "acos is not implemented yet!");
+    return 0;
 }
 
-sqrt_i32 :: (x: i32) -> i32 do return ~~sqrt_f32(~~x);
-sqrt_i64 :: (x: i64) -> i64 do return ~~sqrt_f64(~~x);
-sqrt :: proc { sqrt_f32, sqrt_f64, sqrt_i32, sqrt_i64 }
+atan :: (t: f32) -> f32 {
+    assert(false, "atan is not implemented yet!");
+    return 0;
+}
 
-copysign :: proc { copysign_f32, copysign_f64 }
+atan2 :: (t: f32) -> f32 {
+    assert(false, "atan2 is not implemented yet!");
+    return 0;
+}
 
-abs_i32 :: (x: i32) -> i32 {
-    if x >= 0 do return x;
-    return -x;
+
+
+
+//
+// Hyperbolic trigonometry.
+// The hyperbolic trigonometry functions are implemented using the naive
+// definitions. There may be fancier, faster and far more precise methods
+// of implementing these, but these definitions should suffice.
+//
+
+sinh :: (t: $T) -> T {
+    et := exp(t);
+    return (et - (1 / et)) / 2;
 }
-abs_i64 :: (x: i64) -> i64 {
-    if x >= 0 do return x;
-    return -x;
+
+cosh :: (t: $T) -> T {
+    et := exp(t);
+    return (et + (1 / et)) / 2;
 }
-abs :: proc { abs_i32, abs_i64, abs_f32, abs_f64 }
 
-pow_int :: (base: $T, p: i32) -> T {
-    if base == 0 do return 0;
-    if p == 0    do return 1;
+tanh :: (t: $T) -> T {
+    et := exp(t);
+    one_over_et := 1 / et;
+    return (et - one_over_et) / (et + one_over_et);
+}
 
-    a: T = 1;
-    while p > 0 {
-        if p % 2 == 1 do a *= base;
-        p = p >> 1;
-        base *= base;
-    }
+asinh :: (t: $T) -> T {
+    return ~~ ln(cast(f32) (t + sqrt(t * t + 1)));
+}
 
-    return a;
+acosh :: (t: $T) -> T {
+    return ~~ ln(cast(f32) (t + sqrt(t * t - 1)));
 }
 
-pow_float :: (base: $T, p: T) -> T {
-    if p == 0 do return 1;
-    if p < 0  do return 1 / pow_float(base, -p);
+atanh :: (t: $T) -> T {
+    return ~~ ln(cast(f32) ((1 + t) / (1 - t))) / 2;
+}
 
-    if p >= 1 {
-        tmp := pow_float(p = p / 2, base = base);
-        return tmp * tmp;
-    }
 
-    low  : T = 0;
-    high : T = 1;
 
-    sqr := sqrt(base);
-    acc := sqr;
-    mid := high / 2;
+//
+// Exponentials and logarithms.
+// Exponentials with floats are implemented using a binary search using square roots, since
+// square roots are intrinsic to WASM and therefore "fast". Expoentials with integers are
+// implemented using a fast algorithm that minimizes the number of the mulitplications that
+// are needed. Logarithms are implemented using a polynomial that is accurate in the range of
+// [1, 2], and then utilizes this identity for values outside of that range,
+//
+//      ln(x) = ln(2^n * v) = n * ln(2) + v,   v is in [1, 2]
+//
 
-    while abs(mid - p) > 0.00001 {
-        sqr = sqrt(sqr);
+exp :: (p: $T) -> T do return pow(base = cast(T) E, p = p);
 
-        if mid <= p {
-            low = mid;
-            acc *= sqr;
-        } else {
-            high = mid;
-            acc /= sqr;
+pow :: proc {
+    // Fast implementation of power when raising to an integer power.
+    (base: $T, p: i32) -> T {
+        if base == 0 do return 0;
+        if p == 0    do return 1;
+
+        a: T = 1;
+        while p > 0 {
+            if p % 2 == 1 do a *= base;
+            p = p >> 1;
+            base *= base;
         }
 
-        mid = (low + high) / 2;
+        return a;
+    },
+
+    // Also make the implementation work for 64-bit integers.
+    (base: $T, p: i64) -> T do return pow(base, cast(i32) p);,
+
+    // Generic power implementation for integers using square roots.
+    (base: $T, p: T) -> T {
+        if p == 0 do return 1;
+        if p < 0  do return 1 / pow(base, -p);
+
+        if p >= 1 {
+            tmp := pow(p = p / 2, base = base);
+            return tmp * tmp;
+        }
+
+        low  : T = 0;
+        high : T = 1;
+
+        sqr := sqrt(base);
+        acc := sqr;
+        mid := high / 2;
+
+        while abs(mid - p) > 0.00001 {
+            sqr = sqrt(sqr);
+
+            if mid <= p {
+                low = mid;
+                acc *= sqr;
+            } else {
+                high = mid;
+                acc /= sqr;
+            }
+
+            mid = (low + high) / 2;
+        }
+
+        return acc;
+    }
+}
+
+ln :: (a: f32) -> f32 {
+    // FIX: This is probably not the most numerically stable solution.
+    if a < 1 {
+        return -ln(1 / a);
+        // log2 := 63 - cast(i32) clz_i64(cast(i64) (1 / a));
+        // x    := a / cast(f32) (1 << log2);
+        // res  := -8.6731532f + (129.946172f + (-558.971892f + (843.967330f - 409.109529f * x) * x) * x) * x;
+        // return res + cast(f32) log2 * 0.69314718f; // ln(2) = 0.69314718
     }
 
-    return acc;
+    log2 := 63 - cast(i32) wasm.clz_i64(cast(i64) a);
+    x    := a / cast(f32) (1 << log2);
+    res  := -1.7417939f + (2.8212026f + (-1.4699568f + (0.44717955f - 0.056570851f * x) * x) * x) * x;
+    res  += cast(f32) log2 * 0.69314718; // ln(2) = 0.69314718
+    return res;
 }
 
-pow :: proc { pow_int, pow_float }
-exp :: (p: $T) -> T do return pow(base = cast(T) E, p = p);
+log :: (a: $T, base: $R) -> T {
+    if a <= 0 || base <= 0 do return 0;
+    return ~~(ln(cast(f32) a) / ln(cast(f32) base));
+}
 
-ln :: (a: f32) -> f32 {
 
+
+
+// These function are overloaded in order to use the builtin WASM intrinsics for the
+// operation first, and then default to a polymoprhic function that works on any type.
+// The clunky part about these at the moment is that, if you wanted to pass 'max' to
+// a procedure, you would have to pass 'max_poly' instead, because overloaded functions
+// are not resolved when used by value, i.e. foo : (f32) -> f32 = math.max; Even if they
+// would be however, the fact that these overloads are intrinsic means they cannot be
+// reference from the element section and therefore cannot be passed around or used as
+// values.
+max :: proc { wasm.max_f32, wasm.max_f64, max_poly }
+max_poly :: (a: $T, b: T) -> T {
+    if a >= b do return a;
+    return b;
 }
 
-log :: (a: $T, b: $R) ->T {
+min :: proc { wasm.min_f32, wasm.min_f64, min_poly }
+min_poly :: (a: $T, b: T) -> T {
+    if a <= b do return a;
+    return b;
+}
 
+sqrt :: proc { wasm.sqrt_f32, wasm.sqrt_f64, sqrt_poly }
+sqrt_poly :: proc (x: $T) -> T {
+    return ~~ sqrt_f64(~~ x);
 }
+
+abs :: proc { wasm.abs_f32, wasm.abs_f64, abs_poly }
+abs_poly :: (x: $T) -> T {
+    if x >= 0 do return x;
+    return -x;
+}
+
+sign :: (x: $T) -> T {
+    if x > 0 do return 1;
+    if x < 0 do return cast(T) -1;
+    return 0;
+}
+
+copysign :: proc { wasm.copysign_f32, wasm.copysign_f64, copysign_poly }
+copysign_poly :: (x: $T, y: T) -> T {
+    return abs(x) * sign(y);
+}
+
+
+
+
+//
+// Floating point rounding
+//
+
+ceil    :: proc { wasm.ceil_f32,    wasm.ceil_f64    }
+floor   :: proc { wasm.floor_f32,   wasm.floor_f64   }
+trunc   :: proc { wasm.trunc_f32,   wasm.trunc_f64   }
+nearest :: proc { wasm.nearest_f32, wasm.nearest_f64 }
+
+
+
+//
+// Integer operations
+//
+
+clz          :: proc { wasm.clz_i32,    wasm.clz_i64    }
+ctz          :: proc { wasm.ctz_i32,    wasm.ctz_i64    }
+popcnt       :: proc { wasm.popcnt_i32, wasm.popcnt_i64 }
+rotate_left  :: proc { wasm.rotl_i32,   wasm.rotl_i64   }
+rotate_right :: proc { wasm.rotr_i32,   wasm.rotr_i64   }
index 15d6b2921feb99ceb29f1feb0e4b59221ef6d50b..92f387a9ceabeb07d067cabb0ab5a7fba62fe723 100644 (file)
@@ -67,8 +67,8 @@ main :: proc (args: [] cstr) {
 
     se := find_contiguous_subarray_with_sum(nums, invalid);
 
-    max := array.fold_slice(nums.data[se.start .. se.end + 1], cast(u64) 0, math.max);
-    min := array.fold_slice(nums.data[se.start .. se.end + 1], cast(u64) max, math.min);
+    max := array.fold_slice(nums.data[se.start .. se.end + 1], cast(u64) 0, math.max_poly);
+    min := array.fold_slice(nums.data[se.start .. se.end + 1], cast(u64) max, math.min_poly);
 
     printf("Extrema sum: %l\n", min + max);
 }
index 518fe9caa617d182cc06b2cd008490b222c0299a..a850326129c9b94363d8c9f2624ac7210a78119e 100644 (file)
@@ -39,7 +39,7 @@ main :: proc (args: [] cstr) {
     t.vt = ^test_vtable2;
     println(do_a_thing(^t));
 
-    tmp_vt := Test_VTable.{ first = test_vtable1.second, second = test_vtable2.first, third = math.max };
+    tmp_vt := Test_VTable.{ first = test_vtable1.second, second = test_vtable2.first, third = math.max_poly };
     t.vt = ^tmp_vt;
     println(do_a_thing(^t));