diff options
Diffstat (limited to 'src/lib/src/matrix.c')
-rw-r--r-- | src/lib/src/matrix.c | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/src/lib/src/matrix.c b/src/lib/src/matrix.c index a7a4ce6..29cdec5 100644 --- a/src/lib/src/matrix.c +++ b/src/lib/src/matrix.c | |||
@@ -150,6 +150,35 @@ void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out) { | |||
150 | } | 150 | } |
151 | } | 151 | } |
152 | 152 | ||
153 | void nnMatrixMulRows(const nnMatrix* left, const nnMatrix* right, nnMatrix* out) { | ||
154 | assert(left != 0); | ||
155 | assert(right != 0); | ||
156 | assert(out != 0); | ||
157 | assert(out != left); | ||
158 | assert(out != right); | ||
159 | assert(left->cols == right->cols); | ||
160 | assert(out->rows == left->rows); | ||
161 | assert(out->cols == right->rows); | ||
162 | |||
163 | R* out_value = out->values; | ||
164 | |||
165 | for (int i = 0; i < left->rows; ++i) { | ||
166 | const R* left_row = &left->values[i * left->cols]; | ||
167 | const R* right_value = right->values; | ||
168 | |||
169 | for (int j = 0; j < right->rows; ++j) { | ||
170 | *out_value = 0; | ||
171 | |||
172 | // Vector dot product. | ||
173 | for (int k = 0; k < left->cols; ++k) { | ||
174 | *out_value += left_row[k] * *right_value++; | ||
175 | } | ||
176 | |||
177 | out_value++; | ||
178 | } | ||
179 | } | ||
180 | } | ||
181 | |||
153 | void nnMatrixMulAdd(const nnMatrix* left, const nnMatrix* right, R scale, nnMatrix* out) { | 182 | void nnMatrixMulAdd(const nnMatrix* left, const nnMatrix* right, R scale, nnMatrix* out) { |
154 | assert(left); | 183 | assert(left); |
155 | assert(right); | 184 | assert(right); |