aboutsummaryrefslogtreecommitdiff
path: root/src/lib/include/neuralnet/neuralnet.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib/include/neuralnet/neuralnet.h')
-rw-r--r--src/lib/include/neuralnet/neuralnet.h64
1 files changed, 64 insertions, 0 deletions
diff --git a/src/lib/include/neuralnet/neuralnet.h b/src/lib/include/neuralnet/neuralnet.h
new file mode 100644
index 0000000..1cf1c53
--- /dev/null
+++ b/src/lib/include/neuralnet/neuralnet.h
@@ -0,0 +1,64 @@
1#pragma once
2
3#include <neuralnet/types.h>
4
5typedef struct nnMatrix nnMatrix;
6
7typedef struct nnNeuralNetwork nnNeuralNetwork;
8typedef struct nnQueryObject nnQueryObject;
9
10/// Neuron activation.
11typedef enum nnActivation {
12 nnIdentity,
13 nnSigmoid,
14 nnRelu,
15} nnActivation;
16
17/// Create a network.
18nnNeuralNetwork* nnMakeNet(int num_layers, const int* layer_sizes, const nnActivation* activations);
19
20/// Delete the network and free its internal memory.
21void nnDeleteNet(nnNeuralNetwork**);
22
23/// Set the network's weights.
24void nnSetWeights(nnNeuralNetwork*, const R* weights);
25
26/// Set the network's biases.
27void nnSetBiases(nnNeuralNetwork*, const R* biases);
28
29/// Query the network.
30///
31/// |input| is a matrix of inputs, one row per input and as many columns as the
32/// input's dimension.
33///
34/// The query object's output matrix (see nnQueryOutputs()) is a matrix of
35/// outputs, one row per output and as many columns as the output's dimension.
36void nnQuery(const nnNeuralNetwork*, nnQueryObject*, const nnMatrix* input);
37
38/// Query the network, array version.
39void nnQueryArray(const nnNeuralNetwork*, nnQueryObject*, const R* input, R* output);
40
41/// Create a query object.
42///
43/// The query object holds all the internal memory required to query a network.
44/// Query objects allocate all memory up front so that network queries can run
45/// without additional memory allocation.
46nnQueryObject* nnMakeQueryObject(const nnNeuralNetwork*, int num_inputs);
47
48/// Delete the query object and free its internal memory.
49void nnDeleteQueryObject(nnQueryObject**);
50
51/// Return the outputs of the query.
52const nnMatrix* nnNetOutputs(const nnQueryObject*);
53
54/// Return the network's input size.
55int nnNetInputSize(const nnNeuralNetwork*);
56
57/// Return the network's output size.
58int nnNetOutputSize(const nnNeuralNetwork*);
59
60/// Return the layer's input size.
61int nnLayerInputSize(const nnMatrix* weights);
62
63/// Return the layer's output size.
64int nnLayerOutputSize(const nnMatrix* weights);