added correct count for past 100 examples
authorBrendan Hansen <brendan.f.hansen@gmail.com>
Mon, 25 Jan 2021 15:07:04 +0000 (09:07 -0600)
committerBrendan Hansen <brendan.f.hansen@gmail.com>
Mon, 25 Jan 2021 15:07:04 +0000 (09:07 -0600)
src/mnist.onyx

index be8522d691034af5de0561458498477d14ab6f1b..dbb7fe6385f545d5d08b54056d2d6f2f914b16a7 100644 (file)
@@ -46,6 +46,8 @@ stocastic_gradient_descent :: (nn: ^NeuralNet, mnist_data: ^MNIST_Data, training
     input := memory.make_slice(f32, 784);
     defer cfree(input.data);
 
+    past_100_correct := 0;
+
     for i: 10 {
         for ex: training_examples {
             label := load_example(mnist_data, ex, example);
@@ -58,6 +60,9 @@ stocastic_gradient_descent :: (nn: ^NeuralNet, mnist_data: ^MNIST_Data, training
             neural_net_forward(nn, ~~ input);
             neural_net_backward(nn, ~~ expected);
 
+            prediction := neural_net_get_prediction(nn);
+            if prediction == label do past_100_correct += 1;
+
             if ex % 100 == 0 {
                 print_colored_array :: (arr: [] $T, color_idx: i32, color_code := 94) {
                     for i: arr.count {
@@ -73,21 +78,22 @@ stocastic_gradient_descent :: (nn: ^NeuralNet, mnist_data: ^MNIST_Data, training
                     print("\n");
                 }
 
-                output := neural_net_get_output(nn);
-                prediction := neural_net_get_prediction(nn);
-
                 color := 94;
                 if prediction != label do color = 91;
 
+                output := neural_net_get_output(nn);
+
                 print_colored_array(cast([] f32) expected, label, color);
                 print_colored_array(output, prediction, color);
 
                 loss := neural_net_loss(nn, ~~ expected);
-                printf("MSE loss: %f\n", cast(f32) loss);
+                printf("MSE loss: %f     Correct: %i / 100\n", cast(f32) loss, past_100_correct);
+
+                past_100_correct = 0;
 
                 if ex % 10000 == 0 {
                     println("Saving neural network...");
-                    neural_net_save(nn, "data/test_1.nn");
+                    neural_net_save(nn, "data/test_2.nn");
                 }
             }
         }
@@ -100,7 +106,7 @@ main :: (args: [] cstr) {
     // main_allocator := context.allocator;
     // context.allocator = alloc.log.logging_allocator(^main_allocator);
 
-    nn := neural_net_load("data/test_1.nn");
+    nn := neural_net_load("data/test_2.nn");
     // nn := make_neural_net(28 * 28, 512, 256, 100, 10);
     defer neural_net_free(^nn);