aboutsummaryrefslogtreecommitdiff
path: root/src/bin
diff options
context:
space:
mode:
author3gg <3gg@shellblade.net>2023-12-16 10:21:16 -0800
committer3gg <3gg@shellblade.net>2023-12-16 10:21:16 -0800
commit653e98e029a0d0f110b0ac599e50406060bb0f87 (patch)
tree6f909215218f6720266bde1b3f49aeddad8b1da3 /src/bin
parent3df7b6fb0c65295eed4590e6f166d60e89b3c68e (diff)
Decouple activations from linear layer.
Diffstat (limited to 'src/bin')
-rw-r--r--src/bin/mnist/src/main.c195
1 files changed, 108 insertions, 87 deletions
diff --git a/src/bin/mnist/src/main.c b/src/bin/mnist/src/main.c
index 9aa3ce5..53e0197 100644
--- a/src/bin/mnist/src/main.c
+++ b/src/bin/mnist/src/main.c
@@ -29,32 +29,35 @@ static const double LABEL_UPPER_BOUND = 0.99;
29// Epsilon used to compare R values. 29// Epsilon used to compare R values.
30static const double EPS = 1e-10; 30static const double EPS = 1e-10;
31 31
32#define min(a,b) ((a) < (b) ? (a) : (b)) 32#define min(a, b) ((a) < (b) ? (a) : (b))
33 33
34typedef struct ImageSet { 34typedef struct ImageSet {
35 nnMatrix images; // Images flattened into row vectors of the matrix. 35 nnMatrix images; // Images flattened into row vectors of the matrix.
36 nnMatrix labels; // One-hot-encoded labels. 36 nnMatrix labels; // One-hot-encoded labels.
37 int count; // Number of images and labels. 37 int count; // Number of images and labels.
38 int rows; // Rows in an image. 38 int rows; // Rows in an image.
39 int cols; // Columns in an image. 39 int cols; // Columns in an image.
40} ImageSet; 40} ImageSet;
41 41
42static void usage(const char* argv0) { 42static void usage(const char* argv0) {
43 fprintf(stderr, "Usage: %s <path to mnist files directory> [num images]\n", argv0); 43 fprintf(
44 stderr, "Usage: %s <path to mnist files directory> [num images]\n",
45 argv0);
44 fprintf(stderr, "\n"); 46 fprintf(stderr, "\n");
45 fprintf(stderr, " Use -1 for [num images] to use all the images in the data set\n"); 47 fprintf(
48 stderr,
49 " Use -1 for [num images] to use all the images in the data set\n");
46} 50}
47 51
48static bool R_eq(R a, R b) { 52static bool R_eq(R a, R b) { return fabs(a - b) <= EPS; }
49 return fabs(a-b) <= EPS;
50}
51 53
52static void PrintImage(const nnMatrix* images, int rows, int cols, int image_index) { 54static void PrintImage(
55 const nnMatrix* images, int rows, int cols, int image_index) {
53 assert(images); 56 assert(images);
54 assert((0 <= image_index) && (image_index < images->rows)); 57 assert((0 <= image_index) && (image_index < images->rows));
55 58
56 // Top line. 59 // Top line.
57 for (int j = 0; j < cols/2; ++j) { 60 for (int j = 0; j < cols / 2; ++j) {
58 printf(" -"); 61 printf(" -");
59 } 62 }
60 printf("\n"); 63 printf("\n");
@@ -68,8 +71,7 @@ static void PrintImage(const nnMatrix* images, int rows, int cols, int image_ind
68 printf("#"); 71 printf("#");
69 } else if (*value > 0.5) { 72 } else if (*value > 0.5) {
70 printf("*"); 73 printf("*");
71 } 74 } else if (*value > PIXEL_LOWER_BOUND) {
72 else if (*value > PIXEL_LOWER_BOUND) {
73 printf(":"); 75 printf(":");
74 } else if (*value == 0.0) { 76 } else if (*value == 0.0) {
75 // Values should not be exactly 0, otherwise they cancel out weights 77 // Values should not be exactly 0, otherwise they cancel out weights
@@ -84,7 +86,7 @@ static void PrintImage(const nnMatrix* images, int rows, int cols, int image_ind
84 } 86 }
85 87
86 // Bottom line. 88 // Bottom line.
87 for (int j = 0; j < cols/2; ++j) { 89 for (int j = 0; j < cols / 2; ++j) {
88 printf(" -"); 90 printf(" -");
89 } 91 }
90 printf("\n"); 92 printf("\n");
@@ -96,7 +98,7 @@ static void PrintLabel(const nnMatrix* labels, int label_index) {
96 98
97 // Compute the label from the one-hot encoding. 99 // Compute the label from the one-hot encoding.
98 const R* value = nnMatrixRow(labels, label_index); 100 const R* value = nnMatrixRow(labels, label_index);
99 int label = -1; 101 int label = -1;
100 for (int i = 0; i < 10; ++i) { 102 for (int i = 0; i < 10; ++i) {
101 if (R_eq(*value++, LABEL_UPPER_BOUND)) { 103 if (R_eq(*value++, LABEL_UPPER_BOUND)) {
102 label = i; 104 label = i;
@@ -113,13 +115,12 @@ static void PrintLabel(const nnMatrix* labels, int label_index) {
113 printf(")\n"); 115 printf(")\n");
114} 116}
115 117
116static R lerp(R a, R b, R t) { 118static R lerp(R a, R b, R t) { return a + t * (b - a); }
117 return a + t*(b-a);
118}
119 119
120/// Rescales a pixel from [0,255] to [PIXEL_LOWER_BOUND, 1.0]. 120/// Rescales a pixel from [0,255] to [PIXEL_LOWER_BOUND, 1.0].
121static R FormatPixel(uint8_t pixel) { 121static R FormatPixel(uint8_t pixel) {
122 const R value = (R)(pixel) / 255.0 * (1.0 - PIXEL_LOWER_BOUND) + PIXEL_LOWER_BOUND; 122 const R value =
123 (R)(pixel) / 255.0 * (1.0 - PIXEL_LOWER_BOUND) + PIXEL_LOWER_BOUND;
123 assert(value >= PIXEL_LOWER_BOUND); 124 assert(value >= PIXEL_LOWER_BOUND);
124 assert(value <= 1.0); 125 assert(value <= 1.0);
125 return value; 126 return value;
@@ -152,7 +153,8 @@ static void ImageToMatrix(
152 } 153 }
153} 154}
154 155
155static bool ReadImages(gzFile images_file, int max_num_images, ImageSet* image_set) { 156static bool ReadImages(
157 gzFile images_file, int max_num_images, ImageSet* image_set) {
156 assert(images_file != Z_NULL); 158 assert(images_file != Z_NULL);
157 assert(image_set); 159 assert(image_set);
158 160
@@ -161,36 +163,41 @@ static bool ReadImages(gzFile images_file, int max_num_images, ImageSet* image_s
161 uint8_t* pixels = 0; 163 uint8_t* pixels = 0;
162 164
163 int32_t magic, total_images, rows, cols; 165 int32_t magic, total_images, rows, cols;
164 if ( (gzread(images_file, (char*)&magic, sizeof(int32_t)) != sizeof(int32_t)) || 166 if ((gzread(images_file, (char*)&magic, sizeof(int32_t)) !=
165 (gzread(images_file, (char*)&total_images, sizeof(int32_t)) != sizeof(int32_t)) || 167 sizeof(int32_t)) ||
166 (gzread(images_file, (char*)&rows, sizeof(int32_t)) != sizeof(int32_t)) || 168 (gzread(images_file, (char*)&total_images, sizeof(int32_t)) !=
167 (gzread(images_file, (char*)&cols, sizeof(int32_t)) != sizeof(int32_t)) ) { 169 sizeof(int32_t)) ||
170 (gzread(images_file, (char*)&rows, sizeof(int32_t)) != sizeof(int32_t)) ||
171 (gzread(images_file, (char*)&cols, sizeof(int32_t)) != sizeof(int32_t))) {
168 fprintf(stderr, "Failed to read header\n"); 172 fprintf(stderr, "Failed to read header\n");
169 goto cleanup; 173 goto cleanup;
170 } 174 }
171 175
172 magic = ReverseEndian32(magic); 176 magic = ReverseEndian32(magic);
173 total_images = ReverseEndian32(total_images); 177 total_images = ReverseEndian32(total_images);
174 rows = ReverseEndian32(rows); 178 rows = ReverseEndian32(rows);
175 cols = ReverseEndian32(cols); 179 cols = ReverseEndian32(cols);
176 180
177 if (magic != IMAGE_FILE_MAGIC) { 181 if (magic != IMAGE_FILE_MAGIC) {
178 fprintf(stderr, "Magic number mismatch. Got %x, expected: %x\n", 182 fprintf(
179 magic, IMAGE_FILE_MAGIC); 183 stderr, "Magic number mismatch. Got %x, expected: %x\n", magic,
184 IMAGE_FILE_MAGIC);
180 goto cleanup; 185 goto cleanup;
181 } 186 }
182 187
183 printf("Magic: %.8x\nTotal images: %d\nRows: %d\nCols: %d\n", 188 printf(
184 magic, total_images, rows, cols); 189 "Magic: %.8x\nTotal images: %d\nRows: %d\nCols: %d\n", magic,
190 total_images, rows, cols);
185 191
186 total_images = max_num_images >= 0 ? min(total_images, max_num_images) : total_images; 192 total_images =
193 max_num_images >= 0 ? min(total_images, max_num_images) : total_images;
187 194
188 // Images are flattened into single row vectors. 195 // Images are flattened into single row vectors.
189 const int num_pixels = rows * cols; 196 const int num_pixels = rows * cols;
190 image_set->images = nnMatrixMake(total_images, num_pixels); 197 image_set->images = nnMatrixMake(total_images, num_pixels);
191 image_set->count = total_images; 198 image_set->count = total_images;
192 image_set->rows = rows; 199 image_set->rows = rows;
193 image_set->cols = cols; 200 image_set->cols = cols;
194 201
195 pixels = calloc(1, num_pixels); 202 pixels = calloc(1, num_pixels);
196 if (!pixels) { 203 if (!pixels) {
@@ -219,30 +226,31 @@ cleanup:
219 return success; 226 return success;
220} 227}
221 228
222static void OneHotEncode(const uint8_t* labels_bytes, int num_labels, nnMatrix* labels) { 229static void OneHotEncode(
230 const uint8_t* labels_bytes, int num_labels, nnMatrix* labels) {
223 assert(labels_bytes); 231 assert(labels_bytes);
224 assert(labels); 232 assert(labels);
225 assert(labels->rows == num_labels); 233 assert(labels->rows == num_labels);
226 assert(labels->cols == 10); 234 assert(labels->cols == 10);
227 235
228 static const R one_hot[10][10] = { 236 static const R one_hot[10][10] = {
229 { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, 237 {1, 0, 0, 0, 0, 0, 0, 0, 0, 0},
230 { 0, 1, 0, 0, 0, 0, 0, 0, 0, 0 }, 238 {0, 1, 0, 0, 0, 0, 0, 0, 0, 0},
231 { 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 }, 239 {0, 0, 1, 0, 0, 0, 0, 0, 0, 0},
232 { 0, 0, 0, 1, 0, 0, 0, 0, 0, 0 }, 240 {0, 0, 0, 1, 0, 0, 0, 0, 0, 0},
233 { 0, 0, 0, 0, 1, 0, 0, 0, 0, 0 }, 241 {0, 0, 0, 0, 1, 0, 0, 0, 0, 0},
234 { 0, 0, 0, 0, 0, 1, 0, 0, 0, 0 }, 242 {0, 0, 0, 0, 0, 1, 0, 0, 0, 0},
235 { 0, 0, 0, 0, 0, 0, 1, 0, 0, 0 }, 243 {0, 0, 0, 0, 0, 0, 1, 0, 0, 0},
236 { 0, 0, 0, 0, 0, 0, 0, 1, 0, 0 }, 244 {0, 0, 0, 0, 0, 0, 0, 1, 0, 0},
237 { 0, 0, 0, 0, 0, 0, 0, 0, 1, 0 }, 245 {0, 0, 0, 0, 0, 0, 0, 0, 1, 0},
238 { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }, 246 {0, 0, 0, 0, 0, 0, 0, 0, 0, 1},
239 }; 247 };
240 248
241 R* value = labels->values; 249 R* value = labels->values;
242 250
243 for (int i = 0; i < num_labels; ++i) { 251 for (int i = 0; i < num_labels; ++i) {
244 const uint8_t label = labels_bytes[i]; 252 const uint8_t label = labels_bytes[i];
245 const R* one_hot_value = one_hot[label]; 253 const R* one_hot_value = one_hot[label];
246 254
247 for (int j = 0; j < 10; ++j) { 255 for (int j = 0; j < 10; ++j) {
248 *value++ = FormatLabel(*one_hot_value++); 256 *value++ = FormatLabel(*one_hot_value++);
@@ -255,13 +263,13 @@ static int OneHotDecode(const nnMatrix* label_matrix) {
255 assert(label_matrix->cols == 10); 263 assert(label_matrix->cols == 10);
256 assert(label_matrix->rows == 1); 264 assert(label_matrix->rows == 1);
257 265
258 R max_value = 0; 266 R max_value = 0;
259 int pos_max = 0; 267 int pos_max = 0;
260 for (int i = 0; i < 10; ++i) { 268 for (int i = 0; i < 10; ++i) {
261 const R value = nnMatrixAt(label_matrix, 0, i); 269 const R value = nnMatrixAt(label_matrix, 0, i);
262 if (value > max_value) { 270 if (value > max_value) {
263 max_value = value; 271 max_value = value;
264 pos_max = i; 272 pos_max = i;
265 } 273 }
266 } 274 }
267 assert(pos_max >= 0); 275 assert(pos_max >= 0);
@@ -269,7 +277,8 @@ static int OneHotDecode(const nnMatrix* label_matrix) {
269 return pos_max; 277 return pos_max;
270} 278}
271 279
272static bool ReadLabels(gzFile labels_file, int max_num_labels, ImageSet* image_set) { 280static bool ReadLabels(
281 gzFile labels_file, int max_num_labels, ImageSet* image_set) {
273 assert(labels_file != Z_NULL); 282 assert(labels_file != Z_NULL);
274 assert(image_set != 0); 283 assert(image_set != 0);
275 284
@@ -278,24 +287,28 @@ static bool ReadLabels(gzFile labels_file, int max_num_labels, ImageSet* image_s
278 uint8_t* labels = 0; 287 uint8_t* labels = 0;
279 288
280 int32_t magic, total_labels; 289 int32_t magic, total_labels;
281 if ( (gzread(labels_file, (char*)&magic, sizeof(int32_t)) != sizeof(int32_t)) || 290 if ((gzread(labels_file, (char*)&magic, sizeof(int32_t)) !=
282 (gzread(labels_file, (char*)&total_labels, sizeof(int32_t)) != sizeof(int32_t)) ) { 291 sizeof(int32_t)) ||
292 (gzread(labels_file, (char*)&total_labels, sizeof(int32_t)) !=
293 sizeof(int32_t))) {
283 fprintf(stderr, "Failed to read header\n"); 294 fprintf(stderr, "Failed to read header\n");
284 goto cleanup; 295 goto cleanup;
285 } 296 }
286 297
287 magic = ReverseEndian32(magic); 298 magic = ReverseEndian32(magic);
288 total_labels = ReverseEndian32(total_labels); 299 total_labels = ReverseEndian32(total_labels);
289 300
290 if (magic != LABEL_FILE_MAGIC) { 301 if (magic != LABEL_FILE_MAGIC) {
291 fprintf(stderr, "Magic number mismatch. Got %x, expected: %x\n", 302 fprintf(
292 magic, LABEL_FILE_MAGIC); 303 stderr, "Magic number mismatch. Got %x, expected: %x\n", magic,
304 LABEL_FILE_MAGIC);
293 goto cleanup; 305 goto cleanup;
294 } 306 }
295 307
296 printf("Magic: %.8x\nTotal labels: %d\n", magic, total_labels); 308 printf("Magic: %.8x\nTotal labels: %d\n", magic, total_labels);
297 309
298 total_labels = max_num_labels >= 0 ? min(total_labels, max_num_labels) : total_labels; 310 total_labels =
311 max_num_labels >= 0 ? min(total_labels, max_num_labels) : total_labels;
299 312
300 assert(image_set->count == total_labels); 313 assert(image_set->count == total_labels);
301 314
@@ -308,7 +321,8 @@ static bool ReadLabels(gzFile labels_file, int max_num_labels, ImageSet* image_s
308 goto cleanup; 321 goto cleanup;
309 } 322 }
310 323
311 if (gzread(labels_file, labels, total_labels * sizeof(uint8_t)) != total_labels) { 324 if (gzread(labels_file, labels, total_labels * sizeof(uint8_t)) !=
325 total_labels) {
312 fprintf(stderr, "Failed to read labels\n"); 326 fprintf(stderr, "Failed to read labels\n");
313 goto cleanup; 327 goto cleanup;
314 } 328 }
@@ -335,17 +349,17 @@ int main(int argc, const char** argv) {
335 349
336 bool success = false; 350 bool success = false;
337 351
338 gzFile train_images_file = Z_NULL; 352 gzFile train_images_file = Z_NULL;
339 gzFile train_labels_file = Z_NULL; 353 gzFile train_labels_file = Z_NULL;
340 gzFile test_images_file = Z_NULL; 354 gzFile test_images_file = Z_NULL;
341 gzFile test_labels_file = Z_NULL; 355 gzFile test_labels_file = Z_NULL;
342 ImageSet train_set = { 0 }; 356 ImageSet train_set = {0};
343 ImageSet test_set = { 0 }; 357 ImageSet test_set = {0};
344 nnNeuralNetwork* net = 0; 358 nnNeuralNetwork* net = 0;
345 nnQueryObject* query = 0; 359 nnQueryObject* query = 0;
346 360
347 const char* mnist_files_dir = argv[1]; 361 const char* mnist_files_dir = argv[1];
348 const int max_num_images = argc > 2 ? atoi(argv[2]) : -1; 362 const int max_num_images = argc > 2 ? atoi(argv[2]) : -1;
349 363
350 char train_labels_path[PATH_MAX]; 364 char train_labels_path[PATH_MAX];
351 char train_images_path[PATH_MAX]; 365 char train_images_path[PATH_MAX];
@@ -353,12 +367,12 @@ int main(int argc, const char** argv) {
353 char test_images_path[PATH_MAX]; 367 char test_images_path[PATH_MAX];
354 strlcpy(train_labels_path, mnist_files_dir, PATH_MAX); 368 strlcpy(train_labels_path, mnist_files_dir, PATH_MAX);
355 strlcpy(train_images_path, mnist_files_dir, PATH_MAX); 369 strlcpy(train_images_path, mnist_files_dir, PATH_MAX);
356 strlcpy(test_labels_path, mnist_files_dir, PATH_MAX); 370 strlcpy(test_labels_path, mnist_files_dir, PATH_MAX);
357 strlcpy(test_images_path, mnist_files_dir, PATH_MAX); 371 strlcpy(test_images_path, mnist_files_dir, PATH_MAX);
358 strlcat(train_labels_path, "/train-labels-idx1-ubyte.gz", PATH_MAX); 372 strlcat(train_labels_path, "/train-labels-idx1-ubyte.gz", PATH_MAX);
359 strlcat(train_images_path, "/train-images-idx3-ubyte.gz", PATH_MAX); 373 strlcat(train_images_path, "/train-images-idx3-ubyte.gz", PATH_MAX);
360 strlcat(test_labels_path, "/t10k-labels-idx1-ubyte.gz", PATH_MAX); 374 strlcat(test_labels_path, "/t10k-labels-idx1-ubyte.gz", PATH_MAX);
361 strlcat(test_images_path, "/t10k-images-idx3-ubyte.gz", PATH_MAX); 375 strlcat(test_images_path, "/t10k-images-idx3-ubyte.gz", PATH_MAX);
362 376
363 train_images_file = gzopen(train_images_path, "r"); 377 train_images_file = gzopen(train_images_path, "r");
364 if (train_images_file == Z_NULL) { 378 if (train_images_file == Z_NULL) {
@@ -406,11 +420,18 @@ int main(int argc, const char** argv) {
406 } 420 }
407 421
408 // Network definition. 422 // Network definition.
409 const int image_size_pixels = train_set.rows * train_set.cols; 423 const int image_size_pixels = train_set.rows * train_set.cols;
410 const int num_layers = 2; 424 const int num_layers = 4;
411 const int layer_sizes[3] = { image_size_pixels, 100, 10 }; 425 const int hidden_size = 100;
412 const nnActivation layer_activations[2] = { nnSigmoid, nnSigmoid }; 426 const nnLayer layers[4] = {
413 if (!(net = nnMakeNet(num_layers, layer_sizes, layer_activations))) { 427 {.type = nnLinear,
428 .linear = {.input_size = image_size_pixels, .output_size = hidden_size}},
429 {.type = nnSigmoid},
430 {.type = nnLinear,
431 .linear = {.input_size = hidden_size, .output_size = 10}},
432 {.type = nnSigmoid}
433 };
434 if (!(net = nnMakeNet(layers, num_layers, image_size_pixels))) {
414 fprintf(stderr, "Failed to create neural network\n"); 435 fprintf(stderr, "Failed to create neural network\n");
415 goto cleanup; 436 goto cleanup;
416 } 437 }
@@ -418,17 +439,17 @@ int main(int argc, const char** argv) {
418 // Train. 439 // Train.
419 printf("Training with up to %d images from the data set\n\n", max_num_images); 440 printf("Training with up to %d images from the data set\n\n", max_num_images);
420 const nnTrainingParams training_params = { 441 const nnTrainingParams training_params = {
421 .learning_rate = 0.1, 442 .learning_rate = 0.1,
422 .max_iterations = TRAIN_ITERATIONS, 443 .max_iterations = TRAIN_ITERATIONS,
423 .seed = 0, 444 .seed = 0,
424 .weight_init = nnWeightInitNormal, 445 .weight_init = nnWeightInitNormal,
425 .debug = true, 446 .debug = true,
426 }; 447 };
427 nnTrain(net, &train_set.images, &train_set.labels, &training_params); 448 nnTrain(net, &train_set.images, &train_set.labels, &training_params);
428 449
429 // Test. 450 // Test.
430 int hits = 0; 451 int hits = 0;
431 query = nnMakeQueryObject(net, /*num_inputs=*/1); 452 query = nnMakeQueryObject(net, /*num_inputs=*/1);
432 for (int i = 0; i < test_set.count; ++i) { 453 for (int i = 0; i < test_set.count; ++i) {
433 const nnMatrix test_image = nnMatrixBorrowRows(&test_set.images, i, 1); 454 const nnMatrix test_image = nnMatrixBorrowRows(&test_set.images, i, 1);
434 const nnMatrix test_label = nnMatrixBorrowRows(&test_set.labels, i, 1); 455 const nnMatrix test_label = nnMatrixBorrowRows(&test_set.labels, i, 1);
@@ -444,7 +465,7 @@ int main(int argc, const char** argv) {
444 } 465 }
445 const R hit_ratio = (R)hits / (R)test_set.count; 466 const R hit_ratio = (R)hits / (R)test_set.count;
446 printf("Test images: %d\n", test_set.count); 467 printf("Test images: %d\n", test_set.count);
447 printf("Hits: %d/%d (%.3f%%)\n", hits, test_set.count, hit_ratio*100); 468 printf("Hits: %d/%d (%.3f%%)\n", hits, test_set.count, hit_ratio * 100);
448 469
449 success = true; 470 success = true;
450 471