use package core
-MNIST_Data :: struct {
+
+MNIST_DataLoader :: struct {
+ use base : DataLoader;
+
images : io.FileStream;
labels : io.FileStream;
}
-mnist_data_make :: (image_path := "data/train-images-idx3-ubyte", label_path := "data/train-labels-idx1-ubyte") -> MNIST_Data {
- mnist_data: MNIST_Data;
+mnist_data_make :: (image_path := "data/train-images-idx3-ubyte", label_path := "data/train-labels-idx1-ubyte") -> MNIST_DataLoader {
+ mnist_data: MNIST_DataLoader;
+ mnist_data.vtable = ^mnist_dataloader_functions;
err : io.Error;
err, mnist_data.images = io.open(image_path);
return mnist_data;
}
-mnist_data_close :: (use mnist_data: ^MNIST_Data) {
+mnist_data_close :: (use mnist_data: ^MNIST_DataLoader) {
io.stream_close(^images);
io.stream_close(^labels);
}
-load_example :: (use mnist_data: ^MNIST_Data, example: u32, out: [784] u8) -> u32 {
- location := 16 + example * 784;
- _, bytes_read := io.stream_read_at(^images, location, ~~ out);
-
- assert(bytes_read == 784, "Incorrect number of bytes read.");
-
- location = 8 + example;
- label_buf : [1] u8;
- _, bytes_read = io.stream_read_at(^labels, location, ~~ label_buf);
- return ~~ label_buf[0];
+mnist_dataloader_functions := DataLoader_Functions.{
+ get_count = (use data: ^MNIST_DataLoader) -> u32 {
+ return 50000;
+ },
+
+ get_item = (use data: ^MNIST_DataLoader, index: u32, input: [] f32, output: [] f32) -> bool {
+ assert(input.count == 28 * 28, "Input slice was of wrong size. Expected 784.");
+ assert(output.count == 10, "Output slice was of wrong size. Expected 10.");
+
+ if index > 50000 do return false;
+
+ location := 16 + index * 784;
+ input_tmp : [784] u8;
+ _, bytes_read := io.stream_read_at(^images, location, ~~ input_tmp);
+
+ location = 8 + index;
+ label_buf : [1] u8;
+ _, bytes_read = io.stream_read_at(^labels, location, ~~ label_buf);
+
+ // CLEANUP: The double cast that is necessary here is gross.
+ for i: input.count do input[i] = (cast(f32) cast(u32) input_tmp[i]) / 255;
+
+ for i: output.count do output[i] = 0.0f;
+ output[cast(u32) label_buf[0]] = 1.0f;
+
+ return true;
+ }
}
-stocastic_gradient_descent :: (nn: ^NeuralNet, mnist_data: ^MNIST_Data, training_examples := 50000) {
- example : [784] u8;
- expected := f32.[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ];
+
+stocastic_gradient_descent :: (nn: ^NeuralNet, dataloader: ^DataLoader, criterion: Criterion = mean_squared_error) {
input := memory.make_slice(f32, 784);
defer cfree(input.data);
+ expected : [10] f32;
+ training_example_count := dataloader_get_count(dataloader);
+
past_100_correct := 0;
-
for i: 10 {
- for ex: training_examples {
- label := load_example(mnist_data, ex, example);
- expected[label] = 1.0f;
- defer expected[label] = 0.0f;
-
- // CLEANUP: The double cast that is necessary here is gross.
- for i: input.count do input[i] = (cast(f32) cast(u32) example[i]) / 255;
-
+ for ex: training_example_count {
+ dataloader_get_item(dataloader, ex, input, ~~ expected);
+
neural_net_forward(nn, ~~ input);
- neural_net_backward(nn, ~~ expected, mean_squared_error);
-
+ neural_net_backward(nn, ~~ expected, criterion);
+
+ label, _ := array.greatest(expected);
prediction := neural_net_get_prediction(nn);
if prediction == label do past_100_correct += 1;
print_colored_array(cast([] f32) expected, label, color);
print_colored_array(output, prediction, color);
- loss := neural_net_loss(nn, ~~ expected, mean_squared_error);
+ loss := neural_net_loss(nn, ~~ expected, criterion);
printf("Loss: %f Correct: %i / 100\n", cast(f32) loss, past_100_correct);
past_100_correct = 0;
// context.allocator = alloc.log.logging_allocator(^main_allocator);
// nn := neural_net_load("data/test_2.nn");
- nn := make_neural_net(28 * 28, 512, 256, 100, 10);
+ nn := make_neural_net(28 * 28, 1024, 256, 100, 10);
defer neural_net_free(^nn);
random.set_seed(5234);
use package core
NeuralNet :: struct {
- layers : [] Layer;
+ layers : [] Layer;
// CLEANUP: Move these to core.alloc, so the nesting isn't nearly as terrible.
layer_arena : alloc.arena.ArenaState;
deltas[j] /= cast(f32) expected.count;
}
},
-}
\ No newline at end of file
+}
+
+
+
+//
+// DataLoader (move this to somewhere else)
+//
+// Very basic datastructure that represents something you can loader data out of.
+// Specifically, an input and output at a particular index.
+
+DataLoader :: struct {
+ vtable : ^DataLoader_Functions;
+}
+
+DataLoader_Functions :: struct {
+ get_count : (^DataLoader) -> u32;
+
+ // I don't like how these have to be floats, but they seem reasonable for now.
+ get_item : (^DataLoader, index: u32, input: [] f32, output: [] f32) -> bool;
+}
+
+dataloader_get_count :: (use data: ^DataLoader) -> u32 {
+ if vtable == null do return 0;
+ if vtable.get_count == null_proc do return 0;
+
+ return vtable.get_count(data);
+}
+
+dataloader_get_item :: (use data: ^DataLoader, index: u32, input: [] f32, output: [] f32) -> bool {
+ if vtable == null do return false;
+ if vtable.get_item == null_proc do return false;
+
+ return vtable.get_item(data, index, input, output);
+}