diff options
-rw-r--r-- | src/lib/src/matrix.c | 20 |
1 files changed, 11 insertions, 9 deletions
diff --git a/src/lib/src/matrix.c b/src/lib/src/matrix.c index 29cdec5..f937c01 100644 --- a/src/lib/src/matrix.c +++ b/src/lib/src/matrix.c | |||
@@ -131,21 +131,23 @@ void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out) { | |||
131 | assert(out->cols == right->cols); | 131 | assert(out->cols == right->cols); |
132 | 132 | ||
133 | R* out_value = out->values; | 133 | R* out_value = out->values; |
134 | for (int i = 0; i < out->rows * out->cols; ++i) { | ||
135 | *out_value++ = 0; | ||
136 | } | ||
134 | 137 | ||
135 | for (int i = 0; i < left->rows; ++i) { | 138 | for (int i = 0; i < left->rows; ++i) { |
136 | const R* left_row = &left->values[i * left->cols]; | 139 | const R* p_left_value = &left->values[i * left->cols]; |
137 | 140 | ||
138 | for (int j = 0; j < right->cols; ++j) { | 141 | for (int j = 0; j < left->cols; ++j) { |
139 | const R* right_col = &right->values[j]; | 142 | const R left_value = *p_left_value; |
140 | *out_value = 0; | 143 | const R* right_value = &right->values[j * right->cols]; |
144 | R* out_value = &out->values[i * out->cols]; | ||
141 | 145 | ||
142 | // Vector dot product. | 146 | for (int k = 0; k < right->cols; ++k) { |
143 | for (int k = 0; k < left->cols; ++k) { | 147 | *out_value++ += left_value * *right_value++; |
144 | *out_value += left_row[k] * right_col[0]; | ||
145 | right_col += right->cols; // Next row in the column. | ||
146 | } | 148 | } |
147 | 149 | ||
148 | out_value++; | 150 | p_left_value++; |
149 | } | 151 | } |
150 | } | 152 | } |
151 | } | 153 | } |