diff options
| author | 3gg <3gg@shellblade.net> | 2023-12-16 10:21:16 -0800 |
|---|---|---|
| committer | 3gg <3gg@shellblade.net> | 2023-12-16 10:21:16 -0800 |
| commit | 653e98e029a0d0f110b0ac599e50406060bb0f87 (patch) | |
| tree | 6f909215218f6720266bde1b3f49aeddad8b1da3 /src/lib/test/train_linear_perceptron_test.c | |
| parent | 3df7b6fb0c65295eed4590e6f166d60e89b3c68e (diff) | |
Decouple activations from linear layer.
Diffstat (limited to 'src/lib/test/train_linear_perceptron_test.c')
| -rw-r--r-- | src/lib/test/train_linear_perceptron_test.c | 44 |
1 files changed, 23 insertions, 21 deletions
diff --git a/src/lib/test/train_linear_perceptron_test.c b/src/lib/test/train_linear_perceptron_test.c index 2b1336d..667643b 100644 --- a/src/lib/test/train_linear_perceptron_test.c +++ b/src/lib/test/train_linear_perceptron_test.c | |||
| @@ -1,9 +1,8 @@ | |||
| 1 | #include <neuralnet/train.h> | 1 | #include <neuralnet/train.h> |
| 2 | 2 | ||
| 3 | #include "neuralnet_impl.h" | ||
| 3 | #include <neuralnet/matrix.h> | 4 | #include <neuralnet/matrix.h> |
| 4 | #include <neuralnet/neuralnet.h> | 5 | #include <neuralnet/neuralnet.h> |
| 5 | #include "activation.h" | ||
| 6 | #include "neuralnet_impl.h" | ||
| 7 | 6 | ||
| 8 | #include "test.h" | 7 | #include "test.h" |
| 9 | #include "test_util.h" | 8 | #include "test_util.h" |
| @@ -11,19 +10,21 @@ | |||
| 11 | #include <assert.h> | 10 | #include <assert.h> |
| 12 | 11 | ||
| 13 | TEST_CASE(neuralnet_train_linear_perceptron_test) { | 12 | TEST_CASE(neuralnet_train_linear_perceptron_test) { |
| 14 | const int num_layers = 1; | 13 | const int num_layers = 1; |
| 15 | const int layer_sizes[] = { 1, 1 }; | 14 | const int input_size = 1; |
| 16 | const nnActivation layer_activations[] = { nnIdentity }; | 15 | const nnLayer layers[] = { |
| 16 | {.type = nnLinear, .linear = {.input_size = 1, .output_size = 1}} | ||
| 17 | }; | ||
| 17 | 18 | ||
| 18 | nnNeuralNetwork* net = nnMakeNet(num_layers, layer_sizes, layer_activations); | 19 | nnNeuralNetwork* net = nnMakeNet(layers, num_layers, input_size); |
| 19 | assert(net); | 20 | assert(net); |
| 20 | 21 | ||
| 21 | // Train. | 22 | // Train. |
| 22 | 23 | ||
| 23 | // Try to learn the Y=X line. | 24 | // Try to learn the Y=X line. |
| 24 | #define N 2 | 25 | #define N 2 |
| 25 | const R inputs[N] = { 0., 1. }; | 26 | const R inputs[N] = {0., 1.}; |
| 26 | const R targets[N] = { 0., 1. }; | 27 | const R targets[N] = {0., 1.}; |
| 27 | 28 | ||
| 28 | nnMatrix inputs_matrix = nnMatrixMake(N, 1); | 29 | nnMatrix inputs_matrix = nnMatrixMake(N, 1); |
| 29 | nnMatrix targets_matrix = nnMatrixMake(N, 1); | 30 | nnMatrix targets_matrix = nnMatrixMake(N, 1); |
| @@ -31,26 +32,27 @@ TEST_CASE(neuralnet_train_linear_perceptron_test) { | |||
| 31 | nnMatrixInit(&targets_matrix, targets); | 32 | nnMatrixInit(&targets_matrix, targets); |
| 32 | 33 | ||
| 33 | nnTrainingParams params = { | 34 | nnTrainingParams params = { |
| 34 | .learning_rate = 0.7, | 35 | .learning_rate = 0.7, |
| 35 | .max_iterations = 10, | 36 | .max_iterations = 10, |
| 36 | .seed = 0, | 37 | .seed = 0, |
| 37 | .weight_init = nnWeightInit01, | 38 | .weight_init = nnWeightInit01, |
| 38 | .debug = false, | 39 | .debug = false, |
| 39 | }; | 40 | }; |
| 40 | 41 | ||
| 41 | nnTrain(net, &inputs_matrix, &targets_matrix, ¶ms); | 42 | nnTrain(net, &inputs_matrix, &targets_matrix, ¶ms); |
| 42 | 43 | ||
| 43 | const R weight = nnMatrixAt(&net->weights[0], 0, 0); | 44 | const R weight = nnMatrixAt(&net->layers[0].linear.weights, 0, 0); |
| 44 | const R expected_weight = 1.0; | 45 | const R expected_weight = 1.0; |
| 45 | printf("\nTrained network weight: %f, Expected: %f\n", weight, expected_weight); | 46 | printf( |
| 47 | "\nTrained network weight: %f, Expected: %f\n", weight, expected_weight); | ||
| 46 | TEST_TRUE(double_eq(weight, expected_weight, WEIGHT_EPS)); | 48 | TEST_TRUE(double_eq(weight, expected_weight, WEIGHT_EPS)); |
| 47 | 49 | ||
| 48 | // Test. | 50 | // Test. |
| 49 | 51 | ||
| 50 | nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/1); | 52 | nnQueryObject* query = nnMakeQueryObject(net, 1); |
| 51 | 53 | ||
| 52 | const R test_input[] = { 2.3 }; | 54 | const R test_input[] = {2.3}; |
| 53 | R test_output[1]; | 55 | R test_output[1]; |
| 54 | nnQueryArray(net, query, test_input, test_output); | 56 | nnQueryArray(net, query, test_input, test_output); |
| 55 | 57 | ||
| 56 | const R expected_output = test_input[0]; | 58 | const R expected_output = test_input[0]; |
