From edd531d95f6b05277c79a4531354c7336efddc91 Mon Sep 17 00:00:00 2001 From: Brendan Hansen Date: Mon, 25 Jan 2021 09:07:04 -0600 Subject: [PATCH] added correct count for past 100 examples --- src/mnist.onyx | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/mnist.onyx b/src/mnist.onyx index be8522d..dbb7fe6 100644 --- a/src/mnist.onyx +++ b/src/mnist.onyx @@ -46,6 +46,8 @@ stocastic_gradient_descent :: (nn: ^NeuralNet, mnist_data: ^MNIST_Data, training input := memory.make_slice(f32, 784); defer cfree(input.data); + past_100_correct := 0; + for i: 10 { for ex: training_examples { label := load_example(mnist_data, ex, example); @@ -58,6 +60,9 @@ stocastic_gradient_descent :: (nn: ^NeuralNet, mnist_data: ^MNIST_Data, training neural_net_forward(nn, ~~ input); neural_net_backward(nn, ~~ expected); + prediction := neural_net_get_prediction(nn); + if prediction == label do past_100_correct += 1; + if ex % 100 == 0 { print_colored_array :: (arr: [] $T, color_idx: i32, color_code := 94) { for i: arr.count { @@ -73,21 +78,22 @@ stocastic_gradient_descent :: (nn: ^NeuralNet, mnist_data: ^MNIST_Data, training print("\n"); } - output := neural_net_get_output(nn); - prediction := neural_net_get_prediction(nn); - color := 94; if prediction != label do color = 91; + output := neural_net_get_output(nn); + print_colored_array(cast([] f32) expected, label, color); print_colored_array(output, prediction, color); loss := neural_net_loss(nn, ~~ expected); - printf("MSE loss: %f\n", cast(f32) loss); + printf("MSE loss: %f Correct: %i / 100\n", cast(f32) loss, past_100_correct); + + past_100_correct = 0; if ex % 10000 == 0 { println("Saving neural network..."); - neural_net_save(nn, "data/test_1.nn"); + neural_net_save(nn, "data/test_2.nn"); } } } @@ -100,7 +106,7 @@ main :: (args: [] cstr) { // main_allocator := context.allocator; // context.allocator = alloc.log.logging_allocator(^main_allocator); - nn := neural_net_load("data/test_1.nn"); + nn := neural_net_load("data/test_2.nn"); // nn := make_neural_net(28 * 28, 512, 256, 100, 10); defer neural_net_free(^nn); -- 2.25.1