From c2ff5101efb9571f6e56f358cf2049b2e3c67df9 Mon Sep 17 00:00:00 2001 From: Brendan Hansen Date: Fri, 29 Jan 2021 12:02:07 -0600 Subject: [PATCH] updated to use procedures inside of structures --- src/mnist.onyx | 69 ++++++++++++++++++++++++++------------------------ 1 file changed, 36 insertions(+), 33 deletions(-) diff --git a/src/mnist.onyx b/src/mnist.onyx index 87e5aaa..209128b 100644 --- a/src/mnist.onyx +++ b/src/mnist.onyx @@ -10,41 +10,31 @@ MNIST_DataLoader :: struct { images : io.FileStream; labels : io.FileStream; -} - -MNIST_Sample :: struct { - // NOTE(Brendan Hansen): Expected to be 28 * 28 elements in size - input : [] f32; - // NOTE(Brendan Hansen): Expected to be 10 elements in size - output : [] f32; -} - -mnist_data_make :: (image_path := "data/train-images-idx3-ubyte", label_path := "data/train-labels-idx1-ubyte") -> MNIST_DataLoader { - mnist_data: MNIST_DataLoader; - mnist_data.vtable = ^mnist_dataloader_functions; - - err : io.Error; - err, mnist_data.images = io.open(image_path); - assert(err == io.Error.None, "There was an error loading the image file"); - - err, mnist_data.labels = io.open(label_path); - assert(err == io.Error.None, "There was an error loading the label file"); - - return mnist_data; -} - -mnist_data_close :: (use mnist_data: ^MNIST_DataLoader) { - io.stream_close(^images); - io.stream_close(^labels); -} + make :: (image_path := "data/train-images-idx3-ubyte", label_path := "data/train-labels-idx1-ubyte") -> MNIST_DataLoader { + mnist_data: MNIST_DataLoader; + mnist_data.vtable = ^mnist_dataloader_functions; + + err : io.Error; + err, mnist_data.images = io.open(image_path); + assert(err == io.Error.None, "There was an error loading the image file"); + + err, mnist_data.labels = io.open(label_path); + assert(err == io.Error.None, "There was an error loading the label file"); -mnist_dataloader_functions := .{ - get_count = (use data: ^MNIST_DataLoader) -> u32 { + return mnist_data; + } + + close :: (use mnist_data: ^MNIST_DataLoader) { + io.stream_close(^images); + io.stream_close(^labels); + } + + get_count :: (use data: ^MNIST_DataLoader) -> u32 { return 50000; - }, + } - get_item = (use data: ^MNIST_DataLoader, index: u32, use sample: ^MNIST_Sample) -> bool { + get_item :: (use data: ^MNIST_DataLoader, index: u32, use sample: ^MNIST_Sample) -> bool { assert(input.count == 28 * 28, "Input slice was of wrong size. Expected 784."); assert(output.count == 10, "Output slice was of wrong size. Expected 10."); @@ -68,6 +58,19 @@ mnist_dataloader_functions := .{ } } +MNIST_Sample :: struct { + // NOTE(Brendan Hansen): Expected to be 28 * 28 elements in size + input : [] f32; + + // NOTE(Brendan Hansen): Expected to be 10 elements in size + output : [] f32; +} + +mnist_dataloader_functions := .{ + get_count = MNIST_DataLoader.get_count, + get_item = MNIST_DataLoader.get_item, +} + // TODO(Brendan Hansen): Generalize this to all data types train :: (nn: ^NeuralNet, dataloader: ^DataLoader(MNIST_Sample), optimizer: ^Optimizer, criterion: Criterion = mean_squared_error) { sample : MNIST_Sample; @@ -142,8 +145,8 @@ main :: (args: [] cstr) { random.set_seed(5234); - mnist_data := mnist_data_make(); - defer mnist_data_close(^mnist_data); + mnist_data := MNIST_DataLoader.make(); + defer mnist_data.close(^mnist_data); optimizer := sgd_optimizer_create(^nn, learning_rate = 0.005f); neural_net_supply_parameters(^nn, ^optimizer); -- 2.25.1