aboutsummaryrefslogtreecommitdiff
path: root/src/lib/include/neuralnet/neuralnet.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib/include/neuralnet/neuralnet.h')
-rw-r--r--src/lib/include/neuralnet/neuralnet.h51
1 files changed, 32 insertions, 19 deletions
diff --git a/src/lib/include/neuralnet/neuralnet.h b/src/lib/include/neuralnet/neuralnet.h
index 05c9406..f122c2a 100644
--- a/src/lib/include/neuralnet/neuralnet.h
+++ b/src/lib/include/neuralnet/neuralnet.h
@@ -1,32 +1,45 @@
1#pragma once 1#pragma once
2 2
3#include <neuralnet/matrix.h>
3#include <neuralnet/types.h> 4#include <neuralnet/types.h>
4 5
5typedef struct nnMatrix nnMatrix;
6
7typedef struct nnNeuralNetwork nnNeuralNetwork; 6typedef struct nnNeuralNetwork nnNeuralNetwork;
8typedef struct nnQueryObject nnQueryObject; 7typedef struct nnQueryObject nnQueryObject;
9 8
10/// Neuron activation. 9/// Linear layer parameters.
11typedef enum nnActivation { 10///
12 nnIdentity, 11/// Either one of the following must be set:
12/// a) Training: input and output sizes.
13/// b) Inference: weights + biases.
14typedef struct nnLinearParams {
15 int input_size;
16 int output_size;
17 nnMatrix weights;
18 nnMatrix biases;
19} nnLinearParams;
20
21/// Layer type.
22typedef enum nnLayerType {
23 nnLinear,
13 nnSigmoid, 24 nnSigmoid,
14 nnRelu, 25 nnRelu,
15} nnActivation; 26} nnLayerType;
27
28/// Neural network layer.
29typedef struct nnLayer {
30 nnLayerType type;
31 union {
32 nnLinearParams linear;
33 };
34} nnLayer;
16 35
17/// Create a network. 36/// Create a network.
18nnNeuralNetwork* nnMakeNet( 37nnNeuralNetwork* nnMakeNet(
19 int num_layers, const int* layer_sizes, const nnActivation* activations); 38 const nnLayer* layers, int num_layers, int input_size);
20 39
21/// Delete the network and free its internal memory. 40/// Delete the network and free its internal memory.
22void nnDeleteNet(nnNeuralNetwork**); 41void nnDeleteNet(nnNeuralNetwork**);
23 42
24/// Set the network's weights.
25void nnSetWeights(nnNeuralNetwork*, const R* weights);
26
27/// Set the network's biases.
28void nnSetBiases(nnNeuralNetwork*, const R* biases);
29
30/// Query the network. 43/// Query the network.
31/// 44///
32/// |input| is a matrix of inputs, one row per input and as many columns as the 45/// |input| is a matrix of inputs, one row per input and as many columns as the
@@ -42,10 +55,10 @@ void nnQueryArray(
42 55
43/// Create a query object. 56/// Create a query object.
44/// 57///
45/// The query object holds all the internal memory required to query a network. 58/// The query object holds all the internal memory required to query a network
46/// Query objects allocate all memory up front so that network queries can run 59/// with batches of the given size. Memory is allocated up front so that network
47/// without additional memory allocation. 60/// queries can run without additional memory allocation.
48nnQueryObject* nnMakeQueryObject(const nnNeuralNetwork*, int num_inputs); 61nnQueryObject* nnMakeQueryObject(const nnNeuralNetwork*, int batch_size);
49 62
50/// Delete the query object and free its internal memory. 63/// Delete the query object and free its internal memory.
51void nnDeleteQueryObject(nnQueryObject**); 64void nnDeleteQueryObject(nnQueryObject**);
@@ -60,7 +73,7 @@ int nnNetInputSize(const nnNeuralNetwork*);
60int nnNetOutputSize(const nnNeuralNetwork*); 73int nnNetOutputSize(const nnNeuralNetwork*);
61 74
62/// Return the layer's input size. 75/// Return the layer's input size.
63int nnLayerInputSize(const nnMatrix* weights); 76int nnLayerInputSize(const nnNeuralNetwork*, int layer);
64 77
65/// Return the layer's output size. 78/// Return the layer's output size.
66int nnLayerOutputSize(const nnMatrix* weights); 79int nnLayerOutputSize(const nnNeuralNetwork*, int layer);