author     3gg <3gg@shellblade.net>                   2022-05-15 18:22:41 -0700
committer  3gg <3gg@shellblade.net>                   2022-05-15 18:48:32 -0700
commit     6785bd57a1f75a1984b2bac5da69dcbcb05c1bc2 (patch)
tree       79393ee1069c415e195ba94c42d7bc43cf1b49f3
parent     e858458e934a9e2fca953be43120497771292304 (diff)
Optimize away weights_T during training.
Mnist-1000: 18s -> 15s.
-rw-r--r--  src/lib/include/neuralnet/matrix.h |  6
-rw-r--r--  src/lib/src/matrix.c               | 29
-rw-r--r--  src/lib/src/train.c                | 18
3 files changed, 45 insertions, 8 deletions
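
The gist of the optimization: backpropagation needs gradient * W^T, and the new
nnMatrixMulRows(A, B, O) computes A * B^T by walking B's rows directly, so the
cached weights_T matrices (and the transpose refresh after every weight update)
can be dropped. The train.c change below boils down to this substitution:

    // Before: multiply against a transpose kept in sync every iteration.
    nnMatrixMul(gradient, &weights_T[l], &errors[l-1]);

    // After: multiply against the weight matrix's rows; no transpose needed.
    nnMatrixMulRows(gradient, layer_weights, &errors[l-1]);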
diff --git a/src/lib/include/neuralnet/matrix.h b/src/lib/include/neuralnet/matrix.h
index 9816b81..0cb40cf 100644
--- a/src/lib/include/neuralnet/matrix.h
+++ b/src/lib/include/neuralnet/matrix.h
@@ -52,6 +52,12 @@ void nnMatrixInitConstant(nnMatrix*, R value);
 /// Multiply two matrices.
 void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out);
 
+/// Multiply two matrices, row variant.
+///
+/// This function multiplies two matrices row-by-row instead of row-by-column.
+/// nnMatrixMul(A, B, O) == nnMatrixMulRows(A, B^T, O).
+void nnMatrixMulRows(const nnMatrix* left, const nnMatrix* right, nnMatrix* out);
+
 /// Matrix multiply-add.
 ///
 /// out = left + (right * scale)
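
A minimal sketch of the documented identity, assuming only the API visible in
this commit (nnMatrixMake returning a matrix by value, a flat row-major values
array, and the R element type); the include path mirrors the header's location
under src/lib/include:

    #include <assert.h>
    #include <neuralnet/matrix.h>

    int main(void) {
      nnMatrix A  = nnMatrixMake(2, 3);
      nnMatrix B  = nnMatrixMake(3, 2);
      nnMatrix BT = nnMatrixMake(2, 3);  // B^T

      for (int i = 0; i < 6; ++i) {
        A.values[i] = (R)(i + 1);  // A = [1 2 3; 4 5 6]
        B.values[i] = (R)(6 - i);  // B = [6 5; 4 3; 2 1]
      }
      nnMatrixTranspose(&B, &BT);

      nnMatrix O1 = nnMatrixMake(2, 2);
      nnMatrix O2 = nnMatrixMake(2, 2);
      nnMatrixMul(&A, &B, &O1);       // row-by-column
      nnMatrixMulRows(&A, &BT, &O2);  // row-by-row against B^T

      // Both paths accumulate the same products in the same order, so
      // the results should match exactly.
      for (int i = 0; i < 4; ++i) {
        assert(O1.values[i] == O2.values[i]);
      }

      nnMatrixDel(&A);  nnMatrixDel(&B);  nnMatrixDel(&BT);
      nnMatrixDel(&O1); nnMatrixDel(&O2);
      return 0;
    }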
diff --git a/src/lib/src/matrix.c b/src/lib/src/matrix.c
index a7a4ce6..29cdec5 100644
--- a/src/lib/src/matrix.c
+++ b/src/lib/src/matrix.c
@@ -150,6 +150,35 @@ void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out) {
   }
 }
 
+void nnMatrixMulRows(const nnMatrix* left, const nnMatrix* right, nnMatrix* out) {
+  assert(left != 0);
+  assert(right != 0);
+  assert(out != 0);
+  assert(out != left);
+  assert(out != right);
+  assert(left->cols == right->cols);
+  assert(out->rows == left->rows);
+  assert(out->cols == right->rows);
+
+  R* out_value = out->values;
+
+  for (int i = 0; i < left->rows; ++i) {
+    const R* left_row = &left->values[i * left->cols];
+    const R* right_value = right->values;
+
+    for (int j = 0; j < right->rows; ++j) {
+      *out_value = 0;
+
+      // Vector dot product.
+      for (int k = 0; k < left->cols; ++k) {
+        *out_value += left_row[k] * *right_value++;
+      }
+
+      out_value++;
+    }
+  }
+}
+
 void nnMatrixMulAdd(const nnMatrix* left, const nnMatrix* right, R scale, nnMatrix* out) {
   assert(left);
   assert(right);
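
Worth noting: beyond eliminating the transpose maintenance, the inner loop is
likely cache-friendlier than a row-by-column multiply over row-major storage,
since both left_row[k] and *right_value++ advance at unit stride, whereas
indexing a column of a row-major right operand jumps by right->cols elements
per step. A standalone sketch of the same kernel (illustrative names, double
in place of the library's R type):

    #include <stddef.h>

    // C = A * B^T, all matrices row-major. Both inputs are read at unit
    // stride, which is the access pattern nnMatrixMulRows relies on.
    static void mul_rows(const double* a, const double* b, double* c,
                         size_t rows_a, size_t rows_b, size_t cols) {
      for (size_t i = 0; i < rows_a; ++i) {
        for (size_t j = 0; j < rows_b; ++j) {
          double sum = 0.0;
          for (size_t k = 0; k < cols; ++k) {
            sum += a[i * cols + k] * b[j * cols + k];  // unit stride on both
          }
          c[i * rows_b + j] = sum;
        }
      }
    }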
diff --git a/src/lib/src/train.c b/src/lib/src/train.c
index 027de66..3061a99 100644
--- a/src/lib/src/train.c
+++ b/src/lib/src/train.c
@@ -129,7 +129,7 @@ void nnTrain(
   nnMatrix* errors = calloc(net->num_layers, sizeof(nnMatrix));
 
   // Allocate the weight transpose matrices up front for backpropagation.
-  nnMatrix* weights_T = calloc(net->num_layers, sizeof(nnMatrix));
+  //nnMatrix* weights_T = calloc(net->num_layers, sizeof(nnMatrix));
 
   // Allocate the weight delta matrices.
   nnMatrix* weight_deltas = calloc(net->num_layers, sizeof(nnMatrix));
@@ -143,7 +143,7 @@ void nnTrain(
   nnMatrix* outputs_T = calloc(net->num_layers, sizeof(nnMatrix));
 
   assert(errors != 0);
-  assert(weights_T != 0);
+  //assert(weights_T != 0);
   assert(weight_deltas != 0);
   assert(gradient_elems);
   assert(outputs_T);
@@ -155,8 +155,8 @@ void nnTrain(
 
     errors[l] = nnMatrixMake(1, layer_weights->cols);
 
-    weights_T[l] = nnMatrixMake(layer_weights->cols, layer_weights->rows);
-    nnMatrixTranspose(layer_weights, &weights_T[l]);
+    //weights_T[l] = nnMatrixMake(layer_weights->cols, layer_weights->rows);
+    //nnMatrixTranspose(layer_weights, &weights_T[l]);
 
     weight_deltas[l] = nnMatrixMake(layer_weights->rows, layer_weights->cols);
 
@@ -267,7 +267,9 @@ void nnTrain(
 
     // Backpropagate the error before updating weights.
     if (l > 0) {
-      nnMatrixMul(gradient, &weights_T[l], &errors[l-1]);
+      // G * W^T == G *^T W.
+      //nnMatrixMul(gradient, &weights_T[l], &errors[l-1]);
+      nnMatrixMulRows(gradient, layer_weights, &errors[l-1]);
     }
 
     // Update weights.
@@ -278,7 +280,7 @@ void nnTrain(
       nnMatrixSub(layer_weights, &weight_deltas[l], layer_weights);
 
       // Update weight transpose matrix for the next training iteration.
-      nnMatrixTranspose(layer_weights, &weights_T[l]);
+      //nnMatrixTranspose(layer_weights, &weights_T[l]);
 
       // Update biases.
       // This is the same formula as for weights, except that the o_j term is
@@ -319,7 +321,7 @@ void nnTrain(
   for (int l = 0; l < net->num_layers; ++l) {
     nnMatrixDel(&errors[l]);
     nnMatrixDel(&outputs_T[l]);
-    nnMatrixDel(&weights_T[l]);
+    //nnMatrixDel(&weights_T[l]);
     nnMatrixDel(&weight_deltas[l]);
 
     nnGradientElements* elems = &gradient_elems[l];
@@ -340,7 +342,7 @@ void nnTrain(
   nnMatrixDel(&training_inputs_T);
   free(errors);
   free(outputs_T);
-  free(weights_T);
+  //free(weights_T);
   free(weight_deltas);
   free(gradient_elems);
 }
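
Finally, a sketch verifying the substituted backpropagation call in train.c's
own terms: error vectors are 1 x n rows (errors[l] = nnMatrixMake(1,
layer_weights->cols)), so nnMatrixMulRows(gradient, layer_weights, ...) yields
the same 1 x m error that nnMatrixMul against the cached transpose produced.
The m and n values here are illustrative:

    #include <assert.h>
    #include <neuralnet/matrix.h>

    int main(void) {
      const int m = 3, n = 4;               // layer: m inputs, n outputs
      nnMatrix W     = nnMatrixMake(m, n);  // layer_weights
      nnMatrix WT    = nnMatrixMake(n, m);  // the old cached transpose
      nnMatrix g     = nnMatrixMake(1, n);  // gradient at this layer
      nnMatrix e_old = nnMatrixMake(1, m);  // error via G * W^T
      nnMatrix e_new = nnMatrixMake(1, m);  // error via G *^T W

      for (int i = 0; i < m * n; ++i) W.values[i] = (R)i / (R)(m * n);
      for (int j = 0; j < n; ++j)     g.values[j] = (R)(j + 1);
      nnMatrixTranspose(&W, &WT);

      nnMatrixMul(&g, &WT, &e_old);     // old path
      nnMatrixMulRows(&g, &W, &e_new);  // new path, no transpose kept
      for (int i = 0; i < m; ++i) {
        assert(e_old.values[i] == e_new.values[i]);
      }

      nnMatrixDel(&W);     nnMatrixDel(&WT);    nnMatrixDel(&g);
      nnMatrixDel(&e_old); nnMatrixDel(&e_new);
      return 0;
    }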