diff options
Diffstat (limited to 'src/lib/src/train.c')
-rw-r--r-- | src/lib/src/train.c | 15 |
1 file changed, 9 insertions, 6 deletions
diff --git a/src/lib/src/train.c b/src/lib/src/train.c
index 9244907..dc93f0f 100644
--- a/src/lib/src/train.c
+++ b/src/lib/src/train.c
@@ -219,13 +219,15 @@ void nnTrain( | |||
219 | // Assuming one training input per iteration for now. | 219 | // Assuming one training input per iteration for now. |
220 | nnMatrixTranspose(&training_inputs, &training_inputs_T); | 220 | nnMatrixTranspose(&training_inputs, &training_inputs_T); |
221 | 221 | ||
222 | // Run a forward pass and compute the output layer error. | 222 | // Run a forward pass and compute the output layer error relevant to the |
223 | // We don't square the error here; instead, we just compute t-o, which is | 223 | // derivative: o-t. |
224 | // part of the derivative, -2(t-o). Also, we compute o-t instead to | 224 | // Error: (t-o)^2 |
225 | // remove that outer negative sign. | 225 | // dE/do = -2(t-o) |
226 | // = +2(o-t) | ||
227 | // Note that we compute o-t instead to remove that outer negative sign. | ||
228 | // The 2 is dropped because we are only interested in the direction of the | ||
229 | // gradient. The learning rate controls the magnitude. | ||
226 | nnQuery(net, query, &training_inputs); | 230 | nnQuery(net, query, &training_inputs); |
227 | // nnMatrixSub(&training_targets, training_outputs, | ||
228 | // &errors[net->num_layers - 1]); | ||
229 | nnMatrixSub( | 231 | nnMatrixSub( |
230 | training_outputs, &training_targets, &errors[net->num_layers - 1]); | 232 | training_outputs, &training_targets, &errors[net->num_layers - 1]); |
231 | 233 | ||
@@ -328,6 +330,7 @@ void nnTrain( | |||
328 | params->max_iterations, ComputeMSE(&errors[net->num_layers - 1])); | 330 | params->max_iterations, ComputeMSE(&errors[net->num_layers - 1])); |
329 | } | 331 | } |
330 | 332 | ||
333 | // Clean up. | ||
331 | for (int l = 0; l < net->num_layers; ++l) { | 334 | for (int l = 0; l < net->num_layers; ++l) { |
332 | nnMatrixDel(&errors[l]); | 335 | nnMatrixDel(&errors[l]); |
333 | nnMatrixDel(&outputs_T[l]); | 336 | nnMatrixDel(&outputs_T[l]); |