diff options
Diffstat (limited to 'src/lib/src/neuralnet_impl.h')
| -rw-r--r-- | src/lib/src/neuralnet_impl.h | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/src/lib/src/neuralnet_impl.h b/src/lib/src/neuralnet_impl.h new file mode 100644 index 0000000..26107b5 --- /dev/null +++ b/src/lib/src/neuralnet_impl.h | |||
| @@ -0,0 +1,36 @@ | |||
| 1 | #pragma once | ||
| 2 | |||
| 3 | #include <neuralnet/matrix.h> | ||
| 4 | |||
| 5 | /// Neural network object. | ||
| 6 | /// | ||
| 7 | /// We store the transposes of the weight matrices so that we can do forward | ||
| 8 | /// passes with a minimal amount of work. That is, if in paper we write: | ||
| 9 | /// | ||
| 10 | /// [w11 w21] | ||
| 11 | /// [w12 w22] | ||
| 12 | /// | ||
| 13 | /// then the weight matrix in memory is stored as the following array: | ||
| 14 | /// | ||
| 15 | /// w11 w12 w21 w22 | ||
| 16 | typedef struct nnNeuralNetwork { | ||
| 17 | int num_layers; // Number of non-input layers (hidden + output). | ||
| 18 | nnMatrix* weights; // One matrix per non-input layer. | ||
| 19 | nnMatrix* biases; // One vector per non-input layer. | ||
| 20 | nnActivation* activations; // One per non-input layer. | ||
| 21 | } nnNeuralNetwork; | ||
| 22 | |||
| 23 | /// A query object that holds all the memory necessary to query a network. | ||
| 24 | /// | ||
| 25 | /// |layer_outputs| is an array of matrices of intermediate layer outputs. There | ||
| 26 | /// is one matrix per intermediate layer. Each matrix holds the layer's output, | ||
| 27 | /// with one row per input, and as many columns as the layer's output size (the | ||
| 28 | /// output vector is transposed.) | ||
| 29 | /// | ||
| 30 | /// |network_outputs| points to the last output matrix in |layer_outputs| for | ||
| 31 | /// convenience. | ||
| 32 | typedef struct nnQueryObject { | ||
| 33 | int num_layers; | ||
| 34 | nnMatrix* layer_outputs; // Output matrices, one output per layer. | ||
| 35 | nnMatrix* network_outputs; // Points to the last output matrix. | ||
| 36 | } nnTrainingQueryObject; | ||
