aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/lib/src/matrix.c20
1 files changed, 11 insertions, 9 deletions
diff --git a/src/lib/src/matrix.c b/src/lib/src/matrix.c
index 29cdec5..f937c01 100644
--- a/src/lib/src/matrix.c
+++ b/src/lib/src/matrix.c
@@ -131,21 +131,23 @@ void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out) {
131 assert(out->cols == right->cols); 131 assert(out->cols == right->cols);
132 132
133 R* out_value = out->values; 133 R* out_value = out->values;
134 for (int i = 0; i < out->rows * out->cols; ++i) {
135 *out_value++ = 0;
136 }
134 137
135 for (int i = 0; i < left->rows; ++i) { 138 for (int i = 0; i < left->rows; ++i) {
136 const R* left_row = &left->values[i * left->cols]; 139 const R* p_left_value = &left->values[i * left->cols];
137 140
138 for (int j = 0; j < right->cols; ++j) { 141 for (int j = 0; j < left->cols; ++j) {
139 const R* right_col = &right->values[j]; 142 const R left_value = *p_left_value;
140 *out_value = 0; 143 const R* right_value = &right->values[j * right->cols];
144 R* out_value = &out->values[i * out->cols];
141 145
142 // Vector dot product. 146 for (int k = 0; k < right->cols; ++k) {
143 for (int k = 0; k < left->cols; ++k) { 147 *out_value++ += left_value * *right_value++;
144 *out_value += left_row[k] * right_col[0];
145 right_col += right->cols; // Next row in the column.
146 } 148 }
147 149
148 out_value++; 150 p_left_value++;
149 } 151 }
150 } 152 }
151} 153}