aboutsummaryrefslogtreecommitdiff
path: root/src/lib/src/train.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib/src/train.c')
-rw-r--r--src/lib/src/train.c18
1 files changed, 10 insertions, 8 deletions
diff --git a/src/lib/src/train.c b/src/lib/src/train.c
index 027de66..3061a99 100644
--- a/src/lib/src/train.c
+++ b/src/lib/src/train.c
@@ -129,7 +129,7 @@ void nnTrain(
129 nnMatrix* errors = calloc(net->num_layers, sizeof(nnMatrix)); 129 nnMatrix* errors = calloc(net->num_layers, sizeof(nnMatrix));
130 130
131 // Allocate the weight transpose matrices up front for backpropagation. 131 // Allocate the weight transpose matrices up front for backpropagation.
132 nnMatrix* weights_T = calloc(net->num_layers, sizeof(nnMatrix)); 132 //nnMatrix* weights_T = calloc(net->num_layers, sizeof(nnMatrix));
133 133
134 // Allocate the weight delta matrices. 134 // Allocate the weight delta matrices.
135 nnMatrix* weight_deltas = calloc(net->num_layers, sizeof(nnMatrix)); 135 nnMatrix* weight_deltas = calloc(net->num_layers, sizeof(nnMatrix));
@@ -143,7 +143,7 @@ void nnTrain(
143 nnMatrix* outputs_T = calloc(net->num_layers, sizeof(nnMatrix)); 143 nnMatrix* outputs_T = calloc(net->num_layers, sizeof(nnMatrix));
144 144
145 assert(errors != 0); 145 assert(errors != 0);
146 assert(weights_T != 0); 146 //assert(weights_T != 0);
147 assert(weight_deltas != 0); 147 assert(weight_deltas != 0);
148 assert(gradient_elems); 148 assert(gradient_elems);
149 assert(outputs_T); 149 assert(outputs_T);
@@ -155,8 +155,8 @@ void nnTrain(
155 155
156 errors[l] = nnMatrixMake(1, layer_weights->cols); 156 errors[l] = nnMatrixMake(1, layer_weights->cols);
157 157
158 weights_T[l] = nnMatrixMake(layer_weights->cols, layer_weights->rows); 158 //weights_T[l] = nnMatrixMake(layer_weights->cols, layer_weights->rows);
159 nnMatrixTranspose(layer_weights, &weights_T[l]); 159 //nnMatrixTranspose(layer_weights, &weights_T[l]);
160 160
161 weight_deltas[l] = nnMatrixMake(layer_weights->rows, layer_weights->cols); 161 weight_deltas[l] = nnMatrixMake(layer_weights->rows, layer_weights->cols);
162 162
@@ -267,7 +267,9 @@ void nnTrain(
267 267
268 // Backpropagate the error before updating weights. 268 // Backpropagate the error before updating weights.
269 if (l > 0) { 269 if (l > 0) {
270 nnMatrixMul(gradient, &weights_T[l], &errors[l-1]); 270 // G * W^T == G *^T W.
271 //nnMatrixMul(gradient, &weights_T[l], &errors[l-1]);
272 nnMatrixMulRows(gradient, layer_weights, &errors[l-1]);
271 } 273 }
272 274
273 // Update weights. 275 // Update weights.
@@ -278,7 +280,7 @@ void nnTrain(
278 nnMatrixSub(layer_weights, &weight_deltas[l], layer_weights); 280 nnMatrixSub(layer_weights, &weight_deltas[l], layer_weights);
279 281
280 // Update weight transpose matrix for the next training iteration. 282 // Update weight transpose matrix for the next training iteration.
281 nnMatrixTranspose(layer_weights, &weights_T[l]); 283 //nnMatrixTranspose(layer_weights, &weights_T[l]);
282 284
283 // Update biases. 285 // Update biases.
284 // This is the same formula as for weights, except that the o_j term is 286 // This is the same formula as for weights, except that the o_j term is
@@ -319,7 +321,7 @@ void nnTrain(
319 for (int l = 0; l < net->num_layers; ++l) { 321 for (int l = 0; l < net->num_layers; ++l) {
320 nnMatrixDel(&errors[l]); 322 nnMatrixDel(&errors[l]);
321 nnMatrixDel(&outputs_T[l]); 323 nnMatrixDel(&outputs_T[l]);
322 nnMatrixDel(&weights_T[l]); 324 //nnMatrixDel(&weights_T[l]);
323 nnMatrixDel(&weight_deltas[l]); 325 nnMatrixDel(&weight_deltas[l]);
324 326
325 nnGradientElements* elems = &gradient_elems[l]; 327 nnGradientElements* elems = &gradient_elems[l];
@@ -340,7 +342,7 @@ void nnTrain(
340 nnMatrixDel(&training_inputs_T); 342 nnMatrixDel(&training_inputs_T);
341 free(errors); 343 free(errors);
342 free(outputs_T); 344 free(outputs_T);
343 free(weights_T); 345 //free(weights_T);
344 free(weight_deltas); 346 free(weight_deltas);
345 free(gradient_elems); 347 free(gradient_elems);
346} 348}