diff options
| author | 3gg <3gg@shellblade.net> | 2023-12-16 10:21:16 -0800 |
|---|---|---|
| committer | 3gg <3gg@shellblade.net> | 2023-12-16 10:21:16 -0800 |
| commit | 653e98e029a0d0f110b0ac599e50406060bb0f87 (patch) | |
| tree | 6f909215218f6720266bde1b3f49aeddad8b1da3 /src/lib/test/train_xor_test.c | |
| parent | 3df7b6fb0c65295eed4590e6f166d60e89b3c68e (diff) | |
Decouple activations from linear layer.
Diffstat (limited to 'src/lib/test/train_xor_test.c')
| -rw-r--r-- | src/lib/test/train_xor_test.c | 55 |
1 file changed, 35 insertions, 20 deletions
diff --git a/src/lib/test/train_xor_test.c b/src/lib/test/train_xor_test.c index 6ddc6e0..78695a3 100644 --- a/src/lib/test/train_xor_test.c +++ b/src/lib/test/train_xor_test.c | |||
| @@ -1,9 +1,9 @@ | |||
| 1 | #include <neuralnet/train.h> | 1 | #include <neuralnet/train.h> |
| 2 | 2 | ||
| 3 | #include <neuralnet/matrix.h> | ||
| 4 | #include <neuralnet/neuralnet.h> | ||
| 5 | #include "activation.h" | 3 | #include "activation.h" |
| 6 | #include "neuralnet_impl.h" | 4 | #include "neuralnet_impl.h" |
| 5 | #include <neuralnet/matrix.h> | ||
| 6 | #include <neuralnet/neuralnet.h> | ||
| 7 | 7 | ||
| 8 | #include "test.h" | 8 | #include "test.h" |
| 9 | #include "test_util.h" | 9 | #include "test_util.h" |
| @@ -11,18 +11,27 @@ | |||
| 11 | #include <assert.h> | 11 | #include <assert.h> |
| 12 | 12 | ||
| 13 | TEST_CASE(neuralnet_train_xor_test) { | 13 | TEST_CASE(neuralnet_train_xor_test) { |
| 14 | const int num_layers = 2; | 14 | const int num_layers = 3; |
| 15 | const int layer_sizes[] = { 2, 2, 1 }; | 15 | const int input_size = 2; |
| 16 | const nnActivation layer_activations[] = { nnRelu, nnIdentity }; | 16 | const nnLayer layers[] = { |
| 17 | {.type = nnLinear, .linear = {.input_size = 2, .output_size = 2}}, | ||
| 18 | {.type = nnRelu}, | ||
| 19 | {.type = nnLinear, .linear = {.input_size = 2, .output_size = 1}} | ||
| 20 | }; | ||
| 17 | 21 | ||
| 18 | nnNeuralNetwork* net = nnMakeNet(num_layers, layer_sizes, layer_activations); | 22 | nnNeuralNetwork* net = nnMakeNet(layers, num_layers, input_size); |
| 19 | assert(net); | 23 | assert(net); |
| 20 | 24 | ||
| 21 | // Train. | 25 | // Train. |
| 22 | 26 | ||
| 23 | #define N 4 | 27 | #define N 4 |
| 24 | const R inputs[N][2] = { { 0., 0. }, { 0., 1. }, { 1., 0. }, { 1., 1. } }; | 28 | const R inputs[N][2] = { |
| 25 | const R targets[N] = { 0., 1., 1., 0. }; | 29 | {0., 0.}, |
| 30 | {0., 1.}, | ||
| 31 | {1., 0.}, | ||
| 32 | {1., 1.} | ||
| 33 | }; | ||
| 34 | const R targets[N] = {0., 1., 1., 0.}; | ||
| 26 | 35 | ||
| 27 | nnMatrix inputs_matrix = nnMatrixMake(N, 2); | 36 | nnMatrix inputs_matrix = nnMatrixMake(N, 2); |
| 28 | nnMatrix targets_matrix = nnMatrixMake(N, 1); | 37 | nnMatrix targets_matrix = nnMatrixMake(N, 1); |
| @@ -30,31 +39,37 @@ TEST_CASE(neuralnet_train_xor_test) { | |||
| 30 | nnMatrixInit(&targets_matrix, targets); | 39 | nnMatrixInit(&targets_matrix, targets); |
| 31 | 40 | ||
| 32 | nnTrainingParams params = { | 41 | nnTrainingParams params = { |
| 33 | .learning_rate = 0.1, | 42 | .learning_rate = 0.1, |
| 34 | .max_iterations = 500, | 43 | .max_iterations = 500, |
| 35 | .seed = 0, | 44 | .seed = 0, |
| 36 | .weight_init = nnWeightInit01, | 45 | .weight_init = nnWeightInit01, |
| 37 | .debug = false, | 46 | .debug = false, |
| 38 | }; | 47 | }; |
| 39 | 48 | ||
| 40 | nnTrain(net, &inputs_matrix, &targets_matrix, ¶ms); | 49 | nnTrain(net, &inputs_matrix, &targets_matrix, ¶ms); |
| 41 | 50 | ||
| 42 | // Test. | 51 | // Test. |
| 43 | 52 | ||
| 44 | #define M 4 | 53 | #define M 4 |
| 45 | 54 | ||
| 46 | nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/M); | 55 | nnQueryObject* query = nnMakeQueryObject(net, M); |
| 47 | 56 | ||
| 48 | const R test_inputs[M][2] = { { 0., 0. }, { 1., 0. }, { 0., 1. }, { 1., 1. } }; | 57 | const R test_inputs[M][2] = { |
| 58 | {0., 0.}, | ||
| 59 | {1., 0.}, | ||
| 60 | {0., 1.}, | ||
| 61 | {1., 1.} | ||
| 62 | }; | ||
| 49 | nnMatrix test_inputs_matrix = nnMatrixMake(M, 2); | 63 | nnMatrix test_inputs_matrix = nnMatrixMake(M, 2); |
| 50 | nnMatrixInit(&test_inputs_matrix, (const R*)test_inputs); | 64 | nnMatrixInit(&test_inputs_matrix, (const R*)test_inputs); |
| 51 | nnQuery(net, query, &test_inputs_matrix); | 65 | nnQuery(net, query, &test_inputs_matrix); |
| 52 | 66 | ||
| 53 | const R expected_outputs[M] = { 0., 1., 1., 0. }; | 67 | const R expected_outputs[M] = {0., 1., 1., 0.}; |
| 54 | for (int i = 0; i < M; ++i) { | 68 | for (int i = 0; i < M; ++i) { |
| 55 | const R test_output = nnMatrixAt(nnNetOutputs(query), i, 0); | 69 | const R test_output = nnMatrixAt(nnNetOutputs(query), i, 0); |
| 56 | printf("\nInput: (%f, %f), Output: %f, Expected: %f\n", | 70 | printf( |
| 57 | test_inputs[i][0], test_inputs[i][1], test_output, expected_outputs[i]); | 71 | "\nInput: (%f, %f), Output: %f, Expected: %f\n", test_inputs[i][0], |
| 72 | test_inputs[i][1], test_output, expected_outputs[i]); | ||
| 58 | } | 73 | } |
| 59 | for (int i = 0; i < M; ++i) { | 74 | for (int i = 0; i < M; ++i) { |
| 60 | const R test_output = nnMatrixAt(nnNetOutputs(query), i, 0); | 75 | const R test_output = nnMatrixAt(nnNetOutputs(query), i, 0); |
