Diffstat (limited to 'src/lib/src/train.c')
-rw-r--r--  src/lib/src/train.c  |  346
1 file changed, 346 insertions(+), 0 deletions(-)
diff --git a/src/lib/src/train.c b/src/lib/src/train.c
new file mode 100644
index 0000000..027de66
--- /dev/null
+++ b/src/lib/src/train.c
@@ -0,0 +1,346 @@
#include <neuralnet/train.h>

#include <neuralnet/matrix.h>
#include "neuralnet_impl.h"

#include <random/mt19937-64.h>
#include <random/normal.h>

#include <assert.h>
#include <math.h>
#include <stdbool.h> // For the 'false' used in assert() below.
#include <stdlib.h>

#include <stdio.h>
#define LOGD printf
// If debug mode is requested, progress is shown every time this percentage of
// the total iterations has elapsed.
static const int PROGRESS_THRESHOLD = 5; // %

/// Computes the total MSE from the output error matrix.
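/// In formula form, with e_i the entries of |errors| and N their count, this
/// is MSE = (1/N) * sum_i e_i^2. For example, a 1x2 error matrix holding
/// (0.5, -0.5) gives (0.25 + 0.25) / 2 = 0.25.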
R ComputeMSE(const nnMatrix* errors) {
  R sum_sq = 0;
  const int N = errors->rows * errors->cols;
  const R* value = errors->values;
  for (int i = 0; i < N; ++i) {
    sum_sq += *value * *value;
    value++;
  }
  return sum_sq / (R)N;
}

/// Holds the bits required to compute a sigmoid gradient.
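/// The ones vector lets the training loop compute (1 - output) as a plain
/// matrix subtraction (see the nnSigmoid case in nnTrain below), presumably
/// because there is no scalar-minus-matrix helper.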
typedef struct nnSigmoidGradientElements {
  nnMatrix ones; // A vector of just ones, same size as the layer.
} nnSigmoidGradientElements;

/// Holds the various elements required to compute gradients. These depend on
/// which activation functions are used, so they'll potentially be different
/// for each layer. A data type is defined for these because we allocate all
/// the required memory up front before entering the training loop.
typedef struct nnGradientElements {
  nnActivation type;
  // Gradient vector, same size as the layer.
  // This will contain the gradient expression except for the output value of
  // the previous layer.
  nnMatrix gradient;
  union {
    nnSigmoidGradientElements sigmoid;
  };
} nnGradientElements;

// Initialize the network's weights randomly and set their biases to 0.
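// For example, a layer with 4 inputs gets scale = 1/4 = 0.25 and
// stdev = 1/sqrt(4) = 0.5 for the random draws below.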
void nnInitNet(nnNeuralNetwork* net, uint64_t seed, const nnWeightInitStrategy strategy) {
  assert(net);

  mt19937_64 rng = mt19937_64_make();
  mt19937_64_init(&rng, seed);

  for (int l = 0; l < net->num_layers; ++l) {
    nnMatrix* weights = &net->weights[l];
    nnMatrix* biases = &net->biases[l];

    const R layer_size = (R)nnLayerInputSize(weights);
    const R scale = 1. / layer_size;
    const R stdev = 1. / sqrt((R)layer_size);
    const R sigma = stdev * stdev;

    R* value = weights->values;
    for (int k = 0; k < weights->rows * weights->cols; ++k) {
      switch (strategy) {
        case nnWeightInit01: {
          const R x01 = mt19937_64_gen_real3(&rng); // (0, +1) interval.
          *value++ = scale * x01;
          break;
        }
        case nnWeightInit11: {
          const R x11 = mt19937_64_gen_real4(&rng); // (-1, +1) interval.
          *value++ = scale * x11;
          break;
        }
        case nnWeightInitNormal: {
          // Using initialization with a normal distribution of standard
          // deviation 1 / sqrt(num_layer_weights) to prevent saturation when
          // multiplying inputs.
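          // normal2() is assumed here to implement the Box-Muller transform,
          // mapping the two uniform samples to two independent standard normal
          // samples, z0 = sqrt(-2 ln u) * cos(2 pi v) and
          // z1 = sqrt(-2 ln u) * sin(2 pi v), which normal_transform() then
          // shifts and scales to the requested mean and spread.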
          const R u01 = mt19937_64_gen_real3(&rng); // (0, +1) interval.
          const R v01 = mt19937_64_gen_real3(&rng); // (0, +1) interval.
          R z0, z1;
          normal2(u01, v01, &z0, &z1);
          z0 = normal_transform(z0, /*mu=*/0, sigma);
          z1 = normal_transform(z1, /*mu=*/0, sigma);
          *value++ = z0;
          if (k < weights->rows * weights->cols - 1) {
            *value++ = z1;
            ++k;
          }
          break;
        }
        default:
          assert(false);
      }
    }

    // Initialize biases.
    // Biases start at 0 so that each layer's function initially passes through
    // the origin.
    value = biases->values;
    for (int k = 0; k < biases->rows * biases->cols; ++k, ++value) {
      *value = 0;
    }
  }
}

// |inputs| has one row vector per sample.
// |targets| has one row vector per sample.
//
// For now, each iteration trains with one sample (row) at a time.
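//
// For illustration (a hypothetical XOR-style setup, not something defined in
// this file): 4 samples with 2 inputs and 1 target each would be passed as a
// 4x2 |inputs| matrix and a 4x1 |targets| matrix.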
void nnTrain(
    nnNeuralNetwork* net,
    const nnMatrix* inputs,
    const nnMatrix* targets,
    const nnTrainingParams* params) {
  assert(net);
  assert(inputs);
  assert(targets);
  assert(params);
  assert(nnNetOutputSize(net) == targets->cols);
  assert(net->num_layers > 0);

  // Allocate error vectors to hold the backpropagated error values.
  // For now, these are one row vector per layer, meaning that we will train
  // with one sample at a time.
  nnMatrix* errors = calloc(net->num_layers, sizeof(nnMatrix));

  // Allocate the weight transpose matrices up front for backpropagation.
  nnMatrix* weights_T = calloc(net->num_layers, sizeof(nnMatrix));

  // Allocate the weight delta matrices.
  nnMatrix* weight_deltas = calloc(net->num_layers, sizeof(nnMatrix));

  // Allocate the data structures required to compute gradients.
  // This depends on each layer's activation type.
  nnGradientElements* gradient_elems = calloc(net->num_layers, sizeof(nnGradientElements));

  // Allocate the output transpose vectors for weight delta calculation.
  // This is one column vector per layer.
  nnMatrix* outputs_T = calloc(net->num_layers, sizeof(nnMatrix));

  assert(errors != 0);
  assert(weights_T != 0);
  assert(weight_deltas != 0);
  assert(gradient_elems);
  assert(outputs_T);

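  // Per-layer shapes, following the row-vector convention used below (a layer
  // with N inputs and M outputs has an N x M weight matrix): errors[l] is
  // 1 x M, weights_T[l] is M x N, weight_deltas[l] is N x M, and outputs_T[l]
  // is M x 1.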
  for (int l = 0; l < net->num_layers; ++l) {
    const nnMatrix* layer_weights = &net->weights[l];
    const int layer_output_size = net->weights[l].cols;
    const nnActivation activation = net->activations[l];

    errors[l] = nnMatrixMake(1, layer_weights->cols);

    weights_T[l] = nnMatrixMake(layer_weights->cols, layer_weights->rows);
    nnMatrixTranspose(layer_weights, &weights_T[l]);

    weight_deltas[l] = nnMatrixMake(layer_weights->rows, layer_weights->cols);

    outputs_T[l] = nnMatrixMake(layer_output_size, 1);

    // Allocate the gradient elements and vectors for weight delta calculation.
    nnGradientElements* elems = &gradient_elems[l];
    elems->type = activation;
    switch (activation) {
      case nnIdentity:
        break; // Gradient vector will be borrowed, no need to allocate.

      case nnSigmoid:
        elems->gradient = nnMatrixMake(1, layer_output_size);
        // Allocate the 1s vectors.
        elems->sigmoid.ones = nnMatrixMake(1, layer_output_size);
        nnMatrixInitConstant(&elems->sigmoid.ones, 1);
        break;

      case nnRelu:
        elems->gradient = nnMatrixMake(1, layer_output_size);
        break;
    }
  }

  // Construct the query object with a size of 1 since we are training with one
  // sample at a time.
  nnQueryObject* query = nnMakeQueryObject(net, 1);

  // Network outputs are given by the query object. Every network query updates
  // the outputs.
  const nnMatrix* const training_outputs = query->network_outputs;

  // A vector to store the training input transposed.
  nnMatrix training_inputs_T = nnMatrixMake(inputs->cols, 1);

  // If debug mode is requested, we will show progress every Nth iteration.
  // (Guard against progress_frame == 0 when max_iterations is small.)
  const int progress_frame =
      (params->max_iterations * PROGRESS_THRESHOLD < 100)
          ? 1
          : (params->max_iterations * PROGRESS_THRESHOLD / 100);
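  // E.g., with max_iterations = 1000 and PROGRESS_THRESHOLD = 5, progress_frame
  // is 1000 * 5 / 100 = 50, so progress is logged every 50 iterations.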

  // --- TRAIN

  nnInitNet(net, params->seed, params->weight_init);

  for (int iter = 0; iter < params->max_iterations; ++iter) {

    // For now, we train with one sample at a time.
    for (int sample = 0; sample < inputs->rows; ++sample) {
      // Slice the input and target matrices with the batch size.
      // We are not mutating the inputs, but we need the cast to borrow.
      nnMatrix training_inputs = nnMatrixBorrowRows((nnMatrix*)inputs, sample, 1);
      nnMatrix training_targets = nnMatrixBorrowRows((nnMatrix*)targets, sample, 1);

      // Will need the input transposed for backpropagation.
      // Assuming one training input per iteration for now.
      nnMatrixTranspose(&training_inputs, &training_inputs_T);

      // Run a forward pass and compute the output layer error.
      // We don't square the error here; instead we just compute (t - o), which
      // is part of the derivative -2(t - o). We actually compute (o - t) so
      // that the outer negative sign goes away.
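      // (For a single output, d/do (t - o)^2 = -2(t - o) = 2(o - t); the
      // constant factor of 2 is effectively absorbed into the learning rate.)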
      nnQuery(net, query, &training_inputs);
      //nnMatrixSub(&training_targets, training_outputs, &errors[net->num_layers - 1]);
      nnMatrixSub(training_outputs, &training_targets, &errors[net->num_layers - 1]);

      // Update outputs_T, which we need during weight updates.
      for (int l = 0; l < net->num_layers; ++l) {
        nnMatrixTranspose(&query->layer_outputs[l], &outputs_T[l]);
      }

      // Update weights and biases for each internal layer, backpropagating
      // errors along the way.
      for (int l = net->num_layers - 1; l >= 0; --l) {
        const nnMatrix* layer_output = &query->layer_outputs[l];
        nnMatrix* layer_weights = &net->weights[l];
        nnMatrix* layer_biases = &net->biases[l];
        nnGradientElements* elems = &gradient_elems[l];
        nnMatrix* gradient = &elems->gradient;
        const nnActivation activation = net->activations[l];

        // Compute the gradient (the part of the expression that does not
        // contain the output of the previous layer).
        //
        // Identity: G = error_k
        // Sigmoid:  G = error_k * output_k * (1 - output_k)
        // Relu:     G = error_k * (output_k > 0 ? 1 : 0)
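        // The sigmoid expression follows from sigma'(x) = sigma(x) * (1 - sigma(x)),
        // so it can be evaluated from the layer's output alone; the Relu case
        // uses the subgradient, 1 for positive outputs and 0 otherwise.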
        switch (activation) {
          case nnIdentity:
            // TODO: Just copy the pointer?
            *gradient = nnMatrixBorrow(&errors[l]);
            break;
          case nnSigmoid:
            nnMatrixSub(&elems->sigmoid.ones, layer_output, gradient);
            nnMatrixMulPairs(layer_output, gradient, gradient);
            nnMatrixMulPairs(&errors[l], gradient, gradient);
            break;
          case nnRelu:
            nnMatrixGt(layer_output, 0, gradient);
            nnMatrixMulPairs(&errors[l], gradient, gradient);
            break;
        }

        // Outer product to compute the weight deltas.
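        // In index form, delta_w[i][j] = output_prev[i] * G[j]: the previous
        // layer's output column vector times the gradient row vector.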
        const nnMatrix* output_T = (l == 0) ? &training_inputs_T : &outputs_T[l-1];
        nnMatrixMul(output_T, gradient, &weight_deltas[l]);

        // Backpropagate the error before updating weights.
        if (l > 0) {
          nnMatrixMul(gradient, &weights_T[l], &errors[l-1]);
        }

        // Update weights.
        nnMatrixScale(&weight_deltas[l], params->learning_rate);
        // The gradient has a negative sign from -(t - o), but we have computed
        // e = o - t instead, so we can subtract directly.
        //nnMatrixAdd(layer_weights, &weight_deltas[l], layer_weights);
        nnMatrixSub(layer_weights, &weight_deltas[l], layer_weights);

        // Update weight transpose matrix for the next training iteration.
        nnMatrixTranspose(layer_weights, &weights_T[l]);

        // Update biases.
        // This is the same formula as for weights, except that the o_j term is
        // just 1. We can simply re-use the gradient that we have already
        // computed for the weight update.
        //nnMatrixMulAdd(layer_biases, gradient, params->learning_rate, layer_biases);
        nnMatrixMulSub(layer_biases, gradient, params->learning_rate, layer_biases);
      }

      // TODO: Add this under a verbose debugging mode.
      // if (params->debug) {
      //   LOGD("Iter: %d, Sample: %d, Error: %f\n", iter, sample, ComputeMSE(&errors[net->num_layers - 1]));
      //   LOGD("TGT: ");
      //   for (int i = 0; i < training_targets.cols; ++i) {
      //     printf("%.3f ", training_targets.values[i]);
      //   }
      //   printf("\n");
      //   LOGD("OUT: ");
      //   for (int i = 0; i < training_outputs->cols; ++i) {
      //     printf("%.3f ", training_outputs->values[i]);
      //   }
      //   printf("\n");
      // }
    }

    if (params->debug && ((iter % progress_frame) == 0)) {
      LOGD("Iter: %d/%d, Error: %f\n",
          iter, params->max_iterations, ComputeMSE(&errors[net->num_layers - 1]));
    }
  }

  // Print the final error.
  if (params->debug) {
    LOGD("Iter: %d/%d, Error: %f\n",
        params->max_iterations, params->max_iterations, ComputeMSE(&errors[net->num_layers - 1]));
  }

  for (int l = 0; l < net->num_layers; ++l) {
    nnMatrixDel(&errors[l]);
    nnMatrixDel(&outputs_T[l]);
    nnMatrixDel(&weights_T[l]);
    nnMatrixDel(&weight_deltas[l]);

    nnGradientElements* elems = &gradient_elems[l];
    switch (elems->type) {
      case nnIdentity:
        break; // Gradient vector is borrowed, no need to deallocate.

      case nnSigmoid:
        nnMatrixDel(&elems->gradient);
        nnMatrixDel(&elems->sigmoid.ones);
        break;

      case nnRelu:
        nnMatrixDel(&elems->gradient);
        break;
    }
  }
  nnMatrixDel(&training_inputs_T);
  free(errors);
  free(outputs_T);
  free(weights_T);
  free(weight_deltas);
  free(gradient_elems);
}