-rw-r--r--  src/bin/mnist/src/main.c                                 195
-rw-r--r--  src/lib/include/neuralnet/matrix.h                         3
-rw-r--r--  src/lib/include/neuralnet/neuralnet.h                     51
-rw-r--r--  src/lib/src/activation.h                                   4
-rw-r--r--  src/lib/src/matrix.c                                       6
-rw-r--r--  src/lib/src/neuralnet.c                                  218
-rw-r--r--  src/lib/src/neuralnet_impl.h                              35
-rw-r--r--  src/lib/src/train.c                                      182
-rw-r--r--  src/lib/test/neuralnet_test.c                            103
-rw-r--r--  src/lib/test/train_linear_perceptron_non_origin_test.c   46
-rw-r--r--  src/lib/test/train_linear_perceptron_test.c               44
-rw-r--r--  src/lib/test/train_sigmoid_test.c                         46
-rw-r--r--  src/lib/test/train_xor_test.c                             55
13 files changed, 559 insertions, 429 deletions
diff --git a/src/bin/mnist/src/main.c b/src/bin/mnist/src/main.c
index 9aa3ce5..53e0197 100644
--- a/src/bin/mnist/src/main.c
+++ b/src/bin/mnist/src/main.c
@@ -29,32 +29,35 @@ static const double LABEL_UPPER_BOUND = 0.99;
 // Epsilon used to compare R values.
 static const double EPS = 1e-10;
 
-#define min(a,b) ((a) < (b) ? (a) : (b))
+#define min(a, b) ((a) < (b) ? (a) : (b))
 
 typedef struct ImageSet {
   nnMatrix images;  // Images flattened into row vectors of the matrix.
   nnMatrix labels;  // One-hot-encoded labels.
   int count;        // Number of images and labels.
   int rows;         // Rows in an image.
   int cols;         // Columns in an image.
 } ImageSet;
 
 static void usage(const char* argv0) {
-  fprintf(stderr, "Usage: %s <path to mnist files directory> [num images]\n", argv0);
+  fprintf(
+      stderr, "Usage: %s <path to mnist files directory> [num images]\n",
+      argv0);
   fprintf(stderr, "\n");
-  fprintf(stderr, " Use -1 for [num images] to use all the images in the data set\n");
+  fprintf(
+      stderr,
+      " Use -1 for [num images] to use all the images in the data set\n");
 }
 
-static bool R_eq(R a, R b) {
-  return fabs(a-b) <= EPS;
-}
+static bool R_eq(R a, R b) { return fabs(a - b) <= EPS; }
 
-static void PrintImage(const nnMatrix* images, int rows, int cols, int image_index) {
+static void PrintImage(
+    const nnMatrix* images, int rows, int cols, int image_index) {
   assert(images);
   assert((0 <= image_index) && (image_index < images->rows));
 
   // Top line.
-  for (int j = 0; j < cols/2; ++j) {
+  for (int j = 0; j < cols / 2; ++j) {
     printf(" -");
   }
   printf("\n");
@@ -68,8 +71,7 @@ static void PrintImage(const nnMatrix* images, int rows, int cols, int image_ind
       printf("#");
     } else if (*value > 0.5) {
       printf("*");
-    }
-    else if (*value > PIXEL_LOWER_BOUND) {
+    } else if (*value > PIXEL_LOWER_BOUND) {
       printf(":");
     } else if (*value == 0.0) {
       // Values should not be exactly 0, otherwise they cancel out weights
@@ -84,7 +86,7 @@ static void PrintImage(const nnMatrix* images, int rows, int cols, int image_ind
   }
 
   // Bottom line.
-  for (int j = 0; j < cols/2; ++j) {
+  for (int j = 0; j < cols / 2; ++j) {
     printf(" -");
   }
   printf("\n");
@@ -96,7 +98,7 @@ static void PrintLabel(const nnMatrix* labels, int label_index) {
 
   // Compute the label from the one-hot encoding.
   const R* value = nnMatrixRow(labels, label_index);
   int label = -1;
   for (int i = 0; i < 10; ++i) {
     if (R_eq(*value++, LABEL_UPPER_BOUND)) {
       label = i;
@@ -113,13 +115,12 @@ static void PrintLabel(const nnMatrix* labels, int label_index) {
   printf(")\n");
 }
 
-static R lerp(R a, R b, R t) {
-  return a + t*(b-a);
-}
+static R lerp(R a, R b, R t) { return a + t * (b - a); }
 
 /// Rescales a pixel from [0,255] to [PIXEL_LOWER_BOUND, 1.0].
 static R FormatPixel(uint8_t pixel) {
-  const R value = (R)(pixel) / 255.0 * (1.0 - PIXEL_LOWER_BOUND) + PIXEL_LOWER_BOUND;
+  const R value =
+      (R)(pixel) / 255.0 * (1.0 - PIXEL_LOWER_BOUND) + PIXEL_LOWER_BOUND;
   assert(value >= PIXEL_LOWER_BOUND);
   assert(value <= 1.0);
   return value;
@@ -152,7 +153,8 @@ static void ImageToMatrix(
   }
 }
 
-static bool ReadImages(gzFile images_file, int max_num_images, ImageSet* image_set) {
+static bool ReadImages(
+    gzFile images_file, int max_num_images, ImageSet* image_set) {
   assert(images_file != Z_NULL);
   assert(image_set);
 
@@ -161,36 +163,41 @@ static bool ReadImages(gzFile images_file, int max_num_images, ImageSet* image_s
   uint8_t* pixels = 0;
 
   int32_t magic, total_images, rows, cols;
-  if ( (gzread(images_file, (char*)&magic, sizeof(int32_t)) != sizeof(int32_t)) ||
-       (gzread(images_file, (char*)&total_images, sizeof(int32_t)) != sizeof(int32_t)) ||
-       (gzread(images_file, (char*)&rows, sizeof(int32_t)) != sizeof(int32_t)) ||
-       (gzread(images_file, (char*)&cols, sizeof(int32_t)) != sizeof(int32_t)) ) {
+  if ((gzread(images_file, (char*)&magic, sizeof(int32_t)) !=
+       sizeof(int32_t)) ||
+      (gzread(images_file, (char*)&total_images, sizeof(int32_t)) !=
+       sizeof(int32_t)) ||
+      (gzread(images_file, (char*)&rows, sizeof(int32_t)) != sizeof(int32_t)) ||
+      (gzread(images_file, (char*)&cols, sizeof(int32_t)) != sizeof(int32_t))) {
     fprintf(stderr, "Failed to read header\n");
     goto cleanup;
   }
 
   magic = ReverseEndian32(magic);
   total_images = ReverseEndian32(total_images);
   rows = ReverseEndian32(rows);
   cols = ReverseEndian32(cols);
 
   if (magic != IMAGE_FILE_MAGIC) {
-    fprintf(stderr, "Magic number mismatch. Got %x, expected: %x\n",
-        magic, IMAGE_FILE_MAGIC);
+    fprintf(
+        stderr, "Magic number mismatch. Got %x, expected: %x\n", magic,
+        IMAGE_FILE_MAGIC);
     goto cleanup;
   }
 
-  printf("Magic: %.8x\nTotal images: %d\nRows: %d\nCols: %d\n",
-      magic, total_images, rows, cols);
+  printf(
+      "Magic: %.8x\nTotal images: %d\nRows: %d\nCols: %d\n", magic,
+      total_images, rows, cols);
 
-  total_images = max_num_images >= 0 ? min(total_images, max_num_images) : total_images;
+  total_images =
+      max_num_images >= 0 ? min(total_images, max_num_images) : total_images;
 
   // Images are flattened into single row vectors.
   const int num_pixels = rows * cols;
   image_set->images = nnMatrixMake(total_images, num_pixels);
   image_set->count = total_images;
   image_set->rows = rows;
   image_set->cols = cols;
 
   pixels = calloc(1, num_pixels);
   if (!pixels) {
@@ -219,30 +226,31 @@ cleanup:
   return success;
 }
 
-static void OneHotEncode(const uint8_t* labels_bytes, int num_labels, nnMatrix* labels) {
+static void OneHotEncode(
+    const uint8_t* labels_bytes, int num_labels, nnMatrix* labels) {
   assert(labels_bytes);
   assert(labels);
   assert(labels->rows == num_labels);
   assert(labels->cols == 10);
 
   static const R one_hot[10][10] = {
-    { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
-    { 0, 1, 0, 0, 0, 0, 0, 0, 0, 0 },
-    { 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 },
-    { 0, 0, 0, 1, 0, 0, 0, 0, 0, 0 },
-    { 0, 0, 0, 0, 1, 0, 0, 0, 0, 0 },
-    { 0, 0, 0, 0, 0, 1, 0, 0, 0, 0 },
-    { 0, 0, 0, 0, 0, 0, 1, 0, 0, 0 },
-    { 0, 0, 0, 0, 0, 0, 0, 1, 0, 0 },
-    { 0, 0, 0, 0, 0, 0, 0, 0, 1, 0 },
-    { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 },
+      {1, 0, 0, 0, 0, 0, 0, 0, 0, 0},
+      {0, 1, 0, 0, 0, 0, 0, 0, 0, 0},
+      {0, 0, 1, 0, 0, 0, 0, 0, 0, 0},
+      {0, 0, 0, 1, 0, 0, 0, 0, 0, 0},
+      {0, 0, 0, 0, 1, 0, 0, 0, 0, 0},
+      {0, 0, 0, 0, 0, 1, 0, 0, 0, 0},
+      {0, 0, 0, 0, 0, 0, 1, 0, 0, 0},
+      {0, 0, 0, 0, 0, 0, 0, 1, 0, 0},
+      {0, 0, 0, 0, 0, 0, 0, 0, 1, 0},
+      {0, 0, 0, 0, 0, 0, 0, 0, 0, 1},
   };
 
   R* value = labels->values;
 
   for (int i = 0; i < num_labels; ++i) {
     const uint8_t label = labels_bytes[i];
     const R* one_hot_value = one_hot[label];
 
     for (int j = 0; j < 10; ++j) {
       *value++ = FormatLabel(*one_hot_value++);
@@ -255,13 +263,13 @@ static int OneHotDecode(const nnMatrix* label_matrix) {
   assert(label_matrix->cols == 10);
   assert(label_matrix->rows == 1);
 
   R max_value = 0;
   int pos_max = 0;
   for (int i = 0; i < 10; ++i) {
     const R value = nnMatrixAt(label_matrix, 0, i);
     if (value > max_value) {
       max_value = value;
       pos_max = i;
     }
   }
   assert(pos_max >= 0);
@@ -269,7 +277,8 @@ static int OneHotDecode(const nnMatrix* label_matrix) {
   return pos_max;
 }
 
-static bool ReadLabels(gzFile labels_file, int max_num_labels, ImageSet* image_set) {
+static bool ReadLabels(
+    gzFile labels_file, int max_num_labels, ImageSet* image_set) {
   assert(labels_file != Z_NULL);
   assert(image_set != 0);
 
@@ -278,24 +287,28 @@ static bool ReadLabels(gzFile labels_file, int max_num_labels, ImageSet* image_s
   uint8_t* labels = 0;
 
   int32_t magic, total_labels;
-  if ( (gzread(labels_file, (char*)&magic, sizeof(int32_t)) != sizeof(int32_t)) ||
-       (gzread(labels_file, (char*)&total_labels, sizeof(int32_t)) != sizeof(int32_t)) ) {
+  if ((gzread(labels_file, (char*)&magic, sizeof(int32_t)) !=
+       sizeof(int32_t)) ||
+      (gzread(labels_file, (char*)&total_labels, sizeof(int32_t)) !=
+       sizeof(int32_t))) {
     fprintf(stderr, "Failed to read header\n");
     goto cleanup;
   }
 
   magic = ReverseEndian32(magic);
   total_labels = ReverseEndian32(total_labels);
 
   if (magic != LABEL_FILE_MAGIC) {
-    fprintf(stderr, "Magic number mismatch. Got %x, expected: %x\n",
-        magic, LABEL_FILE_MAGIC);
+    fprintf(
+        stderr, "Magic number mismatch. Got %x, expected: %x\n", magic,
+        LABEL_FILE_MAGIC);
     goto cleanup;
   }
 
   printf("Magic: %.8x\nTotal labels: %d\n", magic, total_labels);
 
-  total_labels = max_num_labels >= 0 ? min(total_labels, max_num_labels) : total_labels;
+  total_labels =
+      max_num_labels >= 0 ? min(total_labels, max_num_labels) : total_labels;
 
   assert(image_set->count == total_labels);
 
@@ -308,7 +321,8 @@ static bool ReadLabels(gzFile labels_file, int max_num_labels, ImageSet* image_s
     goto cleanup;
   }
 
-  if (gzread(labels_file, labels, total_labels * sizeof(uint8_t)) != total_labels) {
+  if (gzread(labels_file, labels, total_labels * sizeof(uint8_t)) !=
+      total_labels) {
     fprintf(stderr, "Failed to read labels\n");
     goto cleanup;
   }
@@ -335,17 +349,17 @@ int main(int argc, const char** argv) {
 
   bool success = false;
 
   gzFile train_images_file = Z_NULL;
   gzFile train_labels_file = Z_NULL;
   gzFile test_images_file = Z_NULL;
   gzFile test_labels_file = Z_NULL;
-  ImageSet train_set = { 0 };
-  ImageSet test_set = { 0 };
+  ImageSet train_set = {0};
+  ImageSet test_set = {0};
   nnNeuralNetwork* net = 0;
   nnQueryObject* query = 0;
 
   const char* mnist_files_dir = argv[1];
   const int max_num_images = argc > 2 ? atoi(argv[2]) : -1;
 
   char train_labels_path[PATH_MAX];
   char train_images_path[PATH_MAX];
@@ -353,12 +367,12 @@ int main(int argc, const char** argv) {
   char test_images_path[PATH_MAX];
   strlcpy(train_labels_path, mnist_files_dir, PATH_MAX);
   strlcpy(train_images_path, mnist_files_dir, PATH_MAX);
   strlcpy(test_labels_path, mnist_files_dir, PATH_MAX);
   strlcpy(test_images_path, mnist_files_dir, PATH_MAX);
   strlcat(train_labels_path, "/train-labels-idx1-ubyte.gz", PATH_MAX);
   strlcat(train_images_path, "/train-images-idx3-ubyte.gz", PATH_MAX);
   strlcat(test_labels_path, "/t10k-labels-idx1-ubyte.gz", PATH_MAX);
   strlcat(test_images_path, "/t10k-images-idx3-ubyte.gz", PATH_MAX);
 
   train_images_file = gzopen(train_images_path, "r");
   if (train_images_file == Z_NULL) {
@@ -406,11 +420,18 @@ int main(int argc, const char** argv) {
   }
 
   // Network definition.
   const int image_size_pixels = train_set.rows * train_set.cols;
-  const int num_layers = 2;
-  const int layer_sizes[3] = { image_size_pixels, 100, 10 };
-  const nnActivation layer_activations[2] = { nnSigmoid, nnSigmoid };
-  if (!(net = nnMakeNet(num_layers, layer_sizes, layer_activations))) {
+  const int num_layers = 4;
+  const int hidden_size = 100;
+  const nnLayer layers[4] = {
+      {.type = nnLinear,
+       .linear = {.input_size = image_size_pixels, .output_size = hidden_size}},
+      {.type = nnSigmoid},
+      {.type = nnLinear,
+       .linear = {.input_size = hidden_size, .output_size = 10}},
+      {.type = nnSigmoid}
+  };
+  if (!(net = nnMakeNet(layers, num_layers, image_size_pixels))) {
     fprintf(stderr, "Failed to create neural network\n");
     goto cleanup;
   }
@@ -418,17 +439,17 @@ int main(int argc, const char** argv) {
   // Train.
   printf("Training with up to %d images from the data set\n\n", max_num_images);
   const nnTrainingParams training_params = {
       .learning_rate = 0.1,
       .max_iterations = TRAIN_ITERATIONS,
       .seed = 0,
       .weight_init = nnWeightInitNormal,
       .debug = true,
   };
   nnTrain(net, &train_set.images, &train_set.labels, &training_params);
 
   // Test.
   int hits = 0;
   query = nnMakeQueryObject(net, /*num_inputs=*/1);
   for (int i = 0; i < test_set.count; ++i) {
     const nnMatrix test_image = nnMatrixBorrowRows(&test_set.images, i, 1);
     const nnMatrix test_label = nnMatrixBorrowRows(&test_set.labels, i, 1);
@@ -444,7 +465,7 @@ int main(int argc, const char** argv) {
   }
   const R hit_ratio = (R)hits / (R)test_set.count;
   printf("Test images: %d\n", test_set.count);
-  printf("Hits: %d/%d (%.3f%%)\n", hits, test_set.count, hit_ratio*100);
+  printf("Hits: %d/%d (%.3f%%)\n", hits, test_set.count, hit_ratio * 100);
 
   success = true;
 
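Note that the jump from num_layers = 2 to num_layers = 4 above does not change the network's topology: activations are now standalone layers, so the old {784, 100, 10} sizes-plus-activations description becomes an explicit Linear -> Sigmoid -> Linear -> Sigmoid pipeline. A minimal sketch of the same definition with the sizes written out (assuming 28x28 inputs, so image_size_pixels is 784; mnist_layers and mnist_net are illustrative names, not part of the patch):

    const nnLayer mnist_layers[4] = {
        {.type = nnLinear, .linear = {.input_size = 784, .output_size = 100}},
        {.type = nnSigmoid},
        {.type = nnLinear, .linear = {.input_size = 100, .output_size = 10}},
        {.type = nnSigmoid},
    };
    nnNeuralNetwork* mnist_net = nnMakeNet(mnist_layers, 4, /*input_size=*/784);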
diff --git a/src/lib/include/neuralnet/matrix.h b/src/lib/include/neuralnet/matrix.h
index b7281bf..f80b985 100644
--- a/src/lib/include/neuralnet/matrix.h
+++ b/src/lib/include/neuralnet/matrix.h
@@ -17,6 +17,9 @@ nnMatrix nnMatrixMake(int rows, int cols);
 /// Delete a matrix and free its internal memory.
 void nnMatrixDel(nnMatrix*);
 
+/// Construct a matrix from an array of values.
+nnMatrix nnMatrixFromArray(int rows, int cols, const R values[]);
+
 /// Move a matrix.
 ///
 /// |in| is an empty matrix after the move.
diff --git a/src/lib/include/neuralnet/neuralnet.h b/src/lib/include/neuralnet/neuralnet.h
index 05c9406..f122c2a 100644
--- a/src/lib/include/neuralnet/neuralnet.h
+++ b/src/lib/include/neuralnet/neuralnet.h
@@ -1,32 +1,45 @@
 #pragma once
 
+#include <neuralnet/matrix.h>
 #include <neuralnet/types.h>
 
-typedef struct nnMatrix nnMatrix;
-
 typedef struct nnNeuralNetwork nnNeuralNetwork;
 typedef struct nnQueryObject nnQueryObject;
 
-/// Neuron activation.
-typedef enum nnActivation {
-  nnIdentity,
+/// Linear layer parameters.
+///
+/// Either one of the following must be set:
+/// a) Training: input and output sizes.
+/// b) Inference: weights + biases.
+typedef struct nnLinearParams {
+  int input_size;
+  int output_size;
+  nnMatrix weights;
+  nnMatrix biases;
+} nnLinearParams;
+
+/// Layer type.
+typedef enum nnLayerType {
+  nnLinear,
   nnSigmoid,
   nnRelu,
-} nnActivation;
+} nnLayerType;
+
+/// Neural network layer.
+typedef struct nnLayer {
+  nnLayerType type;
+  union {
+    nnLinearParams linear;
+  };
+} nnLayer;
 
 /// Create a network.
 nnNeuralNetwork* nnMakeNet(
-    int num_layers, const int* layer_sizes, const nnActivation* activations);
+    const nnLayer* layers, int num_layers, int input_size);
 
 /// Delete the network and free its internal memory.
 void nnDeleteNet(nnNeuralNetwork**);
 
-/// Set the network's weights.
-void nnSetWeights(nnNeuralNetwork*, const R* weights);
-
-/// Set the network's biases.
-void nnSetBiases(nnNeuralNetwork*, const R* biases);
-
 /// Query the network.
 ///
 /// |input| is a matrix of inputs, one row per input and as many columns as the
@@ -42,10 +55,10 @@ void nnQueryArray(
 
 /// Create a query object.
 ///
-/// The query object holds all the internal memory required to query a network.
-/// Query objects allocate all memory up front so that network queries can run
-/// without additional memory allocation.
-nnQueryObject* nnMakeQueryObject(const nnNeuralNetwork*, int num_inputs);
+/// The query object holds all the internal memory required to query a network
+/// with batches of the given size. Memory is allocated up front so that network
+/// queries can run without additional memory allocation.
+nnQueryObject* nnMakeQueryObject(const nnNeuralNetwork*, int batch_size);
 
 /// Delete the query object and free its internal memory.
 void nnDeleteQueryObject(nnQueryObject**);
@@ -60,7 +73,7 @@ int nnNetInputSize(const nnNeuralNetwork*);
 int nnNetOutputSize(const nnNeuralNetwork*);
 
 /// Return the layer's input size.
-int nnLayerInputSize(const nnMatrix* weights);
+int nnLayerInputSize(const nnNeuralNetwork*, int layer);
 
 /// Return the layer's output size.
-int nnLayerOutputSize(const nnMatrix* weights);
+int nnLayerOutputSize(const nnNeuralNetwork*, int layer);
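The nnLinearParams doc comment above describes two construction modes for a linear layer. A minimal sketch of both, using only the API declared in this header (the names frozen and trainable are illustrative):

    // Inference: provide weights and biases; sizes are derived from the
    // weight matrix dimensions.
    const R w[] = {0.3};
    const R b[] = {0.0};
    const nnLayer frozen = {
        .type = nnLinear,
        .linear = {.weights = nnMatrixFromArray(1, 1, w),
                   .biases = nnMatrixFromArray(1, 1, b)}};

    // Training: provide sizes only; the library allocates the matrices and
    // the training routine initializes them.
    const nnLayer trainable = {
        .type = nnLinear, .linear = {.input_size = 1, .output_size = 1}};

Both forms appear in the updated tests further down.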
diff --git a/src/lib/src/activation.h b/src/lib/src/activation.h
index b56a69e..4c8a9e4 100644
--- a/src/lib/src/activation.h
+++ b/src/lib/src/activation.h
@@ -9,8 +9,8 @@ static inline R sigmoid(R x) { return 1. / (1. + exp(-x)); }
 static inline R relu(R x) { return fmax(0, x); }
 
 #define NN_MAP_ARRAY(f, in, out, size) \
-  for (int i = 0; i < size; ++i) { \
-    out[i] = f(in[i]); \
+  for (int ii = 0; ii < size; ++ii) { \
+    out[ii] = f(in[ii]); \
   }
 
 #define sigmoid_array(in, out, size) NN_MAP_ARRAY(sigmoid, in, out, size)
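Renaming the macro's loop counter from i to ii is a hygiene fix rather than a style one: NN_MAP_ARRAY expands textually, so an expansion whose arguments mention a caller-side i would have that i captured by the macro's own loop variable. A small illustration of the hazard, with hypothetical caller code:

    enum { kBatch = 4, kLen = 8 };
    R in[kBatch][kLen], out[kBatch][kLen];
    for (int i = 0; i < kBatch; ++i) {
      // The macro expands to its own loop; with the old counter name `i`,
      // the expansion `out[i][i] = sigmoid(in[i][i])` would index both
      // dimensions with the inner counter instead of the batch index.
      sigmoid_array(in[i], out[i], kLen);
    }

With ii as the inner counter, in[i] and out[i] keep referring to the enclosing loop's i.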
diff --git a/src/lib/src/matrix.c b/src/lib/src/matrix.c
index d98c8bb..d5c3fcc 100644
--- a/src/lib/src/matrix.c
+++ b/src/lib/src/matrix.c
@@ -26,6 +26,12 @@ void nnMatrixDel(nnMatrix* matrix) {
   }
 }
 
+nnMatrix nnMatrixFromArray(int rows, int cols, const R values[]) {
+  nnMatrix m = nnMatrixMake(rows, cols);
+  nnMatrixInit(&m, values);
+  return m;
+}
+
 void nnMatrixMove(nnMatrix* in, nnMatrix* out) {
   assert(in);
   assert(out);
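nnMatrixFromArray is a small convenience wrapper: it allocates through nnMatrixMake and fills the new matrix from the caller's array (assuming, as the const parameter suggests, that nnMatrixInit copies the values in row-major order). The caller owns the result:

    const R values[] = {1, 2,
                        3, 4};
    nnMatrix m = nnMatrixFromArray(2, 2, values);  // 2x2, row-major
    // ... use m ...
    nnMatrixDel(&m);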
diff --git a/src/lib/src/neuralnet.c b/src/lib/src/neuralnet.c
index a5fc59b..4322b8c 100644
--- a/src/lib/src/neuralnet.c
+++ b/src/lib/src/neuralnet.c
@@ -7,11 +7,65 @@
 #include <assert.h>
 #include <stdlib.h>
 
+static void MakeLayerImpl(
+    int prev_layer_output_size, const nnLayer* layer, nnLayerImpl* impl) {
+  impl->type = layer->type;
+
+  switch (layer->type) {
+    case nnLinear: {
+      const nnLinearParams* params = &layer->linear;
+      nnLinearImpl* linear = &impl->linear;
+
+      if ((params->input_size > 0) && (params->output_size > 0)) {
+        const int rows = params->input_size;
+        const int cols = params->output_size;
+        linear->weights = nnMatrixMake(rows, cols);
+        linear->biases = nnMatrixMake(1, cols);
+        linear->owned = true;
+      } else {
+        linear->weights = params->weights;
+        linear->biases = params->biases;
+        linear->owned = false;
+      }
+
+      impl->input_size = linear->weights.rows;
+      impl->output_size = linear->weights.cols;
+
+      break;
+    }
+
+    // Activation layers.
+    case nnRelu:
+    case nnSigmoid:
+      impl->input_size = prev_layer_output_size;
+      impl->output_size = prev_layer_output_size;
+      break;
+  }
+}
+
+static void DeleteLayer(nnLayerImpl* layer) {
+  switch (layer->type) {
+    case nnLinear: {
+      nnLinearImpl* linear = &layer->linear;
+      if (linear->owned) {
+        nnMatrixDel(&linear->weights);
+        nnMatrixDel(&linear->biases);
+      }
+      break;
+    }
+
+    // No parameters for these layers.
+    case nnRelu:
+    case nnSigmoid:
+      break;
+  }
+}
+
 nnNeuralNetwork* nnMakeNet(
-    int num_layers, const int* layer_sizes, const nnActivation* activations) {
+    const nnLayer* layers, int num_layers, int input_size) {
+  assert(layers);
   assert(num_layers > 0);
-  assert(layer_sizes);
-  assert(activations);
+  assert(input_size > 0);
 
   nnNeuralNetwork* net = calloc(1, sizeof(nnNeuralNetwork));
   if (net == 0) {
@@ -20,84 +74,38 @@ nnNeuralNetwork* nnMakeNet(
 
   net->num_layers = num_layers;
 
-  net->weights = calloc(num_layers, sizeof(nnMatrix));
-  net->biases = calloc(num_layers, sizeof(nnMatrix));
-  net->activations = calloc(num_layers, sizeof(nnActivation));
-  if ((net->weights == 0) || (net->biases == 0) || (net->activations == 0)) {
+  net->layers = calloc(num_layers, sizeof(nnLayerImpl));
+  if (net->layers == 0) {
     nnDeleteNet(&net);
     return 0;
   }
 
+  int prev_layer_output_size = input_size;
   for (int l = 0; l < num_layers; ++l) {
-    // layer_sizes = { input layer size, first hidden layer size, ...}
-    const int layer_input_size = layer_sizes[l];
-    const int layer_output_size = layer_sizes[l + 1];
-
-    // We store the transpose of the weight matrix as written in textbooks.
-    // Our vectors are row vectors and the matrices row-major.
-    const int rows = layer_input_size;
-    const int cols = layer_output_size;
-
-    net->weights[l] = nnMatrixMake(rows, cols);
-    net->biases[l] = nnMatrixMake(1, cols);
-    net->activations[l] = activations[l];
+    MakeLayerImpl(prev_layer_output_size, &layers[l], &net->layers[l]);
+    prev_layer_output_size = net->layers[l].output_size;
   }
 
   return net;
 }
 
-void nnDeleteNet(nnNeuralNetwork** net) {
-  if ((!net) || (!(*net))) {
+void nnDeleteNet(nnNeuralNetwork** ppNet) {
+  if ((!ppNet) || (!(*ppNet))) {
     return;
   }
-  if ((*net)->weights != 0) {
-    for (int l = 0; l < (*net)->num_layers; ++l) {
-      nnMatrixDel(&(*net)->weights[l]);
-    }
-    free((*net)->weights);
-    (*net)->weights = 0;
-  }
-  if ((*net)->biases != 0) {
-    for (int l = 0; l < (*net)->num_layers; ++l) {
-      nnMatrixDel(&(*net)->biases[l]);
-    }
-    free((*net)->biases);
-    (*net)->biases = 0;
-  }
-  if ((*net)->activations) {
-    free((*net)->activations);
-    (*net)->activations = 0;
-  }
-  free(*net);
-  *net = 0;
-}
-
-void nnSetWeights(nnNeuralNetwork* net, const R* weights) {
-  assert(net);
-  assert(weights);
+  nnNeuralNetwork* net = *ppNet;
 
   for (int l = 0; l < net->num_layers; ++l) {
-    nnMatrix* layer_weights = &net->weights[l];
-    R* layer_values = layer_weights->values;
-
-    for (int j = 0; j < layer_weights->rows * layer_weights->cols; ++j) {
-      *layer_values++ = *weights++;
-    }
+    DeleteLayer(&net->layers[l]);
   }
-}
-
-void nnSetBiases(nnNeuralNetwork* net, const R* biases) {
-  assert(net);
-  assert(biases);
 
-  for (int l = 0; l < net->num_layers; ++l) {
-    nnMatrix* layer_biases = &net->biases[l];
-    R* layer_values = layer_biases->values;
-
-    for (int j = 0; j < layer_biases->rows * layer_biases->cols; ++j) {
-      *layer_values++ = *biases++;
-    }
+  if (net->layers) {
+    free(net->layers);
+    net->layers = 0;
   }
+
+  free(net);
+  *ppNet = 0;
 }
 
 void nnQuery(
@@ -114,35 +122,40 @@ void nnQuery(
     nnMatrix input_vector = nnMatrixBorrowRows((nnMatrix*)input, i, 1);
 
     for (int l = 0; l < net->num_layers; ++l) {
-      const nnMatrix* layer_weights = &net->weights[l];
-      const nnMatrix* layer_biases = &net->biases[l];
-      // Y^T = (W*X)^T = X^T*W^T
-      //
-      // TODO: If we had a row-row matrix multiplication, we could compute:
-      //   Y^T = W ** X^T
-      // The row-row multiplication could be more cache-friendly. We just need
-      // to store W as is, without transposing.
-      // We could also rewrite the original Mul function to go row x row,
-      // decomposing the multiplication. Preserving the original meaning of Mul
-      // makes everything clearer.
       nnMatrix output_vector =
           nnMatrixBorrowRows(&query->layer_outputs[l], i, 1);
-      nnMatrixMul(&input_vector, layer_weights, &output_vector);
-      nnMatrixAddRow(&output_vector, layer_biases, &output_vector);
 
-      switch (net->activations[l]) {
-        case nnIdentity:
-          break;  // Nothing to do for the identity function.
-        case nnSigmoid:
-          sigmoid_array(
-              output_vector.values, output_vector.values, output_vector.cols);
+      switch (net->layers[l].type) {
+        case nnLinear: {
+          const nnLinearImpl* linear = &net->layers[l].linear;
+          const nnMatrix* layer_weights = &linear->weights;
+          const nnMatrix* layer_biases = &linear->biases;
+
+          // Y^T = (W*X)^T = X^T*W^T
+          //
+          // TODO: If we had a row-row matrix multiplication, we could compute:
+          //   Y^T = W ** X^T
+          //
+          // The row-row multiplication could be more cache-friendly. We just
+          // need to store W as is, without transposing.
+          //
+          // We could also rewrite the original Mul function to go row x row,
+          // decomposing the multiplication. Preserving the original meaning of
+          // Mul makes everything clearer.
+          nnMatrixMul(&input_vector, layer_weights, &output_vector);
+          nnMatrixAddRow(&output_vector, layer_biases, &output_vector);
           break;
+        }
         case nnRelu:
+          assert(input_vector.cols == output_vector.cols);
           relu_array(
-              output_vector.values, output_vector.values, output_vector.cols);
+              input_vector.values, output_vector.values, output_vector.cols);
+          break;
+        case nnSigmoid:
+          assert(input_vector.cols == output_vector.cols);
+          sigmoid_array(
+              input_vector.values, output_vector.values, output_vector.cols);
           break;
-        default:
-          assert(0);
       }
 
       input_vector = output_vector;  // Borrow.
@@ -159,15 +172,15 @@ void nnQueryArray(
   assert(output);
   assert(net->num_layers > 0);
 
-  nnMatrix input_vector = nnMatrixMake(net->weights[0].cols, 1);
+  nnMatrix input_vector = nnMatrixMake(1, nnNetInputSize(net));
   nnMatrixInit(&input_vector, input);
   nnQuery(net, query, &input_vector);
   nnMatrixRowToArray(query->network_outputs, 0, output);
 }
 
-nnQueryObject* nnMakeQueryObject(const nnNeuralNetwork* net, int num_inputs) {
+nnQueryObject* nnMakeQueryObject(const nnNeuralNetwork* net, int batch_size) {
   assert(net);
-  assert(num_inputs > 0);
+  assert(batch_size > 0);
   assert(net->num_layers > 0);
 
   nnQueryObject* query = calloc(1, sizeof(nnQueryObject));
@@ -183,11 +196,12 @@ nnQueryObject* nnMakeQueryObject(const nnNeuralNetwork* net, int num_inputs) {
     free(query);
     return 0;
   }
+
   for (int l = 0; l < net->num_layers; ++l) {
-    const nnMatrix* layer_weights = &net->weights[l];
-    const int layer_output_size = nnLayerOutputSize(layer_weights);
-    query->layer_outputs[l] = nnMatrixMake(num_inputs, layer_output_size);
+    const int layer_output_size = nnLayerOutputSize(net, l);
+    query->layer_outputs[l] = nnMatrixMake(batch_size, layer_output_size);
   }
+
   query->network_outputs = &query->layer_outputs[net->num_layers - 1];
 
   return query;
@@ -213,23 +227,19 @@ const nnMatrix* nnNetOutputs(const nnQueryObject* query) {
 }
 
 int nnNetInputSize(const nnNeuralNetwork* net) {
-  assert(net);
-  assert(net->num_layers > 0);
-  return net->weights[0].rows;
+  return nnLayerInputSize(net, 0);
 }
 
 int nnNetOutputSize(const nnNeuralNetwork* net) {
-  assert(net);
-  assert(net->num_layers > 0);
-  return net->weights[net->num_layers - 1].cols;
+  return nnLayerOutputSize(net, net->num_layers - 1);
 }
 
-int nnLayerInputSize(const nnMatrix* weights) {
-  assert(weights);
-  return weights->rows;
+int nnLayerInputSize(const nnNeuralNetwork* net, int layer) {
+  assert(net);
+  return net->layers[layer].input_size;
 }
 
-int nnLayerOutputSize(const nnMatrix* weights) {
-  assert(weights);
-  return weights->cols;
+int nnLayerOutputSize(const nnNeuralNetwork* net, int layer) {
+  assert(net);
+  return net->layers[layer].output_size;
 }
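With the layer list in place, a forward pass walks net->layers and either applies x*W + b (linear layers) or maps the activation element-wise over the previous output. An end-to-end sketch of the query path, adapted from the perceptron test below:

    const R w[] = {0.3};
    const R b[] = {0.0};
    const nnLayer layers[] = {
        {.type = nnLinear,
         .linear = {.weights = nnMatrixFromArray(1, 1, w),
                    .biases = nnMatrixFromArray(1, 1, b)}},
        {.type = nnSigmoid},
    };
    nnNeuralNetwork* net = nnMakeNet(layers, 2, /*input_size=*/1);
    nnQueryObject* query = nnMakeQueryObject(net, /*batch_size=*/1);

    const R input[] = {0.9};
    R output[1];
    nnQueryArray(net, query, input, output);  // output[0] ~= sigmoid(0.9 * 0.3)

    nnDeleteQueryObject(&query);
    nnDeleteNet(&net);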
diff --git a/src/lib/src/neuralnet_impl.h b/src/lib/src/neuralnet_impl.h
index f5a9c63..935c5ea 100644
--- a/src/lib/src/neuralnet_impl.h
+++ b/src/lib/src/neuralnet_impl.h
@@ -2,22 +2,29 @@
 
 #include <neuralnet/matrix.h>
 
+#include <stdbool.h>
+
+/// Linear layer parameters.
+typedef struct nnLinearImpl {
+  nnMatrix weights;
+  nnMatrix biases;
+  bool owned;  /// Whether the library owns the weights and biases.
+} nnLinearImpl;
+
+/// Neural network layer.
+typedef struct nnLayerImpl {
+  nnLayerType type;
+  int input_size;
+  int output_size;
+  union {
+    nnLinearImpl linear;
+  };
+} nnLayerImpl;
+
 /// Neural network object.
-///
-/// We store the transposes of the weight matrices so that we can do forward
-/// passes with a minimal amount of work. That is, if in paper we write:
-///
-///   [w11 w21]
-///   [w12 w22]
-///
-/// then the weight matrix in memory is stored as the following array:
-///
-///   w11 w12 w21 w22
 typedef struct nnNeuralNetwork {
   int num_layers;  // Number of non-input layers (hidden + output).
-  nnMatrix* weights;  // One matrix per non-input layer.
-  nnMatrix* biases;  // One vector per non-input layer.
-  nnActivation* activations;  // One per non-input layer.
+  nnLayerImpl* layers;  // One per non-input layer.
 } nnNeuralNetwork;
 
 /// A query object that holds all the memory necessary to query a network.
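The owned flag mirrors the two construction modes in MakeLayerImpl: matrices the library allocated itself (sizes-only parameters) are freed by DeleteLayer, while caller-provided weights and biases are adopted as shallow nnMatrix copies and left alone on deletion. A consequence worth noting: caller-provided parameter buffers must stay valid for the network's lifetime, since nothing is deep-copied.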
diff --git a/src/lib/src/train.c b/src/lib/src/train.c
index dc93f0f..98f58ad 100644
--- a/src/lib/src/train.c
+++ b/src/lib/src/train.c
@@ -38,7 +38,7 @@ typedef struct nnSigmoidGradientElements {
 /// each layer. A data type is defined for these because we allocate all the
 /// required memory up front before entering the training loop.
 typedef struct nnGradientElements {
-  nnActivation type;
+  nnLayerType type;
   // Gradient vector, same size as the layer.
   // This will contain the gradient expression except for the output value of
   // the previous layer.
@@ -57,10 +57,27 @@ void nnInitNet(
   mt19937_64_init(&rng, seed);
 
   for (int l = 0; l < net->num_layers; ++l) {
-    nnMatrix* weights = &net->weights[l];
-    nnMatrix* biases = &net->biases[l];
+    // Get the layer's weights and biases, if any.
+    nnMatrix* weights = 0;
+    nnMatrix* biases = 0;
+    switch (net->layers[l].type) {
+      case nnLinear: {
+        nnLinearImpl* linear = &net->layers[l].linear;
+
+        weights = &linear->weights;
+        biases = &linear->biases;
+        break;
+      }
+      // Activations.
+      case nnRelu:
+      case nnSigmoid:
+        break;
+    }
+    if (!weights || !biases) {
+      continue;
+    }
 
-    const R layer_size = (R)nnLayerInputSize(weights);
+    const R layer_size = (R)nnLayerInputSize(net, l);
     const R scale = 1. / layer_size;
     const R stdev = 1. / sqrt((R)layer_size);
     const R sigma = stdev * stdev;
@@ -128,9 +145,6 @@ void nnTrain(
   // with one sample at a time.
   nnMatrix* errors = calloc(net->num_layers, sizeof(nnMatrix));
 
-  // Allocate the weight transpose matrices up front for backpropagation.
-  // nnMatrix* weights_T = calloc(net->num_layers, sizeof(nnMatrix));
-
   // Allocate the weight delta matrices.
   nnMatrix* weight_deltas = calloc(net->num_layers, sizeof(nnMatrix));
 
@@ -144,30 +158,24 @@ void nnTrain(
   nnMatrix* outputs_T = calloc(net->num_layers, sizeof(nnMatrix));
 
   assert(errors != 0);
-  // assert(weights_T != 0);
   assert(weight_deltas != 0);
   assert(gradient_elems);
   assert(outputs_T);
 
   for (int l = 0; l < net->num_layers; ++l) {
-    const nnMatrix* layer_weights = &net->weights[l];
-    const int layer_output_size = net->weights[l].cols;
-    const nnActivation activation = net->activations[l];
-
-    errors[l] = nnMatrixMake(1, layer_weights->cols);
-
-    // weights_T[l] = nnMatrixMake(layer_weights->cols, layer_weights->rows);
-    // nnMatrixTranspose(layer_weights, &weights_T[l]);
-
-    weight_deltas[l] = nnMatrixMake(layer_weights->rows, layer_weights->cols);
+    const int layer_input_size = nnLayerInputSize(net, l);
+    const int layer_output_size = nnLayerOutputSize(net, l);
+    const nnLayerImpl* layer = &net->layers[l];
 
-    outputs_T[l] = nnMatrixMake(layer_output_size, 1);
+    errors[l] = nnMatrixMake(1, layer_output_size);
+    weight_deltas[l] = nnMatrixMake(layer_input_size, layer_output_size);
+    outputs_T[l] = nnMatrixMake(layer_output_size, 1);
 
     // Allocate the gradient elements and vectors for weight delta calculation.
     nnGradientElements* elems = &gradient_elems[l];
-    elems->type = activation;
-    switch (activation) {
-      case nnIdentity:
+    elems->type = layer->type;
+    switch (layer->type) {
+      case nnLinear:
         break;  // Gradient vector will be borrowed, no need to allocate.
 
       case nnSigmoid:
@@ -208,6 +216,7 @@ void nnTrain(
 
     // For now, we train with one sample at a time.
     for (int sample = 0; sample < inputs->rows; ++sample) {
+      // TODO: Introduce a BorrowMut.
       // Slice the input and target matrices with the batch size.
       // We are not mutating the inputs, but we need the cast to borrow.
       nnMatrix training_inputs =
@@ -219,15 +228,16 @@ void nnTrain(
       // Assuming one training input per iteration for now.
       nnMatrixTranspose(&training_inputs, &training_inputs_T);
 
-      // Run a forward pass and compute the output layer error relevant to the
-      // derivative: o-t.
-      //   Error: (t-o)^2
-      //   dE/do = -2(t-o)
-      //         = +2(o-t)
+      // Forward pass.
+      nnQuery(net, query, &training_inputs);
+
+      // Compute the error derivative: o-t.
+      //   Error: 1/2 (t-o)^2
+      //   dE/do = -(t-o)
+      //         = +(o-t)
       // Note that we compute o-t instead to remove that outer negative sign.
       // The 2 is dropped because we are only interested in the direction of the
      // gradient. The learning rate controls the magnitude.
-      nnQuery(net, query, &training_inputs);
       nnMatrixSub(
           training_outputs, &training_targets, &errors[net->num_layers - 1]);
 
@@ -236,68 +246,86 @@ void nnTrain(
        nnMatrixTranspose(&query->layer_outputs[l], &outputs_T[l]);
      }
 
-      // Update weights and biases for each internal layer, backpropagating
+      // Update weights and biases for each internal layer, back-propagating
       // errors along the way.
       for (int l = net->num_layers - 1; l >= 0; --l) {
         const nnMatrix* layer_output = &query->layer_outputs[l];
-        nnMatrix* layer_weights = &net->weights[l];
-        nnMatrix* layer_biases = &net->biases[l];
         nnGradientElements* elems = &gradient_elems[l];
         nnMatrix* gradient = &elems->gradient;
-        const nnActivation activation = net->activations[l];
+        nnLayerImpl* layer = &net->layers[l];
 
-        // Compute the gradient (the part of the expression that does not
-        // contain the output of the previous layer).
+        // Compute this layer's gradient.
+        //
+        // By "gradient" we mean the expression common to the weights and bias
+        // gradients. This is the part of the expression that does not contain
+        // this layer's input.
         //
-        // Identity: G = error_k
-        // Sigmoid:  G = error_k * output_k * (1 - output_k).
-        // Relu:     G = error_k * (output_k > 0 ? 1 : 0)
-        switch (activation) {
-          case nnIdentity:
+        // Linear:  G = id
+        // Relu:    G = (output_k > 0 ? 1 : 0)
+        // Sigmoid: G = output_k * (1 - output_k)
+        switch (layer->type) {
+          case nnLinear: {
             // TODO: Just copy the pointer?
             *gradient = nnMatrixBorrow(&errors[l]);
             break;
+          }
+          case nnRelu:
+            nnMatrixGt(layer_output, 0, gradient);
+            break;
           case nnSigmoid:
             nnMatrixSub(&elems->sigmoid.ones, layer_output, gradient);
             nnMatrixMulPairs(layer_output, gradient, gradient);
-            nnMatrixMulPairs(&errors[l], gradient, gradient);
-            break;
-          case nnRelu:
-            nnMatrixGt(layer_output, 0, gradient);
-            nnMatrixMulPairs(&errors[l], gradient, gradient);
             break;
         }
 
-        // Outer product to compute the weight deltas.
-        const nnMatrix* output_T =
-            (l == 0) ? &training_inputs_T : &outputs_T[l - 1];
-        nnMatrixMul(output_T, gradient, &weight_deltas[l]);
-
-        // Backpropagate the error before updating weights.
+        // Back-propagate the error.
+        //
+        // This combines this layer's gradient with the back-propagated error,
+        // which is the combination of the gradients of subsequent layers down
+        // to the output layer error.
+        //
+        // Note that this step uses the layer's original weights.
         if (l > 0) {
-          // G * W^T == G *^T W.
-          // nnMatrixMul(gradient, &weights_T[l], &errors[l-1]);
-          nnMatrixMulRows(gradient, layer_weights, &errors[l - 1]);
+          switch (layer->type) {
+            case nnLinear: {
+              const nnMatrix* layer_weights = &layer->linear.weights;
+              // E * W^T == E *^T W.
+              // Using nnMatrixMulRows, we avoid having to transpose the weight
+              // matrix.
+              nnMatrixMulRows(&errors[l], layer_weights, &errors[l - 1]);
+              break;
+            }
+            // For activations, the error back-propagates as is but multiplied by
+            // the layer's gradient.
+            case nnRelu:
+            case nnSigmoid:
+              nnMatrixMulPairs(&errors[l], gradient, &errors[l - 1]);
+              break;
+          }
         }
 
-        // Update weights.
-        nnMatrixScale(&weight_deltas[l], params->learning_rate);
-        // The gradient has a negative sign from -(t - o), but we have computed
-        // e = o - t instead, so we can subtract directly.
-        // nnMatrixAdd(layer_weights, &weight_deltas[l], layer_weights);
-        nnMatrixSub(layer_weights, &weight_deltas[l], layer_weights);
-
-        // Update weight transpose matrix for the next training iteration.
-        // nnMatrixTranspose(layer_weights, &weights_T[l]);
-
-        // Update biases.
-        // This is the same formula as for weights, except that the o_j term is
-        // just 1. We can simply re-use the gradient that we have already
-        // computed for the weight update.
-        // nnMatrixMulAdd(layer_biases, gradient, params->learning_rate,
-        // layer_biases);
-        nnMatrixMulSub(
-            layer_biases, gradient, params->learning_rate, layer_biases);
+        // Update layer weights.
+        if (layer->type == nnLinear) {
+          nnLinearImpl* linear = &layer->linear;
+          nnMatrix* layer_weights = &linear->weights;
+          nnMatrix* layer_biases = &linear->biases;
+
+          // Outer product to compute the weight deltas.
+          // This layer's input is the previous layer's output.
+          const nnMatrix* input_T =
+              (l == 0) ? &training_inputs_T : &outputs_T[l - 1];
+          nnMatrixMul(input_T, gradient, &weight_deltas[l]);
+
+          // Update weights.
+          nnMatrixScale(&weight_deltas[l], params->learning_rate);
+          nnMatrixSub(layer_weights, &weight_deltas[l], layer_weights);
+
+          // Update biases.
+          // This is the same formula as for weights, except that the o_j term
+          // is just 1.
+          nnMatrixMulSub(
+              layer_biases, gradient, params->learning_rate, layer_biases);
+        }
       }
 
       // TODO: Add this under a verbose debugging mode.
@@ -334,12 +362,11 @@ void nnTrain(
   for (int l = 0; l < net->num_layers; ++l) {
     nnMatrixDel(&errors[l]);
     nnMatrixDel(&outputs_T[l]);
-    // nnMatrixDel(&weights_T[l]);
     nnMatrixDel(&weight_deltas[l]);
 
     nnGradientElements* elems = &gradient_elems[l];
     switch (elems->type) {
-      case nnIdentity:
+      case nnLinear:
         break;  // Gradient vector is borrowed, no need to deallocate.
 
       case nnSigmoid:
@@ -355,7 +382,6 @@ void nnTrain(
   nnMatrixDel(&training_inputs_T);
   free(errors);
   free(outputs_T);
-  // free(weights_T);
   free(weight_deltas);
   free(gradient_elems);
 }
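Splitting linear and activation stages also splits the chain rule across layers, which is what the reworked backward loop implements. In the notation of the comments above (o_k is a layer output, .* is element-wise multiplication, and errors[l] is the error arriving at layer l from the layers after it):

    Sigmoid layer:  G = o_k .* (1 - o_k)     errors[l-1] = errors[l] .* G
    Relu layer:     G = (o_k > 0 ? 1 : 0)    errors[l-1] = errors[l] .* G
    Linear layer:   G = errors[l]            errors[l-1] = errors[l] * W^T
                    dE/dW = x^T * G          dE/db = G

where x is the layer's input, i.e. the previous layer's output (or the training input when l == 0). Multiplying the old fused per-layer expressions back together recovers exactly these factors, so the refactor changes bookkeeping, not math.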
diff --git a/src/lib/test/neuralnet_test.c b/src/lib/test/neuralnet_test.c
index 14d9438..0f8d7b8 100644
--- a/src/lib/test/neuralnet_test.c
+++ b/src/lib/test/neuralnet_test.c
@@ -1,8 +1,8 @@
 #include <neuralnet/neuralnet.h>
 
-#include <neuralnet/matrix.h>
 #include "activation.h"
 #include "neuralnet_impl.h"
+#include <neuralnet/matrix.h>
 
 #include "test.h"
 #include "test_util.h"
@@ -10,23 +10,31 @@
 #include <assert.h>
 
 TEST_CASE(neuralnet_perceptron_test) {
-  const int num_layers = 1;
-  const int layer_sizes[] = { 1, 1 };
-  const nnActivation layer_activations[] = { nnSigmoid };
-  const R weights[] = { 0.3 };
+  const int num_layers = 2;
+  const int input_size = 1;
+  const R weights[] = {0.3};
+  const R biases[] = {0.0};
+  const nnLayer layers[] = {
+      {.type = nnLinear,
+       .linear =
+           {.weights = nnMatrixFromArray(1, 1, weights),
+            .biases = nnMatrixFromArray(1, 1, biases)}},
+      {.type = nnSigmoid},
+  };
 
-  nnNeuralNetwork* net = nnMakeNet(num_layers, layer_sizes, layer_activations);
+  nnNeuralNetwork* net = nnMakeNet(layers, num_layers, input_size);
   assert(net);
-  nnSetWeights(net, weights);
 
-  nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/1);
+  nnQueryObject* query = nnMakeQueryObject(net, 1);
 
-  const R input[] = { 0.9 };
+  const R input[] = {0.9};
   R output[1];
   nnQueryArray(net, query, input, output);
 
   const R expected_output = sigmoid(input[0] * weights[0]);
-  printf("\nOutput: %f, Expected: %f\n", output[0], expected_output);
+  printf(
+      "\n[neuralnet_perceptron_test] Output: %f, Expected: %f\n", output[0],
+      expected_output);
   TEST_TRUE(double_eq(output[0], expected_output, EPS));
 
   nnDeleteQueryObject(&query);
@@ -34,53 +42,66 @@ TEST_CASE(neuralnet_perceptron_test) {
 }
 
 TEST_CASE(neuralnet_xor_test) {
-  const int num_layers = 2;
-  const int layer_sizes[] = { 2, 2, 1 };
-  const nnActivation layer_activations[] = { nnRelu, nnIdentity };
-  const R weights[] = {
-    1, 1, 1, 1, // First (hidden) layer.
-    1, -2 // Second (output) layer.
-  };
-  const R biases[] = {
-    0, -1, // First (hidden) layer.
-    0 // Second (output) layer.
+  // First (hidden) layer.
+  const R weights0[] = {1, 1, 1, 1};
+  const R biases0[] = {0, -1};
+  // Second (output) layer.
+  const R weights1[] = {1, -2};
+  const R biases1[] = {0};
+  // Network.
+  const int num_layers = 3;
+  const int input_size = 2;
+  const nnLayer layers[] = {
+      {.type = nnLinear,
+       .linear =
+           {.weights = nnMatrixFromArray(2, 2, weights0),
+            .biases = nnMatrixFromArray(1, 2, biases0)}},
+      {.type = nnRelu},
+      {.type = nnLinear,
+       .linear =
+           {.weights = nnMatrixFromArray(2, 1, weights1),
+            .biases = nnMatrixFromArray(1, 1, biases1)}},
   };
 
-  nnNeuralNetwork* net = nnMakeNet(num_layers, layer_sizes, layer_activations);
+  nnNeuralNetwork* net = nnMakeNet(layers, num_layers, input_size);
   assert(net);
-  nnSetWeights(net, weights);
-  nnSetBiases(net, biases);
 
   // First layer weights.
-  TEST_EQUAL(nnMatrixAt(&net->weights[0], 0, 0), 1);
-  TEST_EQUAL(nnMatrixAt(&net->weights[0], 0, 1), 1);
-  TEST_EQUAL(nnMatrixAt(&net->weights[0], 0, 2), 1);
-  TEST_EQUAL(nnMatrixAt(&net->weights[0], 0, 3), 1);
-  // Second layer weights.
-  TEST_EQUAL(nnMatrixAt(&net->weights[1], 0, 0), 1);
-  TEST_EQUAL(nnMatrixAt(&net->weights[1], 0, 1), -2);
+  TEST_EQUAL(nnMatrixAt(&net->layers[0].linear.weights, 0, 0), 1);
+  TEST_EQUAL(nnMatrixAt(&net->layers[0].linear.weights, 0, 1), 1);
+  TEST_EQUAL(nnMatrixAt(&net->layers[0].linear.weights, 0, 2), 1);
+  TEST_EQUAL(nnMatrixAt(&net->layers[0].linear.weights, 0, 3), 1);
+  // Second linear layer (third layer) weights.
+  TEST_EQUAL(nnMatrixAt(&net->layers[2].linear.weights, 0, 0), 1);
+  TEST_EQUAL(nnMatrixAt(&net->layers[2].linear.weights, 0, 1), -2);
   // First layer biases.
-  TEST_EQUAL(nnMatrixAt(&net->biases[0], 0, 0), 0);
-  TEST_EQUAL(nnMatrixAt(&net->biases[0], 0, 1), -1);
-  // Second layer biases.
-  TEST_EQUAL(nnMatrixAt(&net->biases[1], 0, 0), 0);
+  TEST_EQUAL(nnMatrixAt(&net->layers[0].linear.biases, 0, 0), 0);
+  TEST_EQUAL(nnMatrixAt(&net->layers[0].linear.biases, 0, 1), -1);
+  // Second linear layer (third layer) biases.
+  TEST_EQUAL(nnMatrixAt(&net->layers[2].linear.biases, 0, 0), 0);
 
   // Test.
 
-  #define M 4
+#define M 4
 
-  nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/M);
+  nnQueryObject* query = nnMakeQueryObject(net, M);
 
-  const R test_inputs[M][2] = { { 0., 0. }, { 1., 0. }, { 0., 1. }, { 1., 1. } };
+  const R test_inputs[M][2] = {
+      {0., 0.},
+      {1., 0.},
+      {0., 1.},
+      {1., 1.}
+  };
   nnMatrix test_inputs_matrix = nnMatrixMake(M, 2);
   nnMatrixInit(&test_inputs_matrix, (const R*)test_inputs);
   nnQuery(net, query, &test_inputs_matrix);
 
-  const R expected_outputs[M] = { 0., 1., 1., 0. };
+  const R expected_outputs[M] = {0., 1., 1., 0.};
   for (int i = 0; i < M; ++i) {
     const R test_output = nnMatrixAt(nnNetOutputs(query), i, 0);
-    printf("\nInput: (%f, %f), Output: %f, Expected: %f\n",
-        test_inputs[i][0], test_inputs[i][1], test_output, expected_outputs[i]);
+    printf(
+        "\nInput: (%f, %f), Output: %f, Expected: %f\n", test_inputs[i][0],
+        test_inputs[i][1], test_output, expected_outputs[i]);
   }
   for (int i = 0; i < M; ++i) {
     const R test_output = nnMatrixAt(nnNetOutputs(query), i, 0);
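The hand-wired XOR network above can be checked by hand, which is what expected_outputs encodes. With the row-vector convention y = x*W + b, take input (1, 1):

    hidden: [1 1] * [[1 1], [1 1]] + [0 -1] = [2 1];  relu -> [2 1]
    output: [2 1] * [[1], [-2]] + [0] = 2 - 2 = 0

and input (1, 0): [1 1] + [0 -1] = [1 0]; relu -> [1 0]; output = 1 - 0 = 1. The remaining two cases follow the same arithmetic, giving {0, 1, 1, 0}.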
diff --git a/src/lib/test/train_linear_perceptron_non_origin_test.c b/src/lib/test/train_linear_perceptron_non_origin_test.c
index 5a320ac..40a42e0 100644
--- a/src/lib/test/train_linear_perceptron_non_origin_test.c
+++ b/src/lib/test/train_linear_perceptron_non_origin_test.c
@@ -1,9 +1,8 @@
 #include <neuralnet/train.h>
 
+#include "neuralnet_impl.h"
 #include <neuralnet/matrix.h>
 #include <neuralnet/neuralnet.h>
-#include "activation.h"
-#include "neuralnet_impl.h"
 
 #include "test.h"
 #include "test_util.h"
@@ -11,19 +10,21 @@
11#include <assert.h> 10#include <assert.h>
12 11
13TEST_CASE(neuralnet_train_linear_perceptron_non_origin_test) { 12TEST_CASE(neuralnet_train_linear_perceptron_non_origin_test) {
14 const int num_layers = 1; 13 const int num_layers = 1;
15 const int layer_sizes[] = { 1, 1 }; 14 const int input_size = 1;
16 const nnActivation layer_activations[] = { nnIdentity }; 15 const nnLayer layers[] = {
16 {.type = nnLinear, .linear = {.input_size = 1, .output_size = 1}}
17 };
17 18
18 nnNeuralNetwork* net = nnMakeNet(num_layers, layer_sizes, layer_activations); 19 nnNeuralNetwork* net = nnMakeNet(layers, num_layers, input_size);
19 assert(net); 20 assert(net);
20 21
21 // Train. 22// Train.
22 23
23 // Try to learn the Y = 2X + 1 line. 24// Try to learn the Y = 2X + 1 line.
24 #define N 2 25#define N 2
25 const R inputs[N] = { 0., 1. }; 26 const R inputs[N] = {0., 1.};
26 const R targets[N] = { 1., 3. }; 27 const R targets[N] = {1., 3.};
27 28
28 nnMatrix inputs_matrix = nnMatrixMake(N, 1); 29 nnMatrix inputs_matrix = nnMatrixMake(N, 1);
29 nnMatrix targets_matrix = nnMatrixMake(N, 1); 30 nnMatrix targets_matrix = nnMatrixMake(N, 1);
@@ -31,31 +32,32 @@ TEST_CASE(neuralnet_train_linear_perceptron_non_origin_test) {
31 nnMatrixInit(&targets_matrix, targets); 32 nnMatrixInit(&targets_matrix, targets);
32 33
33 nnTrainingParams params = { 34 nnTrainingParams params = {
34 .learning_rate = 0.7, 35 .learning_rate = 0.7,
35 .max_iterations = 20, 36 .max_iterations = 20,
36 .seed = 0, 37 .seed = 0,
37 .weight_init = nnWeightInit01, 38 .weight_init = nnWeightInit01,
38 .debug = false, 39 .debug = false,
39 }; 40 };
40 41
41 nnTrain(net, &inputs_matrix, &targets_matrix, &params); 42 nnTrain(net, &inputs_matrix, &targets_matrix, &params);
42 43
43 const R weight = nnMatrixAt(&net->weights[0], 0, 0); 44 const R weight = nnMatrixAt(&net->layers[0].linear.weights, 0, 0);
44 const R expected_weight = 2.0; 45 const R expected_weight = 2.0;
45 printf("\nTrained network weight: %f, Expected: %f\n", weight, expected_weight); 46 printf(
47 "\nTrained network weight: %f, Expected: %f\n", weight, expected_weight);
46 TEST_TRUE(double_eq(weight, expected_weight, WEIGHT_EPS)); 48 TEST_TRUE(double_eq(weight, expected_weight, WEIGHT_EPS));
47 49
48 const R bias = nnMatrixAt(&net->biases[0], 0, 0); 50 const R bias = nnMatrixAt(&net->layers[0].linear.biases, 0, 0);
49 const R expected_bias = 1.0; 51 const R expected_bias = 1.0;
50 printf("Trained network bias: %f, Expected: %f\n", bias, expected_bias); 52 printf("Trained network bias: %f, Expected: %f\n", bias, expected_bias);
51 TEST_TRUE(double_eq(bias, expected_bias, WEIGHT_EPS)); 53 TEST_TRUE(double_eq(bias, expected_bias, WEIGHT_EPS));
52 54
53 // Test. 55 // Test.
54 56
55 nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/1); 57 nnQueryObject* query = nnMakeQueryObject(net, 1);
56 58
57 const R test_input[] = { 2.3 }; 59 const R test_input[] = {2.3};
58 R test_output[1]; 60 R test_output[1];
59 nnQueryArray(net, query, test_input, test_output); 61 nnQueryArray(net, query, test_input, test_output);
60 62
61 const R expected_output = test_input[0] * expected_weight + expected_bias; 63 const R expected_output = test_input[0] * expected_weight + expected_bias;
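Distilled from the hunk above, the old-to-new constructor mapping for these one-weight perceptron tests (a sketch, using only calls shown in this diff): the layer_sizes/layer_activations pair is replaced by one nnLayer descriptor per layer, the input size moves into its own argument, and a plain linear layer no longer needs an explicit nnIdentity entry.

#include <neuralnet/neuralnet.h>

// New-style equivalent of the old
//   nnMakeNet(num_layers, layer_sizes, layer_activations)
// call for a 1x1 identity perceptron; the activation list is gone.
static nnNeuralNetwork* MakePerceptron(void) {
  const nnLayer layers[] = {
      {.type = nnLinear, .linear = {.input_size = 1, .output_size = 1}},
  };
  return nnMakeNet(layers, /*num_layers=*/1, /*input_size=*/1);
}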
diff --git a/src/lib/test/train_linear_perceptron_test.c b/src/lib/test/train_linear_perceptron_test.c
index 2b1336d..667643b 100644
--- a/src/lib/test/train_linear_perceptron_test.c
+++ b/src/lib/test/train_linear_perceptron_test.c
@@ -1,9 +1,8 @@
1#include <neuralnet/train.h> 1#include <neuralnet/train.h>
2 2
3#include "neuralnet_impl.h"
3#include <neuralnet/matrix.h> 4#include <neuralnet/matrix.h>
4#include <neuralnet/neuralnet.h> 5#include <neuralnet/neuralnet.h>
5#include "activation.h"
6#include "neuralnet_impl.h"
7 6
8#include "test.h" 7#include "test.h"
9#include "test_util.h" 8#include "test_util.h"
@@ -11,19 +10,21 @@
11#include <assert.h> 10#include <assert.h>
12 11
13TEST_CASE(neuralnet_train_linear_perceptron_test) { 12TEST_CASE(neuralnet_train_linear_perceptron_test) {
14 const int num_layers = 1; 13 const int num_layers = 1;
15 const int layer_sizes[] = { 1, 1 }; 14 const int input_size = 1;
16 const nnActivation layer_activations[] = { nnIdentity }; 15 const nnLayer layers[] = {
16 {.type = nnLinear, .linear = {.input_size = 1, .output_size = 1}}
17 };
17 18
18 nnNeuralNetwork* net = nnMakeNet(num_layers, layer_sizes, layer_activations); 19 nnNeuralNetwork* net = nnMakeNet(layers, num_layers, input_size);
19 assert(net); 20 assert(net);
20 21
21 // Train. 22// Train.
22 23
23 // Try to learn the Y=X line. 24// Try to learn the Y=X line.
24 #define N 2 25#define N 2
25 const R inputs[N] = { 0., 1. }; 26 const R inputs[N] = {0., 1.};
26 const R targets[N] = { 0., 1. }; 27 const R targets[N] = {0., 1.};
27 28
28 nnMatrix inputs_matrix = nnMatrixMake(N, 1); 29 nnMatrix inputs_matrix = nnMatrixMake(N, 1);
29 nnMatrix targets_matrix = nnMatrixMake(N, 1); 30 nnMatrix targets_matrix = nnMatrixMake(N, 1);
@@ -31,26 +32,27 @@ TEST_CASE(neuralnet_train_linear_perceptron_test) {
31 nnMatrixInit(&targets_matrix, targets); 32 nnMatrixInit(&targets_matrix, targets);
32 33
33 nnTrainingParams params = { 34 nnTrainingParams params = {
34 .learning_rate = 0.7, 35 .learning_rate = 0.7,
35 .max_iterations = 10, 36 .max_iterations = 10,
36 .seed = 0, 37 .seed = 0,
37 .weight_init = nnWeightInit01, 38 .weight_init = nnWeightInit01,
38 .debug = false, 39 .debug = false,
39 }; 40 };
40 41
41 nnTrain(net, &inputs_matrix, &targets_matrix, &params); 42 nnTrain(net, &inputs_matrix, &targets_matrix, &params);
42 43
43 const R weight = nnMatrixAt(&net->weights[0], 0, 0); 44 const R weight = nnMatrixAt(&net->layers[0].linear.weights, 0, 0);
44 const R expected_weight = 1.0; 45 const R expected_weight = 1.0;
45 printf("\nTrained network weight: %f, Expected: %f\n", weight, expected_weight); 46 printf(
47 "\nTrained network weight: %f, Expected: %f\n", weight, expected_weight);
46 TEST_TRUE(double_eq(weight, expected_weight, WEIGHT_EPS)); 48 TEST_TRUE(double_eq(weight, expected_weight, WEIGHT_EPS));
47 49
48 // Test. 50 // Test.
49 51
50 nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/1); 52 nnQueryObject* query = nnMakeQueryObject(net, 1);
51 53
52 const R test_input[] = { 2.3 }; 54 const R test_input[] = {2.3};
53 R test_output[1]; 55 R test_output[1];
54 nnQueryArray(net, query, test_input, test_output); 56 nnQueryArray(net, query, test_input, test_output);
55 57
56 const R expected_output = test_input[0]; 58 const R expected_output = test_input[0];
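Both perceptron tests end with the same single-sample query flow; here is a hypothetical wrapper (not in the commit) that captures it, assuming only nnMakeQueryObject, nnQueryArray, and the R scalar type from the headers above. The second argument to nnMakeQueryObject is the number of input rows, i.e. the batch size, which the old code spelled out as /*num_inputs=*/1.

#include <neuralnet/matrix.h>
#include <neuralnet/neuralnet.h>

// Runs one scalar sample through a trained 1-input/1-output net.
static R QueryScalar(nnNeuralNetwork* net, R x) {
  nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/1);
  const R input[] = {x};
  R output[1];
  nnQueryArray(net, query, input, output);
  // Query-object cleanup omitted; its destructor is not shown in this diff.
  return output[0];
}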
diff --git a/src/lib/test/train_sigmoid_test.c b/src/lib/test/train_sigmoid_test.c
index 588e7ca..39a84b0 100644
--- a/src/lib/test/train_sigmoid_test.c
+++ b/src/lib/test/train_sigmoid_test.c
@@ -1,9 +1,9 @@
1#include <neuralnet/train.h> 1#include <neuralnet/train.h>
2 2
3#include <neuralnet/matrix.h>
4#include <neuralnet/neuralnet.h>
5#include "activation.h" 3#include "activation.h"
6#include "neuralnet_impl.h" 4#include "neuralnet_impl.h"
5#include <neuralnet/matrix.h>
6#include <neuralnet/neuralnet.h>
7 7
8#include "test.h" 8#include "test.h"
9#include "test_util.h" 9#include "test_util.h"
@@ -11,21 +11,24 @@
11#include <assert.h> 11#include <assert.h>
12 12
13TEST_CASE(neuralnet_train_sigmoid_test) { 13TEST_CASE(neuralnet_train_sigmoid_test) {
14 const int num_layers = 1; 14 const int num_layers = 2;
15 const int layer_sizes[] = { 1, 1 }; 15 const int input_size = 1;
16 const nnActivation layer_activations[] = { nnSigmoid }; 16 const nnLayer layers[] = {
17 {.type = nnLinear, .linear = {.input_size = 1, .output_size = 1}},
18 {.type = nnSigmoid},
19 };
17 20
18 nnNeuralNetwork* net = nnMakeNet(num_layers, layer_sizes, layer_activations); 21 nnNeuralNetwork* net = nnMakeNet(layers, num_layers, input_size);
19 assert(net); 22 assert(net);
20 23
21 // Train. 24// Train.
22 25
23 // Try to learn the sigmoid function. 26// Try to learn the sigmoid function.
24 #define N 3 27#define N 3
25 R inputs[N]; 28 R inputs[N];
26 R targets[N]; 29 R targets[N];
27 for (int i = 0; i < N; ++i) { 30 for (int i = 0; i < N; ++i) {
28 inputs[i] = lerp(-1, +1, (R)i / (R)(N-1)); 31 inputs[i] = lerp(-1, +1, (R)i / (R)(N - 1));
29 targets[i] = sigmoid(inputs[i]); 32 targets[i] = sigmoid(inputs[i]);
30 } 33 }
31 34
@@ -35,29 +38,30 @@ TEST_CASE(neuralnet_train_sigmoid_test) {
35 nnMatrixInit(&targets_matrix, targets); 38 nnMatrixInit(&targets_matrix, targets);
36 39
37 nnTrainingParams params = { 40 nnTrainingParams params = {
38 .learning_rate = 0.9, 41 .learning_rate = 0.9,
39 .max_iterations = 100, 42 .max_iterations = 100,
40 .seed = 0, 43 .seed = 0,
41 .weight_init = nnWeightInit01, 44 .weight_init = nnWeightInit01,
42 .debug = false, 45 .debug = false,
43 }; 46 };
44 47
45 nnTrain(net, &inputs_matrix, &targets_matrix, &params); 48 nnTrain(net, &inputs_matrix, &targets_matrix, &params);
46 49
47 const R weight = nnMatrixAt(&net->weights[0], 0, 0); 50 const R weight = nnMatrixAt(&net->layers[0].linear.weights, 0, 0);
48 const R expected_weight = 1.0; 51 const R expected_weight = 1.0;
49 printf("\nTrained network weight: %f, Expected: %f\n", weight, expected_weight); 52 printf(
53 "\nTrained network weight: %f, Expected: %f\n", weight, expected_weight);
50 TEST_TRUE(double_eq(weight, expected_weight, WEIGHT_EPS)); 54 TEST_TRUE(double_eq(weight, expected_weight, WEIGHT_EPS));
51 55
52 // Test. 56 // Test.
53 57
54 nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/1); 58 nnQueryObject* query = nnMakeQueryObject(net, 1);
55 59
56 const R test_input[] = { 0.3 }; 60 const R test_input[] = {0.3};
57 R test_output[1]; 61 R test_output[1];
58 nnQueryArray(net, query, test_input, test_output); 62 nnQueryArray(net, query, test_input, test_output);
59 63
60 const R expected_output = 0.574442516811659; // sigmoid(0.3) 64 const R expected_output = 0.574442516811659; // sigmoid(0.3)
61 printf("Output: %f, Expected: %f\n", test_output[0], expected_output); 65 printf("Output: %f, Expected: %f\n", test_output[0], expected_output);
62 TEST_TRUE(double_eq(test_output[0], expected_output, OUTPUT_EPS)); 66 TEST_TRUE(double_eq(test_output[0], expected_output, OUTPUT_EPS));
63 67
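The sigmoid test makes the activation-as-layer change explicit: num_layers goes from 1 to 2 because the sigmoid now counts as a layer of its own, while the trainable weight still lives at layers[0]. A sketch of just the topology, exactly as the hunk above sets it up:

#include <neuralnet/neuralnet.h>

// One linear layer followed by a sigmoid layer; the sigmoid entry
// presumably carries no parameters, so only .type is set.
static nnNeuralNetwork* MakeSigmoidNet(void) {
  const nnLayer layers[] = {
      {.type = nnLinear, .linear = {.input_size = 1, .output_size = 1}},
      {.type = nnSigmoid},
  };
  return nnMakeNet(layers, /*num_layers=*/2, /*input_size=*/1);
}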
diff --git a/src/lib/test/train_xor_test.c b/src/lib/test/train_xor_test.c
index 6ddc6e0..78695a3 100644
--- a/src/lib/test/train_xor_test.c
+++ b/src/lib/test/train_xor_test.c
@@ -1,9 +1,9 @@
1#include <neuralnet/train.h> 1#include <neuralnet/train.h>
2 2
3#include <neuralnet/matrix.h>
4#include <neuralnet/neuralnet.h>
5#include "activation.h" 3#include "activation.h"
6#include "neuralnet_impl.h" 4#include "neuralnet_impl.h"
5#include <neuralnet/matrix.h>
6#include <neuralnet/neuralnet.h>
7 7
8#include "test.h" 8#include "test.h"
9#include "test_util.h" 9#include "test_util.h"
@@ -11,18 +11,27 @@
11#include <assert.h> 11#include <assert.h>
12 12
13TEST_CASE(neuralnet_train_xor_test) { 13TEST_CASE(neuralnet_train_xor_test) {
14 const int num_layers = 2; 14 const int num_layers = 3;
15 const int layer_sizes[] = { 2, 2, 1 }; 15 const int input_size = 2;
16 const nnActivation layer_activations[] = { nnRelu, nnIdentity }; 16 const nnLayer layers[] = {
17 {.type = nnLinear, .linear = {.input_size = 2, .output_size = 2}},
18 {.type = nnRelu},
19 {.type = nnLinear, .linear = {.input_size = 2, .output_size = 1}}
20 };
17 21
18 nnNeuralNetwork* net = nnMakeNet(num_layers, layer_sizes, layer_activations); 22 nnNeuralNetwork* net = nnMakeNet(layers, num_layers, input_size);
19 assert(net); 23 assert(net);
20 24
21 // Train. 25 // Train.
22 26
23 #define N 4 27#define N 4
24 const R inputs[N][2] = { { 0., 0. }, { 0., 1. }, { 1., 0. }, { 1., 1. } }; 28 const R inputs[N][2] = {
25 const R targets[N] = { 0., 1., 1., 0. }; 29 {0., 0.},
30 {0., 1.},
31 {1., 0.},
32 {1., 1.}
33 };
34 const R targets[N] = {0., 1., 1., 0.};
26 35
27 nnMatrix inputs_matrix = nnMatrixMake(N, 2); 36 nnMatrix inputs_matrix = nnMatrixMake(N, 2);
28 nnMatrix targets_matrix = nnMatrixMake(N, 1); 37 nnMatrix targets_matrix = nnMatrixMake(N, 1);
@@ -30,31 +39,37 @@ TEST_CASE(neuralnet_train_xor_test) {
30 nnMatrixInit(&targets_matrix, targets); 39 nnMatrixInit(&targets_matrix, targets);
31 40
32 nnTrainingParams params = { 41 nnTrainingParams params = {
33 .learning_rate = 0.1, 42 .learning_rate = 0.1,
34 .max_iterations = 500, 43 .max_iterations = 500,
35 .seed = 0, 44 .seed = 0,
36 .weight_init = nnWeightInit01, 45 .weight_init = nnWeightInit01,
37 .debug = false, 46 .debug = false,
38 }; 47 };
39 48
40 nnTrain(net, &inputs_matrix, &targets_matrix, &params); 49 nnTrain(net, &inputs_matrix, &targets_matrix, &params);
41 50
42 // Test. 51 // Test.
43 52
44 #define M 4 53#define M 4
45 54
46 nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/M); 55 nnQueryObject* query = nnMakeQueryObject(net, M);
47 56
48 const R test_inputs[M][2] = { { 0., 0. }, { 1., 0. }, { 0., 1. }, { 1., 1. } }; 57 const R test_inputs[M][2] = {
58 {0., 0.},
59 {1., 0.},
60 {0., 1.},
61 {1., 1.}
62 };
49 nnMatrix test_inputs_matrix = nnMatrixMake(M, 2); 63 nnMatrix test_inputs_matrix = nnMatrixMake(M, 2);
50 nnMatrixInit(&test_inputs_matrix, (const R*)test_inputs); 64 nnMatrixInit(&test_inputs_matrix, (const R*)test_inputs);
51 nnQuery(net, query, &test_inputs_matrix); 65 nnQuery(net, query, &test_inputs_matrix);
52 66
53 const R expected_outputs[M] = { 0., 1., 1., 0. }; 67 const R expected_outputs[M] = {0., 1., 1., 0.};
54 for (int i = 0; i < M; ++i) { 68 for (int i = 0; i < M; ++i) {
55 const R test_output = nnMatrixAt(nnNetOutputs(query), i, 0); 69 const R test_output = nnMatrixAt(nnNetOutputs(query), i, 0);
56 printf("\nInput: (%f, %f), Output: %f, Expected: %f\n", 70 printf(
57 test_inputs[i][0], test_inputs[i][1], test_output, expected_outputs[i]); 71 "\nInput: (%f, %f), Output: %f, Expected: %f\n", test_inputs[i][0],
72 test_inputs[i][1], test_output, expected_outputs[i]);
58 } 73 }
59 for (int i = 0; i < M; ++i) { 74 for (int i = 0; i < M; ++i) {
60 const R test_output = nnMatrixAt(nnNetOutputs(query), i, 0); 75 const R test_output = nnMatrixAt(nnNetOutputs(query), i, 0);
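And the batched flavor of the query flow, as the XOR test uses it: one query object sized to M rows, a single nnQuery over the whole truth table, and per-row reads through nnNetOutputs. A sketch under the same assumptions as above (net is the trained 3-layer XOR network; cleanup routines are not shown in this diff):

#include <neuralnet/matrix.h>
#include <neuralnet/neuralnet.h>

#include <stdio.h>

static void PrintXorTable(nnNeuralNetwork* net) {
  enum { M = 4 };
  const R test_inputs[M][2] = {{0., 0.}, {1., 0.}, {0., 1.}, {1., 1.}};

  nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/M);
  nnMatrix inputs = nnMatrixMake(M, 2);
  nnMatrixInit(&inputs, (const R*)test_inputs);
  nnQuery(net, query, &inputs);

  for (int i = 0; i < M; ++i) {
    printf("(%d, %d) -> %f\n", (int)test_inputs[i][0], (int)test_inputs[i][1],
           nnMatrixAt(nnNetOutputs(query), i, 0));
  }
}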