aboutsummaryrefslogtreecommitdiff
path: root/src/lib/include
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib/include')
-rw-r--r--src/lib/include/neuralnet/matrix.h15
-rw-r--r--src/lib/include/neuralnet/neuralnet.h8
-rw-r--r--src/lib/include/neuralnet/train.h20
3 files changed, 24 insertions, 19 deletions
diff --git a/src/lib/include/neuralnet/matrix.h b/src/lib/include/neuralnet/matrix.h
index 0cb40cf..b7281bf 100644
--- a/src/lib/include/neuralnet/matrix.h
+++ b/src/lib/include/neuralnet/matrix.h
@@ -33,7 +33,8 @@ void nnMatrixToArray(const nnMatrix* in, R* out);
33void nnMatrixRowToArray(const nnMatrix* in, int row, R* out); 33void nnMatrixRowToArray(const nnMatrix* in, int row, R* out);
34 34
35/// Copy a column from a source to a target matrix. 35/// Copy a column from a source to a target matrix.
36void nnMatrixCopyCol(const nnMatrix* in, nnMatrix* out, int col_in, int col_out); 36void nnMatrixCopyCol(
37 const nnMatrix* in, nnMatrix* out, int col_in, int col_out);
37 38
38/// Mutable borrow of a matrix. 39/// Mutable borrow of a matrix.
39nnMatrix nnMatrixBorrow(nnMatrix* in); 40nnMatrix nnMatrixBorrow(nnMatrix* in);
@@ -56,20 +57,24 @@ void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out);
56/// 57///
57/// This function multiples two matrices row-by-row instead of row-by-column. 58/// This function multiples two matrices row-by-row instead of row-by-column.
58/// nnMatrixMul(A, B, O) == nnMatrixMulRows(A, B^T, O). 59/// nnMatrixMul(A, B, O) == nnMatrixMulRows(A, B^T, O).
59void nnMatrixMulRows(const nnMatrix* left, const nnMatrix* right, nnMatrix* out); 60void nnMatrixMulRows(
61 const nnMatrix* left, const nnMatrix* right, nnMatrix* out);
60 62
61/// Matrix multiply-add. 63/// Matrix multiply-add.
62/// 64///
63/// out = left + (right * scale) 65/// out = left + (right * scale)
64void nnMatrixMulAdd(const nnMatrix* left, const nnMatrix* right, R scale, nnMatrix* out); 66void nnMatrixMulAdd(
67 const nnMatrix* left, const nnMatrix* right, R scale, nnMatrix* out);
65 68
66/// Matrix multiply-subtract. 69/// Matrix multiply-subtract.
67/// 70///
68/// out = left - (right * scale) 71/// out = left - (right * scale)
69void nnMatrixMulSub(const nnMatrix* left, const nnMatrix* right, R scale, nnMatrix* out); 72void nnMatrixMulSub(
73 const nnMatrix* left, const nnMatrix* right, R scale, nnMatrix* out);
70 74
71/// Hadamard product of two matrices. 75/// Hadamard product of two matrices.
72void nnMatrixMulPairs(const nnMatrix* left, const nnMatrix* right, nnMatrix* out); 76void nnMatrixMulPairs(
77 const nnMatrix* left, const nnMatrix* right, nnMatrix* out);
73 78
74/// Add two matrices. 79/// Add two matrices.
75void nnMatrixAdd(const nnMatrix* left, const nnMatrix* right, nnMatrix* out); 80void nnMatrixAdd(const nnMatrix* left, const nnMatrix* right, nnMatrix* out);
diff --git a/src/lib/include/neuralnet/neuralnet.h b/src/lib/include/neuralnet/neuralnet.h
index 1cf1c53..05c9406 100644
--- a/src/lib/include/neuralnet/neuralnet.h
+++ b/src/lib/include/neuralnet/neuralnet.h
@@ -5,7 +5,7 @@
5typedef struct nnMatrix nnMatrix; 5typedef struct nnMatrix nnMatrix;
6 6
7typedef struct nnNeuralNetwork nnNeuralNetwork; 7typedef struct nnNeuralNetwork nnNeuralNetwork;
8typedef struct nnQueryObject nnQueryObject; 8typedef struct nnQueryObject nnQueryObject;
9 9
10/// Neuron activation. 10/// Neuron activation.
11typedef enum nnActivation { 11typedef enum nnActivation {
@@ -15,7 +15,8 @@ typedef enum nnActivation {
15} nnActivation; 15} nnActivation;
16 16
17/// Create a network. 17/// Create a network.
18nnNeuralNetwork* nnMakeNet(int num_layers, const int* layer_sizes, const nnActivation* activations); 18nnNeuralNetwork* nnMakeNet(
19 int num_layers, const int* layer_sizes, const nnActivation* activations);
19 20
20/// Delete the network and free its internal memory. 21/// Delete the network and free its internal memory.
21void nnDeleteNet(nnNeuralNetwork**); 22void nnDeleteNet(nnNeuralNetwork**);
@@ -36,7 +37,8 @@ void nnSetBiases(nnNeuralNetwork*, const R* biases);
36void nnQuery(const nnNeuralNetwork*, nnQueryObject*, const nnMatrix* input); 37void nnQuery(const nnNeuralNetwork*, nnQueryObject*, const nnMatrix* input);
37 38
38/// Query the network, array version. 39/// Query the network, array version.
39void nnQueryArray(const nnNeuralNetwork*, nnQueryObject*, const R* input, R* output); 40void nnQueryArray(
41 const nnNeuralNetwork*, nnQueryObject*, const R* input, R* output);
40 42
41/// Create a query object. 43/// Create a query object.
42/// 44///
diff --git a/src/lib/include/neuralnet/train.h b/src/lib/include/neuralnet/train.h
index 79f8e7b..6d811c2 100644
--- a/src/lib/include/neuralnet/train.h
+++ b/src/lib/include/neuralnet/train.h
@@ -14,18 +14,18 @@ typedef struct nnMatrix nnMatrix;
14/// activation with many inputs. Thus, a (0,1) initialization is really 14/// activation with many inputs. Thus, a (0,1) initialization is really
15/// (0,scale), for example. 15/// (0,scale), for example.
16typedef enum nnWeightInitStrategy { 16typedef enum nnWeightInitStrategy {
17 nnWeightInit01, // (0,1) range. 17 nnWeightInit01, // (0,1) range.
18 nnWeightInit11, // (-1,+1) range. 18 nnWeightInit11, // (-1,+1) range.
19 nnWeightInitNormal, // Normal distribution. 19 nnWeightInitNormal, // Normal distribution.
20} nnWeightInitStrategy; 20} nnWeightInitStrategy;
21 21
22/// Network training parameters. 22/// Network training parameters.
23typedef struct nnTrainingParams { 23typedef struct nnTrainingParams {
24 R learning_rate; 24 R learning_rate;
25 int max_iterations; 25 int max_iterations;
26 uint64_t seed; 26 uint64_t seed;
27 nnWeightInitStrategy weight_init; 27 nnWeightInitStrategy weight_init;
28 bool debug; 28 bool debug;
29} nnTrainingParams; 29} nnTrainingParams;
30 30
31/// Train the network. 31/// Train the network.
@@ -36,7 +36,5 @@ typedef struct nnTrainingParams {
36/// |targets| is a matrix of targets, one row per target and as many columns as 36/// |targets| is a matrix of targets, one row per target and as many columns as
37/// the target's dimension. 37/// the target's dimension.
38void nnTrain( 38void nnTrain(
39 nnNeuralNetwork*, 39 nnNeuralNetwork*, const nnMatrix* inputs, const nnMatrix* targets,
40 const nnMatrix* inputs, 40 const nnTrainingParams*);
41 const nnMatrix* targets,
42 const nnTrainingParams*);