aboutsummaryrefslogtreecommitdiff
path: root/src/lib/src/neuralnet_impl.h
blob: 18694f4ec909816492a830e95e187a70e70a400b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#pragma once

#include <neuralnet/matrix.h>

/// Neural network state.
///
/// Weight matrices are kept in memory in transposed form so that a forward
/// pass requires as little work as possible. For example, a weight matrix that
/// would be written on paper as:
///
///   [w11 w21]
///   [w12 w22]
///
/// is laid out in memory as the array:
///
///   w11 w12 w21 w22
typedef struct nnNeuralNetwork {
  /// Count of non-input layers, i.e. hidden layers plus the output layer.
  int num_layers;
  /// Weight matrices; one per non-input layer.
  nnMatrix *weights;
  /// Bias vectors; one per non-input layer.
  nnMatrix *biases;
  /// Activation functions; one per non-input layer.
  nnActivation *activations;
} nnNeuralNetwork;

/// Query object: owns all of the memory required to query a network.
///
/// |layer_outputs| is an array of matrices storing the intermediate layer
/// outputs, one matrix per layer. Each matrix has one row per input and as
/// many columns as the layer's output size (that is, the output vector is
/// stored transposed.)
///
/// |network_outputs| is a convenience pointer to the last output matrix in
/// |layer_outputs|.
///
/// NOTE(review): the struct tag is nnQueryObject while the typedef name is
/// nnTrainingQueryObject. Confirm whether the differing names are intended
/// before relying on either one.
typedef struct nnQueryObject {
  /// Layer count. Presumably the number of matrices in |layer_outputs| --
  /// TODO confirm against the code that allocates this object.
  int num_layers;
  /// Output matrices, one output per layer.
  nnMatrix *layer_outputs;
  /// Points to the last output matrix.
  nnMatrix *network_outputs;
} nnTrainingQueryObject;