aboutsummaryrefslogtreecommitdiff
path: root/src/lib/test/neuralnet_test.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib/test/neuralnet_test.c')
-rw-r--r--src/lib/test/neuralnet_test.c92
1 files changed, 92 insertions, 0 deletions
diff --git a/src/lib/test/neuralnet_test.c b/src/lib/test/neuralnet_test.c
new file mode 100644
index 0000000..14d9438
--- /dev/null
+++ b/src/lib/test/neuralnet_test.c
@@ -0,0 +1,92 @@
1#include <neuralnet/neuralnet.h>
2
3#include <neuralnet/matrix.h>
4#include "activation.h"
5#include "neuralnet_impl.h"
6
7#include "test.h"
8#include "test_util.h"
9
10#include <assert.h>
11
12TEST_CASE(neuralnet_perceptron_test) {
13 const int num_layers = 1;
14 const int layer_sizes[] = { 1, 1 };
15 const nnActivation layer_activations[] = { nnSigmoid };
16 const R weights[] = { 0.3 };
17
18 nnNeuralNetwork* net = nnMakeNet(num_layers, layer_sizes, layer_activations);
19 assert(net);
20 nnSetWeights(net, weights);
21
22 nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/1);
23
24 const R input[] = { 0.9 };
25 R output[1];
26 nnQueryArray(net, query, input, output);
27
28 const R expected_output = sigmoid(input[0] * weights[0]);
29 printf("\nOutput: %f, Expected: %f\n", output[0], expected_output);
30 TEST_TRUE(double_eq(output[0], expected_output, EPS));
31
32 nnDeleteQueryObject(&query);
33 nnDeleteNet(&net);
34}
35
36TEST_CASE(neuralnet_xor_test) {
37 const int num_layers = 2;
38 const int layer_sizes[] = { 2, 2, 1 };
39 const nnActivation layer_activations[] = { nnRelu, nnIdentity };
40 const R weights[] = {
41 1, 1, 1, 1, // First (hidden) layer.
42 1, -2 // Second (output) layer.
43 };
44 const R biases[] = {
45 0, -1, // First (hidden) layer.
46 0 // Second (output) layer.
47 };
48
49 nnNeuralNetwork* net = nnMakeNet(num_layers, layer_sizes, layer_activations);
50 assert(net);
51 nnSetWeights(net, weights);
52 nnSetBiases(net, biases);
53
54 // First layer weights.
55 TEST_EQUAL(nnMatrixAt(&net->weights[0], 0, 0), 1);
56 TEST_EQUAL(nnMatrixAt(&net->weights[0], 0, 1), 1);
57 TEST_EQUAL(nnMatrixAt(&net->weights[0], 0, 2), 1);
58 TEST_EQUAL(nnMatrixAt(&net->weights[0], 0, 3), 1);
59 // Second layer weights.
60 TEST_EQUAL(nnMatrixAt(&net->weights[1], 0, 0), 1);
61 TEST_EQUAL(nnMatrixAt(&net->weights[1], 0, 1), -2);
62 // First layer biases.
63 TEST_EQUAL(nnMatrixAt(&net->biases[0], 0, 0), 0);
64 TEST_EQUAL(nnMatrixAt(&net->biases[0], 0, 1), -1);
65 // Second layer biases.
66 TEST_EQUAL(nnMatrixAt(&net->biases[1], 0, 0), 0);
67
68 // Test.
69
70 #define M 4
71
72 nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/M);
73
74 const R test_inputs[M][2] = { { 0., 0. }, { 1., 0. }, { 0., 1. }, { 1., 1. } };
75 nnMatrix test_inputs_matrix = nnMatrixMake(M, 2);
76 nnMatrixInit(&test_inputs_matrix, (const R*)test_inputs);
77 nnQuery(net, query, &test_inputs_matrix);
78
79 const R expected_outputs[M] = { 0., 1., 1., 0. };
80 for (int i = 0; i < M; ++i) {
81 const R test_output = nnMatrixAt(nnNetOutputs(query), i, 0);
82 printf("\nInput: (%f, %f), Output: %f, Expected: %f\n",
83 test_inputs[i][0], test_inputs[i][1], test_output, expected_outputs[i]);
84 }
85 for (int i = 0; i < M; ++i) {
86 const R test_output = nnMatrixAt(nnNetOutputs(query), i, 0);
87 TEST_TRUE(double_eq(test_output, expected_outputs[i], OUTPUT_EPS));
88 }
89
90 nnDeleteQueryObject(&query);
91 nnDeleteNet(&net);
92}