aboutsummaryrefslogtreecommitdiff
path: root/src/lib/include/neuralnet/matrix.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib/include/neuralnet/matrix.h')
-rw-r--r--src/lib/include/neuralnet/matrix.h111
1 files changed, 111 insertions, 0 deletions
diff --git a/src/lib/include/neuralnet/matrix.h b/src/lib/include/neuralnet/matrix.h
new file mode 100644
index 0000000..9816b81
--- /dev/null
+++ b/src/lib/include/neuralnet/matrix.h
@@ -0,0 +1,111 @@
1#pragma once
2
3#include <neuralnet/types.h>
4
5#include <assert.h>
6
7/// NxM matrix.
8typedef struct nnMatrix {
9 int rows;
10 int cols;
11 R* values;
12} nnMatrix;
13
14/// Construct a matrix.
15nnMatrix nnMatrixMake(int rows, int cols);
16
17/// Delete a matrix and free its internal memory.
18void nnMatrixDel(nnMatrix*);
19
20/// Move a matrix.
21///
22/// |in| is an empty matrix after the move.
23/// |out| is a matrix like |in| before the move.
24void nnMatrixMove(nnMatrix* in, nnMatrix* out);
25
26/// Deep-copy a matrix.
27void nnMatrixCopy(const nnMatrix* in, nnMatrix* out);
28
29/// Write the matrix values into an array in a row-major fashion.
30void nnMatrixToArray(const nnMatrix* in, R* out);
31
32/// Write the given row of a matrix into an array.
33void nnMatrixRowToArray(const nnMatrix* in, int row, R* out);
34
35/// Copy a column from a source to a target matrix.
36void nnMatrixCopyCol(const nnMatrix* in, nnMatrix* out, int col_in, int col_out);
37
38/// Mutable borrow of a matrix.
39nnMatrix nnMatrixBorrow(nnMatrix* in);
40
41/// Mutable borrow of a subrange of rows of a matrix.
42nnMatrix nnMatrixBorrowRows(nnMatrix* in, int row_start, int num_rows);
43
44/// Initialize the matrix from an array of values.
45///
46/// The array must hold values in a row-major fashion.
47void nnMatrixInit(nnMatrix*, const R* values);
48
49/// Initialize all matrix values to a given constant.
50void nnMatrixInitConstant(nnMatrix*, R value);
51
52/// Multiply two matrices.
53void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out);
54
55/// Matrix multiply-add.
56///
57/// out = left + (right * scale)
58void nnMatrixMulAdd(const nnMatrix* left, const nnMatrix* right, R scale, nnMatrix* out);
59
60/// Matrix multiply-subtract.
61///
62/// out = left - (right * scale)
63void nnMatrixMulSub(const nnMatrix* left, const nnMatrix* right, R scale, nnMatrix* out);
64
65/// Hadamard product of two matrices.
66void nnMatrixMulPairs(const nnMatrix* left, const nnMatrix* right, nnMatrix* out);
67
68/// Add two matrices.
69void nnMatrixAdd(const nnMatrix* left, const nnMatrix* right, nnMatrix* out);
70
71/// Subtract two matrices.
72void nnMatrixSub(const nnMatrix* left, const nnMatrix* right, nnMatrix* out);
73
74/// Adds a row vector to all rows of the matrix.
75void nnMatrixAddRow(const nnMatrix* matrix, const nnMatrix* row, nnMatrix* out);
76
77/// Scale a matrix.
78void nnMatrixScale(nnMatrix*, R scale);
79
80/// Transpose a matrix.
81/// |in| must be different than |out|.
82void nnMatrixTranspose(const nnMatrix* in, nnMatrix* out);
83
84/// Threshold the values of a matrix using a greater-than operator.
85///
86/// out[x,y] = 1 if in[x,y] > threshold else 0
87void nnMatrixGt(const nnMatrix* in, R threshold, nnMatrix* out);
88
89/// Return the matrix value at the given row and column.
90static inline R nnMatrixAt(const nnMatrix* matrix, int row, int col) {
91 assert(matrix);
92 return matrix->values[row * matrix->cols + col];
93}
94
95/// Set the matrix value at the given row and column.
96static inline void nnMatrixSet(nnMatrix* matrix, int row, int col, R value) {
97 assert(matrix);
98 matrix->values[row * matrix->cols + col] = value;
99}
100
101/// Return a pointer to the given row in the matrix.
102static inline const R* nnMatrixRow(const nnMatrix* matrix, int row) {
103 assert(matrix);
104 return &matrix->values[row * matrix->cols];
105}
106
107/// Return a mutable pointer to the given row in the matrix.
108static inline R* nnMatrixRow_mut(nnMatrix* matrix, int row) {
109 assert(matrix);
110 return &matrix->values[row * matrix->cols];
111}