updated to use procedures inside of structures
authorBrendan Hansen <brendan.f.hansen@gmail.com>
Fri, 29 Jan 2021 18:02:07 +0000 (12:02 -0600)
committerBrendan Hansen <brendan.f.hansen@gmail.com>
Fri, 29 Jan 2021 18:02:07 +0000 (12:02 -0600)
src/mnist.onyx

index 87e5aaa0c9bd1bc4ed053ddefcff78d32081b6af..209128b9480650a00782189de298e275e09aa29e 100644 (file)
@@ -10,41 +10,31 @@ MNIST_DataLoader :: struct {
     
     images : io.FileStream;
     labels : io.FileStream;
-}
-
-MNIST_Sample :: struct {
-    // NOTE(Brendan Hansen): Expected to be 28 * 28 elements in size
-    input : [] f32;
     
-    // NOTE(Brendan Hansen): Expected to be 10 elements in size
-    output : [] f32;
-}
-
-mnist_data_make :: (image_path := "data/train-images-idx3-ubyte", label_path := "data/train-labels-idx1-ubyte") -> MNIST_DataLoader {
-    mnist_data: MNIST_DataLoader;
-    mnist_data.vtable = ^mnist_dataloader_functions;
-
-    err : io.Error;
-    err, mnist_data.images = io.open(image_path);
-    assert(err == io.Error.None, "There was an error loading the image file");
-
-    err, mnist_data.labels = io.open(label_path);
-    assert(err == io.Error.None, "There was an error loading the label file");
-
-    return mnist_data;
-}
-
-mnist_data_close :: (use mnist_data: ^MNIST_DataLoader) {
-    io.stream_close(^images);
-    io.stream_close(^labels);
-}
+    make :: (image_path := "data/train-images-idx3-ubyte", label_path := "data/train-labels-idx1-ubyte") -> MNIST_DataLoader {
+        mnist_data: MNIST_DataLoader;
+        mnist_data.vtable = ^mnist_dataloader_functions;
+        
+        err : io.Error;
+        err, mnist_data.images = io.open(image_path);
+        assert(err == io.Error.None, "There was an error loading the image file");
+        
+        err, mnist_data.labels = io.open(label_path);
+        assert(err == io.Error.None, "There was an error loading the label file");
 
-mnist_dataloader_functions := <DataLoader_Functions(MNIST_Sample)>.{
-    get_count = (use data: ^MNIST_DataLoader) -> u32 {
+        return mnist_data;
+    }
+       
+    close :: (use mnist_data: ^MNIST_DataLoader) {
+        io.stream_close(^images);
+        io.stream_close(^labels);
+    }
+    
+    get_count :: (use data: ^MNIST_DataLoader) -> u32 {
         return 50000;
-    },
+    }
     
-    get_item  = (use data: ^MNIST_DataLoader, index: u32, use sample: ^MNIST_Sample) -> bool {
+    get_item :: (use data: ^MNIST_DataLoader, index: u32, use sample: ^MNIST_Sample) -> bool {
         assert(input.count  == 28 * 28, "Input slice was of wrong size. Expected 784.");
         assert(output.count == 10,      "Output slice was of wrong size. Expected 10.");
         
@@ -68,6 +58,19 @@ mnist_dataloader_functions := <DataLoader_Functions(MNIST_Sample)>.{
     }
 }
 
+MNIST_Sample :: struct {
+    // NOTE(Brendan Hansen): Expected to be 28 * 28 elements in size
+    input : [] f32;
+    
+    // NOTE(Brendan Hansen): Expected to be 10 elements in size
+    output : [] f32;
+}
+
+mnist_dataloader_functions := <DataLoader_Functions(MNIST_Sample)>.{
+    get_count = MNIST_DataLoader.get_count,
+    get_item  = MNIST_DataLoader.get_item,
+}
+
 // TODO(Brendan Hansen): Generalize this to all data types 
 train :: (nn: ^NeuralNet, dataloader: ^DataLoader(MNIST_Sample), optimizer: ^Optimizer, criterion: Criterion = mean_squared_error) {
     sample : MNIST_Sample;
@@ -142,8 +145,8 @@ main :: (args: [] cstr) {
 
     random.set_seed(5234);
 
-    mnist_data := mnist_data_make();
-    defer mnist_data_close(^mnist_data);
+    mnist_data := MNIST_DataLoader.make();
+    defer mnist_data.close(^mnist_data);
 
     optimizer := sgd_optimizer_create(^nn, learning_rate = 0.005f);
     neural_net_supply_parameters(^nn, ^optimizer);