diff options
author | 3gg <3gg@shellblade.net> | 2023-12-16 10:21:16 -0800 |
---|---|---|
committer | 3gg <3gg@shellblade.net> | 2023-12-16 10:21:16 -0800 |
commit | 653e98e029a0d0f110b0ac599e50406060bb0f87 (patch) | |
tree | 6f909215218f6720266bde1b3f49aeddad8b1da3 /src/lib/test/train_sigmoid_test.c | |
parent | 3df7b6fb0c65295eed4590e6f166d60e89b3c68e (diff) |
Decouple activations from linear layer.
Diffstat (limited to 'src/lib/test/train_sigmoid_test.c')
-rw-r--r-- | src/lib/test/train_sigmoid_test.c | 46 |
1 files changed, 25 insertions, 21 deletions
diff --git a/src/lib/test/train_sigmoid_test.c b/src/lib/test/train_sigmoid_test.c index 588e7ca..39a84b0 100644 --- a/src/lib/test/train_sigmoid_test.c +++ b/src/lib/test/train_sigmoid_test.c | |||
@@ -1,9 +1,9 @@ | |||
1 | #include <neuralnet/train.h> | 1 | #include <neuralnet/train.h> |
2 | 2 | ||
3 | #include <neuralnet/matrix.h> | ||
4 | #include <neuralnet/neuralnet.h> | ||
5 | #include "activation.h" | 3 | #include "activation.h" |
6 | #include "neuralnet_impl.h" | 4 | #include "neuralnet_impl.h" |
5 | #include <neuralnet/matrix.h> | ||
6 | #include <neuralnet/neuralnet.h> | ||
7 | 7 | ||
8 | #include "test.h" | 8 | #include "test.h" |
9 | #include "test_util.h" | 9 | #include "test_util.h" |
@@ -11,21 +11,24 @@ | |||
11 | #include <assert.h> | 11 | #include <assert.h> |
12 | 12 | ||
13 | TEST_CASE(neuralnet_train_sigmoid_test) { | 13 | TEST_CASE(neuralnet_train_sigmoid_test) { |
14 | const int num_layers = 1; | 14 | const int num_layers = 2; |
15 | const int layer_sizes[] = { 1, 1 }; | 15 | const int input_size = 1; |
16 | const nnActivation layer_activations[] = { nnSigmoid }; | 16 | const nnLayer layers[] = { |
17 | {.type = nnLinear, .linear = {.input_size = 1, .output_size = 1}}, | ||
18 | {.type = nnSigmoid}, | ||
19 | }; | ||
17 | 20 | ||
18 | nnNeuralNetwork* net = nnMakeNet(num_layers, layer_sizes, layer_activations); | 21 | nnNeuralNetwork* net = nnMakeNet(layers, num_layers, input_size); |
19 | assert(net); | 22 | assert(net); |
20 | 23 | ||
21 | // Train. | 24 | // Train. |
22 | 25 | ||
23 | // Try to learn the sigmoid function. | 26 | // Try to learn the sigmoid function. |
24 | #define N 3 | 27 | #define N 3 |
25 | R inputs[N]; | 28 | R inputs[N]; |
26 | R targets[N]; | 29 | R targets[N]; |
27 | for (int i = 0; i < N; ++i) { | 30 | for (int i = 0; i < N; ++i) { |
28 | inputs[i] = lerp(-1, +1, (R)i / (R)(N-1)); | 31 | inputs[i] = lerp(-1, +1, (R)i / (R)(N - 1)); |
29 | targets[i] = sigmoid(inputs[i]); | 32 | targets[i] = sigmoid(inputs[i]); |
30 | } | 33 | } |
31 | 34 | ||
@@ -35,29 +38,30 @@ TEST_CASE(neuralnet_train_sigmoid_test) { | |||
35 | nnMatrixInit(&targets_matrix, targets); | 38 | nnMatrixInit(&targets_matrix, targets); |
36 | 39 | ||
37 | nnTrainingParams params = { | 40 | nnTrainingParams params = { |
38 | .learning_rate = 0.9, | 41 | .learning_rate = 0.9, |
39 | .max_iterations = 100, | 42 | .max_iterations = 100, |
40 | .seed = 0, | 43 | .seed = 0, |
41 | .weight_init = nnWeightInit01, | 44 | .weight_init = nnWeightInit01, |
42 | .debug = false, | 45 | .debug = false, |
43 | }; | 46 | }; |
44 | 47 | ||
45 | nnTrain(net, &inputs_matrix, &targets_matrix, ¶ms); | 48 | nnTrain(net, &inputs_matrix, &targets_matrix, ¶ms); |
46 | 49 | ||
47 | const R weight = nnMatrixAt(&net->weights[0], 0, 0); | 50 | const R weight = nnMatrixAt(&net->layers[0].linear.weights, 0, 0); |
48 | const R expected_weight = 1.0; | 51 | const R expected_weight = 1.0; |
49 | printf("\nTrained network weight: %f, Expected: %f\n", weight, expected_weight); | 52 | printf( |
53 | "\nTrained network weight: %f, Expected: %f\n", weight, expected_weight); | ||
50 | TEST_TRUE(double_eq(weight, expected_weight, WEIGHT_EPS)); | 54 | TEST_TRUE(double_eq(weight, expected_weight, WEIGHT_EPS)); |
51 | 55 | ||
52 | // Test. | 56 | // Test. |
53 | 57 | ||
54 | nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/1); | 58 | nnQueryObject* query = nnMakeQueryObject(net, 1); |
55 | 59 | ||
56 | const R test_input[] = { 0.3 }; | 60 | const R test_input[] = {0.3}; |
57 | R test_output[1]; | 61 | R test_output[1]; |
58 | nnQueryArray(net, query, test_input, test_output); | 62 | nnQueryArray(net, query, test_input, test_output); |
59 | 63 | ||
60 | const R expected_output = 0.574442516811659; // sigmoid(0.3) | 64 | const R expected_output = 0.574442516811659; // sigmoid(0.3) |
61 | printf("Output: %f, Expected: %f\n", test_output[0], expected_output); | 65 | printf("Output: %f, Expected: %f\n", test_output[0], expected_output); |
62 | TEST_TRUE(double_eq(test_output[0], expected_output, OUTPUT_EPS)); | 66 | TEST_TRUE(double_eq(test_output[0], expected_output, OUTPUT_EPS)); |
63 | 67 | ||