aboutsummaryrefslogtreecommitdiff
path: root/src/bin/mnist/src/main.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/bin/mnist/src/main.c')
-rw-r--r--src/bin/mnist/src/main.c473
1 files changed, 473 insertions, 0 deletions
diff --git a/src/bin/mnist/src/main.c b/src/bin/mnist/src/main.c
new file mode 100644
index 0000000..4d268ac
--- /dev/null
+++ b/src/bin/mnist/src/main.c
@@ -0,0 +1,473 @@
1#include <neuralnet/matrix.h>
2#include <neuralnet/neuralnet.h>
3#include <neuralnet/train.h>
4
5#include <zlib.h>
6
7#include <assert.h>
8#include <bsd/string.h>
9#include <linux/limits.h>
10#include <math.h>
11#include <stdbool.h>
12#include <stdint.h>
13#include <stdio.h>
14#include <stdlib.h>
15
16static const int TRAIN_ITERATIONS = 100;
17
18static const int32_t IMAGE_FILE_MAGIC = 0x00000803;
19static const int32_t LABEL_FILE_MAGIC = 0x00000801;
20
21// Inputs of 0 cancel weights during training. This value is used to rescale the
22// input pixels from [0,255] to [PIXEL_LOWER_BOUND, 1.0].
23static const double PIXEL_LOWER_BOUND = 0.01;
24
25// Scale the outputs to (0,1) since the sigmoid cannot produce 0 or 1.
26static const double LABEL_LOWER_BOUND = 0.01;
27static const double LABEL_UPPER_BOUND = 0.99;
28
29// Epsilon used to compare R values.
30static const double EPS = 1e-10;
31
32#define min(a,b) ((a) < (b) ? (a) : (b))
33
34typedef struct ImageSet {
35 nnMatrix images; // Images flattened into row vectors of the matrix.
36 nnMatrix labels; // One-hot-encoded labels.
37 int count; // Number of images and labels.
38 int rows; // Rows in an image.
39 int cols; // Columns in an image.
40} ImageSet;
41
42static void usage(const char* argv0) {
43 fprintf(stderr, "Usage: %s <path to mnist files directory> [num images]\n", argv0);
44 fprintf(stderr, "\n");
45 fprintf(stderr, " Use -1 for [num images] to use all the images in the data set\n");
46}
47
48static bool R_eq(R a, R b) {
49 return fabs(a-b) <= EPS;
50}
51
52static void PrintImage(const nnMatrix* images, int rows, int cols, int image_index) {
53 assert(images);
54 assert((0 <= image_index) && (image_index < images->rows));
55
56 // Top line.
57 for (int j = 0; j < cols/2; ++j) {
58 printf(" -");
59 }
60 printf("\n");
61
62 // Image.
63 const R* value = nnMatrixRow(images, image_index);
64 for (int i = 0; i < rows; ++i) {
65 printf("|");
66 for (int j = 0; j < cols; ++j) {
67 if (*value > 0.8) {
68 printf("#");
69 } else if (*value > 0.5) {
70 printf("*");
71 }
72 else if (*value > PIXEL_LOWER_BOUND) {
73 printf(":");
74 } else if (*value == 0.0) {
75 // Values should not be exactly 0, otherwise they cancel out weights
76 // during training.
77 printf("X");
78 } else {
79 printf(" ");
80 }
81 value++;
82 }
83 printf("|\n");
84 }
85
86 // Bottom line.
87 for (int j = 0; j < cols/2; ++j) {
88 printf(" -");
89 }
90 printf("\n");
91}
92
93static void PrintLabel(const nnMatrix* labels, int label_index) {
94 assert(labels);
95 assert((0 <= label_index) && (label_index < labels->rows));
96
97 // Compute the label from the one-hot encoding.
98 const R* value = nnMatrixRow(labels, label_index);
99 int label = -1;
100 for (int i = 0; i < 10; ++i) {
101 if (R_eq(*value++, LABEL_UPPER_BOUND)) {
102 label = i;
103 break;
104 }
105 }
106 assert((0 <= label) && (label <= 9));
107
108 printf("Label: %d ( ", label);
109 value = nnMatrixRow(labels, label_index);
110 for (int i = 0; i < 10; ++i) {
111 printf("%.3f ", *value++);
112 }
113 printf(")\n");
114}
115
116static R lerp(R a, R b, R t) {
117 return a + t*(b-a);
118}
119
120/// Rescales a pixel from [0,255] to [PIXEL_LOWER_BOUND, 1.0].
121static R FormatPixel(uint8_t pixel) {
122 const R value = (R)(pixel) / 255.0 * (1.0 - PIXEL_LOWER_BOUND) + PIXEL_LOWER_BOUND;
123 assert(value >= PIXEL_LOWER_BOUND);
124 assert(value <= 1.0);
125 return value;
126}
127
128/// Rescales a one-hot-encoded label value to (0,1).
129static R FormatLabel(R label) {
130 const R value = lerp(LABEL_LOWER_BOUND, LABEL_UPPER_BOUND, label);
131 assert(value > 0.0);
132 assert(value < 1.0);
133 return value;
134}
135
136static int32_t ReverseEndian32(int32_t x) {
137 const int32_t x0 = x & 0xff;
138 const int32_t x1 = (x >> 8) & 0xff;
139 const int32_t x2 = (x >> 16) & 0xff;
140 const int32_t x3 = (x >> 24) & 0xff;
141 return (x0 << 24) | (x1 << 16) | (x2 << 8) | x3;
142}
143
144static void ImageToMatrix(
145 const uint8_t* pixels, int num_pixels, int row, nnMatrix* images) {
146 assert(pixels);
147 assert(images);
148
149 for (int i = 0; i < num_pixels; ++i) {
150 const R pixel = FormatPixel(pixels[i]);
151 nnMatrixSet(images, row, i, pixel);
152 }
153}
154
155static bool ReadImages(gzFile images_file, int max_num_images, ImageSet* image_set) {
156 assert(images_file != Z_NULL);
157 assert(image_set);
158
159 bool success = false;
160
161 uint8_t* pixels = 0;
162
163 int32_t magic, total_images, rows, cols;
164 if ( (gzread(images_file, (char*)&magic, sizeof(int32_t)) != sizeof(int32_t)) ||
165 (gzread(images_file, (char*)&total_images, sizeof(int32_t)) != sizeof(int32_t)) ||
166 (gzread(images_file, (char*)&rows, sizeof(int32_t)) != sizeof(int32_t)) ||
167 (gzread(images_file, (char*)&cols, sizeof(int32_t)) != sizeof(int32_t)) ) {
168 fprintf(stderr, "Failed to read header\n");
169 goto cleanup;
170 }
171
172 magic = ReverseEndian32(magic);
173 total_images = ReverseEndian32(total_images);
174 rows = ReverseEndian32(rows);
175 cols = ReverseEndian32(cols);
176
177 if (magic != IMAGE_FILE_MAGIC) {
178 fprintf(stderr, "Magic number mismatch. Got %x, expected: %x\n",
179 magic, IMAGE_FILE_MAGIC);
180 goto cleanup;
181 }
182
183 printf("Magic: %.8x\nTotal images: %d\nRows: %d\nCols: %d\n",
184 magic, total_images, rows, cols);
185
186 total_images = max_num_images >= 0 ? min(total_images, max_num_images) : total_images;
187
188 // Images are flattened into single row vectors.
189 const int num_pixels = rows * cols;
190 image_set->images = nnMatrixMake(total_images, num_pixels);
191 image_set->count = total_images;
192 image_set->rows = rows;
193 image_set->cols = cols;
194
195 pixels = calloc(1, num_pixels);
196 if (!pixels) {
197 fprintf(stderr, "Failed to allocate image buffer\n");
198 goto cleanup;
199 }
200
201 for (int i = 0; i < total_images; ++i) {
202 const int bytes_read = gzread(images_file, pixels, num_pixels);
203 if (bytes_read < num_pixels) {
204 fprintf(stderr, "Failed to read image %d\n", i);
205 goto cleanup;
206 }
207 ImageToMatrix(pixels, num_pixels, i, &image_set->images);
208 }
209
210 success = true;
211
212cleanup:
213 if (pixels) {
214 free(pixels);
215 }
216 if (!success) {
217 nnMatrixDel(&image_set->images);
218 }
219 return success;
220}
221
222static void OneHotEncode(const uint8_t* labels_bytes, int num_labels, nnMatrix* labels) {
223 assert(labels_bytes);
224 assert(labels);
225 assert(labels->rows == num_labels);
226 assert(labels->cols == 10);
227
228 static const R one_hot[10][10] = {
229 { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
230 { 0, 1, 0, 0, 0, 0, 0, 0, 0, 0 },
231 { 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 },
232 { 0, 0, 0, 1, 0, 0, 0, 0, 0, 0 },
233 { 0, 0, 0, 0, 1, 0, 0, 0, 0, 0 },
234 { 0, 0, 0, 0, 0, 1, 0, 0, 0, 0 },
235 { 0, 0, 0, 0, 0, 0, 1, 0, 0, 0 },
236 { 0, 0, 0, 0, 0, 0, 0, 1, 0, 0 },
237 { 0, 0, 0, 0, 0, 0, 0, 0, 1, 0 },
238 { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 },
239 };
240
241 R* value = labels->values;
242
243 for (int i = 0; i < num_labels; ++i) {
244 const uint8_t label = labels_bytes[i];
245 const R* one_hot_value = one_hot[label];
246
247 for (int j = 0; j < 10; ++j) {
248 *value++ = FormatLabel(*one_hot_value++);
249 }
250 }
251}
252
253static int OneHotDecode(const nnMatrix* label_matrix) {
254 assert(label_matrix);
255 assert(label_matrix->cols == 1);
256 assert(label_matrix->rows == 10);
257
258 R max_value = 0;
259 int pos_max = 0;
260 for (int i = 0; i < 10; ++i) {
261 const R value = nnMatrixAt(label_matrix, 0, i);
262 if (value > max_value) {
263 max_value = value;
264 pos_max = i;
265 }
266 }
267 assert(pos_max >= 0);
268 assert(pos_max <= 10);
269 return pos_max;
270}
271
272static bool ReadLabels(gzFile labels_file, int max_num_labels, ImageSet* image_set) {
273 assert(labels_file != Z_NULL);
274 assert(image_set != 0);
275
276 bool success = false;
277
278 uint8_t* labels = 0;
279
280 int32_t magic, total_labels;
281 if ( (gzread(labels_file, (char*)&magic, sizeof(int32_t)) != sizeof(int32_t)) ||
282 (gzread(labels_file, (char*)&total_labels, sizeof(int32_t)) != sizeof(int32_t)) ) {
283 fprintf(stderr, "Failed to read header\n");
284 goto cleanup;
285 }
286
287 magic = ReverseEndian32(magic);
288 total_labels = ReverseEndian32(total_labels);
289
290 if (magic != LABEL_FILE_MAGIC) {
291 fprintf(stderr, "Magic number mismatch. Got %x, expected: %x\n",
292 magic, LABEL_FILE_MAGIC);
293 goto cleanup;
294 }
295
296 printf("Magic: %.8x\nTotal labels: %d\n", magic, total_labels);
297
298 total_labels = max_num_labels >= 0 ? min(total_labels, max_num_labels) : total_labels;
299
300 assert(image_set->count == total_labels);
301
302 // One-hot encoding of labels, 10 values (digits) per label.
303 image_set->labels = nnMatrixMake(total_labels, 10);
304
305 labels = calloc(total_labels, sizeof(uint8_t));
306 if (!labels) {
307 fprintf(stderr, "Failed to allocate labels buffer\n");
308 goto cleanup;
309 }
310
311 if (gzread(labels_file, labels, total_labels * sizeof(uint8_t)) != total_labels) {
312 fprintf(stderr, "Failed to read labels\n");
313 goto cleanup;
314 }
315
316 OneHotEncode(labels, total_labels, &image_set->labels);
317
318 success = true;
319
320cleanup:
321 if (labels) {
322 free(labels);
323 }
324 if (!success) {
325 nnMatrixDel(&image_set->labels);
326 }
327 return success;
328}
329
330int main(int argc, const char** argv) {
331 if (argc < 2) {
332 usage(argv[0]);
333 return 1;
334 }
335
336 bool success = false;
337
338 gzFile train_images_file = Z_NULL;
339 gzFile train_labels_file = Z_NULL;
340 gzFile test_images_file = Z_NULL;
341 gzFile test_labels_file = Z_NULL;
342 ImageSet train_set = { 0 };
343 ImageSet test_set = { 0 };
344 nnNeuralNetwork* net = 0;
345 nnQueryObject* query = 0;
346
347 const char* mnist_files_dir = argv[1];
348 const int max_num_images = argc > 2 ? atoi(argv[2]) : -1;
349
350 char train_labels_path[PATH_MAX];
351 char train_images_path[PATH_MAX];
352 char test_labels_path[PATH_MAX];
353 char test_images_path[PATH_MAX];
354 strlcpy(train_labels_path, mnist_files_dir, PATH_MAX);
355 strlcpy(train_images_path, mnist_files_dir, PATH_MAX);
356 strlcpy(test_labels_path, mnist_files_dir, PATH_MAX);
357 strlcpy(test_images_path, mnist_files_dir, PATH_MAX);
358 strlcat(train_labels_path, "/train-labels-idx1-ubyte.gz", PATH_MAX);
359 strlcat(train_images_path, "/train-images-idx3-ubyte.gz", PATH_MAX);
360 strlcat(test_labels_path, "/t10k-labels-idx1-ubyte.gz", PATH_MAX);
361 strlcat(test_images_path, "/t10k-images-idx3-ubyte.gz", PATH_MAX);
362
363 train_images_file = gzopen(train_images_path, "r");
364 if (train_images_file == Z_NULL) {
365 fprintf(stderr, "Failed to open file: %s\n", train_images_path);
366 goto cleanup;
367 }
368
369 train_labels_file = gzopen(train_labels_path, "r");
370 if (train_labels_file == Z_NULL) {
371 fprintf(stderr, "Failed to open file: %s\n", train_labels_path);
372 goto cleanup;
373 }
374
375 test_images_file = gzopen(test_images_path, "r");
376 if (test_images_file == Z_NULL) {
377 fprintf(stderr, "Failed to open file: %s\n", test_images_path);
378 goto cleanup;
379 }
380
381 test_labels_file = gzopen(test_labels_path, "r");
382 if (test_labels_file == Z_NULL) {
383 fprintf(stderr, "Failed to open file: %s\n", test_labels_path);
384 goto cleanup;
385 }
386
387 if (!ReadImages(train_images_file, max_num_images, &train_set)) {
388 goto cleanup;
389 }
390 if (!ReadLabels(train_labels_file, max_num_images, &train_set)) {
391 goto cleanup;
392 }
393
394 if (!ReadImages(test_images_file, max_num_images, &test_set)) {
395 goto cleanup;
396 }
397 if (!ReadLabels(test_labels_file, max_num_images, &test_set)) {
398 goto cleanup;
399 }
400
401 printf("\nTraining image/label pair examples:\n");
402 for (int i = 0; i < min(3, train_set.images.rows); ++i) {
403 PrintImage(&train_set.images, train_set.rows, train_set.cols, i);
404 PrintLabel(&train_set.labels, i);
405 printf("\n");
406 }
407
408 // Network definition.
409 const int image_size_pixels = train_set.rows * train_set.cols;
410 const int num_layers = 2;
411 const int layer_sizes[3] = { image_size_pixels, 100, 10 };
412 const nnActivation layer_activations[2] = { nnSigmoid, nnSigmoid };
413 if (!(net = nnMakeNet(num_layers, layer_sizes, layer_activations))) {
414 fprintf(stderr, "Failed to create neural network\n");
415 goto cleanup;
416 }
417
418 // Train.
419 printf("Training with up to %d images from the data set\n\n", max_num_images);
420 const nnTrainingParams training_params = {
421 .learning_rate = 0.1,
422 .max_iterations = TRAIN_ITERATIONS,
423 .seed = 0,
424 .weight_init = nnWeightInitNormal,
425 .debug = true,
426 };
427 nnTrain(net, &train_set.images, &train_set.labels, &training_params);
428
429 // Test.
430 int hits = 0;
431 query = nnMakeQueryObject(net, /*num_inputs=*/1);
432 for (int i = 0; i < test_set.count; ++i) {
433 const nnMatrix test_image = nnMatrixBorrowRows(&test_set.images, i, 1);
434 const nnMatrix test_label = nnMatrixBorrowRows(&test_set.labels, i, 1);
435
436 nnQuery(net, query, &test_image);
437
438 const int test_label_expected = OneHotDecode(&test_label);
439 const int test_label_actual = OneHotDecode(nnNetOutputs(query));
440
441 if (test_label_actual == test_label_expected) {
442 ++hits;
443 }
444 }
445 const R hit_ratio = (R)hits / (R)test_set.count;
446 printf("Test images: %d\n", test_set.count);
447 printf("Hits: %d/%d (%.3f%%)\n", hits, test_set.count, hit_ratio*100);
448
449 success = true;
450
451cleanup:
452 if (query) {
453 nnDeleteQueryObject(&query);
454 }
455 if (net) {
456 nnDeleteNet(&net);
457 }
458 nnMatrixDel(&train_set.images);
459 nnMatrixDel(&test_set.images);
460 if (train_images_file != Z_NULL) {
461 gzclose(train_images_file);
462 }
463 if (train_labels_file != Z_NULL) {
464 gzclose(train_labels_file);
465 }
466 if (test_images_file != Z_NULL) {
467 gzclose(test_images_file);
468 }
469 if (test_labels_file != Z_NULL) {
470 gzclose(test_labels_file);
471 }
472 return success ? 0 : 1;
473}