aboutsummaryrefslogtreecommitdiff
path: root/src/lib/test/train_sigmoid_test.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib/test/train_sigmoid_test.c')
-rw-r--r--src/lib/test/train_sigmoid_test.c66
1 files changed, 66 insertions, 0 deletions
diff --git a/src/lib/test/train_sigmoid_test.c b/src/lib/test/train_sigmoid_test.c
new file mode 100644
index 0000000..588e7ca
--- /dev/null
+++ b/src/lib/test/train_sigmoid_test.c
@@ -0,0 +1,66 @@
1#include <neuralnet/train.h>
2
3#include <neuralnet/matrix.h>
4#include <neuralnet/neuralnet.h>
5#include "activation.h"
6#include "neuralnet_impl.h"
7
8#include "test.h"
9#include "test_util.h"
10
11#include <assert.h>
12
13TEST_CASE(neuralnet_train_sigmoid_test) {
14 const int num_layers = 1;
15 const int layer_sizes[] = { 1, 1 };
16 const nnActivation layer_activations[] = { nnSigmoid };
17
18 nnNeuralNetwork* net = nnMakeNet(num_layers, layer_sizes, layer_activations);
19 assert(net);
20
21 // Train.
22
23 // Try to learn the sigmoid function.
24 #define N 3
25 R inputs[N];
26 R targets[N];
27 for (int i = 0; i < N; ++i) {
28 inputs[i] = lerp(-1, +1, (R)i / (R)(N-1));
29 targets[i] = sigmoid(inputs[i]);
30 }
31
32 nnMatrix inputs_matrix = nnMatrixMake(N, 1);
33 nnMatrix targets_matrix = nnMatrixMake(N, 1);
34 nnMatrixInit(&inputs_matrix, inputs);
35 nnMatrixInit(&targets_matrix, targets);
36
37 nnTrainingParams params = {
38 .learning_rate = 0.9,
39 .max_iterations = 100,
40 .seed = 0,
41 .weight_init = nnWeightInit01,
42 .debug = false,
43 };
44
45 nnTrain(net, &inputs_matrix, &targets_matrix, &params);
46
47 const R weight = nnMatrixAt(&net->weights[0], 0, 0);
48 const R expected_weight = 1.0;
49 printf("\nTrained network weight: %f, Expected: %f\n", weight, expected_weight);
50 TEST_TRUE(double_eq(weight, expected_weight, WEIGHT_EPS));
51
52 // Test.
53
54 nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/1);
55
56 const R test_input[] = { 0.3 };
57 R test_output[1];
58 nnQueryArray(net, query, test_input, test_output);
59
60 const R expected_output = 0.574442516811659; // sigmoid(0.3)
61 printf("Output: %f, Expected: %f\n", test_output[0], expected_output);
62 TEST_TRUE(double_eq(test_output[0], expected_output, OUTPUT_EPS));
63
64 nnDeleteQueryObject(&query);
65 nnDeleteNet(&net);
66}