diff options
author | 3gg <3gg@shellblade.net> | 2023-12-16 11:06:03 -0800 |
---|---|---|
committer | 3gg <3gg@shellblade.net> | 2023-12-16 11:06:03 -0800 |
commit | dc538733da8d49e7240d00fb05517053076fe261 (patch) | |
tree | 752cefb55f072bebbe716b8fa4e4df2baadc3138 /src/lib/include | |
parent | 2067bd53b182429d059a61b0e060f92b4f317ed1 (diff) |
Define vector outer product (nnMatrixMulOuter), which removes the need to transpose layer inputs during training.
Diffstat (limited to 'src/lib/include')
-rw-r--r-- | src/lib/include/neuralnet/matrix.h | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/src/lib/include/neuralnet/matrix.h b/src/lib/include/neuralnet/matrix.h index f80b985..4cb0d25 100644 --- a/src/lib/include/neuralnet/matrix.h +++ b/src/lib/include/neuralnet/matrix.h | |||
@@ -56,13 +56,20 @@ void nnMatrixInitConstant(nnMatrix*, R value); | |||
56 | /// Multiply two matrices. | 56 | /// Multiply two matrices. |
57 | void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out); | 57 | void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out); |
58 | 58 | ||
59 | /// Multiply two matrices, row variant. | 59 | /// Multiply two matrices, row-by-row variant. |
60 | /// | 60 | /// |
61 | /// This function multiples two matrices row-by-row instead of row-by-column. | 61 | /// This function multiples two matrices row-by-row instead of row-by-column, |
62 | /// nnMatrixMul(A, B, O) == nnMatrixMulRows(A, B^T, O). | 62 | /// which is equivalent to regular multiplication after transposing the right |
63 | /// hand matrix. | ||
64 | /// | ||
65 | /// nnMatrixMul(A, B, O) == nnMatrixMulRows(A, B^T, O). | ||
63 | void nnMatrixMulRows( | 66 | void nnMatrixMulRows( |
64 | const nnMatrix* left, const nnMatrix* right, nnMatrix* out); | 67 | const nnMatrix* left, const nnMatrix* right, nnMatrix* out); |
65 | 68 | ||
69 | /// Compute the outer product of two vectors. | ||
70 | void nnMatrixMulOuter( | ||
71 | const nnMatrix* left, const nnMatrix* right, nnMatrix* out); | ||
72 | |||
66 | /// Matrix multiply-add. | 73 | /// Matrix multiply-add. |
67 | /// | 74 | /// |
68 | /// out = left + (right * scale) | 75 | /// out = left + (right * scale) |