From dd4fb28634e580d98c11f299768bebc615e03bc7 Mon Sep 17 00:00:00 2001
From: Brendan Hansen
Date: Fri, 22 Jan 2021 17:11:03 -0600
Subject: [PATCH] loading neural networks from disk works

---
 src/mnist.onyx     | 31 +++++++++++++++++--------------
 src/neuralnet.onyx | 40 +++++++++++++++++++++++++++++++---------
 2 files changed, 48 insertions(+), 23 deletions(-)

diff --git a/src/mnist.onyx b/src/mnist.onyx
index efad3aa..8e7d6a1 100644
--- a/src/mnist.onyx
+++ b/src/mnist.onyx
@@ -59,24 +59,28 @@ stocastic_gradient_descent :: (nn: ^NeuralNet, mnist_data: ^MNIST_Data, training
         neural_net_backward(nn, ~~ expected);
 
         if ex % 100 == 0 {
-            print_colored_array :: (arr: [] $T) {
-                greatest_idx := 0;
-                for i: arr.count do if arr[i] > arr[greatest_idx] do greatest_idx = i;
-
+            print_colored_array :: (arr: [] $T, color_idx: i32, color_code := 94) {
                 for i: arr.count {
-                    if i == greatest_idx {
-                        printf("\x1b[94m%f\x1b[0m ", cast(f32) arr[i]);
+                    if i == color_idx {
+                        printf("\x1b[%im", color_code);
+                        print(arr[i]);
+                        print("\x1b[0m ");
                     } else {
-                        printf("%f ", cast(f32) arr[i]);
+                        print(arr[i]);
+                        print(" ");
                     }
                 }
                 print("\n");
             }
 
-            print_colored_array(cast([] f32) expected);
-
             output := neural_net_get_output(nn);
-            print_colored_array(output);
+            prediction := neural_net_get_prediction(nn);
+
+            color := 94;
+            if prediction != label do color = 91;
+
+            print_colored_array(cast([] f32) expected, label, color);
+            print_colored_array(output, prediction, color);
 
             loss := neural_net_loss(nn, ~~ expected);
             printf("MSE loss: %f\n", cast(f32) loss);
@@ -96,15 +100,14 @@ main :: (args: [] cstr) {
     // main_allocator := context.allocator;
     // context.allocator = alloc.log.logging_allocator(^main_allocator);
 
-    _ := neural_net_load("data/dummy.nn");
+    nn := neural_net_load("data/dummy.nn");
+    // nn := make_neural_net(28 * 28, 512, 256, 100, 10);
+    defer neural_net_free(^nn);
 
     random.set_seed(5234);
 
     mnist_data := mnist_data_make();
     defer mnist_data_close(^mnist_data);
 
-    nn := make_neural_net(28 * 28, 512, 256, 100, 10);
-    defer neural_net_free(^nn);
-
     stocastic_gradient_descent(^nn, ^mnist_data);
 }
\ No newline at end of file
diff --git a/src/neuralnet.onyx b/src/neuralnet.onyx
index 884d6a7..89dd30b 100644
--- a/src/neuralnet.onyx
+++ b/src/neuralnet.onyx
@@ -97,6 +97,16 @@ neural_net_get_output :: (use nn: ^NeuralNet) -> [] f32 {
     return layers[layers.count - 1].neurons;
 }
 
+// :MNISTSpecific
+neural_net_get_prediction :: (use nn: ^NeuralNet) -> i32 {
+    output := neural_net_get_output(nn);
+
+    greatest_idx := 0;
+    for i: output.count do if output[i] > output[greatest_idx] do greatest_idx = i;
+
+    return greatest_idx;
+}
+
 neural_net_loss :: (use nn: ^NeuralNet, expected_output: [] f32) -> f32 {
     // MSE loss
     assert(layers[layers.count - 1].neurons.count == expected_output.count,
@@ -129,23 +139,23 @@ Layer :: struct {
     deltas : [] f32;
 }
 
-layer_init :: (use layer: ^Layer, layer_size: u32, prev_layer_size: u32, allocator := context.allocator) {
+layer_init :: (use layer: ^Layer, layer_size: u32, prev_layer_size: u32, allocator := context.allocator, allocate_weights_and_biases := true) {
     neurons = memory.make_slice(f32, layer_size, allocator);
     pre_activation_neurons = memory.make_slice(f32, layer_size, allocator);
+    weights = memory.make_slice(#type [] f32, layer_size, allocator);
+    use_bias = true;
 
     deltas = memory.make_slice(f32, layer_size, allocator);
 
     activation = sigmoid_activation;
     is_input = (prev_layer_size == 0);
 
-    if !is_input {
+    if !is_input && allocate_weights_and_biases {
         if use_bias {
             biases = memory.make_slice(f32, layer_size, allocator);
         }
 
-        weights = memory.make_slice(#type [] f32, layer_size, allocator);
-
         for ^weight: weights {
             *weight = memory.make_slice(f32, prev_layer_size, allocator);
         }
 
@@ -237,14 +247,26 @@ neural_net_load :: (filename: str) -> NeuralNet {
 
         layer_size := io.binary_read_i32(^reader);
         is_input := cast(bool) io.binary_read_byte(^reader);
 
-        layer_init(^nn.layers[l], layer_size, prev_layer_size, layer_allocator);
-        if is_input do continue;
+        layer_init(^nn.layers[l], layer_size, prev_layer_size, allocator = layer_allocator, allocate_weights_and_biases = false);
+        if !is_input {
+            nn.layers[l].use_bias = cast(bool) io.binary_read_byte(^reader);
 
-        nn.layers[l].use_bias = cast(bool) io.binary_read_byte(^reader);
+            activation_id := cast(ActivationFunctionID) io.binary_read_byte(^reader);
+            nn.layers[l].activation = activation_function_from_id(activation_id);
 
-        activation_id := cast(ActivationFunctionID) io.binary_read_byte(^reader);
-        nn.layers[l].activation = activation_function_from_id(activation_id);
+            if nn.layers[l].use_bias {
+                nn.layers[l].biases = io.binary_read_slice(^reader, f32, layer_size, allocator = layer_allocator);
+            }
+
+            for w: layer_size {
+                nn.layers[l].weights[w] = io.binary_read_slice(^reader, f32, prev_layer_size, allocator = layer_allocator);
+            }
+        }
+
+        prev_layer_size = layer_size;
     }
+
+    return nn;
 }
 
-- 
2.25.1