diff options
Diffstat (limited to 'src/lib/include/neuralnet/matrix.h')
-rw-r--r-- | src/lib/include/neuralnet/matrix.h | 111 |
1 files changed, 111 insertions, 0 deletions
diff --git a/src/lib/include/neuralnet/matrix.h b/src/lib/include/neuralnet/matrix.h new file mode 100644 index 0000000..9816b81 --- /dev/null +++ b/src/lib/include/neuralnet/matrix.h | |||
@@ -0,0 +1,111 @@ | |||
1 | #pragma once | ||
2 | |||
3 | #include <neuralnet/types.h> | ||
4 | |||
5 | #include <assert.h> | ||
6 | |||
7 | /// NxM matrix. | ||
8 | typedef struct nnMatrix { | ||
9 | int rows; | ||
10 | int cols; | ||
11 | R* values; | ||
12 | } nnMatrix; | ||
13 | |||
14 | /// Construct a matrix. | ||
15 | nnMatrix nnMatrixMake(int rows, int cols); | ||
16 | |||
17 | /// Delete a matrix and free its internal memory. | ||
18 | void nnMatrixDel(nnMatrix*); | ||
19 | |||
20 | /// Move a matrix. | ||
21 | /// | ||
22 | /// |in| is an empty matrix after the move. | ||
23 | /// |out| is a matrix like |in| before the move. | ||
24 | void nnMatrixMove(nnMatrix* in, nnMatrix* out); | ||
25 | |||
26 | /// Deep-copy a matrix. | ||
27 | void nnMatrixCopy(const nnMatrix* in, nnMatrix* out); | ||
28 | |||
29 | /// Write the matrix values into an array in a row-major fashion. | ||
30 | void nnMatrixToArray(const nnMatrix* in, R* out); | ||
31 | |||
32 | /// Write the given row of a matrix into an array. | ||
33 | void nnMatrixRowToArray(const nnMatrix* in, int row, R* out); | ||
34 | |||
35 | /// Copy a column from a source to a target matrix. | ||
36 | void nnMatrixCopyCol(const nnMatrix* in, nnMatrix* out, int col_in, int col_out); | ||
37 | |||
38 | /// Mutable borrow of a matrix. | ||
39 | nnMatrix nnMatrixBorrow(nnMatrix* in); | ||
40 | |||
41 | /// Mutable borrow of a subrange of rows of a matrix. | ||
42 | nnMatrix nnMatrixBorrowRows(nnMatrix* in, int row_start, int num_rows); | ||
43 | |||
44 | /// Initialize the matrix from an array of values. | ||
45 | /// | ||
46 | /// The array must hold values in a row-major fashion. | ||
47 | void nnMatrixInit(nnMatrix*, const R* values); | ||
48 | |||
49 | /// Initialize all matrix values to a given constant. | ||
50 | void nnMatrixInitConstant(nnMatrix*, R value); | ||
51 | |||
52 | /// Multiply two matrices. | ||
53 | void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out); | ||
54 | |||
55 | /// Matrix multiply-add. | ||
56 | /// | ||
57 | /// out = left + (right * scale) | ||
58 | void nnMatrixMulAdd(const nnMatrix* left, const nnMatrix* right, R scale, nnMatrix* out); | ||
59 | |||
60 | /// Matrix multiply-subtract. | ||
61 | /// | ||
62 | /// out = left - (right * scale) | ||
63 | void nnMatrixMulSub(const nnMatrix* left, const nnMatrix* right, R scale, nnMatrix* out); | ||
64 | |||
65 | /// Hadamard product of two matrices. | ||
66 | void nnMatrixMulPairs(const nnMatrix* left, const nnMatrix* right, nnMatrix* out); | ||
67 | |||
68 | /// Add two matrices. | ||
69 | void nnMatrixAdd(const nnMatrix* left, const nnMatrix* right, nnMatrix* out); | ||
70 | |||
71 | /// Subtract two matrices. | ||
72 | void nnMatrixSub(const nnMatrix* left, const nnMatrix* right, nnMatrix* out); | ||
73 | |||
74 | /// Adds a row vector to all rows of the matrix. | ||
75 | void nnMatrixAddRow(const nnMatrix* matrix, const nnMatrix* row, nnMatrix* out); | ||
76 | |||
77 | /// Scale a matrix. | ||
78 | void nnMatrixScale(nnMatrix*, R scale); | ||
79 | |||
80 | /// Transpose a matrix. | ||
81 | /// |in| must be different than |out|. | ||
82 | void nnMatrixTranspose(const nnMatrix* in, nnMatrix* out); | ||
83 | |||
84 | /// Threshold the values of a matrix using a greater-than operator. | ||
85 | /// | ||
86 | /// out[x,y] = 1 if in[x,y] > threshold else 0 | ||
87 | void nnMatrixGt(const nnMatrix* in, R threshold, nnMatrix* out); | ||
88 | |||
89 | /// Return the matrix value at the given row and column. | ||
90 | static inline R nnMatrixAt(const nnMatrix* matrix, int row, int col) { | ||
91 | assert(matrix); | ||
92 | return matrix->values[row * matrix->cols + col]; | ||
93 | } | ||
94 | |||
95 | /// Set the matrix value at the given row and column. | ||
96 | static inline void nnMatrixSet(nnMatrix* matrix, int row, int col, R value) { | ||
97 | assert(matrix); | ||
98 | matrix->values[row * matrix->cols + col] = value; | ||
99 | } | ||
100 | |||
101 | /// Return a pointer to the given row in the matrix. | ||
102 | static inline const R* nnMatrixRow(const nnMatrix* matrix, int row) { | ||
103 | assert(matrix); | ||
104 | return &matrix->values[row * matrix->cols]; | ||
105 | } | ||
106 | |||
107 | /// Return a mutable pointer to the given row in the matrix. | ||
108 | static inline R* nnMatrixRow_mut(nnMatrix* matrix, int row) { | ||
109 | assert(matrix); | ||
110 | return &matrix->values[row * matrix->cols]; | ||
111 | } | ||