Diffstat (limited to 'src')
-rw-r--r--  src/lib/src/train.c | 13 ++++++-------
1 file changed, 6 insertions(+), 7 deletions(-)
diff --git a/src/lib/src/train.c b/src/lib/src/train.c
index fe9f598..7559ece 100644
--- a/src/lib/src/train.c
+++ b/src/lib/src/train.c
@@ -239,17 +239,16 @@ void nnTrain(
 
     // Compute this layer's gradient.
    //
-    // By "gradient" we mean the expression common to the weights and bias
-    // gradients. This is the part of the expression that does not contain
-    // this layer's input.
+    // By 'gradient' we mean the subexpression common to all the gradients
+    // for this layer.
+    // For linear layers, this is the subexpression common to both the
+    // weights and bias gradients.
     //
     // Linear: G = id
     // Relu: G = (output_k > 0 ? 1 : 0)
     // Sigmoid: G = output_k * (1 - output_k)
     switch (layer->type) {
     case nnLinear: {
-      // TODO: Just copy the pointer?
-      *gradient = nnMatrixBorrow(&errors[l]);
       break;
     }
     case nnRelu:
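
This hunk drops the special case for nnLinear: since G = id, the per-layer gradient is just the backpropagated error, so the later updates can read &errors[l] directly instead of going through a borrowed copy, which is why the TODO and the nnMatrixBorrow call could be deleted outright. As a rough illustration only (plain arrays and made-up names, not the library's nnMatrix API), the per-layer gradient term described in the comment could be computed elementwise like this:

#include <stddef.h>

typedef enum { kLinear, kRelu, kSigmoid } ActivationKind;

/* gradient[k] = G(output_k) * error[k] for one layer with n outputs.
 * For kLinear, G = id, so the gradient is simply the error itself. */
static void layer_gradient(ActivationKind kind, const float* output,
                           const float* error, float* gradient, size_t n) {
  for (size_t k = 0; k < n; ++k) {
    float g;
    switch (kind) {
      case kLinear:
        g = 1.0f; /* G = id */
        break;
      case kRelu:
        g = (output[k] > 0.0f) ? 1.0f : 0.0f; /* G = (output_k > 0 ? 1 : 0) */
        break;
      default: /* kSigmoid */
        g = output[k] * (1.0f - output[k]); /* G = output_k * (1 - output_k) */
        break;
    }
    gradient[k] = g * error[k];
  }
}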
@@ -294,7 +293,7 @@ void nnTrain(
       nnMatrix* layer_biases = &linear->biases;
 
       // Outer product to compute the weight deltas.
-      nnMatrixMulOuter(layer_input, gradient, &weight_deltas[l]);
+      nnMatrixMulOuter(layer_input, &errors[l], &weight_deltas[l]);
 
       // Update weights.
       nnMatrixScale(&weight_deltas[l], params->learning_rate);
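
For the weight update, the diff now feeds &errors[l] (the linear layer's gradient) straight into the outer product. In index form the delta is delta_jk = o_j * g_k, scaled by the learning rate and subtracted from the weights. A minimal sketch with plain row-major arrays and illustrative names (it fuses into one loop the scale and subtract that the code presumably performs with separate nnMatrix calls):

#include <stddef.h>

/* weights is rows x cols, row-major; input holds the layer's inputs o_j,
 * error holds the per-layer gradient g_k. Each weight moves against the
 * outer-product delta: w_jk -= learning_rate * o_j * g_k. */
static void update_weights(float* weights, const float* input,
                           const float* error, size_t rows, size_t cols,
                           float learning_rate) {
  for (size_t j = 0; j < rows; ++j) {
    for (size_t k = 0; k < cols; ++k) {
      weights[j * cols + k] -= learning_rate * input[j] * error[k];
    }
  }
}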
@@ -304,7 +303,7 @@ void nnTrain(
       // This is the same formula as for weights, except that the o_j term
       // is just 1.
       nnMatrixMulSub(
-          layer_biases, gradient, params->learning_rate, layer_biases);
+          layer_biases, &errors[l], params->learning_rate, layer_biases);
     }
   }
 
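
The bias update mirrors the weight update with the o_j term fixed to 1; reading nnMatrixMulSub here as biases = biases - learning_rate * error is an assumption about its signature, not something the diff states. In sketch form, again with plain arrays and illustrative names:

#include <stddef.h>

/* Bias update with the o_j term equal to 1: b_k -= learning_rate * g_k. */
static void update_biases(float* biases, const float* error, size_t cols,
                          float learning_rate) {
  for (size_t k = 0; k < cols; ++k) {
    biases[k] -= learning_rate * error[k];
  }
}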