bugfixes with parallel_for and distributor
authorBrendan Hansen <brendan.f.hansen@gmail.com>
Mon, 3 Jan 2022 21:26:43 +0000 (15:26 -0600)
committerBrendan Hansen <brendan.f.hansen@gmail.com>
Mon, 3 Jan 2022 21:26:43 +0000 (15:26 -0600)
core/container/iter.onyx
core/intrinsics/onyx.onyx
scripts/run_tests.onyx

index f75ce56f71b28d9707b540fe67a7a96d2db35f29..11242eafdba88ce649b561645c4d9fa6b4c584b8 100644 (file)
@@ -1,6 +1,7 @@
 package core.iter
 
 use package core.intrinsics.onyx { __zero_value }
+#local sync   :: package core.sync
 #local memory :: package core.memory
 
 as_iterator :: #match {}
@@ -549,11 +550,19 @@ distributor :: #match {}
     Context :: struct (T: type_expr) {
         mutex: sync.Mutex;
         iterator: Iterator(T);
+        ended := false;
     }
 
     next :: (use c: ^Context($T)) -> (T, bool) {
+        if ended do return __zero_value(T), false;
         sync.scoped_mutex(^mutex);
-        return take_one(iterator);
+
+        if v, success := take_one(iterator); !success {
+            ended = true;
+            return v, false;
+        } else {
+            return v, true;
+        }
     }
 
     close :: (use c: ^Context($T)) {
@@ -568,27 +577,39 @@ distributor :: #match {}
     return .{c, #solidify next {T=T}, #solidify close {T=T}};
 }
 
-parallel_for :: macro (iterable: $I, thread_count: u32, body: Code) where Iterable(I) {
+parallel_for :: macro (iterable: $I, thread_count: u32, thread_data: ^$Ctx, body: Code) where Iterable(I) {
     thread :: package core.thread;
     alloc  :: package core.alloc;
+    distributor :: distributor;
+    as_iterator :: as_iterator;
 
     if thread_count != 0 {
-        dist := distributor(iterable);
+        dist := distributor(as_iterator(iterable));
         hacky_crap_to_get_the_type_of_T(dist);
     }
 
     hacky_crap_to_get_the_type_of_T :: macro (dist: Iterator($T)) {
         threads := (cast(^thread.Thread) alloc.from_stack(thread_count * sizeof thread.Thread))[0 .. (thread_count - 1)];
-        for^ threads do thread.spawn(it, ^dist, #solidify thread_function {body=body, T=T});
+        t_data := Thread_Data(T, Ctx).{
+            iter = ^dist,
+            data = thread_data,
+        };
+        for^ threads do thread.spawn(it, ^t_data, #solidify thread_function {body=body, T=T, Ctx=Ctx});
 
-        thread_function(body, ^dist);
+        thread_function(body, ^t_data);
 
         for^ threads do thread.join(it);
         dist.close(dist.data);
     }
 
-    thread_function :: ($body: Code, iter: ^Iterator($T)) {
-        for #no_close *iter {
+    Thread_Data :: struct (T: type_expr, Ctx: type_expr) {
+        iter: ^Iterator(T);
+        data: ^Ctx;
+    }
+
+    thread_function :: ($body: Code, __data: ^Thread_Data($T, $Ctx)) {
+        thread_data := __data.data;
+        for #no_close *__data.iter {
             #insert body;
         }
     }
index cf752a906914565b55fbce0b513bcd90d6e6964e..ad7064681ee814ac3de758bc8a429ee47dea9de7 100644 (file)
@@ -4,6 +4,8 @@ __initialize :: (val: ^$T)      -> void #intrinsic ---
 __zero_value :: ($T: type_expr) -> T    #intrinsic ---
 
 init :: macro ($T: type_expr) -> T {
+    __initialize :: __initialize
+
     val: T;
     __initialize(^val);
     return val;
index 87b2da06b9f809b8345e1944357ee4eaba9d3a40..41e36f621d5307bdde0c1e337c218fb806f2525d 100644 (file)
@@ -6,6 +6,7 @@
 #load "core/std"
 
 use package core
+use package core.intrinsics.onyx { init }
 #local runtime :: package runtime
 
 #if false {
@@ -26,37 +27,6 @@ use package core
     }
 }
 
-@Relocate
-distributor :: (arr: [] $T) -> Iterator(T) {
-    Context :: struct (T: type_expr) {
-        mutex: sync.Mutex;
-        arr: [] T;
-        curr_pos: i32;
-    }
-
-    next :: (use c: ^Context($T)) -> (T, bool) {
-        sync.scoped_mutex(^mutex);
-
-        use package core.intrinsics.onyx {__zero_value}
-        if curr_pos >= arr.count do return __zero_value(T), false;
-
-        defer curr_pos += 1;
-        return arr[curr_pos], true;
-    }
-
-    close :: (use c: ^Context($T)) {
-        sync.mutex_destroy(^c.mutex);
-        cfree(c);
-    }
-
-    c := new(Context(T));
-    sync.mutex_init(^c.mutex);
-    c.arr = arr;
-    c.curr_pos = 0;
-
-    return .{ c, #solidify next {T=T}, #solidify close {T=T}};
-}
-
 Color :: enum {
     White;
     Red;
@@ -117,9 +87,6 @@ find_onyx_files :: (root: str, cases: ^[..] Test_Case) {
     return;
 }
 
-// The executable to use when compiling
-onyx_cmd: str;
-at_least_one_test_failed := false;
 settings := Settings.{};
 
 Settings :: struct {
@@ -193,12 +160,19 @@ main :: (args) => {
     args_parse(args, ^settings);
     printf("Using {p*}\n", ^settings);
 
+    Execution_Context :: struct {
+        // The executable to use when compiling
+        onyx_cmd: str;
+        at_least_one_test_failed := false;
+    }
+    exec_context := init(Execution_Context);
+
     switch runtime.compiler_os {
         case .Linux {
-            onyx_cmd = "./bin/onyx";
-            if settings.debug do onyx_cmd = "./bin/onyx-debug";
+            exec_context.onyx_cmd = "./bin/onyx";
+            if settings.debug do exec_context.onyx_cmd = "./bin/onyx-debug";
         }
-        case .Windows do onyx_cmd = "onyx.exe";
+        case .Windows do exec_context.onyx_cmd = "onyx.exe";
     }
 
     cases := array.make(Test_Case, capacity=256);
@@ -206,16 +180,14 @@ main :: (args) => {
 
     thread_count := settings.threads;
 
-    iter.parallel_for(cases, settings.threads) {
+    iter.parallel_for(cases, settings.threads, ^exec_context) {
         // Weird macros mean I have to forward external names
         use package core
-        onyx_cmd :: onyx_cmd
         print_color :: print_color
-        at_least_one_test_failed :: at_least_one_test_failed
 
         printf("[{}]  Running test {}...\n", context.thread_id, it.source_file);
 
-        proc := os.process_spawn(onyx_cmd, .["run", it.source_file]);
+        proc := os.process_spawn(thread_data.onyx_cmd, .["run", it.source_file]);
         defer os.process_destroy(^proc);
 
         proc_reader := io.reader_make(^proc);
@@ -225,7 +197,7 @@ main :: (args) => {
         if exit := os.process_wait(^proc); exit != .Success {
             // Error running the test case
             print_color(.Red, "[{}]  Error '{}' in test case {}.\n{}", context.thread_id, exit, it.source_file, output);
-            at_least_one_test_failed = true;
+            thread_data.at_least_one_test_failed = true;
             continue;
         }
 
@@ -237,12 +209,12 @@ main :: (args) => {
                 print_color(.Red, "[{}]  Output did not match for {}.\n", context.thread_id, it.source_file);
                 printf("Expected:\n{}\n", expected_output);
                 printf("Got:\n{}\n", output);
-                at_least_one_test_failed = true;
+                thread_data.at_least_one_test_failed = true;
             }
         }
     }
 
-    if at_least_one_test_failed {
+    if exec_context.at_least_one_test_failed {
         print_color(.Red, "FAILED\n");
         os.exit(-1);