aboutsummaryrefslogtreecommitdiff
path: root/src/lib/include/neuralnet/train.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib/include/neuralnet/train.h')
-rw-r--r--src/lib/include/neuralnet/train.h42
1 files changed, 42 insertions, 0 deletions
diff --git a/src/lib/include/neuralnet/train.h b/src/lib/include/neuralnet/train.h
new file mode 100644
index 0000000..79f8e7b
--- /dev/null
+++ b/src/lib/include/neuralnet/train.h
@@ -0,0 +1,42 @@
1#pragma once
2
3#include <neuralnet/neuralnet.h>
4
5#include <stdbool.h>
6#include <stdint.h>
7
8typedef struct nnMatrix nnMatrix;
9
10/// Weight initialization strategy.
11///
12/// Note that regardless of strategy, a layer's weights are scaled by the
13/// layer's size. This is to avoid saturation when, e.g., using a sigmoid
14/// activation with many inputs. Thus, a (0,1) initialization is really
15/// (0,scale), for example.
16typedef enum nnWeightInitStrategy {
17 nnWeightInit01, // (0,1) range.
18 nnWeightInit11, // (-1,+1) range.
19 nnWeightInitNormal, // Normal distribution.
20} nnWeightInitStrategy;
21
22/// Network training parameters.
23typedef struct nnTrainingParams {
24 R learning_rate;
25 int max_iterations;
26 uint64_t seed;
27 nnWeightInitStrategy weight_init;
28 bool debug;
29} nnTrainingParams;
30
31/// Train the network.
32///
33/// |inputs| is a matrix of inputs, one row per input and as many columns as
34/// the input's dimension.
35///
36/// |targets| is a matrix of targets, one row per target and as many columns as
37/// the target's dimension.
38void nnTrain(
39 nnNeuralNetwork*,
40 const nnMatrix* inputs,
41 const nnMatrix* targets,
42 const nnTrainingParams*);