From 653e98e029a0d0f110b0ac599e50406060bb0f87 Mon Sep 17 00:00:00 2001 From: 3gg <3gg@shellblade.net> Date: Sat, 16 Dec 2023 10:21:16 -0800 Subject: Decouple activations from linear layer. --- src/bin/mnist/src/main.c | 195 ++++++++++++++++++++++++++--------------------- 1 file changed, 108 insertions(+), 87 deletions(-) (limited to 'src/bin') diff --git a/src/bin/mnist/src/main.c b/src/bin/mnist/src/main.c index 9aa3ce5..53e0197 100644 --- a/src/bin/mnist/src/main.c +++ b/src/bin/mnist/src/main.c @@ -29,32 +29,35 @@ static const double LABEL_UPPER_BOUND = 0.99; // Epsilon used to compare R values. static const double EPS = 1e-10; -#define min(a,b) ((a) < (b) ? (a) : (b)) +#define min(a, b) ((a) < (b) ? (a) : (b)) typedef struct ImageSet { - nnMatrix images; // Images flattened into row vectors of the matrix. - nnMatrix labels; // One-hot-encoded labels. - int count; // Number of images and labels. - int rows; // Rows in an image. - int cols; // Columns in an image. + nnMatrix images; // Images flattened into row vectors of the matrix. + nnMatrix labels; // One-hot-encoded labels. + int count; // Number of images and labels. + int rows; // Rows in an image. + int cols; // Columns in an image. } ImageSet; static void usage(const char* argv0) { - fprintf(stderr, "Usage: %s [num images]\n", argv0); + fprintf( + stderr, "Usage: %s [num images]\n", + argv0); fprintf(stderr, "\n"); - fprintf(stderr, " Use -1 for [num images] to use all the images in the data set\n"); + fprintf( + stderr, + " Use -1 for [num images] to use all the images in the data set\n"); } -static bool R_eq(R a, R b) { - return fabs(a-b) <= EPS; -} +static bool R_eq(R a, R b) { return fabs(a - b) <= EPS; } -static void PrintImage(const nnMatrix* images, int rows, int cols, int image_index) { +static void PrintImage( + const nnMatrix* images, int rows, int cols, int image_index) { assert(images); assert((0 <= image_index) && (image_index < images->rows)); // Top line. - for (int j = 0; j < cols/2; ++j) { + for (int j = 0; j < cols / 2; ++j) { printf(" -"); } printf("\n"); @@ -68,8 +71,7 @@ static void PrintImage(const nnMatrix* images, int rows, int cols, int image_ind printf("#"); } else if (*value > 0.5) { printf("*"); - } - else if (*value > PIXEL_LOWER_BOUND) { + } else if (*value > PIXEL_LOWER_BOUND) { printf(":"); } else if (*value == 0.0) { // Values should not be exactly 0, otherwise they cancel out weights @@ -84,7 +86,7 @@ static void PrintImage(const nnMatrix* images, int rows, int cols, int image_ind } // Bottom line. - for (int j = 0; j < cols/2; ++j) { + for (int j = 0; j < cols / 2; ++j) { printf(" -"); } printf("\n"); @@ -96,7 +98,7 @@ static void PrintLabel(const nnMatrix* labels, int label_index) { // Compute the label from the one-hot encoding. const R* value = nnMatrixRow(labels, label_index); - int label = -1; + int label = -1; for (int i = 0; i < 10; ++i) { if (R_eq(*value++, LABEL_UPPER_BOUND)) { label = i; @@ -113,13 +115,12 @@ static void PrintLabel(const nnMatrix* labels, int label_index) { printf(")\n"); } -static R lerp(R a, R b, R t) { - return a + t*(b-a); -} +static R lerp(R a, R b, R t) { return a + t * (b - a); } /// Rescales a pixel from [0,255] to [PIXEL_LOWER_BOUND, 1.0]. static R FormatPixel(uint8_t pixel) { - const R value = (R)(pixel) / 255.0 * (1.0 - PIXEL_LOWER_BOUND) + PIXEL_LOWER_BOUND; + const R value = + (R)(pixel) / 255.0 * (1.0 - PIXEL_LOWER_BOUND) + PIXEL_LOWER_BOUND; assert(value >= PIXEL_LOWER_BOUND); assert(value <= 1.0); return value; @@ -152,7 +153,8 @@ static void ImageToMatrix( } } -static bool ReadImages(gzFile images_file, int max_num_images, ImageSet* image_set) { +static bool ReadImages( + gzFile images_file, int max_num_images, ImageSet* image_set) { assert(images_file != Z_NULL); assert(image_set); @@ -161,36 +163,41 @@ static bool ReadImages(gzFile images_file, int max_num_images, ImageSet* image_s uint8_t* pixels = 0; int32_t magic, total_images, rows, cols; - if ( (gzread(images_file, (char*)&magic, sizeof(int32_t)) != sizeof(int32_t)) || - (gzread(images_file, (char*)&total_images, sizeof(int32_t)) != sizeof(int32_t)) || - (gzread(images_file, (char*)&rows, sizeof(int32_t)) != sizeof(int32_t)) || - (gzread(images_file, (char*)&cols, sizeof(int32_t)) != sizeof(int32_t)) ) { + if ((gzread(images_file, (char*)&magic, sizeof(int32_t)) != + sizeof(int32_t)) || + (gzread(images_file, (char*)&total_images, sizeof(int32_t)) != + sizeof(int32_t)) || + (gzread(images_file, (char*)&rows, sizeof(int32_t)) != sizeof(int32_t)) || + (gzread(images_file, (char*)&cols, sizeof(int32_t)) != sizeof(int32_t))) { fprintf(stderr, "Failed to read header\n"); goto cleanup; } - magic = ReverseEndian32(magic); + magic = ReverseEndian32(magic); total_images = ReverseEndian32(total_images); - rows = ReverseEndian32(rows); - cols = ReverseEndian32(cols); + rows = ReverseEndian32(rows); + cols = ReverseEndian32(cols); if (magic != IMAGE_FILE_MAGIC) { - fprintf(stderr, "Magic number mismatch. Got %x, expected: %x\n", - magic, IMAGE_FILE_MAGIC); + fprintf( + stderr, "Magic number mismatch. Got %x, expected: %x\n", magic, + IMAGE_FILE_MAGIC); goto cleanup; } - printf("Magic: %.8x\nTotal images: %d\nRows: %d\nCols: %d\n", - magic, total_images, rows, cols); + printf( + "Magic: %.8x\nTotal images: %d\nRows: %d\nCols: %d\n", magic, + total_images, rows, cols); - total_images = max_num_images >= 0 ? min(total_images, max_num_images) : total_images; + total_images = + max_num_images >= 0 ? min(total_images, max_num_images) : total_images; // Images are flattened into single row vectors. const int num_pixels = rows * cols; - image_set->images = nnMatrixMake(total_images, num_pixels); - image_set->count = total_images; - image_set->rows = rows; - image_set->cols = cols; + image_set->images = nnMatrixMake(total_images, num_pixels); + image_set->count = total_images; + image_set->rows = rows; + image_set->cols = cols; pixels = calloc(1, num_pixels); if (!pixels) { @@ -219,30 +226,31 @@ cleanup: return success; } -static void OneHotEncode(const uint8_t* labels_bytes, int num_labels, nnMatrix* labels) { +static void OneHotEncode( + const uint8_t* labels_bytes, int num_labels, nnMatrix* labels) { assert(labels_bytes); assert(labels); assert(labels->rows == num_labels); assert(labels->cols == 10); static const R one_hot[10][10] = { - { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, - { 0, 1, 0, 0, 0, 0, 0, 0, 0, 0 }, - { 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 }, - { 0, 0, 0, 1, 0, 0, 0, 0, 0, 0 }, - { 0, 0, 0, 0, 1, 0, 0, 0, 0, 0 }, - { 0, 0, 0, 0, 0, 1, 0, 0, 0, 0 }, - { 0, 0, 0, 0, 0, 0, 1, 0, 0, 0 }, - { 0, 0, 0, 0, 0, 0, 0, 1, 0, 0 }, - { 0, 0, 0, 0, 0, 0, 0, 0, 1, 0 }, - { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }, + {1, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 1, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 1, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 1, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 0, 1, 0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0, 1, 0, 0, 0, 0}, + {0, 0, 0, 0, 0, 0, 1, 0, 0, 0}, + {0, 0, 0, 0, 0, 0, 0, 1, 0, 0}, + {0, 0, 0, 0, 0, 0, 0, 0, 1, 0}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, }; R* value = labels->values; for (int i = 0; i < num_labels; ++i) { - const uint8_t label = labels_bytes[i]; - const R* one_hot_value = one_hot[label]; + const uint8_t label = labels_bytes[i]; + const R* one_hot_value = one_hot[label]; for (int j = 0; j < 10; ++j) { *value++ = FormatLabel(*one_hot_value++); @@ -255,13 +263,13 @@ static int OneHotDecode(const nnMatrix* label_matrix) { assert(label_matrix->cols == 10); assert(label_matrix->rows == 1); - R max_value = 0; - int pos_max = 0; + R max_value = 0; + int pos_max = 0; for (int i = 0; i < 10; ++i) { const R value = nnMatrixAt(label_matrix, 0, i); if (value > max_value) { max_value = value; - pos_max = i; + pos_max = i; } } assert(pos_max >= 0); @@ -269,7 +277,8 @@ static int OneHotDecode(const nnMatrix* label_matrix) { return pos_max; } -static bool ReadLabels(gzFile labels_file, int max_num_labels, ImageSet* image_set) { +static bool ReadLabels( + gzFile labels_file, int max_num_labels, ImageSet* image_set) { assert(labels_file != Z_NULL); assert(image_set != 0); @@ -278,24 +287,28 @@ static bool ReadLabels(gzFile labels_file, int max_num_labels, ImageSet* image_s uint8_t* labels = 0; int32_t magic, total_labels; - if ( (gzread(labels_file, (char*)&magic, sizeof(int32_t)) != sizeof(int32_t)) || - (gzread(labels_file, (char*)&total_labels, sizeof(int32_t)) != sizeof(int32_t)) ) { + if ((gzread(labels_file, (char*)&magic, sizeof(int32_t)) != + sizeof(int32_t)) || + (gzread(labels_file, (char*)&total_labels, sizeof(int32_t)) != + sizeof(int32_t))) { fprintf(stderr, "Failed to read header\n"); goto cleanup; } - magic = ReverseEndian32(magic); + magic = ReverseEndian32(magic); total_labels = ReverseEndian32(total_labels); if (magic != LABEL_FILE_MAGIC) { - fprintf(stderr, "Magic number mismatch. Got %x, expected: %x\n", - magic, LABEL_FILE_MAGIC); + fprintf( + stderr, "Magic number mismatch. Got %x, expected: %x\n", magic, + LABEL_FILE_MAGIC); goto cleanup; } printf("Magic: %.8x\nTotal labels: %d\n", magic, total_labels); - total_labels = max_num_labels >= 0 ? min(total_labels, max_num_labels) : total_labels; + total_labels = + max_num_labels >= 0 ? min(total_labels, max_num_labels) : total_labels; assert(image_set->count == total_labels); @@ -308,7 +321,8 @@ static bool ReadLabels(gzFile labels_file, int max_num_labels, ImageSet* image_s goto cleanup; } - if (gzread(labels_file, labels, total_labels * sizeof(uint8_t)) != total_labels) { + if (gzread(labels_file, labels, total_labels * sizeof(uint8_t)) != + total_labels) { fprintf(stderr, "Failed to read labels\n"); goto cleanup; } @@ -335,17 +349,17 @@ int main(int argc, const char** argv) { bool success = false; - gzFile train_images_file = Z_NULL; - gzFile train_labels_file = Z_NULL; - gzFile test_images_file = Z_NULL; - gzFile test_labels_file = Z_NULL; - ImageSet train_set = { 0 }; - ImageSet test_set = { 0 }; - nnNeuralNetwork* net = 0; - nnQueryObject* query = 0; + gzFile train_images_file = Z_NULL; + gzFile train_labels_file = Z_NULL; + gzFile test_images_file = Z_NULL; + gzFile test_labels_file = Z_NULL; + ImageSet train_set = {0}; + ImageSet test_set = {0}; + nnNeuralNetwork* net = 0; + nnQueryObject* query = 0; const char* mnist_files_dir = argv[1]; - const int max_num_images = argc > 2 ? atoi(argv[2]) : -1; + const int max_num_images = argc > 2 ? atoi(argv[2]) : -1; char train_labels_path[PATH_MAX]; char train_images_path[PATH_MAX]; @@ -353,12 +367,12 @@ int main(int argc, const char** argv) { char test_images_path[PATH_MAX]; strlcpy(train_labels_path, mnist_files_dir, PATH_MAX); strlcpy(train_images_path, mnist_files_dir, PATH_MAX); - strlcpy(test_labels_path, mnist_files_dir, PATH_MAX); - strlcpy(test_images_path, mnist_files_dir, PATH_MAX); + strlcpy(test_labels_path, mnist_files_dir, PATH_MAX); + strlcpy(test_images_path, mnist_files_dir, PATH_MAX); strlcat(train_labels_path, "/train-labels-idx1-ubyte.gz", PATH_MAX); strlcat(train_images_path, "/train-images-idx3-ubyte.gz", PATH_MAX); - strlcat(test_labels_path, "/t10k-labels-idx1-ubyte.gz", PATH_MAX); - strlcat(test_images_path, "/t10k-images-idx3-ubyte.gz", PATH_MAX); + strlcat(test_labels_path, "/t10k-labels-idx1-ubyte.gz", PATH_MAX); + strlcat(test_images_path, "/t10k-images-idx3-ubyte.gz", PATH_MAX); train_images_file = gzopen(train_images_path, "r"); if (train_images_file == Z_NULL) { @@ -406,11 +420,18 @@ int main(int argc, const char** argv) { } // Network definition. - const int image_size_pixels = train_set.rows * train_set.cols; - const int num_layers = 2; - const int layer_sizes[3] = { image_size_pixels, 100, 10 }; - const nnActivation layer_activations[2] = { nnSigmoid, nnSigmoid }; - if (!(net = nnMakeNet(num_layers, layer_sizes, layer_activations))) { + const int image_size_pixels = train_set.rows * train_set.cols; + const int num_layers = 4; + const int hidden_size = 100; + const nnLayer layers[4] = { + {.type = nnLinear, + .linear = {.input_size = image_size_pixels, .output_size = hidden_size}}, + {.type = nnSigmoid}, + {.type = nnLinear, + .linear = {.input_size = hidden_size, .output_size = 10}}, + {.type = nnSigmoid} + }; + if (!(net = nnMakeNet(layers, num_layers, image_size_pixels))) { fprintf(stderr, "Failed to create neural network\n"); goto cleanup; } @@ -418,17 +439,17 @@ int main(int argc, const char** argv) { // Train. printf("Training with up to %d images from the data set\n\n", max_num_images); const nnTrainingParams training_params = { - .learning_rate = 0.1, - .max_iterations = TRAIN_ITERATIONS, - .seed = 0, - .weight_init = nnWeightInitNormal, - .debug = true, + .learning_rate = 0.1, + .max_iterations = TRAIN_ITERATIONS, + .seed = 0, + .weight_init = nnWeightInitNormal, + .debug = true, }; nnTrain(net, &train_set.images, &train_set.labels, &training_params); // Test. int hits = 0; - query = nnMakeQueryObject(net, /*num_inputs=*/1); + query = nnMakeQueryObject(net, /*num_inputs=*/1); for (int i = 0; i < test_set.count; ++i) { const nnMatrix test_image = nnMatrixBorrowRows(&test_set.images, i, 1); const nnMatrix test_label = nnMatrixBorrowRows(&test_set.labels, i, 1); @@ -444,7 +465,7 @@ int main(int argc, const char** argv) { } const R hit_ratio = (R)hits / (R)test_set.count; printf("Test images: %d\n", test_set.count); - printf("Hits: %d/%d (%.3f%%)\n", hits, test_set.count, hit_ratio*100); + printf("Hits: %d/%d (%.3f%%)\n", hits, test_set.count, hit_ratio * 100); success = true; -- cgit v1.2.3