diff options
author | jeanne <jeanne@localhost.localdomain> | 2022-05-11 09:54:38 -0700 |
---|---|---|
committer | jeanne <jeanne@localhost.localdomain> | 2022-05-11 09:54:38 -0700 |
commit | 411f66a2540fa17c736116d865e0ceb0cfe5623b (patch) | |
tree | fa92c69ec627642c8452f928798ff6eccd24ddd6 /src/lib/test/train_sigmoid_test.c | |
parent | 7705b07456dfd4b89c272613e98eda36cc787254 (diff) |
Initial commit.
Diffstat (limited to 'src/lib/test/train_sigmoid_test.c')
-rw-r--r-- | src/lib/test/train_sigmoid_test.c | 66 |
1 files changed, 66 insertions, 0 deletions
diff --git a/src/lib/test/train_sigmoid_test.c b/src/lib/test/train_sigmoid_test.c new file mode 100644 index 0000000..588e7ca --- /dev/null +++ b/src/lib/test/train_sigmoid_test.c | |||
@@ -0,0 +1,66 @@ | |||
1 | #include <neuralnet/train.h> | ||
2 | |||
3 | #include <neuralnet/matrix.h> | ||
4 | #include <neuralnet/neuralnet.h> | ||
5 | #include "activation.h" | ||
6 | #include "neuralnet_impl.h" | ||
7 | |||
8 | #include "test.h" | ||
9 | #include "test_util.h" | ||
10 | |||
11 | #include <assert.h> | ||
12 | |||
13 | TEST_CASE(neuralnet_train_sigmoid_test) { | ||
14 | const int num_layers = 1; | ||
15 | const int layer_sizes[] = { 1, 1 }; | ||
16 | const nnActivation layer_activations[] = { nnSigmoid }; | ||
17 | |||
18 | nnNeuralNetwork* net = nnMakeNet(num_layers, layer_sizes, layer_activations); | ||
19 | assert(net); | ||
20 | |||
21 | // Train. | ||
22 | |||
23 | // Try to learn the sigmoid function. | ||
24 | #define N 3 | ||
25 | R inputs[N]; | ||
26 | R targets[N]; | ||
27 | for (int i = 0; i < N; ++i) { | ||
28 | inputs[i] = lerp(-1, +1, (R)i / (R)(N-1)); | ||
29 | targets[i] = sigmoid(inputs[i]); | ||
30 | } | ||
31 | |||
32 | nnMatrix inputs_matrix = nnMatrixMake(N, 1); | ||
33 | nnMatrix targets_matrix = nnMatrixMake(N, 1); | ||
34 | nnMatrixInit(&inputs_matrix, inputs); | ||
35 | nnMatrixInit(&targets_matrix, targets); | ||
36 | |||
37 | nnTrainingParams params = { | ||
38 | .learning_rate = 0.9, | ||
39 | .max_iterations = 100, | ||
40 | .seed = 0, | ||
41 | .weight_init = nnWeightInit01, | ||
42 | .debug = false, | ||
43 | }; | ||
44 | |||
45 | nnTrain(net, &inputs_matrix, &targets_matrix, ¶ms); | ||
46 | |||
47 | const R weight = nnMatrixAt(&net->weights[0], 0, 0); | ||
48 | const R expected_weight = 1.0; | ||
49 | printf("\nTrained network weight: %f, Expected: %f\n", weight, expected_weight); | ||
50 | TEST_TRUE(double_eq(weight, expected_weight, WEIGHT_EPS)); | ||
51 | |||
52 | // Test. | ||
53 | |||
54 | nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/1); | ||
55 | |||
56 | const R test_input[] = { 0.3 }; | ||
57 | R test_output[1]; | ||
58 | nnQueryArray(net, query, test_input, test_output); | ||
59 | |||
60 | const R expected_output = 0.574442516811659; // sigmoid(0.3) | ||
61 | printf("Output: %f, Expected: %f\n", test_output[0], expected_output); | ||
62 | TEST_TRUE(double_eq(test_output[0], expected_output, OUTPUT_EPS)); | ||
63 | |||
64 | nnDeleteQueryObject(&query); | ||
65 | nnDeleteNet(&net); | ||
66 | } | ||