diff options
Diffstat (limited to 'src/lib/src/neuralnet_impl.h')
| -rw-r--r-- | src/lib/src/neuralnet_impl.h | 35 |
1 files changed, 21 insertions, 14 deletions
diff --git a/src/lib/src/neuralnet_impl.h b/src/lib/src/neuralnet_impl.h index f5a9c63..935c5ea 100644 --- a/src/lib/src/neuralnet_impl.h +++ b/src/lib/src/neuralnet_impl.h | |||
| @@ -2,22 +2,29 @@ | |||
| 2 | 2 | ||
| 3 | #include <neuralnet/matrix.h> | 3 | #include <neuralnet/matrix.h> |
| 4 | 4 | ||
| 5 | #include <stdbool.h> | ||
| 6 | |||
| 7 | /// Linear layer parameters. | ||
| 8 | typedef struct nnLinearImpl { | ||
| 9 | nnMatrix weights; | ||
| 10 | nnMatrix biases; | ||
| 11 | bool owned; /// Whether the library owns the weights and biases. | ||
| 12 | } nnLinearImpl; | ||
| 13 | |||
| 14 | /// Neural network layer. | ||
| 15 | typedef struct nnLayerImpl { | ||
| 16 | nnLayerType type; | ||
| 17 | int input_size; | ||
| 18 | int output_size; | ||
| 19 | union { | ||
| 20 | nnLinearImpl linear; | ||
| 21 | }; | ||
| 22 | } nnLayerImpl; | ||
| 23 | |||
| 5 | /// Neural network object. | 24 | /// Neural network object. |
| 6 | /// | ||
| 7 | /// We store the transposes of the weight matrices so that we can do forward | ||
| 8 | /// passes with a minimal amount of work. That is, if in paper we write: | ||
| 9 | /// | ||
| 10 | /// [w11 w21] | ||
| 11 | /// [w12 w22] | ||
| 12 | /// | ||
| 13 | /// then the weight matrix in memory is stored as the following array: | ||
| 14 | /// | ||
| 15 | /// w11 w12 w21 w22 | ||
| 16 | typedef struct nnNeuralNetwork { | 25 | typedef struct nnNeuralNetwork { |
| 17 | int num_layers; // Number of non-input layers (hidden + output). | 26 | int num_layers; // Number of non-input layers (hidden + output). |
| 18 | nnMatrix* weights; // One matrix per non-input layer. | 27 | nnLayerImpl* layers; // One per non-input layer. |
| 19 | nnMatrix* biases; // One vector per non-input layer. | ||
| 20 | nnActivation* activations; // One per non-input layer. | ||
| 21 | } nnNeuralNetwork; | 28 | } nnNeuralNetwork; |
| 22 | 29 | ||
| 23 | /// A query object that holds all the memory necessary to query a network. | 30 | /// A query object that holds all the memory necessary to query a network. |
