small updates with new features in Onyx
authorBrendan Hansen <brendan.f.hansen@gmail.com>
Wed, 3 Feb 2021 04:41:49 +0000 (22:41 -0600)
committerBrendan Hansen <brendan.f.hansen@gmail.com>
Wed, 3 Feb 2021 04:41:49 +0000 (22:41 -0600)
src/mnist.onyx
src/neuralnet.onyx

index 3be07bb7d2c716acbb9a3ea0b96e81d5d777654d..51835aac01b84ca3baf9a258f340a63d7b7740ce 100644 (file)
@@ -8,8 +8,7 @@ use package core
 MNIST_DataLoader :: struct {
     use base : DataLoader(MNIST_Sample);
     
-    images : io.FileStream;
-    labels : io.FileStream;
+    images, labels : io.FileStream;
     
     make :: (image_path := "data/train-images-idx3-ubyte", label_path := "data/train-labels-idx1-ubyte") -> MNIST_DataLoader {
         mnist_data: MNIST_DataLoader;
@@ -88,14 +87,14 @@ train :: (nn: ^NeuralNet, dataloader: ^DataLoader(MNIST_Sample), optimizer: ^Opt
             dataloader_get_item(dataloader, ex, ^sample);
             
             optimizer_zero_gradient(optimizer);
-            NeuralNet.forward(nn, ~~ sample.input);
-            NeuralNet.backward(nn, ~~ sample.output, criterion);
+            (*nn)->forward(sample.input);
+            (*nn)->backward(sample.output, criterion);
             optimizer_step(optimizer);
 
 
             // NOTE(Brendan Hansen): Prediction printing and tracking.
             label, _   := array.greatest(sample.output);
-            prediction := NeuralNet.get_prediction(nn);
+            prediction := (*nn)->get_prediction();
             if prediction == label do past_100_correct += 1;
 
             if ex % 100 == 0 {
@@ -116,7 +115,7 @@ train :: (nn: ^NeuralNet, dataloader: ^DataLoader(MNIST_Sample), optimizer: ^Opt
                 color := 94;
                 if prediction != label do color = 91;
 
-                output := NeuralNet.get_output(nn);
+                output := (*nn)->get_output();
 
                 print_colored_array(sample.output, label, color);
                 print_colored_array(output, prediction, color);
@@ -126,30 +125,26 @@ train :: (nn: ^NeuralNet, dataloader: ^DataLoader(MNIST_Sample), optimizer: ^Opt
 
                 past_100_correct = 0;
                 
-                if ex % 10000 == 0 {
-                    println("Saving neural network...");
-                    neural_net_save(nn, "data/still_working.nn");
-                }
+                // if ex % 10000 == 0 {
+                //     println("Saving neural network...");
+                //     neural_net_save(nn, "data/still_working.nn");
+                // }
             }
         }
     }
 }
 
 main :: (args: [] cstr) {
-    // Enables a logging allocator to print every allocation
-    // main_allocator := context.allocator;
-    // context.allocator = alloc.log.logging_allocator(^main_allocator);
-
     nn := NeuralNet.make(28 * 28, 512, 256, 100, 10);
-    defer nn.free(^nn);
+    defer nn->free();
 
     random.set_seed(5234);
 
     mnist_data := MNIST_DataLoader.make();
-    defer mnist_data.close(^mnist_data);
+    defer mnist_data->close();
 
     optimizer := sgd_optimizer_create(^nn, learning_rate = 0.005f);
-    nn.supply_parameters(^nn, ^optimizer);
+    nn->supply_parameters(^optimizer);
 
     println("Starting training");
     train(^nn, ^mnist_data, ^optimizer);
index 4cf6b0e2e0b8c38bfd5282d8909f0378671f12e2..162755797a56ca9235adfd65a6fcf27f24febeee 100644 (file)
@@ -5,10 +5,7 @@ use package core
 // Variable
 //
 // TODO(Brendan Hansen): Document this better
-Variable :: struct {
-    value : f32;
-    delta : f32;
-}
+Variable :: struct { value, delta: f32; }
 
 //
 // General purpose Multi-Layer Perceptron (MLP)
@@ -122,7 +119,7 @@ NeuralNet :: struct {
 
     // :MNISTSpecific
     get_prediction :: (use nn: ^NeuralNet) -> i32 {
-        output := NeuralNet.get_output(nn);
+        output := get_output(nn);
 
         greatest_idx := 0;
         for i: output.count do if output[i] > output[greatest_idx] do greatest_idx = i;
@@ -173,7 +170,7 @@ Layer :: struct {
             
             weights = memory.make_slice(Variable, layer_size * prev_layer_size, allocator);
 
-            Layer.randomize_weights_and_biases(layer);
+            randomize_weights_and_biases(layer);
         }
     }