aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/lib/include/neuralnet/matrix.h13
-rw-r--r--src/lib/src/matrix.c26
-rw-r--r--src/lib/src/train.c28
3 files changed, 39 insertions, 28 deletions
diff --git a/src/lib/include/neuralnet/matrix.h b/src/lib/include/neuralnet/matrix.h
index f80b985..4cb0d25 100644
--- a/src/lib/include/neuralnet/matrix.h
+++ b/src/lib/include/neuralnet/matrix.h
@@ -56,13 +56,20 @@ void nnMatrixInitConstant(nnMatrix*, R value);
56/// Multiply two matrices. 56/// Multiply two matrices.
57void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out); 57void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out);
58 58
59/// Multiply two matrices, row variant. 59/// Multiply two matrices, row-by-row variant.
60/// 60///
61/// This function multiples two matrices row-by-row instead of row-by-column. 61/// This function multiples two matrices row-by-row instead of row-by-column,
62/// nnMatrixMul(A, B, O) == nnMatrixMulRows(A, B^T, O). 62/// which is equivalent to regular multiplication after transposing the right
63/// hand matrix.
64///
65/// nnMatrixMul(A, B, O) == nnMatrixMulRows(A, B^T, O).
63void nnMatrixMulRows( 66void nnMatrixMulRows(
64 const nnMatrix* left, const nnMatrix* right, nnMatrix* out); 67 const nnMatrix* left, const nnMatrix* right, nnMatrix* out);
65 68
69/// Compute the outer product of two vectors.
70void nnMatrixMulOuter(
71 const nnMatrix* left, const nnMatrix* right, nnMatrix* out);
72
66/// Matrix multiply-add. 73/// Matrix multiply-add.
67/// 74///
68/// out = left + (right * scale) 75/// out = left + (right * scale)
diff --git a/src/lib/src/matrix.c b/src/lib/src/matrix.c
index d5c3fcc..29511eb 100644
--- a/src/lib/src/matrix.c
+++ b/src/lib/src/matrix.c
@@ -189,6 +189,32 @@ void nnMatrixMulRows(
189 } 189 }
190} 190}
191 191
192void nnMatrixMulOuter(
193 const nnMatrix* left, const nnMatrix* right, nnMatrix* out) {
194 assert(left != 0);
195 assert(right != 0);
196 assert(out != 0);
197 assert(out != left);
198 assert(out != right);
199 assert((left->rows == 1) || (left->cols == 1)); // Vector.
200 assert((right->rows == 1) || (right->cols == 1)); // Vector.
201 const int N = left->rows * left->cols;
202 const int M = right->rows * right->cols;
203 assert((out->rows == N) && (out->cols == M));
204
205 const R* left_value = left->values;
206 R* out_value = out->values;
207
208 for (int i = 0; i < N; ++i) {
209 const R* right_value = right->values;
210
211 for (int j = 0; j < M; ++j) {
212 *out_value++ = *left_value * *right_value++;
213 }
214 left_value++;
215 }
216}
217
192void nnMatrixMulAdd( 218void nnMatrixMulAdd(
193 const nnMatrix* left, const nnMatrix* right, R scale, nnMatrix* out) { 219 const nnMatrix* left, const nnMatrix* right, R scale, nnMatrix* out) {
194 assert(left); 220 assert(left);
diff --git a/src/lib/src/train.c b/src/lib/src/train.c
index ccff553..fe9f598 100644
--- a/src/lib/src/train.c
+++ b/src/lib/src/train.c
@@ -153,14 +153,9 @@ void nnTrain(
153 nnGradientElements* gradient_elems = 153 nnGradientElements* gradient_elems =
154 calloc(net->num_layers, sizeof(nnGradientElements)); 154 calloc(net->num_layers, sizeof(nnGradientElements));
155 155
156 // Allocate the output transpose vectors for weight delta calculation.
157 // This is one column vector per layer.
158 nnMatrix* outputs_T = calloc(net->num_layers, sizeof(nnMatrix));
159
160 assert(errors != 0); 156 assert(errors != 0);
161 assert(weight_deltas != 0); 157 assert(weight_deltas != 0);
162 assert(gradient_elems); 158 assert(gradient_elems);
163 assert(outputs_T);
164 159
165 for (int l = 0; l < net->num_layers; ++l) { 160 for (int l = 0; l < net->num_layers; ++l) {
166 const int layer_input_size = nnLayerInputSize(net, l); 161 const int layer_input_size = nnLayerInputSize(net, l);
@@ -169,7 +164,6 @@ void nnTrain(
169 164
170 errors[l] = nnMatrixMake(1, layer_output_size); 165 errors[l] = nnMatrixMake(1, layer_output_size);
171 weight_deltas[l] = nnMatrixMake(layer_input_size, layer_output_size); 166 weight_deltas[l] = nnMatrixMake(layer_input_size, layer_output_size);
172 outputs_T[l] = nnMatrixMake(layer_output_size, 1);
173 167
174 // Allocate the gradient elements and vectors for weight delta calculation. 168 // Allocate the gradient elements and vectors for weight delta calculation.
175 nnGradientElements* elems = &gradient_elems[l]; 169 nnGradientElements* elems = &gradient_elems[l];
@@ -199,9 +193,6 @@ void nnTrain(
199 // the outputs. 193 // the outputs.
200 const nnMatrix* const training_outputs = query->network_outputs; 194 const nnMatrix* const training_outputs = query->network_outputs;
201 195
202 // A vector to store the training input transposed.
203 nnMatrix training_inputs_T = nnMatrixMake(inputs->cols, 1);
204
205 // If debug mode is requested, we will show progress every Nth iteration. 196 // If debug mode is requested, we will show progress every Nth iteration.
206 const int progress_frame = 197 const int progress_frame =
207 (params->max_iterations < PROGRESS_THRESHOLD) 198 (params->max_iterations < PROGRESS_THRESHOLD)
@@ -223,10 +214,6 @@ void nnTrain(
223 const nnMatrix training_targets = 214 const nnMatrix training_targets =
224 nnMatrixBorrowRows((nnMatrix*)targets, sample, 1); 215 nnMatrixBorrowRows((nnMatrix*)targets, sample, 1);
225 216
226 // Will need the input transposed for backpropagation.
227 // Assuming one training input per iteration for now.
228 nnMatrixTranspose(&training_inputs, &training_inputs_T);
229
230 // Forward pass. 217 // Forward pass.
231 nnQuery(net, query, &training_inputs); 218 nnQuery(net, query, &training_inputs);
232 219
@@ -240,14 +227,11 @@ void nnTrain(
240 nnMatrixSub( 227 nnMatrixSub(
241 training_outputs, &training_targets, &errors[net->num_layers - 1]); 228 training_outputs, &training_targets, &errors[net->num_layers - 1]);
242 229
243 // Update outputs_T, which we need during weight updates.
244 for (int l = 0; l < net->num_layers; ++l) {
245 nnMatrixTranspose(&query->layer_outputs[l], &outputs_T[l]);
246 }
247
248 // Update weights and biases for each internal layer, back-propagating 230 // Update weights and biases for each internal layer, back-propagating
249 // errors along the way. 231 // errors along the way.
250 for (int l = net->num_layers - 1; l >= 0; --l) { 232 for (int l = net->num_layers - 1; l >= 0; --l) {
233 const nnMatrix* layer_input =
234 (l == 0) ? &training_inputs : &query->layer_outputs[l - 1];
251 const nnMatrix* layer_output = &query->layer_outputs[l]; 235 const nnMatrix* layer_output = &query->layer_outputs[l];
252 nnGradientElements* elems = &gradient_elems[l]; 236 nnGradientElements* elems = &gradient_elems[l];
253 nnMatrix* gradient = &elems->gradient; 237 nnMatrix* gradient = &elems->gradient;
@@ -310,10 +294,7 @@ void nnTrain(
310 nnMatrix* layer_biases = &linear->biases; 294 nnMatrix* layer_biases = &linear->biases;
311 295
312 // Outer product to compute the weight deltas. 296 // Outer product to compute the weight deltas.
313 // This layer's input is the previous layer's output. 297 nnMatrixMulOuter(layer_input, gradient, &weight_deltas[l]);
314 const nnMatrix* input_T =
315 (l == 0) ? &training_inputs_T : &outputs_T[l - 1];
316 nnMatrixMul(input_T, gradient, &weight_deltas[l]);
317 298
318 // Update weights. 299 // Update weights.
319 nnMatrixScale(&weight_deltas[l], params->learning_rate); 300 nnMatrixScale(&weight_deltas[l], params->learning_rate);
@@ -360,7 +341,6 @@ void nnTrain(
360 // Clean up. 341 // Clean up.
361 for (int l = 0; l < net->num_layers; ++l) { 342 for (int l = 0; l < net->num_layers; ++l) {
362 nnMatrixDel(&errors[l]); 343 nnMatrixDel(&errors[l]);
363 nnMatrixDel(&outputs_T[l]);
364 nnMatrixDel(&weight_deltas[l]); 344 nnMatrixDel(&weight_deltas[l]);
365 345
366 nnGradientElements* elems = &gradient_elems[l]; 346 nnGradientElements* elems = &gradient_elems[l];
@@ -378,9 +358,7 @@ void nnTrain(
378 break; 358 break;
379 } 359 }
380 } 360 }
381 nnMatrixDel(&training_inputs_T);
382 free(errors); 361 free(errors);
383 free(outputs_T);
384 free(weight_deltas); 362 free(weight_deltas);
385 free(gradient_elems); 363 free(gradient_elems);
386} 364}