From 411f66a2540fa17c736116d865e0ceb0cfe5623b Mon Sep 17 00:00:00 2001 From: jeanne Date: Wed, 11 May 2022 09:54:38 -0700 Subject: Initial commit. --- src/bin/mnist/CMakeLists.txt | 11 + src/bin/mnist/src/main.c | 473 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 484 insertions(+) create mode 100644 src/bin/mnist/CMakeLists.txt create mode 100644 src/bin/mnist/src/main.c (limited to 'src/bin/mnist') diff --git a/src/bin/mnist/CMakeLists.txt b/src/bin/mnist/CMakeLists.txt new file mode 100644 index 0000000..a6c54f2 --- /dev/null +++ b/src/bin/mnist/CMakeLists.txt @@ -0,0 +1,11 @@ +cmake_minimum_required(VERSION 3.0) + +add_executable(mnist + src/main.c) + +target_link_libraries(mnist PRIVATE + neuralnet + bsd + z) + +target_compile_options(mnist PRIVATE -Wall -Wextra) diff --git a/src/bin/mnist/src/main.c b/src/bin/mnist/src/main.c new file mode 100644 index 0000000..4d268ac --- /dev/null +++ b/src/bin/mnist/src/main.c @@ -0,0 +1,473 @@ +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +static const int TRAIN_ITERATIONS = 100; + +static const int32_t IMAGE_FILE_MAGIC = 0x00000803; +static const int32_t LABEL_FILE_MAGIC = 0x00000801; + +// Inputs of 0 cancel weights during training. This value is used to rescale the +// input pixels from [0,255] to [PIXEL_LOWER_BOUND, 1.0]. +static const double PIXEL_LOWER_BOUND = 0.01; + +// Scale the outputs to (0,1) since the sigmoid cannot produce 0 or 1. +static const double LABEL_LOWER_BOUND = 0.01; +static const double LABEL_UPPER_BOUND = 0.99; + +// Epsilon used to compare R values. +static const double EPS = 1e-10; + +#define min(a,b) ((a) < (b) ? (a) : (b)) + +typedef struct ImageSet { + nnMatrix images; // Images flattened into row vectors of the matrix. + nnMatrix labels; // One-hot-encoded labels. + int count; // Number of images and labels. + int rows; // Rows in an image. + int cols; // Columns in an image. +} ImageSet; + +static void usage(const char* argv0) { + fprintf(stderr, "Usage: %s [num images]\n", argv0); + fprintf(stderr, "\n"); + fprintf(stderr, " Use -1 for [num images] to use all the images in the data set\n"); +} + +static bool R_eq(R a, R b) { + return fabs(a-b) <= EPS; +} + +static void PrintImage(const nnMatrix* images, int rows, int cols, int image_index) { + assert(images); + assert((0 <= image_index) && (image_index < images->rows)); + + // Top line. + for (int j = 0; j < cols/2; ++j) { + printf(" -"); + } + printf("\n"); + + // Image. + const R* value = nnMatrixRow(images, image_index); + for (int i = 0; i < rows; ++i) { + printf("|"); + for (int j = 0; j < cols; ++j) { + if (*value > 0.8) { + printf("#"); + } else if (*value > 0.5) { + printf("*"); + } + else if (*value > PIXEL_LOWER_BOUND) { + printf(":"); + } else if (*value == 0.0) { + // Values should not be exactly 0, otherwise they cancel out weights + // during training. + printf("X"); + } else { + printf(" "); + } + value++; + } + printf("|\n"); + } + + // Bottom line. + for (int j = 0; j < cols/2; ++j) { + printf(" -"); + } + printf("\n"); +} + +static void PrintLabel(const nnMatrix* labels, int label_index) { + assert(labels); + assert((0 <= label_index) && (label_index < labels->rows)); + + // Compute the label from the one-hot encoding. + const R* value = nnMatrixRow(labels, label_index); + int label = -1; + for (int i = 0; i < 10; ++i) { + if (R_eq(*value++, LABEL_UPPER_BOUND)) { + label = i; + break; + } + } + assert((0 <= label) && (label <= 9)); + + printf("Label: %d ( ", label); + value = nnMatrixRow(labels, label_index); + for (int i = 0; i < 10; ++i) { + printf("%.3f ", *value++); + } + printf(")\n"); +} + +static R lerp(R a, R b, R t) { + return a + t*(b-a); +} + +/// Rescales a pixel from [0,255] to [PIXEL_LOWER_BOUND, 1.0]. +static R FormatPixel(uint8_t pixel) { + const R value = (R)(pixel) / 255.0 * (1.0 - PIXEL_LOWER_BOUND) + PIXEL_LOWER_BOUND; + assert(value >= PIXEL_LOWER_BOUND); + assert(value <= 1.0); + return value; +} + +/// Rescales a one-hot-encoded label value to (0,1). +static R FormatLabel(R label) { + const R value = lerp(LABEL_LOWER_BOUND, LABEL_UPPER_BOUND, label); + assert(value > 0.0); + assert(value < 1.0); + return value; +} + +static int32_t ReverseEndian32(int32_t x) { + const int32_t x0 = x & 0xff; + const int32_t x1 = (x >> 8) & 0xff; + const int32_t x2 = (x >> 16) & 0xff; + const int32_t x3 = (x >> 24) & 0xff; + return (x0 << 24) | (x1 << 16) | (x2 << 8) | x3; +} + +static void ImageToMatrix( + const uint8_t* pixels, int num_pixels, int row, nnMatrix* images) { + assert(pixels); + assert(images); + + for (int i = 0; i < num_pixels; ++i) { + const R pixel = FormatPixel(pixels[i]); + nnMatrixSet(images, row, i, pixel); + } +} + +static bool ReadImages(gzFile images_file, int max_num_images, ImageSet* image_set) { + assert(images_file != Z_NULL); + assert(image_set); + + bool success = false; + + uint8_t* pixels = 0; + + int32_t magic, total_images, rows, cols; + if ( (gzread(images_file, (char*)&magic, sizeof(int32_t)) != sizeof(int32_t)) || + (gzread(images_file, (char*)&total_images, sizeof(int32_t)) != sizeof(int32_t)) || + (gzread(images_file, (char*)&rows, sizeof(int32_t)) != sizeof(int32_t)) || + (gzread(images_file, (char*)&cols, sizeof(int32_t)) != sizeof(int32_t)) ) { + fprintf(stderr, "Failed to read header\n"); + goto cleanup; + } + + magic = ReverseEndian32(magic); + total_images = ReverseEndian32(total_images); + rows = ReverseEndian32(rows); + cols = ReverseEndian32(cols); + + if (magic != IMAGE_FILE_MAGIC) { + fprintf(stderr, "Magic number mismatch. Got %x, expected: %x\n", + magic, IMAGE_FILE_MAGIC); + goto cleanup; + } + + printf("Magic: %.8x\nTotal images: %d\nRows: %d\nCols: %d\n", + magic, total_images, rows, cols); + + total_images = max_num_images >= 0 ? min(total_images, max_num_images) : total_images; + + // Images are flattened into single row vectors. + const int num_pixels = rows * cols; + image_set->images = nnMatrixMake(total_images, num_pixels); + image_set->count = total_images; + image_set->rows = rows; + image_set->cols = cols; + + pixels = calloc(1, num_pixels); + if (!pixels) { + fprintf(stderr, "Failed to allocate image buffer\n"); + goto cleanup; + } + + for (int i = 0; i < total_images; ++i) { + const int bytes_read = gzread(images_file, pixels, num_pixels); + if (bytes_read < num_pixels) { + fprintf(stderr, "Failed to read image %d\n", i); + goto cleanup; + } + ImageToMatrix(pixels, num_pixels, i, &image_set->images); + } + + success = true; + +cleanup: + if (pixels) { + free(pixels); + } + if (!success) { + nnMatrixDel(&image_set->images); + } + return success; +} + +static void OneHotEncode(const uint8_t* labels_bytes, int num_labels, nnMatrix* labels) { + assert(labels_bytes); + assert(labels); + assert(labels->rows == num_labels); + assert(labels->cols == 10); + + static const R one_hot[10][10] = { + { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0, 1, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 }, + { 0, 0, 0, 1, 0, 0, 0, 0, 0, 0 }, + { 0, 0, 0, 0, 1, 0, 0, 0, 0, 0 }, + { 0, 0, 0, 0, 0, 1, 0, 0, 0, 0 }, + { 0, 0, 0, 0, 0, 0, 1, 0, 0, 0 }, + { 0, 0, 0, 0, 0, 0, 0, 1, 0, 0 }, + { 0, 0, 0, 0, 0, 0, 0, 0, 1, 0 }, + { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }, + }; + + R* value = labels->values; + + for (int i = 0; i < num_labels; ++i) { + const uint8_t label = labels_bytes[i]; + const R* one_hot_value = one_hot[label]; + + for (int j = 0; j < 10; ++j) { + *value++ = FormatLabel(*one_hot_value++); + } + } +} + +static int OneHotDecode(const nnMatrix* label_matrix) { + assert(label_matrix); + assert(label_matrix->cols == 1); + assert(label_matrix->rows == 10); + + R max_value = 0; + int pos_max = 0; + for (int i = 0; i < 10; ++i) { + const R value = nnMatrixAt(label_matrix, 0, i); + if (value > max_value) { + max_value = value; + pos_max = i; + } + } + assert(pos_max >= 0); + assert(pos_max <= 10); + return pos_max; +} + +static bool ReadLabels(gzFile labels_file, int max_num_labels, ImageSet* image_set) { + assert(labels_file != Z_NULL); + assert(image_set != 0); + + bool success = false; + + uint8_t* labels = 0; + + int32_t magic, total_labels; + if ( (gzread(labels_file, (char*)&magic, sizeof(int32_t)) != sizeof(int32_t)) || + (gzread(labels_file, (char*)&total_labels, sizeof(int32_t)) != sizeof(int32_t)) ) { + fprintf(stderr, "Failed to read header\n"); + goto cleanup; + } + + magic = ReverseEndian32(magic); + total_labels = ReverseEndian32(total_labels); + + if (magic != LABEL_FILE_MAGIC) { + fprintf(stderr, "Magic number mismatch. Got %x, expected: %x\n", + magic, LABEL_FILE_MAGIC); + goto cleanup; + } + + printf("Magic: %.8x\nTotal labels: %d\n", magic, total_labels); + + total_labels = max_num_labels >= 0 ? min(total_labels, max_num_labels) : total_labels; + + assert(image_set->count == total_labels); + + // One-hot encoding of labels, 10 values (digits) per label. + image_set->labels = nnMatrixMake(total_labels, 10); + + labels = calloc(total_labels, sizeof(uint8_t)); + if (!labels) { + fprintf(stderr, "Failed to allocate labels buffer\n"); + goto cleanup; + } + + if (gzread(labels_file, labels, total_labels * sizeof(uint8_t)) != total_labels) { + fprintf(stderr, "Failed to read labels\n"); + goto cleanup; + } + + OneHotEncode(labels, total_labels, &image_set->labels); + + success = true; + +cleanup: + if (labels) { + free(labels); + } + if (!success) { + nnMatrixDel(&image_set->labels); + } + return success; +} + +int main(int argc, const char** argv) { + if (argc < 2) { + usage(argv[0]); + return 1; + } + + bool success = false; + + gzFile train_images_file = Z_NULL; + gzFile train_labels_file = Z_NULL; + gzFile test_images_file = Z_NULL; + gzFile test_labels_file = Z_NULL; + ImageSet train_set = { 0 }; + ImageSet test_set = { 0 }; + nnNeuralNetwork* net = 0; + nnQueryObject* query = 0; + + const char* mnist_files_dir = argv[1]; + const int max_num_images = argc > 2 ? atoi(argv[2]) : -1; + + char train_labels_path[PATH_MAX]; + char train_images_path[PATH_MAX]; + char test_labels_path[PATH_MAX]; + char test_images_path[PATH_MAX]; + strlcpy(train_labels_path, mnist_files_dir, PATH_MAX); + strlcpy(train_images_path, mnist_files_dir, PATH_MAX); + strlcpy(test_labels_path, mnist_files_dir, PATH_MAX); + strlcpy(test_images_path, mnist_files_dir, PATH_MAX); + strlcat(train_labels_path, "/train-labels-idx1-ubyte.gz", PATH_MAX); + strlcat(train_images_path, "/train-images-idx3-ubyte.gz", PATH_MAX); + strlcat(test_labels_path, "/t10k-labels-idx1-ubyte.gz", PATH_MAX); + strlcat(test_images_path, "/t10k-images-idx3-ubyte.gz", PATH_MAX); + + train_images_file = gzopen(train_images_path, "r"); + if (train_images_file == Z_NULL) { + fprintf(stderr, "Failed to open file: %s\n", train_images_path); + goto cleanup; + } + + train_labels_file = gzopen(train_labels_path, "r"); + if (train_labels_file == Z_NULL) { + fprintf(stderr, "Failed to open file: %s\n", train_labels_path); + goto cleanup; + } + + test_images_file = gzopen(test_images_path, "r"); + if (test_images_file == Z_NULL) { + fprintf(stderr, "Failed to open file: %s\n", test_images_path); + goto cleanup; + } + + test_labels_file = gzopen(test_labels_path, "r"); + if (test_labels_file == Z_NULL) { + fprintf(stderr, "Failed to open file: %s\n", test_labels_path); + goto cleanup; + } + + if (!ReadImages(train_images_file, max_num_images, &train_set)) { + goto cleanup; + } + if (!ReadLabels(train_labels_file, max_num_images, &train_set)) { + goto cleanup; + } + + if (!ReadImages(test_images_file, max_num_images, &test_set)) { + goto cleanup; + } + if (!ReadLabels(test_labels_file, max_num_images, &test_set)) { + goto cleanup; + } + + printf("\nTraining image/label pair examples:\n"); + for (int i = 0; i < min(3, train_set.images.rows); ++i) { + PrintImage(&train_set.images, train_set.rows, train_set.cols, i); + PrintLabel(&train_set.labels, i); + printf("\n"); + } + + // Network definition. + const int image_size_pixels = train_set.rows * train_set.cols; + const int num_layers = 2; + const int layer_sizes[3] = { image_size_pixels, 100, 10 }; + const nnActivation layer_activations[2] = { nnSigmoid, nnSigmoid }; + if (!(net = nnMakeNet(num_layers, layer_sizes, layer_activations))) { + fprintf(stderr, "Failed to create neural network\n"); + goto cleanup; + } + + // Train. + printf("Training with up to %d images from the data set\n\n", max_num_images); + const nnTrainingParams training_params = { + .learning_rate = 0.1, + .max_iterations = TRAIN_ITERATIONS, + .seed = 0, + .weight_init = nnWeightInitNormal, + .debug = true, + }; + nnTrain(net, &train_set.images, &train_set.labels, &training_params); + + // Test. + int hits = 0; + query = nnMakeQueryObject(net, /*num_inputs=*/1); + for (int i = 0; i < test_set.count; ++i) { + const nnMatrix test_image = nnMatrixBorrowRows(&test_set.images, i, 1); + const nnMatrix test_label = nnMatrixBorrowRows(&test_set.labels, i, 1); + + nnQuery(net, query, &test_image); + + const int test_label_expected = OneHotDecode(&test_label); + const int test_label_actual = OneHotDecode(nnNetOutputs(query)); + + if (test_label_actual == test_label_expected) { + ++hits; + } + } + const R hit_ratio = (R)hits / (R)test_set.count; + printf("Test images: %d\n", test_set.count); + printf("Hits: %d/%d (%.3f%%)\n", hits, test_set.count, hit_ratio*100); + + success = true; + +cleanup: + if (query) { + nnDeleteQueryObject(&query); + } + if (net) { + nnDeleteNet(&net); + } + nnMatrixDel(&train_set.images); + nnMatrixDel(&test_set.images); + if (train_images_file != Z_NULL) { + gzclose(train_images_file); + } + if (train_labels_file != Z_NULL) { + gzclose(train_labels_file); + } + if (test_images_file != Z_NULL) { + gzclose(test_images_file); + } + if (test_labels_file != Z_NULL) { + gzclose(test_labels_file); + } + return success ? 0 : 1; +} -- cgit v1.2.3