From 6785bd57a1f75a1984b2bac5da69dcbcb05c1bc2 Mon Sep 17 00:00:00 2001
From: 3gg <3gg@shellblade.net>
Date: Sun, 15 May 2022 18:22:41 -0700
Subject: Optimize away weights_T during training. Mnist-1000: 18s -> 15s.

---
 src/lib/include/neuralnet/matrix.h |  6 ++++++
 src/lib/src/matrix.c               | 29 +++++++++++++++++++++++++++++
 src/lib/src/train.c                | 18 ++++++++++--------
 3 files changed, 45 insertions(+), 8 deletions(-)

diff --git a/src/lib/include/neuralnet/matrix.h b/src/lib/include/neuralnet/matrix.h
index 9816b81..0cb40cf 100644
--- a/src/lib/include/neuralnet/matrix.h
+++ b/src/lib/include/neuralnet/matrix.h
@@ -52,6 +52,12 @@ void nnMatrixInitConstant(nnMatrix*, R value);
 /// Multiply two matrices.
 void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out);
 
+/// Multiply two matrices, row variant.
+///
+/// This function multiplies two matrices row-by-row instead of row-by-column.
+/// nnMatrixMul(A, B, O) == nnMatrixMulRows(A, B^T, O).
+void nnMatrixMulRows(const nnMatrix* left, const nnMatrix* right, nnMatrix* out);
+
 /// Matrix multiply-add.
 ///
 /// out = left + (right * scale)
diff --git a/src/lib/src/matrix.c b/src/lib/src/matrix.c
index a7a4ce6..29cdec5 100644
--- a/src/lib/src/matrix.c
+++ b/src/lib/src/matrix.c
@@ -150,6 +150,35 @@ void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out) {
   }
 }
 
+void nnMatrixMulRows(const nnMatrix* left, const nnMatrix* right, nnMatrix* out) {
+  assert(left != 0);
+  assert(right != 0);
+  assert(out != 0);
+  assert(out != left);
+  assert(out != right);
+  assert(left->cols == right->cols);
+  assert(out->rows == left->rows);
+  assert(out->cols == right->rows);
+
+  R* out_value = out->values;
+
+  for (int i = 0; i < left->rows; ++i) {
+    const R* left_row = &left->values[i * left->cols];
+    const R* right_value = right->values;
+
+    for (int j = 0; j < right->rows; ++j) {
+      *out_value = 0;
+
+      // Vector dot product of left's row i and right's row j.
+      for (int k = 0; k < left->cols; ++k) {
+        *out_value += left_row[k] * *right_value++;
+      }
+
+      out_value++;
+    }
+  }
+}
+
 void nnMatrixMulAdd(const nnMatrix* left, const nnMatrix* right, R scale, nnMatrix* out) {
   assert(left);
   assert(right);
diff --git a/src/lib/src/train.c b/src/lib/src/train.c
index 027de66..3061a99 100644
--- a/src/lib/src/train.c
+++ b/src/lib/src/train.c
@@ -129,7 +129,7 @@ void nnTrain(
   nnMatrix* errors = calloc(net->num_layers, sizeof(nnMatrix));
 
   // Allocate the weight transpose matrices up front for backpropagation.
-  nnMatrix* weights_T = calloc(net->num_layers, sizeof(nnMatrix));
+  //nnMatrix* weights_T = calloc(net->num_layers, sizeof(nnMatrix));
 
   // Allocate the weight delta matrices.
   nnMatrix* weight_deltas = calloc(net->num_layers, sizeof(nnMatrix));
@@ -143,7 +143,7 @@
   nnMatrix* outputs_T = calloc(net->num_layers, sizeof(nnMatrix));
 
   assert(errors != 0);
-  assert(weights_T != 0);
+  //assert(weights_T != 0);
   assert(weight_deltas != 0);
   assert(gradient_elems);
   assert(outputs_T);
@@ -155,8 +155,8 @@
 
     errors[l] = nnMatrixMake(1, layer_weights->cols);
 
-    weights_T[l] = nnMatrixMake(layer_weights->cols, layer_weights->rows);
-    nnMatrixTranspose(layer_weights, &weights_T[l]);
+    //weights_T[l] = nnMatrixMake(layer_weights->cols, layer_weights->rows);
+    //nnMatrixTranspose(layer_weights, &weights_T[l]);
 
     weight_deltas[l] = nnMatrixMake(layer_weights->rows, layer_weights->cols);
 
@@ -267,7 +267,9 @@
 
       // Backpropagate the error before updating weights.
      if (l > 0) {
-        nnMatrixMul(gradient, &weights_T[l], &errors[l-1]);
+        // G * W^T == nnMatrixMulRows(G, W), so the transpose can be skipped.
+        //nnMatrixMul(gradient, &weights_T[l], &errors[l-1]);
+        nnMatrixMulRows(gradient, layer_weights, &errors[l-1]);
       }
 
       // Update weights.
@@ -278,7 +280,7 @@
       nnMatrixSub(layer_weights, &weight_deltas[l], layer_weights);
 
       // Update weight transpose matrix for the next training iteration.
-      nnMatrixTranspose(layer_weights, &weights_T[l]);
+      //nnMatrixTranspose(layer_weights, &weights_T[l]);
 
       // Update biases.
       // This is the same formula as for weights, except that the o_j term is
@@ -319,7 +321,7 @@
   for (int l = 0; l < net->num_layers; ++l) {
     nnMatrixDel(&errors[l]);
     nnMatrixDel(&outputs_T[l]);
-    nnMatrixDel(&weights_T[l]);
+    //nnMatrixDel(&weights_T[l]);
     nnMatrixDel(&weight_deltas[l]);
 
     nnGradientElements* elems = &gradient_elems[l];
@@ -340,7 +342,7 @@
   nnMatrixDel(&training_inputs_T);
   free(errors);
   free(outputs_T);
-  free(weights_T);
+  //free(weights_T);
   free(weight_deltas);
   free(gradient_elems);
 }
-- 
cgit v1.2.3
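
The identity this commit relies on, nnMatrixMul(A, B, O) == nnMatrixMulRows(A, B^T, O), can be sanity-checked with a small standalone test. The sketch below is illustrative and not part of the commit; it assumes the nnMatrix API shown in the patch (nnMatrixMake, nnMatrixMul, nnMatrixMulRows, nnMatrixTranspose, nnMatrixDel), row-major storage in the values field as used by matrix.c, and that the header is reachable as <neuralnet/matrix.h>.

#include <assert.h>
#include <math.h>

#include <neuralnet/matrix.h>

int main(void) {
  nnMatrix A  = nnMatrixMake(2, 3);
  nnMatrix B  = nnMatrixMake(3, 2);
  nnMatrix Bt = nnMatrixMake(2, 3);  // B^T.
  nnMatrix O1 = nnMatrixMake(2, 2);  // A * B via the row-by-column path.
  nnMatrix O2 = nnMatrixMake(2, 2);  // A * B via the row-by-row path.

  // Fill A and B with arbitrary values (row-major, as in matrix.c).
  for (int i = 0; i < 6; ++i) {
    A.values[i] = (R)(i + 1);
    B.values[i] = (R)(6 - i);
  }

  nnMatrixMul(&A, &B, &O1);       // O1 = A * B.
  nnMatrixTranspose(&B, &Bt);     // Bt = B^T.
  nnMatrixMulRows(&A, &Bt, &O2);  // O2 = A * B, dotting rows of A with rows of Bt.

  // Both paths must produce the same product.
  for (int i = 0; i < 4; ++i) {
    assert(fabs((double)(O1.values[i] - O2.values[i])) < 1e-6);
  }

  nnMatrixDel(&A);
  nnMatrixDel(&B);
  nnMatrixDel(&Bt);
  nnMatrixDel(&O1);
  nnMatrixDel(&O2);
  return 0;
}

In the training loop the identity is applied in the other direction: errors[l-1] = G * W^T becomes nnMatrixMulRows(G, W), so backpropagation reads each weight matrix directly. That removes the weights_T copies and the per-iteration nnMatrixTranspose that kept them in sync after every weight update, which is the source of the Mnist-1000 improvement (18s -> 15s) noted in the subject.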