diff options
Diffstat (limited to 'src/bin/mnist')
| -rw-r--r-- | src/bin/mnist/src/main.c | 195 |
1 files changed, 108 insertions, 87 deletions
diff --git a/src/bin/mnist/src/main.c b/src/bin/mnist/src/main.c index 9aa3ce5..53e0197 100644 --- a/src/bin/mnist/src/main.c +++ b/src/bin/mnist/src/main.c | |||
| @@ -29,32 +29,35 @@ static const double LABEL_UPPER_BOUND = 0.99; | |||
| 29 | // Epsilon used to compare R values. | 29 | // Epsilon used to compare R values. |
| 30 | static const double EPS = 1e-10; | 30 | static const double EPS = 1e-10; |
| 31 | 31 | ||
| 32 | #define min(a,b) ((a) < (b) ? (a) : (b)) | 32 | #define min(a, b) ((a) < (b) ? (a) : (b)) |
| 33 | 33 | ||
| 34 | typedef struct ImageSet { | 34 | typedef struct ImageSet { |
| 35 | nnMatrix images; // Images flattened into row vectors of the matrix. | 35 | nnMatrix images; // Images flattened into row vectors of the matrix. |
| 36 | nnMatrix labels; // One-hot-encoded labels. | 36 | nnMatrix labels; // One-hot-encoded labels. |
| 37 | int count; // Number of images and labels. | 37 | int count; // Number of images and labels. |
| 38 | int rows; // Rows in an image. | 38 | int rows; // Rows in an image. |
| 39 | int cols; // Columns in an image. | 39 | int cols; // Columns in an image. |
| 40 | } ImageSet; | 40 | } ImageSet; |
| 41 | 41 | ||
| 42 | static void usage(const char* argv0) { | 42 | static void usage(const char* argv0) { |
| 43 | fprintf(stderr, "Usage: %s <path to mnist files directory> [num images]\n", argv0); | 43 | fprintf( |
| 44 | stderr, "Usage: %s <path to mnist files directory> [num images]\n", | ||
| 45 | argv0); | ||
| 44 | fprintf(stderr, "\n"); | 46 | fprintf(stderr, "\n"); |
| 45 | fprintf(stderr, " Use -1 for [num images] to use all the images in the data set\n"); | 47 | fprintf( |
| 48 | stderr, | ||
| 49 | " Use -1 for [num images] to use all the images in the data set\n"); | ||
| 46 | } | 50 | } |
| 47 | 51 | ||
| 48 | static bool R_eq(R a, R b) { | 52 | static bool R_eq(R a, R b) { return fabs(a - b) <= EPS; } |
| 49 | return fabs(a-b) <= EPS; | ||
| 50 | } | ||
| 51 | 53 | ||
| 52 | static void PrintImage(const nnMatrix* images, int rows, int cols, int image_index) { | 54 | static void PrintImage( |
| 55 | const nnMatrix* images, int rows, int cols, int image_index) { | ||
| 53 | assert(images); | 56 | assert(images); |
| 54 | assert((0 <= image_index) && (image_index < images->rows)); | 57 | assert((0 <= image_index) && (image_index < images->rows)); |
| 55 | 58 | ||
| 56 | // Top line. | 59 | // Top line. |
| 57 | for (int j = 0; j < cols/2; ++j) { | 60 | for (int j = 0; j < cols / 2; ++j) { |
| 58 | printf(" -"); | 61 | printf(" -"); |
| 59 | } | 62 | } |
| 60 | printf("\n"); | 63 | printf("\n"); |
| @@ -68,8 +71,7 @@ static void PrintImage(const nnMatrix* images, int rows, int cols, int image_ind | |||
| 68 | printf("#"); | 71 | printf("#"); |
| 69 | } else if (*value > 0.5) { | 72 | } else if (*value > 0.5) { |
| 70 | printf("*"); | 73 | printf("*"); |
| 71 | } | 74 | } else if (*value > PIXEL_LOWER_BOUND) { |
| 72 | else if (*value > PIXEL_LOWER_BOUND) { | ||
| 73 | printf(":"); | 75 | printf(":"); |
| 74 | } else if (*value == 0.0) { | 76 | } else if (*value == 0.0) { |
| 75 | // Values should not be exactly 0, otherwise they cancel out weights | 77 | // Values should not be exactly 0, otherwise they cancel out weights |
| @@ -84,7 +86,7 @@ static void PrintImage(const nnMatrix* images, int rows, int cols, int image_ind | |||
| 84 | } | 86 | } |
| 85 | 87 | ||
| 86 | // Bottom line. | 88 | // Bottom line. |
| 87 | for (int j = 0; j < cols/2; ++j) { | 89 | for (int j = 0; j < cols / 2; ++j) { |
| 88 | printf(" -"); | 90 | printf(" -"); |
| 89 | } | 91 | } |
| 90 | printf("\n"); | 92 | printf("\n"); |
| @@ -96,7 +98,7 @@ static void PrintLabel(const nnMatrix* labels, int label_index) { | |||
| 96 | 98 | ||
| 97 | // Compute the label from the one-hot encoding. | 99 | // Compute the label from the one-hot encoding. |
| 98 | const R* value = nnMatrixRow(labels, label_index); | 100 | const R* value = nnMatrixRow(labels, label_index); |
| 99 | int label = -1; | 101 | int label = -1; |
| 100 | for (int i = 0; i < 10; ++i) { | 102 | for (int i = 0; i < 10; ++i) { |
| 101 | if (R_eq(*value++, LABEL_UPPER_BOUND)) { | 103 | if (R_eq(*value++, LABEL_UPPER_BOUND)) { |
| 102 | label = i; | 104 | label = i; |
| @@ -113,13 +115,12 @@ static void PrintLabel(const nnMatrix* labels, int label_index) { | |||
| 113 | printf(")\n"); | 115 | printf(")\n"); |
| 114 | } | 116 | } |
| 115 | 117 | ||
| 116 | static R lerp(R a, R b, R t) { | 118 | static R lerp(R a, R b, R t) { return a + t * (b - a); } |
| 117 | return a + t*(b-a); | ||
| 118 | } | ||
| 119 | 119 | ||
| 120 | /// Rescales a pixel from [0,255] to [PIXEL_LOWER_BOUND, 1.0]. | 120 | /// Rescales a pixel from [0,255] to [PIXEL_LOWER_BOUND, 1.0]. |
| 121 | static R FormatPixel(uint8_t pixel) { | 121 | static R FormatPixel(uint8_t pixel) { |
| 122 | const R value = (R)(pixel) / 255.0 * (1.0 - PIXEL_LOWER_BOUND) + PIXEL_LOWER_BOUND; | 122 | const R value = |
| 123 | (R)(pixel) / 255.0 * (1.0 - PIXEL_LOWER_BOUND) + PIXEL_LOWER_BOUND; | ||
| 123 | assert(value >= PIXEL_LOWER_BOUND); | 124 | assert(value >= PIXEL_LOWER_BOUND); |
| 124 | assert(value <= 1.0); | 125 | assert(value <= 1.0); |
| 125 | return value; | 126 | return value; |
| @@ -152,7 +153,8 @@ static void ImageToMatrix( | |||
| 152 | } | 153 | } |
| 153 | } | 154 | } |
| 154 | 155 | ||
| 155 | static bool ReadImages(gzFile images_file, int max_num_images, ImageSet* image_set) { | 156 | static bool ReadImages( |
| 157 | gzFile images_file, int max_num_images, ImageSet* image_set) { | ||
| 156 | assert(images_file != Z_NULL); | 158 | assert(images_file != Z_NULL); |
| 157 | assert(image_set); | 159 | assert(image_set); |
| 158 | 160 | ||
| @@ -161,36 +163,41 @@ static bool ReadImages(gzFile images_file, int max_num_images, ImageSet* image_s | |||
| 161 | uint8_t* pixels = 0; | 163 | uint8_t* pixels = 0; |
| 162 | 164 | ||
| 163 | int32_t magic, total_images, rows, cols; | 165 | int32_t magic, total_images, rows, cols; |
| 164 | if ( (gzread(images_file, (char*)&magic, sizeof(int32_t)) != sizeof(int32_t)) || | 166 | if ((gzread(images_file, (char*)&magic, sizeof(int32_t)) != |
| 165 | (gzread(images_file, (char*)&total_images, sizeof(int32_t)) != sizeof(int32_t)) || | 167 | sizeof(int32_t)) || |
| 166 | (gzread(images_file, (char*)&rows, sizeof(int32_t)) != sizeof(int32_t)) || | 168 | (gzread(images_file, (char*)&total_images, sizeof(int32_t)) != |
| 167 | (gzread(images_file, (char*)&cols, sizeof(int32_t)) != sizeof(int32_t)) ) { | 169 | sizeof(int32_t)) || |
| 170 | (gzread(images_file, (char*)&rows, sizeof(int32_t)) != sizeof(int32_t)) || | ||
| 171 | (gzread(images_file, (char*)&cols, sizeof(int32_t)) != sizeof(int32_t))) { | ||
| 168 | fprintf(stderr, "Failed to read header\n"); | 172 | fprintf(stderr, "Failed to read header\n"); |
| 169 | goto cleanup; | 173 | goto cleanup; |
| 170 | } | 174 | } |
| 171 | 175 | ||
| 172 | magic = ReverseEndian32(magic); | 176 | magic = ReverseEndian32(magic); |
| 173 | total_images = ReverseEndian32(total_images); | 177 | total_images = ReverseEndian32(total_images); |
| 174 | rows = ReverseEndian32(rows); | 178 | rows = ReverseEndian32(rows); |
| 175 | cols = ReverseEndian32(cols); | 179 | cols = ReverseEndian32(cols); |
| 176 | 180 | ||
| 177 | if (magic != IMAGE_FILE_MAGIC) { | 181 | if (magic != IMAGE_FILE_MAGIC) { |
| 178 | fprintf(stderr, "Magic number mismatch. Got %x, expected: %x\n", | 182 | fprintf( |
| 179 | magic, IMAGE_FILE_MAGIC); | 183 | stderr, "Magic number mismatch. Got %x, expected: %x\n", magic, |
| 184 | IMAGE_FILE_MAGIC); | ||
| 180 | goto cleanup; | 185 | goto cleanup; |
| 181 | } | 186 | } |
| 182 | 187 | ||
| 183 | printf("Magic: %.8x\nTotal images: %d\nRows: %d\nCols: %d\n", | 188 | printf( |
| 184 | magic, total_images, rows, cols); | 189 | "Magic: %.8x\nTotal images: %d\nRows: %d\nCols: %d\n", magic, |
| 190 | total_images, rows, cols); | ||
| 185 | 191 | ||
| 186 | total_images = max_num_images >= 0 ? min(total_images, max_num_images) : total_images; | 192 | total_images = |
| 193 | max_num_images >= 0 ? min(total_images, max_num_images) : total_images; | ||
| 187 | 194 | ||
| 188 | // Images are flattened into single row vectors. | 195 | // Images are flattened into single row vectors. |
| 189 | const int num_pixels = rows * cols; | 196 | const int num_pixels = rows * cols; |
| 190 | image_set->images = nnMatrixMake(total_images, num_pixels); | 197 | image_set->images = nnMatrixMake(total_images, num_pixels); |
| 191 | image_set->count = total_images; | 198 | image_set->count = total_images; |
| 192 | image_set->rows = rows; | 199 | image_set->rows = rows; |
| 193 | image_set->cols = cols; | 200 | image_set->cols = cols; |
| 194 | 201 | ||
| 195 | pixels = calloc(1, num_pixels); | 202 | pixels = calloc(1, num_pixels); |
| 196 | if (!pixels) { | 203 | if (!pixels) { |
| @@ -219,30 +226,31 @@ cleanup: | |||
| 219 | return success; | 226 | return success; |
| 220 | } | 227 | } |
| 221 | 228 | ||
| 222 | static void OneHotEncode(const uint8_t* labels_bytes, int num_labels, nnMatrix* labels) { | 229 | static void OneHotEncode( |
| 230 | const uint8_t* labels_bytes, int num_labels, nnMatrix* labels) { | ||
| 223 | assert(labels_bytes); | 231 | assert(labels_bytes); |
| 224 | assert(labels); | 232 | assert(labels); |
| 225 | assert(labels->rows == num_labels); | 233 | assert(labels->rows == num_labels); |
| 226 | assert(labels->cols == 10); | 234 | assert(labels->cols == 10); |
| 227 | 235 | ||
| 228 | static const R one_hot[10][10] = { | 236 | static const R one_hot[10][10] = { |
| 229 | { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, | 237 | {1, 0, 0, 0, 0, 0, 0, 0, 0, 0}, |
| 230 | { 0, 1, 0, 0, 0, 0, 0, 0, 0, 0 }, | 238 | {0, 1, 0, 0, 0, 0, 0, 0, 0, 0}, |
| 231 | { 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 }, | 239 | {0, 0, 1, 0, 0, 0, 0, 0, 0, 0}, |
| 232 | { 0, 0, 0, 1, 0, 0, 0, 0, 0, 0 }, | 240 | {0, 0, 0, 1, 0, 0, 0, 0, 0, 0}, |
| 233 | { 0, 0, 0, 0, 1, 0, 0, 0, 0, 0 }, | 241 | {0, 0, 0, 0, 1, 0, 0, 0, 0, 0}, |
| 234 | { 0, 0, 0, 0, 0, 1, 0, 0, 0, 0 }, | 242 | {0, 0, 0, 0, 0, 1, 0, 0, 0, 0}, |
| 235 | { 0, 0, 0, 0, 0, 0, 1, 0, 0, 0 }, | 243 | {0, 0, 0, 0, 0, 0, 1, 0, 0, 0}, |
| 236 | { 0, 0, 0, 0, 0, 0, 0, 1, 0, 0 }, | 244 | {0, 0, 0, 0, 0, 0, 0, 1, 0, 0}, |
| 237 | { 0, 0, 0, 0, 0, 0, 0, 0, 1, 0 }, | 245 | {0, 0, 0, 0, 0, 0, 0, 0, 1, 0}, |
| 238 | { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }, | 246 | {0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, |
| 239 | }; | 247 | }; |
| 240 | 248 | ||
| 241 | R* value = labels->values; | 249 | R* value = labels->values; |
| 242 | 250 | ||
| 243 | for (int i = 0; i < num_labels; ++i) { | 251 | for (int i = 0; i < num_labels; ++i) { |
| 244 | const uint8_t label = labels_bytes[i]; | 252 | const uint8_t label = labels_bytes[i]; |
| 245 | const R* one_hot_value = one_hot[label]; | 253 | const R* one_hot_value = one_hot[label]; |
| 246 | 254 | ||
| 247 | for (int j = 0; j < 10; ++j) { | 255 | for (int j = 0; j < 10; ++j) { |
| 248 | *value++ = FormatLabel(*one_hot_value++); | 256 | *value++ = FormatLabel(*one_hot_value++); |
| @@ -255,13 +263,13 @@ static int OneHotDecode(const nnMatrix* label_matrix) { | |||
| 255 | assert(label_matrix->cols == 10); | 263 | assert(label_matrix->cols == 10); |
| 256 | assert(label_matrix->rows == 1); | 264 | assert(label_matrix->rows == 1); |
| 257 | 265 | ||
| 258 | R max_value = 0; | 266 | R max_value = 0; |
| 259 | int pos_max = 0; | 267 | int pos_max = 0; |
| 260 | for (int i = 0; i < 10; ++i) { | 268 | for (int i = 0; i < 10; ++i) { |
| 261 | const R value = nnMatrixAt(label_matrix, 0, i); | 269 | const R value = nnMatrixAt(label_matrix, 0, i); |
| 262 | if (value > max_value) { | 270 | if (value > max_value) { |
| 263 | max_value = value; | 271 | max_value = value; |
| 264 | pos_max = i; | 272 | pos_max = i; |
| 265 | } | 273 | } |
| 266 | } | 274 | } |
| 267 | assert(pos_max >= 0); | 275 | assert(pos_max >= 0); |
| @@ -269,7 +277,8 @@ static int OneHotDecode(const nnMatrix* label_matrix) { | |||
| 269 | return pos_max; | 277 | return pos_max; |
| 270 | } | 278 | } |
| 271 | 279 | ||
| 272 | static bool ReadLabels(gzFile labels_file, int max_num_labels, ImageSet* image_set) { | 280 | static bool ReadLabels( |
| 281 | gzFile labels_file, int max_num_labels, ImageSet* image_set) { | ||
| 273 | assert(labels_file != Z_NULL); | 282 | assert(labels_file != Z_NULL); |
| 274 | assert(image_set != 0); | 283 | assert(image_set != 0); |
| 275 | 284 | ||
| @@ -278,24 +287,28 @@ static bool ReadLabels(gzFile labels_file, int max_num_labels, ImageSet* image_s | |||
| 278 | uint8_t* labels = 0; | 287 | uint8_t* labels = 0; |
| 279 | 288 | ||
| 280 | int32_t magic, total_labels; | 289 | int32_t magic, total_labels; |
| 281 | if ( (gzread(labels_file, (char*)&magic, sizeof(int32_t)) != sizeof(int32_t)) || | 290 | if ((gzread(labels_file, (char*)&magic, sizeof(int32_t)) != |
| 282 | (gzread(labels_file, (char*)&total_labels, sizeof(int32_t)) != sizeof(int32_t)) ) { | 291 | sizeof(int32_t)) || |
| 292 | (gzread(labels_file, (char*)&total_labels, sizeof(int32_t)) != | ||
| 293 | sizeof(int32_t))) { | ||
| 283 | fprintf(stderr, "Failed to read header\n"); | 294 | fprintf(stderr, "Failed to read header\n"); |
| 284 | goto cleanup; | 295 | goto cleanup; |
| 285 | } | 296 | } |
| 286 | 297 | ||
| 287 | magic = ReverseEndian32(magic); | 298 | magic = ReverseEndian32(magic); |
| 288 | total_labels = ReverseEndian32(total_labels); | 299 | total_labels = ReverseEndian32(total_labels); |
| 289 | 300 | ||
| 290 | if (magic != LABEL_FILE_MAGIC) { | 301 | if (magic != LABEL_FILE_MAGIC) { |
| 291 | fprintf(stderr, "Magic number mismatch. Got %x, expected: %x\n", | 302 | fprintf( |
| 292 | magic, LABEL_FILE_MAGIC); | 303 | stderr, "Magic number mismatch. Got %x, expected: %x\n", magic, |
| 304 | LABEL_FILE_MAGIC); | ||
| 293 | goto cleanup; | 305 | goto cleanup; |
| 294 | } | 306 | } |
| 295 | 307 | ||
| 296 | printf("Magic: %.8x\nTotal labels: %d\n", magic, total_labels); | 308 | printf("Magic: %.8x\nTotal labels: %d\n", magic, total_labels); |
| 297 | 309 | ||
| 298 | total_labels = max_num_labels >= 0 ? min(total_labels, max_num_labels) : total_labels; | 310 | total_labels = |
| 311 | max_num_labels >= 0 ? min(total_labels, max_num_labels) : total_labels; | ||
| 299 | 312 | ||
| 300 | assert(image_set->count == total_labels); | 313 | assert(image_set->count == total_labels); |
| 301 | 314 | ||
| @@ -308,7 +321,8 @@ static bool ReadLabels(gzFile labels_file, int max_num_labels, ImageSet* image_s | |||
| 308 | goto cleanup; | 321 | goto cleanup; |
| 309 | } | 322 | } |
| 310 | 323 | ||
| 311 | if (gzread(labels_file, labels, total_labels * sizeof(uint8_t)) != total_labels) { | 324 | if (gzread(labels_file, labels, total_labels * sizeof(uint8_t)) != |
| 325 | total_labels) { | ||
| 312 | fprintf(stderr, "Failed to read labels\n"); | 326 | fprintf(stderr, "Failed to read labels\n"); |
| 313 | goto cleanup; | 327 | goto cleanup; |
| 314 | } | 328 | } |
| @@ -335,17 +349,17 @@ int main(int argc, const char** argv) { | |||
| 335 | 349 | ||
| 336 | bool success = false; | 350 | bool success = false; |
| 337 | 351 | ||
| 338 | gzFile train_images_file = Z_NULL; | 352 | gzFile train_images_file = Z_NULL; |
| 339 | gzFile train_labels_file = Z_NULL; | 353 | gzFile train_labels_file = Z_NULL; |
| 340 | gzFile test_images_file = Z_NULL; | 354 | gzFile test_images_file = Z_NULL; |
| 341 | gzFile test_labels_file = Z_NULL; | 355 | gzFile test_labels_file = Z_NULL; |
| 342 | ImageSet train_set = { 0 }; | 356 | ImageSet train_set = {0}; |
| 343 | ImageSet test_set = { 0 }; | 357 | ImageSet test_set = {0}; |
| 344 | nnNeuralNetwork* net = 0; | 358 | nnNeuralNetwork* net = 0; |
| 345 | nnQueryObject* query = 0; | 359 | nnQueryObject* query = 0; |
| 346 | 360 | ||
| 347 | const char* mnist_files_dir = argv[1]; | 361 | const char* mnist_files_dir = argv[1]; |
| 348 | const int max_num_images = argc > 2 ? atoi(argv[2]) : -1; | 362 | const int max_num_images = argc > 2 ? atoi(argv[2]) : -1; |
| 349 | 363 | ||
| 350 | char train_labels_path[PATH_MAX]; | 364 | char train_labels_path[PATH_MAX]; |
| 351 | char train_images_path[PATH_MAX]; | 365 | char train_images_path[PATH_MAX]; |
| @@ -353,12 +367,12 @@ int main(int argc, const char** argv) { | |||
| 353 | char test_images_path[PATH_MAX]; | 367 | char test_images_path[PATH_MAX]; |
| 354 | strlcpy(train_labels_path, mnist_files_dir, PATH_MAX); | 368 | strlcpy(train_labels_path, mnist_files_dir, PATH_MAX); |
| 355 | strlcpy(train_images_path, mnist_files_dir, PATH_MAX); | 369 | strlcpy(train_images_path, mnist_files_dir, PATH_MAX); |
| 356 | strlcpy(test_labels_path, mnist_files_dir, PATH_MAX); | 370 | strlcpy(test_labels_path, mnist_files_dir, PATH_MAX); |
| 357 | strlcpy(test_images_path, mnist_files_dir, PATH_MAX); | 371 | strlcpy(test_images_path, mnist_files_dir, PATH_MAX); |
| 358 | strlcat(train_labels_path, "/train-labels-idx1-ubyte.gz", PATH_MAX); | 372 | strlcat(train_labels_path, "/train-labels-idx1-ubyte.gz", PATH_MAX); |
| 359 | strlcat(train_images_path, "/train-images-idx3-ubyte.gz", PATH_MAX); | 373 | strlcat(train_images_path, "/train-images-idx3-ubyte.gz", PATH_MAX); |
| 360 | strlcat(test_labels_path, "/t10k-labels-idx1-ubyte.gz", PATH_MAX); | 374 | strlcat(test_labels_path, "/t10k-labels-idx1-ubyte.gz", PATH_MAX); |
| 361 | strlcat(test_images_path, "/t10k-images-idx3-ubyte.gz", PATH_MAX); | 375 | strlcat(test_images_path, "/t10k-images-idx3-ubyte.gz", PATH_MAX); |
| 362 | 376 | ||
| 363 | train_images_file = gzopen(train_images_path, "r"); | 377 | train_images_file = gzopen(train_images_path, "r"); |
| 364 | if (train_images_file == Z_NULL) { | 378 | if (train_images_file == Z_NULL) { |
| @@ -406,11 +420,18 @@ int main(int argc, const char** argv) { | |||
| 406 | } | 420 | } |
| 407 | 421 | ||
| 408 | // Network definition. | 422 | // Network definition. |
| 409 | const int image_size_pixels = train_set.rows * train_set.cols; | 423 | const int image_size_pixels = train_set.rows * train_set.cols; |
| 410 | const int num_layers = 2; | 424 | const int num_layers = 4; |
| 411 | const int layer_sizes[3] = { image_size_pixels, 100, 10 }; | 425 | const int hidden_size = 100; |
| 412 | const nnActivation layer_activations[2] = { nnSigmoid, nnSigmoid }; | 426 | const nnLayer layers[4] = { |
| 413 | if (!(net = nnMakeNet(num_layers, layer_sizes, layer_activations))) { | 427 | {.type = nnLinear, |
| 428 | .linear = {.input_size = image_size_pixels, .output_size = hidden_size}}, | ||
| 429 | {.type = nnSigmoid}, | ||
| 430 | {.type = nnLinear, | ||
| 431 | .linear = {.input_size = hidden_size, .output_size = 10}}, | ||
| 432 | {.type = nnSigmoid} | ||
| 433 | }; | ||
| 434 | if (!(net = nnMakeNet(layers, num_layers, image_size_pixels))) { | ||
| 414 | fprintf(stderr, "Failed to create neural network\n"); | 435 | fprintf(stderr, "Failed to create neural network\n"); |
| 415 | goto cleanup; | 436 | goto cleanup; |
| 416 | } | 437 | } |
| @@ -418,17 +439,17 @@ int main(int argc, const char** argv) { | |||
| 418 | // Train. | 439 | // Train. |
| 419 | printf("Training with up to %d images from the data set\n\n", max_num_images); | 440 | printf("Training with up to %d images from the data set\n\n", max_num_images); |
| 420 | const nnTrainingParams training_params = { | 441 | const nnTrainingParams training_params = { |
| 421 | .learning_rate = 0.1, | 442 | .learning_rate = 0.1, |
| 422 | .max_iterations = TRAIN_ITERATIONS, | 443 | .max_iterations = TRAIN_ITERATIONS, |
| 423 | .seed = 0, | 444 | .seed = 0, |
| 424 | .weight_init = nnWeightInitNormal, | 445 | .weight_init = nnWeightInitNormal, |
| 425 | .debug = true, | 446 | .debug = true, |
| 426 | }; | 447 | }; |
| 427 | nnTrain(net, &train_set.images, &train_set.labels, &training_params); | 448 | nnTrain(net, &train_set.images, &train_set.labels, &training_params); |
| 428 | 449 | ||
| 429 | // Test. | 450 | // Test. |
| 430 | int hits = 0; | 451 | int hits = 0; |
| 431 | query = nnMakeQueryObject(net, /*num_inputs=*/1); | 452 | query = nnMakeQueryObject(net, /*num_inputs=*/1); |
| 432 | for (int i = 0; i < test_set.count; ++i) { | 453 | for (int i = 0; i < test_set.count; ++i) { |
| 433 | const nnMatrix test_image = nnMatrixBorrowRows(&test_set.images, i, 1); | 454 | const nnMatrix test_image = nnMatrixBorrowRows(&test_set.images, i, 1); |
| 434 | const nnMatrix test_label = nnMatrixBorrowRows(&test_set.labels, i, 1); | 455 | const nnMatrix test_label = nnMatrixBorrowRows(&test_set.labels, i, 1); |
| @@ -444,7 +465,7 @@ int main(int argc, const char** argv) { | |||
| 444 | } | 465 | } |
| 445 | const R hit_ratio = (R)hits / (R)test_set.count; | 466 | const R hit_ratio = (R)hits / (R)test_set.count; |
| 446 | printf("Test images: %d\n", test_set.count); | 467 | printf("Test images: %d\n", test_set.count); |
| 447 | printf("Hits: %d/%d (%.3f%%)\n", hits, test_set.count, hit_ratio*100); | 468 | printf("Hits: %d/%d (%.3f%%)\n", hits, test_set.count, hit_ratio * 100); |
| 448 | 469 | ||
| 449 | success = true; | 470 | success = true; |
| 450 | 471 | ||
