From 041613467a0915e6ec07cdab0ca3e7b8d757fe5f Mon Sep 17 00:00:00 2001 From: 3gg <3gg@shellblade.net> Date: Sun, 15 May 2022 18:49:10 -0700 Subject: Optimize matrix multiplication to be more cache-friendly. Mnist-1000: 15s -> 7.74s. --- src/lib/src/matrix.c | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) (limited to 'src') diff --git a/src/lib/src/matrix.c b/src/lib/src/matrix.c index 29cdec5..f937c01 100644 --- a/src/lib/src/matrix.c +++ b/src/lib/src/matrix.c @@ -131,21 +131,23 @@ void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out) { assert(out->cols == right->cols); R* out_value = out->values; + for (int i = 0; i < out->rows * out->cols; ++i) { + *out_value++ = 0; + } for (int i = 0; i < left->rows; ++i) { - const R* left_row = &left->values[i * left->cols]; + const R* p_left_value = &left->values[i * left->cols]; - for (int j = 0; j < right->cols; ++j) { - const R* right_col = &right->values[j]; - *out_value = 0; + for (int j = 0; j < left->cols; ++j) { + const R left_value = *p_left_value; + const R* right_value = &right->values[j * right->cols]; + R* out_value = &out->values[i * out->cols]; - // Vector dot product. - for (int k = 0; k < left->cols; ++k) { - *out_value += left_row[k] * right_col[0]; - right_col += right->cols; // Next row in the column. + for (int k = 0; k < right->cols; ++k) { + *out_value++ += left_value * *right_value++; } - out_value++; + p_left_value++; } } } -- cgit v1.2.3