From 3eadd421e508e2955c8de00b139082fea8defa84 Mon Sep 17 00:00:00 2001
From: Brendan Hansen
Date: Tue, 26 Jan 2021 09:40:19 -0600
Subject: [PATCH] generalized data loading

---
 docs/abstractions  |  6 +++-
 src/mnist.onyx     | 78 ++++++++++++++++++++++++++++------------------
 src/neuralnet.onyx | 37 ++++++++++++++++++++--
 3 files changed, 88 insertions(+), 33 deletions(-)

diff --git a/docs/abstractions b/docs/abstractions
index 626fd14..bfe5ab8 100644
--- a/docs/abstractions
+++ b/docs/abstractions
@@ -12,7 +12,7 @@ Abstractions still needed:
 
 * Criteria
     - MSE (implemented)
-    - MAE
+    - MAE (implemented)
     - BCE
 
 * Data Loader
@@ -21,3 +21,7 @@
     the dataloader as it is needed. The dataloader then has the freedom
     to cache or preload the data.
 
+    This is currently implemented in a very basic way with something that
+    behaves like an abstract class, but I think it is general enough to
+    work for the moment.
+
diff --git a/src/mnist.onyx b/src/mnist.onyx
index 897347a..b391419 100644
--- a/src/mnist.onyx
+++ b/src/mnist.onyx
@@ -5,13 +5,17 @@
 
 use package core
 
-MNIST_Data :: struct {
+
+MNIST_DataLoader :: struct {
+    use base : DataLoader;
+
     images : io.FileStream;
     labels : io.FileStream;
 }
 
-mnist_data_make :: (image_path := "data/train-images-idx3-ubyte", label_path := "data/train-labels-idx1-ubyte") -> MNIST_Data {
-    mnist_data: MNIST_Data;
+mnist_data_make :: (image_path := "data/train-images-idx3-ubyte", label_path := "data/train-labels-idx1-ubyte") -> MNIST_DataLoader {
+    mnist_data: MNIST_DataLoader;
+    mnist_data.vtable = ^mnist_dataloader_functions;
 
     err : io.Error;
     err, mnist_data.images = io.open(image_path);
@@ -23,43 +27,57 @@ mnist_data_make :: (image_path := "data/train-images-idx3-ubyte", label_path :=
     return mnist_data;
 }
 
-mnist_data_close :: (use mnist_data: ^MNIST_Data) {
+mnist_data_close :: (use mnist_data: ^MNIST_DataLoader) {
     io.stream_close(^images);
     io.stream_close(^labels);
 }
 
-load_example :: (use mnist_data: ^MNIST_Data, example: u32, out: [784] u8) -> u32 {
-    location := 16 + example * 784;
-    _, bytes_read := io.stream_read_at(^images, location, ~~ out);
-
-    assert(bytes_read == 784, "Incorrect number of bytes read.");
-
-    location = 8 + example;
-    label_buf : [1] u8;
-    _, bytes_read = io.stream_read_at(^labels, location, ~~ label_buf);
-    return ~~ label_buf[0];
+mnist_dataloader_functions := DataLoader_Functions.{
+    get_count = (use data: ^MNIST_DataLoader) -> u32 {
+        return 50000;
+    },
+
+    get_item = (use data: ^MNIST_DataLoader, index: u32, input: [] f32, output: [] f32) -> bool {
+        assert(input.count == 28 * 28, "Input slice was of wrong size. Expected 784.");
+        assert(output.count == 10, "Output slice was of wrong size. Expected 10.");
+
+        if index > 50000 do return false;
+
+        location := 16 + index * 784;
+        input_tmp : [784] u8;
+        _, bytes_read := io.stream_read_at(^images, location, ~~ input_tmp);
+
+        location = 8 + index;
+        label_buf : [1] u8;
+        _, bytes_read = io.stream_read_at(^labels, location, ~~ label_buf);
+
+        // CLEANUP: The double cast that is necessary here is gross.
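+        // (Each raw byte is widened to u32 and then converted to f32;
+        // dividing by 255 normalizes each pixel into the range 0.0 to 1.0.)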
+        for i: input.count do input[i] = (cast(f32) cast(u32) input_tmp[i]) / 255;
+
+        for i: output.count do output[i] = 0.0f;
+        output[cast(u32) label_buf[0]] = 1.0f;
+
+        return true;
+    }
 }
 
-stocastic_gradient_descent :: (nn: ^NeuralNet, mnist_data: ^MNIST_Data, training_examples := 50000) {
-    example : [784] u8;
-    expected := f32.[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ];
+
+stocastic_gradient_descent :: (nn: ^NeuralNet, dataloader: ^DataLoader, criterion: Criterion = mean_squared_error) {
     input := memory.make_slice(f32, 784);
     defer cfree(input.data);
 
+    expected : [10] f32;
+    training_example_count := dataloader_get_count(dataloader);
+
     past_100_correct := 0;
     for i: 10 {
-        for ex: training_examples {
-            label := load_example(mnist_data, ex, example);
-            expected[label] = 1.0f;
-            defer expected[label] = 0.0f;
-
-            // CLEANUP: The double cast that is necessary here is gross.
-            for i: input.count do input[i] = (cast(f32) cast(u32) example[i]) / 255;
-
+        for ex: training_example_count {
+            dataloader_get_item(dataloader, ex, input, ~~ expected);
+
             neural_net_forward(nn, ~~ input);
-            neural_net_backward(nn, ~~ expected, mean_squared_error);
-
+            neural_net_backward(nn, ~~ expected, criterion);
+
+            label, _ := array.greatest(expected);
             prediction := neural_net_get_prediction(nn);
             if prediction == label do past_100_correct += 1;
@@ -86,7 +104,7 @@ stocastic_gradient_descent :: (nn: ^NeuralNet, mnist_data: ^MNIST_Data, training
                 print_colored_array(cast([] f32) expected, label, color);
                 print_colored_array(output, prediction, color);
 
-                loss := neural_net_loss(nn, ~~ expected, mean_squared_error);
+                loss := neural_net_loss(nn, ~~ expected, criterion);
                 printf("Loss: %f Correct: %i / 100\n", cast(f32) loss, past_100_correct);
 
                 past_100_correct = 0;
@@ -106,7 +124,7 @@ main :: (args: [] cstr) {
     // context.allocator = alloc.log.logging_allocator(^main_allocator);
 
     // nn := neural_net_load("data/test_2.nn");
-    nn := make_neural_net(28 * 28, 512, 256, 100, 10);
+    nn := make_neural_net(28 * 28, 1024, 256, 100, 10);
     defer neural_net_free(^nn);
 
     random.set_seed(5234);
diff --git a/src/neuralnet.onyx b/src/neuralnet.onyx
index 3c5ea46..c50f4e5 100644
--- a/src/neuralnet.onyx
+++ b/src/neuralnet.onyx
@@ -1,7 +1,7 @@
 use package core
 
 NeuralNet :: struct {
-    layers : [] Layer;
+    layers : [] Layer;
 
     // CLEANUP: Move these to core.alloc, so the nesting isn't nearly as terrible.
     layer_arena : alloc.arena.ArenaState;
@@ -426,4 +426,37 @@ mean_absolute_error := Criterion.{
             deltas[j] /= cast(f32) expected.count;
         }
     },
-}
\ No newline at end of file
+}
+
+
+
+//
+// DataLoader (move this somewhere else)
+//
+// A very basic data structure that represents something you can load data
+// out of; specifically, an input and an output at a particular index.
+
+DataLoader :: struct {
+    vtable : ^DataLoader_Functions;
+}
+
+DataLoader_Functions :: struct {
+    get_count : (^DataLoader) -> u32;
+
+    // I don't like how these have to be floats, but they seem reasonable for now.
+    get_item : (^DataLoader, index: u32, input: [] f32, output: [] f32) -> bool;
+}
+
+dataloader_get_count :: (use data: ^DataLoader) -> u32 {
+    if vtable == null do return 0;
+    if vtable.get_count == null_proc do return 0;
+
+    return vtable.get_count(data);
+}
+
+dataloader_get_item :: (use data: ^DataLoader, index: u32, input: [] f32, output: [] f32) -> bool {
+    if vtable == null do return false;
+    if vtable.get_item == null_proc do return false;
+
+    return vtable.get_item(data, index, input, output);
+}
-- 
2.25.1
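
Usage note: any dataset can plug into this abstraction by embedding DataLoader
as its first member and pointing its vtable at a DataLoader_Functions instance.
The XOR_DataLoader below is a minimal hypothetical sketch, not part of the
patch; it assumes only the DataLoader, DataLoader_Functions, dataloader_get_count,
and dataloader_get_item definitions from src/neuralnet.onyx above, and mirrors
the vtable pattern of MNIST_DataLoader:

    // Hypothetical loader over the four rows of the XOR truth table.
    XOR_DataLoader :: struct {
        use base : DataLoader;
    }

    xor_dataloader_functions := DataLoader_Functions.{
        get_count = (use data: ^XOR_DataLoader) -> u32 {
            return 4;
        },

        get_item = (use data: ^XOR_DataLoader, index: u32, input: [] f32, output: [] f32) -> bool {
            assert(input.count == 2, "Input slice was of wrong size. Expected 2.");
            assert(output.count == 1, "Output slice was of wrong size. Expected 1.");

            if index >= 4 do return false;

            // The two input bits come from the index; the single output is
            // their XOR, computed as (a + b) % 2 to avoid bitwise operators.
            a := index % 2;
            b := index / 2;
            input[0] = cast(f32) a;
            input[1] = cast(f32) b;
            output[0] = cast(f32) ((a + b) % 2);

            return true;
        }
    }

    xor_dataloader_make :: () -> XOR_DataLoader {
        loader : XOR_DataLoader;
        loader.vtable = ^xor_dataloader_functions;
        return loader;
    }

Because get_item receives the example index along with preallocated input and
output slices, a loader like this is free to generate data on the fly, read it
from disk, or cache it, which matches the intent described in docs/abstractions.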