aboutsummaryrefslogtreecommitdiff
path: root/src/lib/test/train_xor_test.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib/test/train_xor_test.c')
-rw-r--r--src/lib/test/train_xor_test.c55
1 files changed, 35 insertions, 20 deletions
diff --git a/src/lib/test/train_xor_test.c b/src/lib/test/train_xor_test.c
index 6ddc6e0..78695a3 100644
--- a/src/lib/test/train_xor_test.c
+++ b/src/lib/test/train_xor_test.c
@@ -1,9 +1,9 @@
1#include <neuralnet/train.h> 1#include <neuralnet/train.h>
2 2
3#include <neuralnet/matrix.h>
4#include <neuralnet/neuralnet.h>
5#include "activation.h" 3#include "activation.h"
6#include "neuralnet_impl.h" 4#include "neuralnet_impl.h"
5#include <neuralnet/matrix.h>
6#include <neuralnet/neuralnet.h>
7 7
8#include "test.h" 8#include "test.h"
9#include "test_util.h" 9#include "test_util.h"
@@ -11,18 +11,27 @@
11#include <assert.h> 11#include <assert.h>
12 12
13TEST_CASE(neuralnet_train_xor_test) { 13TEST_CASE(neuralnet_train_xor_test) {
14 const int num_layers = 2; 14 const int num_layers = 3;
15 const int layer_sizes[] = { 2, 2, 1 }; 15 const int input_size = 2;
16 const nnActivation layer_activations[] = { nnRelu, nnIdentity }; 16 const nnLayer layers[] = {
17 {.type = nnLinear, .linear = {.input_size = 2, .output_size = 2}},
18 {.type = nnRelu},
19 {.type = nnLinear, .linear = {.input_size = 2, .output_size = 1}}
20 };
17 21
18 nnNeuralNetwork* net = nnMakeNet(num_layers, layer_sizes, layer_activations); 22 nnNeuralNetwork* net = nnMakeNet(layers, num_layers, input_size);
19 assert(net); 23 assert(net);
20 24
21 // Train. 25 // Train.
22 26
23 #define N 4 27#define N 4
24 const R inputs[N][2] = { { 0., 0. }, { 0., 1. }, { 1., 0. }, { 1., 1. } }; 28 const R inputs[N][2] = {
25 const R targets[N] = { 0., 1., 1., 0. }; 29 {0., 0.},
30 {0., 1.},
31 {1., 0.},
32 {1., 1.}
33 };
34 const R targets[N] = {0., 1., 1., 0.};
26 35
27 nnMatrix inputs_matrix = nnMatrixMake(N, 2); 36 nnMatrix inputs_matrix = nnMatrixMake(N, 2);
28 nnMatrix targets_matrix = nnMatrixMake(N, 1); 37 nnMatrix targets_matrix = nnMatrixMake(N, 1);
@@ -30,31 +39,37 @@ TEST_CASE(neuralnet_train_xor_test) {
30 nnMatrixInit(&targets_matrix, targets); 39 nnMatrixInit(&targets_matrix, targets);
31 40
32 nnTrainingParams params = { 41 nnTrainingParams params = {
33 .learning_rate = 0.1, 42 .learning_rate = 0.1,
34 .max_iterations = 500, 43 .max_iterations = 500,
35 .seed = 0, 44 .seed = 0,
36 .weight_init = nnWeightInit01, 45 .weight_init = nnWeightInit01,
37 .debug = false, 46 .debug = false,
38 }; 47 };
39 48
40 nnTrain(net, &inputs_matrix, &targets_matrix, &params); 49 nnTrain(net, &inputs_matrix, &targets_matrix, &params);
41 50
42 // Test. 51 // Test.
43 52
44 #define M 4 53#define M 4
45 54
46 nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/M); 55 nnQueryObject* query = nnMakeQueryObject(net, M);
47 56
48 const R test_inputs[M][2] = { { 0., 0. }, { 1., 0. }, { 0., 1. }, { 1., 1. } }; 57 const R test_inputs[M][2] = {
58 {0., 0.},
59 {1., 0.},
60 {0., 1.},
61 {1., 1.}
62 };
49 nnMatrix test_inputs_matrix = nnMatrixMake(M, 2); 63 nnMatrix test_inputs_matrix = nnMatrixMake(M, 2);
50 nnMatrixInit(&test_inputs_matrix, (const R*)test_inputs); 64 nnMatrixInit(&test_inputs_matrix, (const R*)test_inputs);
51 nnQuery(net, query, &test_inputs_matrix); 65 nnQuery(net, query, &test_inputs_matrix);
52 66
53 const R expected_outputs[M] = { 0., 1., 1., 0. }; 67 const R expected_outputs[M] = {0., 1., 1., 0.};
54 for (int i = 0; i < M; ++i) { 68 for (int i = 0; i < M; ++i) {
55 const R test_output = nnMatrixAt(nnNetOutputs(query), i, 0); 69 const R test_output = nnMatrixAt(nnNetOutputs(query), i, 0);
56 printf("\nInput: (%f, %f), Output: %f, Expected: %f\n", 70 printf(
57 test_inputs[i][0], test_inputs[i][1], test_output, expected_outputs[i]); 71 "\nInput: (%f, %f), Output: %f, Expected: %f\n", test_inputs[i][0],
72 test_inputs[i][1], test_output, expected_outputs[i]);
58 } 73 }
59 for (int i = 0; i < M; ++i) { 74 for (int i = 0; i < M; ++i) {
60 const R test_output = nnMatrixAt(nnNetOutputs(query), i, 0); 75 const R test_output = nnMatrixAt(nnNetOutputs(query), i, 0);