aboutsummaryrefslogtreecommitdiff
path: root/src/lib/src/matrix.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib/src/matrix.c')
-rw-r--r--src/lib/src/matrix.c29
1 files changed, 29 insertions, 0 deletions
diff --git a/src/lib/src/matrix.c b/src/lib/src/matrix.c
index a7a4ce6..29cdec5 100644
--- a/src/lib/src/matrix.c
+++ b/src/lib/src/matrix.c
@@ -150,6 +150,35 @@ void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out) {
150 } 150 }
151} 151}
152 152
153void nnMatrixMulRows(const nnMatrix* left, const nnMatrix* right, nnMatrix* out) {
154 assert(left != 0);
155 assert(right != 0);
156 assert(out != 0);
157 assert(out != left);
158 assert(out != right);
159 assert(left->cols == right->cols);
160 assert(out->rows == left->rows);
161 assert(out->cols == right->rows);
162
163 R* out_value = out->values;
164
165 for (int i = 0; i < left->rows; ++i) {
166 const R* left_row = &left->values[i * left->cols];
167 const R* right_value = right->values;
168
169 for (int j = 0; j < right->rows; ++j) {
170 *out_value = 0;
171
172 // Vector dot product.
173 for (int k = 0; k < left->cols; ++k) {
174 *out_value += left_row[k] * *right_value++;
175 }
176
177 out_value++;
178 }
179 }
180}
181
153void nnMatrixMulAdd(const nnMatrix* left, const nnMatrix* right, R scale, nnMatrix* out) { 182void nnMatrixMulAdd(const nnMatrix* left, const nnMatrix* right, R scale, nnMatrix* out) {
154 assert(left); 183 assert(left);
155 assert(right); 184 assert(right);