diff options
Diffstat (limited to 'src/lib/src/train.c')
-rw-r--r-- | src/lib/src/train.c | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/src/lib/src/train.c b/src/lib/src/train.c index 027de66..3061a99 100644 --- a/src/lib/src/train.c +++ b/src/lib/src/train.c | |||
@@ -129,7 +129,7 @@ void nnTrain( | |||
129 | nnMatrix* errors = calloc(net->num_layers, sizeof(nnMatrix)); | 129 | nnMatrix* errors = calloc(net->num_layers, sizeof(nnMatrix)); |
130 | 130 | ||
131 | // Allocate the weight transpose matrices up front for backpropagation. | 131 | // Allocate the weight transpose matrices up front for backpropagation. |
132 | nnMatrix* weights_T = calloc(net->num_layers, sizeof(nnMatrix)); | 132 | //nnMatrix* weights_T = calloc(net->num_layers, sizeof(nnMatrix)); |
133 | 133 | ||
134 | // Allocate the weight delta matrices. | 134 | // Allocate the weight delta matrices. |
135 | nnMatrix* weight_deltas = calloc(net->num_layers, sizeof(nnMatrix)); | 135 | nnMatrix* weight_deltas = calloc(net->num_layers, sizeof(nnMatrix)); |
@@ -143,7 +143,7 @@ void nnTrain( | |||
143 | nnMatrix* outputs_T = calloc(net->num_layers, sizeof(nnMatrix)); | 143 | nnMatrix* outputs_T = calloc(net->num_layers, sizeof(nnMatrix)); |
144 | 144 | ||
145 | assert(errors != 0); | 145 | assert(errors != 0); |
146 | assert(weights_T != 0); | 146 | //assert(weights_T != 0); |
147 | assert(weight_deltas != 0); | 147 | assert(weight_deltas != 0); |
148 | assert(gradient_elems); | 148 | assert(gradient_elems); |
149 | assert(outputs_T); | 149 | assert(outputs_T); |
@@ -155,8 +155,8 @@ void nnTrain( | |||
155 | 155 | ||
156 | errors[l] = nnMatrixMake(1, layer_weights->cols); | 156 | errors[l] = nnMatrixMake(1, layer_weights->cols); |
157 | 157 | ||
158 | weights_T[l] = nnMatrixMake(layer_weights->cols, layer_weights->rows); | 158 | //weights_T[l] = nnMatrixMake(layer_weights->cols, layer_weights->rows); |
159 | nnMatrixTranspose(layer_weights, &weights_T[l]); | 159 | //nnMatrixTranspose(layer_weights, &weights_T[l]); |
160 | 160 | ||
161 | weight_deltas[l] = nnMatrixMake(layer_weights->rows, layer_weights->cols); | 161 | weight_deltas[l] = nnMatrixMake(layer_weights->rows, layer_weights->cols); |
162 | 162 | ||
@@ -267,7 +267,9 @@ void nnTrain( | |||
267 | 267 | ||
268 | // Backpropagate the error before updating weights. | 268 | // Backpropagate the error before updating weights. |
269 | if (l > 0) { | 269 | if (l > 0) { |
270 | nnMatrixMul(gradient, &weights_T[l], &errors[l-1]); | 270 | // G * W^T == G *^T W. |
271 | //nnMatrixMul(gradient, &weights_T[l], &errors[l-1]); | ||
272 | nnMatrixMulRows(gradient, layer_weights, &errors[l-1]); | ||
271 | } | 273 | } |
272 | 274 | ||
273 | // Update weights. | 275 | // Update weights. |
@@ -278,7 +280,7 @@ void nnTrain( | |||
278 | nnMatrixSub(layer_weights, &weight_deltas[l], layer_weights); | 280 | nnMatrixSub(layer_weights, &weight_deltas[l], layer_weights); |
279 | 281 | ||
280 | // Update weight transpose matrix for the next training iteration. | 282 | // Update weight transpose matrix for the next training iteration. |
281 | nnMatrixTranspose(layer_weights, &weights_T[l]); | 283 | //nnMatrixTranspose(layer_weights, &weights_T[l]); |
282 | 284 | ||
283 | // Update biases. | 285 | // Update biases. |
284 | // This is the same formula as for weights, except that the o_j term is | 286 | // This is the same formula as for weights, except that the o_j term is |
@@ -319,7 +321,7 @@ void nnTrain( | |||
319 | for (int l = 0; l < net->num_layers; ++l) { | 321 | for (int l = 0; l < net->num_layers; ++l) { |
320 | nnMatrixDel(&errors[l]); | 322 | nnMatrixDel(&errors[l]); |
321 | nnMatrixDel(&outputs_T[l]); | 323 | nnMatrixDel(&outputs_T[l]); |
322 | nnMatrixDel(&weights_T[l]); | 324 | //nnMatrixDel(&weights_T[l]); |
323 | nnMatrixDel(&weight_deltas[l]); | 325 | nnMatrixDel(&weight_deltas[l]); |
324 | 326 | ||
325 | nnGradientElements* elems = &gradient_elems[l]; | 327 | nnGradientElements* elems = &gradient_elems[l]; |
@@ -340,7 +342,7 @@ void nnTrain( | |||
340 | nnMatrixDel(&training_inputs_T); | 342 | nnMatrixDel(&training_inputs_T); |
341 | free(errors); | 343 | free(errors); |
342 | free(outputs_T); | 344 | free(outputs_T); |
343 | free(weights_T); | 345 | //free(weights_T); |
344 | free(weight_deltas); | 346 | free(weight_deltas); |
345 | free(gradient_elems); | 347 | free(gradient_elems); |
346 | } | 348 | } |