aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/lib/src/train.c13
1 files changed, 6 insertions, 7 deletions
diff --git a/src/lib/src/train.c b/src/lib/src/train.c
index fe9f598..7559ece 100644
--- a/src/lib/src/train.c
+++ b/src/lib/src/train.c
@@ -239,17 +239,16 @@ void nnTrain(
239 239
240 // Compute this layer's gradient. 240 // Compute this layer's gradient.
241 // 241 //
242 // By "gradient" we mean the expression common to the weights and bias 242 // By 'gradient' we mean the subexpression common to all the gradients
243 // gradients. This is the part of the expression that does not contain 243 // for this layer.
244 // this layer's input. 244 // For linear layers, this is the subexpression common to both the
245 // weights and bias gradients.
245 // 246 //
246 // Linear: G = id 247 // Linear: G = id
247 // Relu: G = (output_k > 0 ? 1 : 0) 248 // Relu: G = (output_k > 0 ? 1 : 0)
248 // Sigmoid: G = output_k * (1 - output_k) 249 // Sigmoid: G = output_k * (1 - output_k)
249 switch (layer->type) { 250 switch (layer->type) {
250 case nnLinear: { 251 case nnLinear: {
251 // TODO: Just copy the pointer?
252 *gradient = nnMatrixBorrow(&errors[l]);
253 break; 252 break;
254 } 253 }
255 case nnRelu: 254 case nnRelu:
@@ -294,7 +293,7 @@ void nnTrain(
294 nnMatrix* layer_biases = &linear->biases; 293 nnMatrix* layer_biases = &linear->biases;
295 294
296 // Outer product to compute the weight deltas. 295 // Outer product to compute the weight deltas.
297 nnMatrixMulOuter(layer_input, gradient, &weight_deltas[l]); 296 nnMatrixMulOuter(layer_input, &errors[l], &weight_deltas[l]);
298 297
299 // Update weights. 298 // Update weights.
300 nnMatrixScale(&weight_deltas[l], params->learning_rate); 299 nnMatrixScale(&weight_deltas[l], params->learning_rate);
@@ -304,7 +303,7 @@ void nnTrain(
304 // This is the same formula as for weights, except that the o_j term 303 // This is the same formula as for weights, except that the o_j term
305 // is just 1. 304 // is just 1.
306 nnMatrixMulSub( 305 nnMatrixMulSub(
307 layer_biases, gradient, params->learning_rate, layer_biases); 306 layer_biases, &errors[l], params->learning_rate, layer_biases);
308 } 307 }
309 } 308 }
310 309