path: root/src/lib/test/train_linear_perceptron_test.c
author     3gg <3gg@shellblade.net>    2023-12-16 10:21:16 -0800
committer  3gg <3gg@shellblade.net>    2023-12-16 10:21:16 -0800
commit     653e98e029a0d0f110b0ac599e50406060bb0f87 (patch)
tree       6f909215218f6720266bde1b3f49aeddad8b1da3 /src/lib/test/train_linear_perceptron_test.c
parent     3df7b6fb0c65295eed4590e6f166d60e89b3c68e (diff)
Decouple activations from linear layer.
Diffstat (limited to 'src/lib/test/train_linear_perceptron_test.c')
-rw-r--r--  src/lib/test/train_linear_perceptron_test.c  44
1 file changed, 23 insertions, 21 deletions
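
The commit replaces the parallel layer_sizes/layer_activations arrays with an array of nnLayer descriptors passed to nnMakeNet. Below is a minimal sketch of the new construction call as exercised by the updated test; the descriptor fields and the nnMakeNet argument order are taken from the diff itself, while the surrounding fragment (include, comments) is an assumption and not a definitive usage guide.

    #include <neuralnet/neuralnet.h>

    /* Sketch (not part of the commit): one linear layer mapping 1 input to
     * 1 output, described with the new nnLayer struct. Activations are no
     * longer attached to the linear layer. */
    const nnLayer layers[] = {
        {.type = nnLinear, .linear = {.input_size = 1, .output_size = 1}},
    };

    /* New signature: layer descriptors, layer count, and network input size. */
    nnNeuralNetwork* net = nnMakeNet(layers, /*num_layers=*/1, /*input_size=*/1);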
diff --git a/src/lib/test/train_linear_perceptron_test.c b/src/lib/test/train_linear_perceptron_test.c
index 2b1336d..667643b 100644
--- a/src/lib/test/train_linear_perceptron_test.c
+++ b/src/lib/test/train_linear_perceptron_test.c
@@ -1,9 +1,8 @@
 #include <neuralnet/train.h>
 
+#include "neuralnet_impl.h"
 #include <neuralnet/matrix.h>
 #include <neuralnet/neuralnet.h>
-#include "activation.h"
-#include "neuralnet_impl.h"
 
 #include "test.h"
 #include "test_util.h"
@@ -11,19 +10,21 @@
 #include <assert.h>
 
 TEST_CASE(neuralnet_train_linear_perceptron_test) {
   const int num_layers = 1;
-  const int layer_sizes[] = { 1, 1 };
-  const nnActivation layer_activations[] = { nnIdentity };
+  const int input_size = 1;
+  const nnLayer layers[] = {
+      {.type = nnLinear, .linear = {.input_size = 1, .output_size = 1}}
+  };
 
-  nnNeuralNetwork* net = nnMakeNet(num_layers, layer_sizes, layer_activations);
+  nnNeuralNetwork* net = nnMakeNet(layers, num_layers, input_size);
   assert(net);
 
   // Train.
 
   // Try to learn the Y=X line.
-  #define N 2
-  const R inputs[N] = { 0., 1. };
-  const R targets[N] = { 0., 1. };
+#define N 2
+  const R inputs[N] = {0., 1.};
+  const R targets[N] = {0., 1.};
 
   nnMatrix inputs_matrix = nnMatrixMake(N, 1);
   nnMatrix targets_matrix = nnMatrixMake(N, 1);
@@ -31,26 +32,27 @@ TEST_CASE(neuralnet_train_linear_perceptron_test) {
   nnMatrixInit(&targets_matrix, targets);
 
   nnTrainingParams params = {
       .learning_rate = 0.7,
       .max_iterations = 10,
       .seed = 0,
       .weight_init = nnWeightInit01,
       .debug = false,
   };
 
   nnTrain(net, &inputs_matrix, &targets_matrix, &params);
 
-  const R weight = nnMatrixAt(&net->weights[0], 0, 0);
+  const R weight = nnMatrixAt(&net->layers[0].linear.weights, 0, 0);
   const R expected_weight = 1.0;
-  printf("\nTrained network weight: %f, Expected: %f\n", weight, expected_weight);
+  printf(
+      "\nTrained network weight: %f, Expected: %f\n", weight, expected_weight);
   TEST_TRUE(double_eq(weight, expected_weight, WEIGHT_EPS));
 
   // Test.
 
-  nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/1);
+  nnQueryObject* query = nnMakeQueryObject(net, 1);
 
-  const R test_input[] = { 2.3 };
+  const R test_input[] = {2.3};
   R test_output[1];
   nnQueryArray(net, query, test_input, test_output);
 
   const R expected_output = test_input[0];