aboutsummaryrefslogtreecommitdiff
path: root/src/lib/include
diff options
context:
space:
mode:
author3gg <3gg@shellblade.net>2023-12-16 11:06:03 -0800
committer3gg <3gg@shellblade.net>2023-12-16 11:06:03 -0800
commitdc538733da8d49e7240d00fb05517053076fe261 (patch)
tree752cefb55f072bebbe716b8fa4e4df2baadc3138 /src/lib/include
parent2067bd53b182429d059a61b0e060f92b4f317ed1 (diff)
Define vector outer product (nnMatrixMulOuter), which removes the need to transpose layer inputs during training.
Diffstat (limited to 'src/lib/include')
-rw-r--r--src/lib/include/neuralnet/matrix.h13
1 files changed, 10 insertions, 3 deletions
diff --git a/src/lib/include/neuralnet/matrix.h b/src/lib/include/neuralnet/matrix.h
index f80b985..4cb0d25 100644
--- a/src/lib/include/neuralnet/matrix.h
+++ b/src/lib/include/neuralnet/matrix.h
@@ -56,13 +56,20 @@ void nnMatrixInitConstant(nnMatrix*, R value);
56/// Multiply two matrices. 56/// Multiply two matrices.
57void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out); 57void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out);
58 58
59/// Multiply two matrices, row variant. 59/// Multiply two matrices, row-by-row variant.
60/// 60///
61/// This function multiples two matrices row-by-row instead of row-by-column. 61/// This function multiples two matrices row-by-row instead of row-by-column,
62/// nnMatrixMul(A, B, O) == nnMatrixMulRows(A, B^T, O). 62/// which is equivalent to regular multiplication after transposing the right
63/// hand matrix.
64///
65/// nnMatrixMul(A, B, O) == nnMatrixMulRows(A, B^T, O).
63void nnMatrixMulRows( 66void nnMatrixMulRows(
64 const nnMatrix* left, const nnMatrix* right, nnMatrix* out); 67 const nnMatrix* left, const nnMatrix* right, nnMatrix* out);
65 68
69/// Compute the outer product of two vectors.
70void nnMatrixMulOuter(
71 const nnMatrix* left, const nnMatrix* right, nnMatrix* out);
72
66/// Matrix multiply-add. 73/// Matrix multiply-add.
67/// 74///
68/// out = left + (right * scale) 75/// out = left + (right * scale)