From: Brendan Hansen
Date: Fri, 5 Feb 2021 04:18:47 +0000 (-0600)
Subject: generalize train function
X-Git-Url: https://git.brendanfh.com/?a=commitdiff_plain;h=0990da841697cad2581c766427bcc558540e27af;p=onyx-mnist.git

generalize train function
---
diff --git a/project.4coder b/project.4coder
index f8f7ca0..9306168 100644
--- a/project.4coder
+++ b/project.4coder
@@ -19,7 +19,7 @@ load_paths = {
  { load_paths_custom, .os = "mac" },
 };
 
-build_win32 = "\\dev\\onyx\\onyx.exe -V src\\mnist.onyx -o network.wasm";
+build_win32 = "\\dev\\onyx\\onyx.exe -V src\\mnist.onyx -o mnist.wasm";
 build_linux = "/usr/bin/onyx -V src/mnist.onyx -o mnist.wasm";
 
 command_list = {
diff --git a/src/mnist.onyx b/src/mnist.onyx
index 51835aa..c2d5e82 100644
--- a/src/mnist.onyx
+++ b/src/mnist.onyx
@@ -30,14 +30,14 @@ MNIST_DataLoader :: struct {
     }
 
     get_count :: (use data: ^MNIST_DataLoader) -> u32 {
-        return 50000;
+        return 60000;
     }
 
     get_item :: (use data: ^MNIST_DataLoader, index: u32, use sample: ^MNIST_Sample) -> bool {
         assert(input.count == 28 * 28, "Input slice was of wrong size. Expected 784.");
         assert(output.count == 10, "Output slice was of wrong size. Expected 10.");
 
-        if index > 50000 do return false;
+        if index >= 60000 do return false;
 
         location := 16 + index * 784;
         input_tmp : [784] u8;
@@ -57,47 +57,100 @@ MNIST_DataLoader :: struct {
     }
 }
 
-MNIST_Sample :: struct {
-    // NOTE(Brendan Hansen): Expected to be 28 * 28 elements in size
-    input : [] f32;
-
-    // NOTE(Brendan Hansen): Expected to be 10 elements in size
-    output : [] f32;
-}
-
 mnist_dataloader_functions := .{
     get_count = MNIST_DataLoader.get_count,
     get_item  = MNIST_DataLoader.get_item,
 }
 
-// TODO(Brendan Hansen): Generalize this to all data types
-train :: (nn: ^NeuralNet, dataloader: ^DataLoader(MNIST_Sample), optimizer: ^Optimizer, criterion: Criterion = mean_squared_error) {
-    sample : MNIST_Sample;
-    sample.input  = memory.make_slice(f32, 784);
-    sample.output = memory.make_slice(f32, 10);
-    defer cfree(sample.input.data);
-    defer cfree(sample.output.data);
+MNIST_Sample :: struct {
+    input, output : [] f32;
+
+    init :: (use s: ^MNIST_Sample, allocator := context.allocator) {
+        input  = memory.make_slice(f32, 784, allocator);
+        output = memory.make_slice(f32, 10, allocator);
+    }
+
+    deinit :: (use s: ^MNIST_Sample, allocator := context.allocator) {
+        raw_free(allocator, input.data);
+        raw_free(allocator, output.data);
+    }
+}
+
+train :: (
+    nn: ^NeuralNet,                            // The neural network.
+    dataloader: ^DataLoader($Sample_Type),     // Data loader that provides samples of type Sample_Type.
+    optimizer: ^Optimizer,                     // The optimizer of choice; expected to already have the neural net's parameters supplied.
+    criterion: Criterion = mean_squared_error, // The criterion of choice.
+    batch_size := 10,                          // How many samples per batch.
+    batches_per_epoch := -1,                   // -1 means dataset size divided by batch size.
+    epochs := 5,                               // The number of epochs.
+    ) {
 
+    sample : Sample_Type;
+    sample->init();
+    defer sample->deinit();
 
     training_example_count := dataloader_get_count(dataloader);
+    printf("Training sample count: %i\n", training_example_count);
+
+    if batches_per_epoch == -1 {
+        batches_per_epoch = training_example_count / batch_size;
+    }
 
+    // Tracking how many of the past 100 samples were correct.
     past_100_correct := 0;
 
-    for i: 10 {
-        printf("Staring epoch %i ===================================\n", i);
-        for ex: training_example_count {
-            dataloader_get_item(dataloader, ex, ^sample);
-
+    for epoch: epochs {
+        printf("Starting epoch %i ===================================\n", epoch + 1);
+
+        for batch_num: batches_per_epoch {
             optimizer_zero_gradient(optimizer);
-            (*nn)->forward(sample.input);
-            (*nn)->backward(sample.output, criterion);
+
+            for batch: batch_size {
+                sample_num := random.between(0, training_example_count);
+                dataloader_get_item(dataloader, sample_num, ^sample);
+
+                (*nn)->forward(sample.input);
+                (*nn)->backward(sample.output, criterion);
+
+                label, _   := array.greatest(sample.output);
+                prediction := (*nn)->get_prediction();
+                if prediction == label do past_100_correct += 1;
+            }
+
             optimizer_step(optimizer);
+
+            if batch_num % (100 / batch_size) == 0 {
+                loss := (*nn)->get_loss(sample.output, criterion);
+                printf("Loss: %f Correct: %i / 100\n", cast(f32) loss, past_100_correct);
+
+                past_100_correct = 0;
+            }
+        }
+    }
+}
+
+main :: (args: [] cstr) {
+    nn := NeuralNet.make(28 * 28, 512, 256, 100, 10);
+    defer nn->free();
+
+    random.set_seed(5234);
 
-            // NOTE(Brendan Hansen): Prediction printing and tracking.
-            label, _ := array.greatest(sample.output);
-            prediction := (*nn)->get_prediction();
-            if prediction == label do past_100_correct += 1;
+    mnist_data := MNIST_DataLoader.make();
+    defer mnist_data->close();
 
-            if ex % 100 == 0 {
+    optimizer := sgd_optimizer_create(^nn, learning_rate = 0.01f);
+    nn->supply_parameters(^optimizer);
+
+    train(^nn, ^mnist_data.base, ^optimizer);
+}
+
+
+
+
+// Old code for printing the outputs fancily:
+/*
+            {
                 print_colored_array :: (arr: [] $T, color_idx: i32, color_code := 94) {
                     for i: arr.count {
                         if i == color_idx {
@@ -120,7 +173,7 @@ train :: (nn: ^NeuralNet, dataloader: ^DataLoader(MNIST_Sample), optimizer: ^Opt
                 print_colored_array(sample.output, label, color);
                 print_colored_array(output, prediction, color);
 
-                loss := NeuralNet.get_loss(nn, sample.output, criterion);
+                loss := (*nn)->get_loss(sample.output, criterion);
                 printf("Loss: %f Correct: %i / 100\n", cast(f32) loss, past_100_correct);
 
                 past_100_correct = 0;
@@ -130,22 +183,4 @@ train :: (nn: ^NeuralNet, dataloader: ^DataLoader(MNIST_Sample), optimizer: ^Opt
                 // neural_net_save(nn, "data/still_working.nn");
                 // }
             }
-        }
-    }
-}
-
-main :: (args: [] cstr) {
-    nn := NeuralNet.make(28 * 28, 512, 256, 100, 10);
-    defer nn->free();
-
-    random.set_seed(5234);
-
-    mnist_data := MNIST_DataLoader.make();
-    defer mnist_data->close();
-
-    optimizer := sgd_optimizer_create(^nn, learning_rate = 0.005f);
-    nn->supply_parameters(^optimizer);
-
-    println("Starting training");
-    train(^nn, ^mnist_data, ^optimizer);
-}
\ No newline at end of file
+ */
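With this change, train is generic over the sample type: it accepts any ^DataLoader($Sample_Type), requiring only that the sample type provide init and deinit procedures and that the loader's vtable supply get_count and get_item. A minimal sketch of a call site that overrides the new optional hyperparameters, using the same named-argument style that main uses for sgd_optimizer_create (the names below come from the code above; the specific values are illustrative only, not part of this commit):

    train(^nn, ^mnist_data.base, ^optimizer,
          criterion = mean_absolute_error, // the other Criterion defined in src/neuralnet.onyx
          batch_size = 20,
          batches_per_epoch = 500,
          epochs = 10);
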
diff --git a/src/neuralnet.onyx b/src/neuralnet.onyx
index 1627557..92c1424 100644
--- a/src/neuralnet.onyx
+++ b/src/neuralnet.onyx
@@ -308,7 +308,6 @@ neural_net_load :: (filename: str) -> NeuralNet {
 
 
 
-
 //
 // Activation functions
 // The activation functions that are currently implemented are:
@@ -468,7 +467,7 @@ mean_absolute_error := Criterion.{
 //
 // DataLoader (move this to somewhere else)
 //
-// Very basic datastructure that represents something you can loader data out of.
+// Very basic data structure that represents something you can load data out of.
 // Specifically, an input and output at a particular index.
 //
@@ -478,8 +477,6 @@ DataLoader :: struct (Sample_Type: type_expr) {
 
 DataLoader_Functions :: struct (Sample_Type: type_expr) {
     get_count : (^DataLoader(Sample_Type)) -> u32;
-
-    // I don't like how these have to be floats, but they seem reasonable for now.
     get_item  : (^DataLoader(Sample_Type), index: u32, sample: ^Sample_Type) -> bool;
 }
 
@@ -518,7 +515,7 @@ Optimizer :: struct {
 }
 
 Optimizer_Functions :: struct {
-    step : (optimizer: ^Optimizer) -> void;
+    step : (optimizer: ^Optimizer, scale: f32) -> void;
 }
 
 optimizer_init :: (use optim: ^Optimizer, nn: ^NeuralNet, allocator := context.allocator) {
@@ -532,11 +529,11 @@ optimizer_init :: (use optim: ^Optimizer, nn: ^NeuralNet, allocator := context.a
     }
 }
 
-optimizer_step :: (use optim: ^Optimizer) {
+optimizer_step :: (use optim: ^Optimizer, scale: f32 = 1) {
     if vtable == null do return;
     if vtable.step == null_proc do return;
 
-    vtable.step(optim);
+    vtable.step(optim, scale);
 }
 
 optimizer_zero_gradient :: (use optim: ^Optimizer) {
@@ -573,14 +570,16 @@ sgd_optimizer_create :: (nn: ^NeuralNet, learning_rate := 0.01f, allocator := co
     return sgd;
 }
 
-sgd_optimizer_step :: (use optimizer: ^SGD_Optimizer) {
+sgd_optimizer_step :: (use optimizer: ^SGD_Optimizer, scale: f32) {
+    alpha := scale * learning_rate;
+
     for variable: variables {
-        variable.value += variable.delta * learning_rate;
+        variable.value += variable.delta * alpha;
     }
 
     for variable_array: variable_arrays {
         for ^variable: *variable_array {
-            variable.value += variable.delta * learning_rate;
+            variable.value += variable.delta * alpha;
         }
     }
 }
\ No newline at end of file
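
The scale parameter threaded through Optimizer_Functions makes the effective SGD update value += delta * (scale * learning_rate), and the default scale of 1 keeps the behavior identical to before this commit. A minimal sketch (not part of the commit) of how the train loop above could use it to apply the gradients accumulated over a batch as an average rather than a sum:

    // Replace the plain optimizer_step(optimizer) call in train with a
    // scaled step; batch_size is the train parameter defined above.
    optimizer_step(optimizer, 1.0f / cast(f32) batch_size);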