aboutsummaryrefslogtreecommitdiff
path: root/src/lib/src/neuralnet_impl.h
blob: 26107b5e49553102e896a07e30ab17e451685614 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#pragma once

#include <neuralnet/matrix.h>

/// Neural network object.
///
/// Weight matrices are kept in transposed form so that a forward pass needs
/// as little work as possible. For example, a layer whose weights are written
/// on paper as:
///
///   [w11 w21]
///   [w12 w22]
///
/// is laid out in memory as the array:
///
///   w11 w12 w21 w22
///
/// |weights|, |biases| and |activations| each hold |num_layers| entries, one
/// per non-input layer (the hidden layers plus the output layer).
typedef struct nnNeuralNetwork {
  /// Count of non-input layers: hidden layers plus the output layer.
  int num_layers;
  /// Per-layer weight matrices, stored transposed as described above.
  nnMatrix* weights;
  /// Per-layer bias vectors.
  nnMatrix* biases;
  /// Per-layer activations.
  nnActivation* activations;
} nnNeuralNetwork;

/// A query object that holds all the memory necessary to query a network.
///
/// |layer_outputs| is an array of matrices holding the output of every
/// non-input layer, i.e. the hidden layers and the output layer. Each matrix
/// holds that layer's output, with one row per input and as many columns as
/// the layer's output size (the output vector is transposed).
///
/// |network_outputs| points to the last matrix in |layer_outputs| (the output
/// layer's result) for convenience. It points into |layer_outputs| rather than
/// to separate storage.
typedef struct nnQueryObject {
  int       num_layers;       // Presumably the length of |layer_outputs|.
  nnMatrix* layer_outputs;    // Output matrices, one per non-input layer.
  nnMatrix* network_outputs;  // Points to the last matrix in |layer_outputs|.
} nnQueryObject;

/// Alias kept for backward compatibility: this header previously exposed the
/// struct under this (mismatched) typedef name. Prefer nnQueryObject.
typedef struct nnQueryObject nnTrainingQueryObject;