images : io.FileStream;
labels : io.FileStream;
-}
-
-MNIST_Sample :: struct {
- // NOTE(Brendan Hansen): Expected to be 28 * 28 elements in size
- input : [] f32;
- // NOTE(Brendan Hansen): Expected to be 10 elements in size
- output : [] f32;
-}
-
-mnist_data_make :: (image_path := "data/train-images-idx3-ubyte", label_path := "data/train-labels-idx1-ubyte") -> MNIST_DataLoader {
- mnist_data: MNIST_DataLoader;
- mnist_data.vtable = ^mnist_dataloader_functions;
-
- err : io.Error;
- err, mnist_data.images = io.open(image_path);
- assert(err == io.Error.None, "There was an error loading the image file");
-
- err, mnist_data.labels = io.open(label_path);
- assert(err == io.Error.None, "There was an error loading the label file");
-
- return mnist_data;
-}
-
-mnist_data_close :: (use mnist_data: ^MNIST_DataLoader) {
- io.stream_close(^images);
- io.stream_close(^labels);
-}
+ make :: (image_path := "data/train-images-idx3-ubyte", label_path := "data/train-labels-idx1-ubyte") -> MNIST_DataLoader {
+ mnist_data: MNIST_DataLoader;
+ mnist_data.vtable = ^mnist_dataloader_functions;
+
+ err : io.Error;
+ err, mnist_data.images = io.open(image_path);
+ assert(err == io.Error.None, "There was an error loading the image file");
+
+ err, mnist_data.labels = io.open(label_path);
+ assert(err == io.Error.None, "There was an error loading the label file");
-mnist_dataloader_functions := <DataLoader_Functions(MNIST_Sample)>.{
- get_count = (use data: ^MNIST_DataLoader) -> u32 {
+ return mnist_data;
+ }
+
+ close :: (use mnist_data: ^MNIST_DataLoader) {
+ io.stream_close(^images);
+ io.stream_close(^labels);
+ }
+
+ get_count :: (use data: ^MNIST_DataLoader) -> u32 {
return 50000;
- },
+ }
- get_item = (use data: ^MNIST_DataLoader, index: u32, use sample: ^MNIST_Sample) -> bool {
+ get_item :: (use data: ^MNIST_DataLoader, index: u32, use sample: ^MNIST_Sample) -> bool {
assert(input.count == 28 * 28, "Input slice was of wrong size. Expected 784.");
assert(output.count == 10, "Output slice was of wrong size. Expected 10.");
}
}
+MNIST_Sample :: struct {
+ // NOTE(Brendan Hansen): Expected to be 28 * 28 elements in size
+ input : [] f32;
+
+ // NOTE(Brendan Hansen): Expected to be 10 elements in size
+ output : [] f32;
+}
+
+mnist_dataloader_functions := <DataLoader_Functions(MNIST_Sample)>.{
+ get_count = MNIST_DataLoader.get_count,
+ get_item = MNIST_DataLoader.get_item,
+}
+
// TODO(Brendan Hansen): Generalize this to all data types
train :: (nn: ^NeuralNet, dataloader: ^DataLoader(MNIST_Sample), optimizer: ^Optimizer, criterion: Criterion = mean_squared_error) {
sample : MNIST_Sample;
random.set_seed(5234);
- mnist_data := mnist_data_make();
- defer mnist_data_close(^mnist_data);
+ mnist_data := MNIST_DataLoader.make();
+ defer mnist_data.close(^mnist_data);
optimizer := sgd_optimizer_create(^nn, learning_rate = 0.005f);
neural_net_supply_parameters(^nn, ^optimizer);