aboutsummaryrefslogtreecommitdiff
path: root/src/lib/include/neuralnet/matrix.h
blob: 4cb0d2506e91e44d7e6e0f21e45fdd569cb69d5e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#pragma once

#include <neuralnet/types.h>

#include <assert.h>

/// Dense matrix of |rows| x |cols| values of type R.
///
/// Values are stored contiguously in row-major order: the element at the
/// zero-based position (row, col) lives at values[row * cols + col].
typedef struct nnMatrix {
  int rows;  // Number of rows.
  int cols;  // Number of columns.
  R*  values;  // Row-major element buffer; presumably rows * cols elements.
} nnMatrix;

/// Construct a matrix with |rows| rows and |cols| columns.
///
/// Release the result with nnMatrixDel(), which frees its internal memory.
/// NOTE(review): confirm whether the values are zero-initialized or left
/// uninitialized before relying on their contents.
nnMatrix nnMatrixMake(int rows, int cols);

/// Delete a matrix and free its internal memory.
void nnMatrixDel(nnMatrix* matrix);

/// Construct a matrix from an array of values.
///
/// |values| presumably holds rows * cols elements in row-major order, like
/// the array taken by nnMatrixInit() — TODO confirm.
nnMatrix nnMatrixFromArray(int rows, int cols, const R values[]);

/// Move a matrix.
///
/// |in| is an empty matrix after the move.
/// |out| is a matrix like |in| before the move.
void nnMatrixMove(nnMatrix* in, nnMatrix* out);

/// Deep-copy a matrix.
///
/// Copies the values of |in| into |out| so the two do not share storage.
/// NOTE(review): confirm whether |out| must already be allocated with the
/// same dimensions as |in|, or is (re)allocated by this call.
void nnMatrixCopy(const nnMatrix* in, nnMatrix* out);

/// Write the matrix values into an array in a row-major fashion.
///
/// |out| must have room for in->rows * in->cols values.
void nnMatrixToArray(const nnMatrix* in, R* out);

/// Write the given row of a matrix into an array.
///
/// |out| must have room for one row, i.e. in->cols values.
void nnMatrixRowToArray(const nnMatrix* in, int row, R* out);

/// Copy a column from a source to a target matrix.
///
/// Copies column |col_in| of |in| into column |col_out| of |out|.
/// NOTE(review): presumably both matrices must have the same number of rows —
/// confirm against the implementation.
void nnMatrixCopyCol(
    const nnMatrix* in, nnMatrix* out, int col_in, int col_out);

/// Mutable borrow of a matrix.
///
/// NOTE(review): presumably the result shares |in|'s value storage instead of
/// copying it, so it must not outlive |in| nor be passed to nnMatrixDel() —
/// confirm against the implementation.
nnMatrix nnMatrixBorrow(nnMatrix* in);

/// Mutable borrow of a subrange of rows of a matrix.
///
/// The borrow covers |num_rows| rows of |in|, starting at row |row_start|.
/// NOTE(review): presumably shares |in|'s storage like nnMatrixBorrow(), and
/// the row range must lie within |in| — confirm.
nnMatrix nnMatrixBorrowRows(nnMatrix* in, int row_start, int num_rows);

/// Initialize the matrix from an array of values.
///
/// The array must hold values in a row-major fashion, one value per matrix
/// element (matrix->rows * matrix->cols values).
void nnMatrixInit(nnMatrix* matrix, const R* values);

/// Initialize all matrix values to a given constant.
void nnMatrixInitConstant(nnMatrix* matrix, R value);

/// Multiply two matrices.
///
/// out = left * right (regular row-by-column matrix product).
/// NOTE(review): confirm whether |out| may alias |left| or |right|.
void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out);

/// Multiply two matrices, row-by-row variant.
///
/// This function multiplies two matrices row-by-row instead of row-by-column,
/// which is equivalent to regular multiplication after transposing the right
/// hand matrix:
///
///     nnMatrixMul(A, B, O) == nnMatrixMulRows(A, B^T, O).
///
/// That is, out[i,j] is the dot product of row i of |left| and row j of
/// |right|.
void nnMatrixMulRows(
    const nnMatrix* left, const nnMatrix* right, nnMatrix* out);

/// Compute the outer product of two vectors.
///
/// out[i,j] = left[i] * right[j]
void nnMatrixMulOuter(
    const nnMatrix* left, const nnMatrix* right, nnMatrix* out);

/// Matrix multiply-add, where |scale| is a scalar factor.
///
/// out = left + (right * scale)
void nnMatrixMulAdd(
    const nnMatrix* left, const nnMatrix* right, R scale, nnMatrix* out);

/// Matrix multiply-subtract, where |scale| is a scalar factor.
///
/// out = left - (right * scale)
void nnMatrixMulSub(
    const nnMatrix* left, const nnMatrix* right, R scale, nnMatrix* out);

/// Hadamard (element-wise) product of two matrices.
///
/// out[x,y] = left[x,y] * right[x,y]
void nnMatrixMulPairs(
    const nnMatrix* left, const nnMatrix* right, nnMatrix* out);

/// Add two matrices element-wise.
///
/// out = left + right
void nnMatrixAdd(const nnMatrix* left, const nnMatrix* right, nnMatrix* out);

/// Subtract two matrices element-wise.
///
/// out = left - right
void nnMatrixSub(const nnMatrix* left, const nnMatrix* right, nnMatrix* out);

/// Adds a row vector to all rows of the matrix.
///
/// out[x,y] = matrix[x,y] + row[y]
/// NOTE(review): |row| is presumably a 1 x matrix->cols matrix — confirm.
void nnMatrixAddRow(const nnMatrix* matrix, const nnMatrix* row, nnMatrix* out);

/// Scale a matrix in place.
///
/// matrix[x,y] = matrix[x,y] * scale
void nnMatrixScale(nnMatrix* matrix, R scale);

/// Transpose a matrix.
///
/// |in| must be different from |out|; |out| receives in->cols rows and
/// in->rows columns.
void nnMatrixTranspose(const nnMatrix* in, nnMatrix* out);

/// Threshold the values of a matrix using a greater-than operator.
///
/// out[x,y] = 1 if in[x,y] > threshold else 0
void nnMatrixGt(const nnMatrix* in, R threshold, nnMatrix* out);

/// Return the matrix value at the given row and column.
static inline R nnMatrixAt(const nnMatrix* matrix, int row, int col) {
  assert(matrix);
  return matrix->values[row * matrix->cols + col];
}

/// Set the matrix value at the given row and column.
///
/// |row| and |col| are zero-based and must satisfy 0 <= row < matrix->rows
/// and 0 <= col < matrix->cols. Values are stored row-major, so the element
/// lives at values[row * cols + col]. Out-of-range indices would write outside
/// the element buffer and corrupt memory (undefined behavior); debug builds
/// catch them with an assertion.
static inline void nnMatrixSet(nnMatrix* matrix, int row, int col, R value) {
  assert(matrix);
  assert(0 <= row && row < matrix->rows);
  assert(0 <= col && col < matrix->cols);
  matrix->values[row * matrix->cols + col] = value;
}

/// Return a pointer to the given row in the matrix.
static inline const R* nnMatrixRow(const nnMatrix* matrix, int row) {
  assert(matrix);
  return &matrix->values[row * matrix->cols];
}

/// Return a mutable pointer to the given row in the matrix.
///
/// |row| is zero-based and must satisfy 0 <= row < matrix->rows. Storage is
/// row-major, so the returned pointer addresses the row's matrix->cols
/// consecutive values. An out-of-range row is caught by an assertion in debug
/// builds.
static inline R* nnMatrixRow_mut(nnMatrix* matrix, int row) {
  assert(matrix);
  assert(0 <= row && row < matrix->rows);
  return &matrix->values[row * matrix->cols];
}