aboutsummaryrefslogtreecommitdiff
path: root/src/lib
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib')
-rw-r--r--src/lib/include/neuralnet/matrix.h6
-rw-r--r--src/lib/src/matrix.c29
-rw-r--r--src/lib/src/train.c18
3 files changed, 45 insertions, 8 deletions
diff --git a/src/lib/include/neuralnet/matrix.h b/src/lib/include/neuralnet/matrix.h
index 9816b81..0cb40cf 100644
--- a/src/lib/include/neuralnet/matrix.h
+++ b/src/lib/include/neuralnet/matrix.h
@@ -52,6 +52,12 @@ void nnMatrixInitConstant(nnMatrix*, R value);
52/// Multiply two matrices. 52/// Multiply two matrices.
53void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out); 53void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out);
54 54
55/// Multiply two matrices, row variant.
56///
57/// This function multiples two matrices row-by-row instead of row-by-column.
58/// nnMatrixMul(A, B, O) == nnMatrixMulRows(A, B^T, O).
59void nnMatrixMulRows(const nnMatrix* left, const nnMatrix* right, nnMatrix* out);
60
55/// Matrix multiply-add. 61/// Matrix multiply-add.
56/// 62///
57/// out = left + (right * scale) 63/// out = left + (right * scale)
diff --git a/src/lib/src/matrix.c b/src/lib/src/matrix.c
index a7a4ce6..29cdec5 100644
--- a/src/lib/src/matrix.c
+++ b/src/lib/src/matrix.c
@@ -150,6 +150,35 @@ void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out) {
150 } 150 }
151} 151}
152 152
153void nnMatrixMulRows(const nnMatrix* left, const nnMatrix* right, nnMatrix* out) {
154 assert(left != 0);
155 assert(right != 0);
156 assert(out != 0);
157 assert(out != left);
158 assert(out != right);
159 assert(left->cols == right->cols);
160 assert(out->rows == left->rows);
161 assert(out->cols == right->rows);
162
163 R* out_value = out->values;
164
165 for (int i = 0; i < left->rows; ++i) {
166 const R* left_row = &left->values[i * left->cols];
167 const R* right_value = right->values;
168
169 for (int j = 0; j < right->rows; ++j) {
170 *out_value = 0;
171
172 // Vector dot product.
173 for (int k = 0; k < left->cols; ++k) {
174 *out_value += left_row[k] * *right_value++;
175 }
176
177 out_value++;
178 }
179 }
180}
181
153void nnMatrixMulAdd(const nnMatrix* left, const nnMatrix* right, R scale, nnMatrix* out) { 182void nnMatrixMulAdd(const nnMatrix* left, const nnMatrix* right, R scale, nnMatrix* out) {
154 assert(left); 183 assert(left);
155 assert(right); 184 assert(right);
diff --git a/src/lib/src/train.c b/src/lib/src/train.c
index 027de66..3061a99 100644
--- a/src/lib/src/train.c
+++ b/src/lib/src/train.c
@@ -129,7 +129,7 @@ void nnTrain(
129 nnMatrix* errors = calloc(net->num_layers, sizeof(nnMatrix)); 129 nnMatrix* errors = calloc(net->num_layers, sizeof(nnMatrix));
130 130
131 // Allocate the weight transpose matrices up front for backpropagation. 131 // Allocate the weight transpose matrices up front for backpropagation.
132 nnMatrix* weights_T = calloc(net->num_layers, sizeof(nnMatrix)); 132 //nnMatrix* weights_T = calloc(net->num_layers, sizeof(nnMatrix));
133 133
134 // Allocate the weight delta matrices. 134 // Allocate the weight delta matrices.
135 nnMatrix* weight_deltas = calloc(net->num_layers, sizeof(nnMatrix)); 135 nnMatrix* weight_deltas = calloc(net->num_layers, sizeof(nnMatrix));
@@ -143,7 +143,7 @@ void nnTrain(
143 nnMatrix* outputs_T = calloc(net->num_layers, sizeof(nnMatrix)); 143 nnMatrix* outputs_T = calloc(net->num_layers, sizeof(nnMatrix));
144 144
145 assert(errors != 0); 145 assert(errors != 0);
146 assert(weights_T != 0); 146 //assert(weights_T != 0);
147 assert(weight_deltas != 0); 147 assert(weight_deltas != 0);
148 assert(gradient_elems); 148 assert(gradient_elems);
149 assert(outputs_T); 149 assert(outputs_T);
@@ -155,8 +155,8 @@ void nnTrain(
155 155
156 errors[l] = nnMatrixMake(1, layer_weights->cols); 156 errors[l] = nnMatrixMake(1, layer_weights->cols);
157 157
158 weights_T[l] = nnMatrixMake(layer_weights->cols, layer_weights->rows); 158 //weights_T[l] = nnMatrixMake(layer_weights->cols, layer_weights->rows);
159 nnMatrixTranspose(layer_weights, &weights_T[l]); 159 //nnMatrixTranspose(layer_weights, &weights_T[l]);
160 160
161 weight_deltas[l] = nnMatrixMake(layer_weights->rows, layer_weights->cols); 161 weight_deltas[l] = nnMatrixMake(layer_weights->rows, layer_weights->cols);
162 162
@@ -267,7 +267,9 @@ void nnTrain(
267 267
268 // Backpropagate the error before updating weights. 268 // Backpropagate the error before updating weights.
269 if (l > 0) { 269 if (l > 0) {
270 nnMatrixMul(gradient, &weights_T[l], &errors[l-1]); 270 // G * W^T == G *^T W.
271 //nnMatrixMul(gradient, &weights_T[l], &errors[l-1]);
272 nnMatrixMulRows(gradient, layer_weights, &errors[l-1]);
271 } 273 }
272 274
273 // Update weights. 275 // Update weights.
@@ -278,7 +280,7 @@ void nnTrain(
278 nnMatrixSub(layer_weights, &weight_deltas[l], layer_weights); 280 nnMatrixSub(layer_weights, &weight_deltas[l], layer_weights);
279 281
280 // Update weight transpose matrix for the next training iteration. 282 // Update weight transpose matrix for the next training iteration.
281 nnMatrixTranspose(layer_weights, &weights_T[l]); 283 //nnMatrixTranspose(layer_weights, &weights_T[l]);
282 284
283 // Update biases. 285 // Update biases.
284 // This is the same formula as for weights, except that the o_j term is 286 // This is the same formula as for weights, except that the o_j term is
@@ -319,7 +321,7 @@ void nnTrain(
319 for (int l = 0; l < net->num_layers; ++l) { 321 for (int l = 0; l < net->num_layers; ++l) {
320 nnMatrixDel(&errors[l]); 322 nnMatrixDel(&errors[l]);
321 nnMatrixDel(&outputs_T[l]); 323 nnMatrixDel(&outputs_T[l]);
322 nnMatrixDel(&weights_T[l]); 324 //nnMatrixDel(&weights_T[l]);
323 nnMatrixDel(&weight_deltas[l]); 325 nnMatrixDel(&weight_deltas[l]);
324 326
325 nnGradientElements* elems = &gradient_elems[l]; 327 nnGradientElements* elems = &gradient_elems[l];
@@ -340,7 +342,7 @@ void nnTrain(
340 nnMatrixDel(&training_inputs_T); 342 nnMatrixDel(&training_inputs_T);
341 free(errors); 343 free(errors);
342 free(outputs_T); 344 free(outputs_T);
343 free(weights_T); 345 //free(weights_T);
344 free(weight_deltas); 346 free(weight_deltas);
345 free(gradient_elems); 347 free(gradient_elems);
346} 348}