diff options
| author | 3gg <3gg@shellblade.net> | 2023-12-16 11:06:03 -0800 |
|---|---|---|
| committer | 3gg <3gg@shellblade.net> | 2023-12-16 11:06:03 -0800 |
| commit | dc538733da8d49e7240d00fb05517053076fe261 (patch) | |
| tree | 752cefb55f072bebbe716b8fa4e4df2baadc3138 /src/lib/include | |
| parent | 2067bd53b182429d059a61b0e060f92b4f317ed1 (diff) | |
Define vector outer product (nnMatrixMulOuter), which removes the need to transpose layer inputs during training.
Diffstat (limited to 'src/lib/include')
| -rw-r--r-- | src/lib/include/neuralnet/matrix.h | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/src/lib/include/neuralnet/matrix.h b/src/lib/include/neuralnet/matrix.h index f80b985..4cb0d25 100644 --- a/src/lib/include/neuralnet/matrix.h +++ b/src/lib/include/neuralnet/matrix.h | |||
| @@ -56,13 +56,20 @@ void nnMatrixInitConstant(nnMatrix*, R value); | |||
| 56 | /// Multiply two matrices. | 56 | /// Multiply two matrices. |
| 57 | void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out); | 57 | void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out); |
| 58 | 58 | ||
| 59 | /// Multiply two matrices, row variant. | 59 | /// Multiply two matrices, row-by-row variant. |
| 60 | /// | 60 | /// |
| 61 | /// This function multiples two matrices row-by-row instead of row-by-column. | 61 | /// This function multiples two matrices row-by-row instead of row-by-column, |
| 62 | /// nnMatrixMul(A, B, O) == nnMatrixMulRows(A, B^T, O). | 62 | /// which is equivalent to regular multiplication after transposing the right |
| 63 | /// hand matrix. | ||
| 64 | /// | ||
| 65 | /// nnMatrixMul(A, B, O) == nnMatrixMulRows(A, B^T, O). | ||
| 63 | void nnMatrixMulRows( | 66 | void nnMatrixMulRows( |
| 64 | const nnMatrix* left, const nnMatrix* right, nnMatrix* out); | 67 | const nnMatrix* left, const nnMatrix* right, nnMatrix* out); |
| 65 | 68 | ||
| 69 | /// Compute the outer product of two vectors. | ||
| 70 | void nnMatrixMulOuter( | ||
| 71 | const nnMatrix* left, const nnMatrix* right, nnMatrix* out); | ||
| 72 | |||
| 66 | /// Matrix multiply-add. | 73 | /// Matrix multiply-add. |
| 67 | /// | 74 | /// |
| 68 | /// out = left + (right * scale) | 75 | /// out = left + (right * scale) |
