Diffstat (limited to 'src/lib/src/train.c')
 src/lib/src/train.c | 28 +++-------------------------
 1 file changed, 3 insertions(+), 25 deletions(-)
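The change below drops the explicit transpose buffers that backpropagation used to maintain. For a layer whose input is the row vector $x \in \mathbb{R}^{1 \times n}$ and whose back-propagated gradient is the row vector $g \in \mathbb{R}^{1 \times m}$ (the diff allocates errors[l] as 1 x layer_output_size), the weight delta is the outer product

$$\Delta W = x^{\top} g, \qquad (\Delta W)_{ij} = x_i \, g_j, \qquad \Delta W \in \mathbb{R}^{n \times m}.$$

Forming $(\Delta W)_{ij} = x_i g_j$ directly from the two row vectors is equivalent to transposing $x$ into a column vector and calling a general matrix multiply, which is all the deleted training_inputs_T and outputs_T copies existed for; a sketch of such an outer-product routine follows the diff.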
diff --git a/src/lib/src/train.c b/src/lib/src/train.c
index ccff553..fe9f598 100644
--- a/src/lib/src/train.c
+++ b/src/lib/src/train.c
@@ -153,14 +153,9 @@ void nnTrain(
   nnGradientElements* gradient_elems =
       calloc(net->num_layers, sizeof(nnGradientElements));
 
-  // Allocate the output transpose vectors for weight delta calculation.
-  // This is one column vector per layer.
-  nnMatrix* outputs_T = calloc(net->num_layers, sizeof(nnMatrix));
-
   assert(errors != 0);
   assert(weight_deltas != 0);
   assert(gradient_elems);
-  assert(outputs_T);
 
   for (int l = 0; l < net->num_layers; ++l) {
     const int layer_input_size = nnLayerInputSize(net, l);
@@ -169,7 +164,6 @@ void nnTrain(
 
     errors[l] = nnMatrixMake(1, layer_output_size);
     weight_deltas[l] = nnMatrixMake(layer_input_size, layer_output_size);
-    outputs_T[l] = nnMatrixMake(layer_output_size, 1);
 
     // Allocate the gradient elements and vectors for weight delta calculation.
     nnGradientElements* elems = &gradient_elems[l];
@@ -199,9 +193,6 @@ void nnTrain(
   // the outputs.
   const nnMatrix* const training_outputs = query->network_outputs;
 
-  // A vector to store the training input transposed.
-  nnMatrix training_inputs_T = nnMatrixMake(inputs->cols, 1);
-
   // If debug mode is requested, we will show progress every Nth iteration.
   const int progress_frame =
       (params->max_iterations < PROGRESS_THRESHOLD)
@@ -223,10 +214,6 @@
       const nnMatrix training_targets =
          nnMatrixBorrowRows((nnMatrix*)targets, sample, 1);
 
-      // Will need the input transposed for backpropagation.
-      // Assuming one training input per iteration for now.
-      nnMatrixTranspose(&training_inputs, &training_inputs_T);
-
      // Forward pass.
      nnQuery(net, query, &training_inputs);
 
@@ -240,14 +227,11 @@
      nnMatrixSub(
          training_outputs, &training_targets, &errors[net->num_layers - 1]);
 
-      // Update outputs_T, which we need during weight updates.
-      for (int l = 0; l < net->num_layers; ++l) {
-        nnMatrixTranspose(&query->layer_outputs[l], &outputs_T[l]);
-      }
-
      // Update weights and biases for each internal layer, back-propagating
      // errors along the way.
      for (int l = net->num_layers - 1; l >= 0; --l) {
+        const nnMatrix* layer_input =
+            (l == 0) ? &training_inputs : &query->layer_outputs[l - 1];
        const nnMatrix* layer_output = &query->layer_outputs[l];
        nnGradientElements* elems = &gradient_elems[l];
        nnMatrix* gradient = &elems->gradient;
@@ -310,10 +294,7 @@
        nnMatrix* layer_biases = &linear->biases;
 
        // Outer product to compute the weight deltas.
-        // This layer's input is the previous layer's output.
-        const nnMatrix* input_T =
-            (l == 0) ? &training_inputs_T : &outputs_T[l - 1];
-        nnMatrixMul(input_T, gradient, &weight_deltas[l]);
+        nnMatrixMulOuter(layer_input, gradient, &weight_deltas[l]);
 
        // Update weights.
        nnMatrixScale(&weight_deltas[l], params->learning_rate);
@@ -360,7 +341,6 @@
   // Clean up.
   for (int l = 0; l < net->num_layers; ++l) {
     nnMatrixDel(&errors[l]);
-    nnMatrixDel(&outputs_T[l]);
     nnMatrixDel(&weight_deltas[l]);
 
     nnGradientElements* elems = &gradient_elems[l];
@@ -378,9 +358,7 @@
      break;
    }
  }
-  nnMatrixDel(&training_inputs_T);
  free(errors);
-  free(outputs_T);
  free(weight_deltas);
  free(gradient_elems);
 }
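For illustration, a minimal sketch of an outer-product routine matching the nnMatrixMulOuter call site above. Only the call's shape contract is taken from the diff; the nnMatrix definition here (rows, cols, row-major float storage) is an assumption made for the sketch, not the repository's actual type.

#include <assert.h>

// Hypothetical stand-in for the library's matrix type; the real definition
// lives elsewhere in src/lib and may differ (e.g. in element type).
typedef struct nnMatrix {
  int    rows;
  int    cols;
  float* values;  // Row-major: element (i, j) is values[i * cols + j].
} nnMatrix;

// Outer product of two row vectors: out[i][j] = left[0][i] * right[0][j].
// Equivalent to transposing `left` into a column vector and doing a general
// matrix multiply, but with no transposed copy to allocate, update, or free.
void nnMatrixMulOuter(
    const nnMatrix* left, const nnMatrix* right, nnMatrix* out) {
  assert(left->rows == 1);
  assert(right->rows == 1);
  assert(out->rows == left->cols);
  assert(out->cols == right->cols);

  for (int i = 0; i < left->cols; ++i) {
    for (int j = 0; j < right->cols; ++j) {
      out->values[i * out->cols + j] = left->values[i] * right->values[j];
    }
  }
}

Under this contract, the call in the diff fills the layer_input_size x layer_output_size matrix allocated earlier with nnMatrixMake(layer_input_size, layer_output_size), whether the layer input is training_inputs (l == 0) or the previous layer's output.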