commit 411f66a2540fa17c736116d865e0ceb0cfe5623b
Author:    jeanne <jeanne@localhost.localdomain>  2022-05-11 09:54:38 -0700
Committer: jeanne <jeanne@localhost.localdomain>  2022-05-11 09:54:38 -0700
tree       fa92c69ec627642c8452f928798ff6eccd24ddd6 /src/lib/test/train_xor_test.c
parent     7705b07456dfd4b89c272613e98eda36cc787254

    Initial commit.
Diffstat (limited to 'src/lib/test/train_xor_test.c')

 -rw-r--r--   src/lib/test/train_xor_test.c   66
 1 file changed, 66 insertions(+), 0 deletions(-)
diff --git a/src/lib/test/train_xor_test.c b/src/lib/test/train_xor_test.c
new file mode 100644
index 0000000..6ddc6e0
--- /dev/null
+++ b/src/lib/test/train_xor_test.c
@@ -0,0 +1,66 @@
+#include <neuralnet/train.h>
+
+#include <neuralnet/matrix.h>
+#include <neuralnet/neuralnet.h>
+#include "activation.h"
+#include "neuralnet_impl.h"
+
+#include "test.h"
+#include "test_util.h"
+
+#include <assert.h>
+
+TEST_CASE(neuralnet_train_xor_test) {
+  const int num_layers = 2;
+  const int layer_sizes[] = { 2, 2, 1 };
+  const nnActivation layer_activations[] = { nnRelu, nnIdentity };
+
+  nnNeuralNetwork* net = nnMakeNet(num_layers, layer_sizes, layer_activations);
+  assert(net);
+
+  // Train.
+
+  #define N 4
+  const R inputs[N][2] = { { 0., 0. }, { 0., 1. }, { 1., 0. }, { 1., 1. } };
+  const R targets[N] = { 0., 1., 1., 0. };
+
+  nnMatrix inputs_matrix = nnMatrixMake(N, 2);
+  nnMatrix targets_matrix = nnMatrixMake(N, 1);
+  nnMatrixInit(&inputs_matrix, (const R*)inputs);
+  nnMatrixInit(&targets_matrix, targets);
+
+  nnTrainingParams params = {
+    .learning_rate = 0.1,
+    .max_iterations = 500,
+    .seed = 0,
+    .weight_init = nnWeightInit01,
+    .debug = false,
+  };
+
+  nnTrain(net, &inputs_matrix, &targets_matrix, &params);
+
+  // Test.
+
+  #define M 4
+
+  nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/M);
+
+  const R test_inputs[M][2] = { { 0., 0. }, { 1., 0. }, { 0., 1. }, { 1., 1. } };
+  nnMatrix test_inputs_matrix = nnMatrixMake(M, 2);
+  nnMatrixInit(&test_inputs_matrix, (const R*)test_inputs);
+  nnQuery(net, query, &test_inputs_matrix);
+
+  const R expected_outputs[M] = { 0., 1., 1., 0. };
+  for (int i = 0; i < M; ++i) {
+    const R test_output = nnMatrixAt(nnNetOutputs(query), i, 0);
+    printf("\nInput: (%f, %f), Output: %f, Expected: %f\n",
+           test_inputs[i][0], test_inputs[i][1], test_output, expected_outputs[i]);
+  }
+  for (int i = 0; i < M; ++i) {
+    const R test_output = nnMatrixAt(nnNetOutputs(query), i, 0);
+    TEST_TRUE(double_eq(test_output, expected_outputs[i], OUTPUT_EPS));
+  }
+
+  nnDeleteQueryObject(&query);
+  nnDeleteNet(&net);
+}
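
For context on why the 2-2-1 architecture tested above (two inputs, two ReLU hidden units, one identity output) can hit the XOR targets exactly, here is a minimal standalone sketch in plain C. It is independent of the neuralnet library and this commit; the weights are one well-known hand-picked solution, and all names in it are illustrative only.

    #include <stdio.h>

    /* Standalone illustration (not part of this commit): a 2-2-1 ReLU
     * network with hand-picked weights that computes XOR exactly. */
    static double relu(double x) { return x > 0.0 ? x : 0.0; }

    static double xor_net(double x1, double x2) {
      /* Hidden layer: h1 = relu(x1 + x2), h2 = relu(x1 + x2 - 1). */
      const double h1 = relu(x1 + x2);
      const double h2 = relu(x1 + x2 - 1.0);
      /* Identity output: y = h1 - 2*h2 yields 0, 1, 1, 0 on the four inputs. */
      return h1 - 2.0 * h2;
    }

    int main(void) {
      const double inputs[4][2] = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
      for (int i = 0; i < 4; ++i) {
        printf("(%g, %g) -> %g\n", inputs[i][0], inputs[i][1],
               xor_net(inputs[i][0], inputs[i][1]));
      }
      return 0;
    }

Gradient descent from the random initialization used in the test is not guaranteed to land on exactly these weights, but the sketch shows the XOR targets are representable by the architecture under test, which is what the exact-match check against OUTPUT_EPS relies on.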
