diff options
Diffstat (limited to 'src/bin/mnist')
-rw-r--r-- | src/bin/mnist/src/main.c | 195 |
1 files changed, 108 insertions, 87 deletions
diff --git a/src/bin/mnist/src/main.c b/src/bin/mnist/src/main.c index 9aa3ce5..53e0197 100644 --- a/src/bin/mnist/src/main.c +++ b/src/bin/mnist/src/main.c | |||
@@ -29,32 +29,35 @@ static const double LABEL_UPPER_BOUND = 0.99; | |||
29 | // Epsilon used to compare R values. | 29 | // Epsilon used to compare R values. |
30 | static const double EPS = 1e-10; | 30 | static const double EPS = 1e-10; |
31 | 31 | ||
32 | #define min(a,b) ((a) < (b) ? (a) : (b)) | 32 | #define min(a, b) ((a) < (b) ? (a) : (b)) |
33 | 33 | ||
34 | typedef struct ImageSet { | 34 | typedef struct ImageSet { |
35 | nnMatrix images; // Images flattened into row vectors of the matrix. | 35 | nnMatrix images; // Images flattened into row vectors of the matrix. |
36 | nnMatrix labels; // One-hot-encoded labels. | 36 | nnMatrix labels; // One-hot-encoded labels. |
37 | int count; // Number of images and labels. | 37 | int count; // Number of images and labels. |
38 | int rows; // Rows in an image. | 38 | int rows; // Rows in an image. |
39 | int cols; // Columns in an image. | 39 | int cols; // Columns in an image. |
40 | } ImageSet; | 40 | } ImageSet; |
41 | 41 | ||
42 | static void usage(const char* argv0) { | 42 | static void usage(const char* argv0) { |
43 | fprintf(stderr, "Usage: %s <path to mnist files directory> [num images]\n", argv0); | 43 | fprintf( |
44 | stderr, "Usage: %s <path to mnist files directory> [num images]\n", | ||
45 | argv0); | ||
44 | fprintf(stderr, "\n"); | 46 | fprintf(stderr, "\n"); |
45 | fprintf(stderr, " Use -1 for [num images] to use all the images in the data set\n"); | 47 | fprintf( |
48 | stderr, | ||
49 | " Use -1 for [num images] to use all the images in the data set\n"); | ||
46 | } | 50 | } |
47 | 51 | ||
48 | static bool R_eq(R a, R b) { | 52 | static bool R_eq(R a, R b) { return fabs(a - b) <= EPS; } |
49 | return fabs(a-b) <= EPS; | ||
50 | } | ||
51 | 53 | ||
52 | static void PrintImage(const nnMatrix* images, int rows, int cols, int image_index) { | 54 | static void PrintImage( |
55 | const nnMatrix* images, int rows, int cols, int image_index) { | ||
53 | assert(images); | 56 | assert(images); |
54 | assert((0 <= image_index) && (image_index < images->rows)); | 57 | assert((0 <= image_index) && (image_index < images->rows)); |
55 | 58 | ||
56 | // Top line. | 59 | // Top line. |
57 | for (int j = 0; j < cols/2; ++j) { | 60 | for (int j = 0; j < cols / 2; ++j) { |
58 | printf(" -"); | 61 | printf(" -"); |
59 | } | 62 | } |
60 | printf("\n"); | 63 | printf("\n"); |
@@ -68,8 +71,7 @@ static void PrintImage(const nnMatrix* images, int rows, int cols, int image_ind | |||
68 | printf("#"); | 71 | printf("#"); |
69 | } else if (*value > 0.5) { | 72 | } else if (*value > 0.5) { |
70 | printf("*"); | 73 | printf("*"); |
71 | } | 74 | } else if (*value > PIXEL_LOWER_BOUND) { |
72 | else if (*value > PIXEL_LOWER_BOUND) { | ||
73 | printf(":"); | 75 | printf(":"); |
74 | } else if (*value == 0.0) { | 76 | } else if (*value == 0.0) { |
75 | // Values should not be exactly 0, otherwise they cancel out weights | 77 | // Values should not be exactly 0, otherwise they cancel out weights |
@@ -84,7 +86,7 @@ static void PrintImage(const nnMatrix* images, int rows, int cols, int image_ind | |||
84 | } | 86 | } |
85 | 87 | ||
86 | // Bottom line. | 88 | // Bottom line. |
87 | for (int j = 0; j < cols/2; ++j) { | 89 | for (int j = 0; j < cols / 2; ++j) { |
88 | printf(" -"); | 90 | printf(" -"); |
89 | } | 91 | } |
90 | printf("\n"); | 92 | printf("\n"); |
@@ -96,7 +98,7 @@ static void PrintLabel(const nnMatrix* labels, int label_index) { | |||
96 | 98 | ||
97 | // Compute the label from the one-hot encoding. | 99 | // Compute the label from the one-hot encoding. |
98 | const R* value = nnMatrixRow(labels, label_index); | 100 | const R* value = nnMatrixRow(labels, label_index); |
99 | int label = -1; | 101 | int label = -1; |
100 | for (int i = 0; i < 10; ++i) { | 102 | for (int i = 0; i < 10; ++i) { |
101 | if (R_eq(*value++, LABEL_UPPER_BOUND)) { | 103 | if (R_eq(*value++, LABEL_UPPER_BOUND)) { |
102 | label = i; | 104 | label = i; |
@@ -113,13 +115,12 @@ static void PrintLabel(const nnMatrix* labels, int label_index) { | |||
113 | printf(")\n"); | 115 | printf(")\n"); |
114 | } | 116 | } |
115 | 117 | ||
116 | static R lerp(R a, R b, R t) { | 118 | static R lerp(R a, R b, R t) { return a + t * (b - a); } |
117 | return a + t*(b-a); | ||
118 | } | ||
119 | 119 | ||
120 | /// Rescales a pixel from [0,255] to [PIXEL_LOWER_BOUND, 1.0]. | 120 | /// Rescales a pixel from [0,255] to [PIXEL_LOWER_BOUND, 1.0]. |
121 | static R FormatPixel(uint8_t pixel) { | 121 | static R FormatPixel(uint8_t pixel) { |
122 | const R value = (R)(pixel) / 255.0 * (1.0 - PIXEL_LOWER_BOUND) + PIXEL_LOWER_BOUND; | 122 | const R value = |
123 | (R)(pixel) / 255.0 * (1.0 - PIXEL_LOWER_BOUND) + PIXEL_LOWER_BOUND; | ||
123 | assert(value >= PIXEL_LOWER_BOUND); | 124 | assert(value >= PIXEL_LOWER_BOUND); |
124 | assert(value <= 1.0); | 125 | assert(value <= 1.0); |
125 | return value; | 126 | return value; |
@@ -152,7 +153,8 @@ static void ImageToMatrix( | |||
152 | } | 153 | } |
153 | } | 154 | } |
154 | 155 | ||
155 | static bool ReadImages(gzFile images_file, int max_num_images, ImageSet* image_set) { | 156 | static bool ReadImages( |
157 | gzFile images_file, int max_num_images, ImageSet* image_set) { | ||
156 | assert(images_file != Z_NULL); | 158 | assert(images_file != Z_NULL); |
157 | assert(image_set); | 159 | assert(image_set); |
158 | 160 | ||
@@ -161,36 +163,41 @@ static bool ReadImages(gzFile images_file, int max_num_images, ImageSet* image_s | |||
161 | uint8_t* pixels = 0; | 163 | uint8_t* pixels = 0; |
162 | 164 | ||
163 | int32_t magic, total_images, rows, cols; | 165 | int32_t magic, total_images, rows, cols; |
164 | if ( (gzread(images_file, (char*)&magic, sizeof(int32_t)) != sizeof(int32_t)) || | 166 | if ((gzread(images_file, (char*)&magic, sizeof(int32_t)) != |
165 | (gzread(images_file, (char*)&total_images, sizeof(int32_t)) != sizeof(int32_t)) || | 167 | sizeof(int32_t)) || |
166 | (gzread(images_file, (char*)&rows, sizeof(int32_t)) != sizeof(int32_t)) || | 168 | (gzread(images_file, (char*)&total_images, sizeof(int32_t)) != |
167 | (gzread(images_file, (char*)&cols, sizeof(int32_t)) != sizeof(int32_t)) ) { | 169 | sizeof(int32_t)) || |
170 | (gzread(images_file, (char*)&rows, sizeof(int32_t)) != sizeof(int32_t)) || | ||
171 | (gzread(images_file, (char*)&cols, sizeof(int32_t)) != sizeof(int32_t))) { | ||
168 | fprintf(stderr, "Failed to read header\n"); | 172 | fprintf(stderr, "Failed to read header\n"); |
169 | goto cleanup; | 173 | goto cleanup; |
170 | } | 174 | } |
171 | 175 | ||
172 | magic = ReverseEndian32(magic); | 176 | magic = ReverseEndian32(magic); |
173 | total_images = ReverseEndian32(total_images); | 177 | total_images = ReverseEndian32(total_images); |
174 | rows = ReverseEndian32(rows); | 178 | rows = ReverseEndian32(rows); |
175 | cols = ReverseEndian32(cols); | 179 | cols = ReverseEndian32(cols); |
176 | 180 | ||
177 | if (magic != IMAGE_FILE_MAGIC) { | 181 | if (magic != IMAGE_FILE_MAGIC) { |
178 | fprintf(stderr, "Magic number mismatch. Got %x, expected: %x\n", | 182 | fprintf( |
179 | magic, IMAGE_FILE_MAGIC); | 183 | stderr, "Magic number mismatch. Got %x, expected: %x\n", magic, |
184 | IMAGE_FILE_MAGIC); | ||
180 | goto cleanup; | 185 | goto cleanup; |
181 | } | 186 | } |
182 | 187 | ||
183 | printf("Magic: %.8x\nTotal images: %d\nRows: %d\nCols: %d\n", | 188 | printf( |
184 | magic, total_images, rows, cols); | 189 | "Magic: %.8x\nTotal images: %d\nRows: %d\nCols: %d\n", magic, |
190 | total_images, rows, cols); | ||
185 | 191 | ||
186 | total_images = max_num_images >= 0 ? min(total_images, max_num_images) : total_images; | 192 | total_images = |
193 | max_num_images >= 0 ? min(total_images, max_num_images) : total_images; | ||
187 | 194 | ||
188 | // Images are flattened into single row vectors. | 195 | // Images are flattened into single row vectors. |
189 | const int num_pixels = rows * cols; | 196 | const int num_pixels = rows * cols; |
190 | image_set->images = nnMatrixMake(total_images, num_pixels); | 197 | image_set->images = nnMatrixMake(total_images, num_pixels); |
191 | image_set->count = total_images; | 198 | image_set->count = total_images; |
192 | image_set->rows = rows; | 199 | image_set->rows = rows; |
193 | image_set->cols = cols; | 200 | image_set->cols = cols; |
194 | 201 | ||
195 | pixels = calloc(1, num_pixels); | 202 | pixels = calloc(1, num_pixels); |
196 | if (!pixels) { | 203 | if (!pixels) { |
@@ -219,30 +226,31 @@ cleanup: | |||
219 | return success; | 226 | return success; |
220 | } | 227 | } |
221 | 228 | ||
222 | static void OneHotEncode(const uint8_t* labels_bytes, int num_labels, nnMatrix* labels) { | 229 | static void OneHotEncode( |
230 | const uint8_t* labels_bytes, int num_labels, nnMatrix* labels) { | ||
223 | assert(labels_bytes); | 231 | assert(labels_bytes); |
224 | assert(labels); | 232 | assert(labels); |
225 | assert(labels->rows == num_labels); | 233 | assert(labels->rows == num_labels); |
226 | assert(labels->cols == 10); | 234 | assert(labels->cols == 10); |
227 | 235 | ||
228 | static const R one_hot[10][10] = { | 236 | static const R one_hot[10][10] = { |
229 | { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, | 237 | {1, 0, 0, 0, 0, 0, 0, 0, 0, 0}, |
230 | { 0, 1, 0, 0, 0, 0, 0, 0, 0, 0 }, | 238 | {0, 1, 0, 0, 0, 0, 0, 0, 0, 0}, |
231 | { 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 }, | 239 | {0, 0, 1, 0, 0, 0, 0, 0, 0, 0}, |
232 | { 0, 0, 0, 1, 0, 0, 0, 0, 0, 0 }, | 240 | {0, 0, 0, 1, 0, 0, 0, 0, 0, 0}, |
233 | { 0, 0, 0, 0, 1, 0, 0, 0, 0, 0 }, | 241 | {0, 0, 0, 0, 1, 0, 0, 0, 0, 0}, |
234 | { 0, 0, 0, 0, 0, 1, 0, 0, 0, 0 }, | 242 | {0, 0, 0, 0, 0, 1, 0, 0, 0, 0}, |
235 | { 0, 0, 0, 0, 0, 0, 1, 0, 0, 0 }, | 243 | {0, 0, 0, 0, 0, 0, 1, 0, 0, 0}, |
236 | { 0, 0, 0, 0, 0, 0, 0, 1, 0, 0 }, | 244 | {0, 0, 0, 0, 0, 0, 0, 1, 0, 0}, |
237 | { 0, 0, 0, 0, 0, 0, 0, 0, 1, 0 }, | 245 | {0, 0, 0, 0, 0, 0, 0, 0, 1, 0}, |
238 | { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }, | 246 | {0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, |
239 | }; | 247 | }; |
240 | 248 | ||
241 | R* value = labels->values; | 249 | R* value = labels->values; |
242 | 250 | ||
243 | for (int i = 0; i < num_labels; ++i) { | 251 | for (int i = 0; i < num_labels; ++i) { |
244 | const uint8_t label = labels_bytes[i]; | 252 | const uint8_t label = labels_bytes[i]; |
245 | const R* one_hot_value = one_hot[label]; | 253 | const R* one_hot_value = one_hot[label]; |
246 | 254 | ||
247 | for (int j = 0; j < 10; ++j) { | 255 | for (int j = 0; j < 10; ++j) { |
248 | *value++ = FormatLabel(*one_hot_value++); | 256 | *value++ = FormatLabel(*one_hot_value++); |
@@ -255,13 +263,13 @@ static int OneHotDecode(const nnMatrix* label_matrix) { | |||
255 | assert(label_matrix->cols == 10); | 263 | assert(label_matrix->cols == 10); |
256 | assert(label_matrix->rows == 1); | 264 | assert(label_matrix->rows == 1); |
257 | 265 | ||
258 | R max_value = 0; | 266 | R max_value = 0; |
259 | int pos_max = 0; | 267 | int pos_max = 0; |
260 | for (int i = 0; i < 10; ++i) { | 268 | for (int i = 0; i < 10; ++i) { |
261 | const R value = nnMatrixAt(label_matrix, 0, i); | 269 | const R value = nnMatrixAt(label_matrix, 0, i); |
262 | if (value > max_value) { | 270 | if (value > max_value) { |
263 | max_value = value; | 271 | max_value = value; |
264 | pos_max = i; | 272 | pos_max = i; |
265 | } | 273 | } |
266 | } | 274 | } |
267 | assert(pos_max >= 0); | 275 | assert(pos_max >= 0); |
@@ -269,7 +277,8 @@ static int OneHotDecode(const nnMatrix* label_matrix) { | |||
269 | return pos_max; | 277 | return pos_max; |
270 | } | 278 | } |
271 | 279 | ||
272 | static bool ReadLabels(gzFile labels_file, int max_num_labels, ImageSet* image_set) { | 280 | static bool ReadLabels( |
281 | gzFile labels_file, int max_num_labels, ImageSet* image_set) { | ||
273 | assert(labels_file != Z_NULL); | 282 | assert(labels_file != Z_NULL); |
274 | assert(image_set != 0); | 283 | assert(image_set != 0); |
275 | 284 | ||
@@ -278,24 +287,28 @@ static bool ReadLabels(gzFile labels_file, int max_num_labels, ImageSet* image_s | |||
278 | uint8_t* labels = 0; | 287 | uint8_t* labels = 0; |
279 | 288 | ||
280 | int32_t magic, total_labels; | 289 | int32_t magic, total_labels; |
281 | if ( (gzread(labels_file, (char*)&magic, sizeof(int32_t)) != sizeof(int32_t)) || | 290 | if ((gzread(labels_file, (char*)&magic, sizeof(int32_t)) != |
282 | (gzread(labels_file, (char*)&total_labels, sizeof(int32_t)) != sizeof(int32_t)) ) { | 291 | sizeof(int32_t)) || |
292 | (gzread(labels_file, (char*)&total_labels, sizeof(int32_t)) != | ||
293 | sizeof(int32_t))) { | ||
283 | fprintf(stderr, "Failed to read header\n"); | 294 | fprintf(stderr, "Failed to read header\n"); |
284 | goto cleanup; | 295 | goto cleanup; |
285 | } | 296 | } |
286 | 297 | ||
287 | magic = ReverseEndian32(magic); | 298 | magic = ReverseEndian32(magic); |
288 | total_labels = ReverseEndian32(total_labels); | 299 | total_labels = ReverseEndian32(total_labels); |
289 | 300 | ||
290 | if (magic != LABEL_FILE_MAGIC) { | 301 | if (magic != LABEL_FILE_MAGIC) { |
291 | fprintf(stderr, "Magic number mismatch. Got %x, expected: %x\n", | 302 | fprintf( |
292 | magic, LABEL_FILE_MAGIC); | 303 | stderr, "Magic number mismatch. Got %x, expected: %x\n", magic, |
304 | LABEL_FILE_MAGIC); | ||
293 | goto cleanup; | 305 | goto cleanup; |
294 | } | 306 | } |
295 | 307 | ||
296 | printf("Magic: %.8x\nTotal labels: %d\n", magic, total_labels); | 308 | printf("Magic: %.8x\nTotal labels: %d\n", magic, total_labels); |
297 | 309 | ||
298 | total_labels = max_num_labels >= 0 ? min(total_labels, max_num_labels) : total_labels; | 310 | total_labels = |
311 | max_num_labels >= 0 ? min(total_labels, max_num_labels) : total_labels; | ||
299 | 312 | ||
300 | assert(image_set->count == total_labels); | 313 | assert(image_set->count == total_labels); |
301 | 314 | ||
@@ -308,7 +321,8 @@ static bool ReadLabels(gzFile labels_file, int max_num_labels, ImageSet* image_s | |||
308 | goto cleanup; | 321 | goto cleanup; |
309 | } | 322 | } |
310 | 323 | ||
311 | if (gzread(labels_file, labels, total_labels * sizeof(uint8_t)) != total_labels) { | 324 | if (gzread(labels_file, labels, total_labels * sizeof(uint8_t)) != |
325 | total_labels) { | ||
312 | fprintf(stderr, "Failed to read labels\n"); | 326 | fprintf(stderr, "Failed to read labels\n"); |
313 | goto cleanup; | 327 | goto cleanup; |
314 | } | 328 | } |
@@ -335,17 +349,17 @@ int main(int argc, const char** argv) { | |||
335 | 349 | ||
336 | bool success = false; | 350 | bool success = false; |
337 | 351 | ||
338 | gzFile train_images_file = Z_NULL; | 352 | gzFile train_images_file = Z_NULL; |
339 | gzFile train_labels_file = Z_NULL; | 353 | gzFile train_labels_file = Z_NULL; |
340 | gzFile test_images_file = Z_NULL; | 354 | gzFile test_images_file = Z_NULL; |
341 | gzFile test_labels_file = Z_NULL; | 355 | gzFile test_labels_file = Z_NULL; |
342 | ImageSet train_set = { 0 }; | 356 | ImageSet train_set = {0}; |
343 | ImageSet test_set = { 0 }; | 357 | ImageSet test_set = {0}; |
344 | nnNeuralNetwork* net = 0; | 358 | nnNeuralNetwork* net = 0; |
345 | nnQueryObject* query = 0; | 359 | nnQueryObject* query = 0; |
346 | 360 | ||
347 | const char* mnist_files_dir = argv[1]; | 361 | const char* mnist_files_dir = argv[1]; |
348 | const int max_num_images = argc > 2 ? atoi(argv[2]) : -1; | 362 | const int max_num_images = argc > 2 ? atoi(argv[2]) : -1; |
349 | 363 | ||
350 | char train_labels_path[PATH_MAX]; | 364 | char train_labels_path[PATH_MAX]; |
351 | char train_images_path[PATH_MAX]; | 365 | char train_images_path[PATH_MAX]; |
@@ -353,12 +367,12 @@ int main(int argc, const char** argv) { | |||
353 | char test_images_path[PATH_MAX]; | 367 | char test_images_path[PATH_MAX]; |
354 | strlcpy(train_labels_path, mnist_files_dir, PATH_MAX); | 368 | strlcpy(train_labels_path, mnist_files_dir, PATH_MAX); |
355 | strlcpy(train_images_path, mnist_files_dir, PATH_MAX); | 369 | strlcpy(train_images_path, mnist_files_dir, PATH_MAX); |
356 | strlcpy(test_labels_path, mnist_files_dir, PATH_MAX); | 370 | strlcpy(test_labels_path, mnist_files_dir, PATH_MAX); |
357 | strlcpy(test_images_path, mnist_files_dir, PATH_MAX); | 371 | strlcpy(test_images_path, mnist_files_dir, PATH_MAX); |
358 | strlcat(train_labels_path, "/train-labels-idx1-ubyte.gz", PATH_MAX); | 372 | strlcat(train_labels_path, "/train-labels-idx1-ubyte.gz", PATH_MAX); |
359 | strlcat(train_images_path, "/train-images-idx3-ubyte.gz", PATH_MAX); | 373 | strlcat(train_images_path, "/train-images-idx3-ubyte.gz", PATH_MAX); |
360 | strlcat(test_labels_path, "/t10k-labels-idx1-ubyte.gz", PATH_MAX); | 374 | strlcat(test_labels_path, "/t10k-labels-idx1-ubyte.gz", PATH_MAX); |
361 | strlcat(test_images_path, "/t10k-images-idx3-ubyte.gz", PATH_MAX); | 375 | strlcat(test_images_path, "/t10k-images-idx3-ubyte.gz", PATH_MAX); |
362 | 376 | ||
363 | train_images_file = gzopen(train_images_path, "r"); | 377 | train_images_file = gzopen(train_images_path, "r"); |
364 | if (train_images_file == Z_NULL) { | 378 | if (train_images_file == Z_NULL) { |
@@ -406,11 +420,18 @@ int main(int argc, const char** argv) { | |||
406 | } | 420 | } |
407 | 421 | ||
408 | // Network definition. | 422 | // Network definition. |
409 | const int image_size_pixels = train_set.rows * train_set.cols; | 423 | const int image_size_pixels = train_set.rows * train_set.cols; |
410 | const int num_layers = 2; | 424 | const int num_layers = 4; |
411 | const int layer_sizes[3] = { image_size_pixels, 100, 10 }; | 425 | const int hidden_size = 100; |
412 | const nnActivation layer_activations[2] = { nnSigmoid, nnSigmoid }; | 426 | const nnLayer layers[4] = { |
413 | if (!(net = nnMakeNet(num_layers, layer_sizes, layer_activations))) { | 427 | {.type = nnLinear, |
428 | .linear = {.input_size = image_size_pixels, .output_size = hidden_size}}, | ||
429 | {.type = nnSigmoid}, | ||
430 | {.type = nnLinear, | ||
431 | .linear = {.input_size = hidden_size, .output_size = 10}}, | ||
432 | {.type = nnSigmoid} | ||
433 | }; | ||
434 | if (!(net = nnMakeNet(layers, num_layers, image_size_pixels))) { | ||
414 | fprintf(stderr, "Failed to create neural network\n"); | 435 | fprintf(stderr, "Failed to create neural network\n"); |
415 | goto cleanup; | 436 | goto cleanup; |
416 | } | 437 | } |
@@ -418,17 +439,17 @@ int main(int argc, const char** argv) { | |||
418 | // Train. | 439 | // Train. |
419 | printf("Training with up to %d images from the data set\n\n", max_num_images); | 440 | printf("Training with up to %d images from the data set\n\n", max_num_images); |
420 | const nnTrainingParams training_params = { | 441 | const nnTrainingParams training_params = { |
421 | .learning_rate = 0.1, | 442 | .learning_rate = 0.1, |
422 | .max_iterations = TRAIN_ITERATIONS, | 443 | .max_iterations = TRAIN_ITERATIONS, |
423 | .seed = 0, | 444 | .seed = 0, |
424 | .weight_init = nnWeightInitNormal, | 445 | .weight_init = nnWeightInitNormal, |
425 | .debug = true, | 446 | .debug = true, |
426 | }; | 447 | }; |
427 | nnTrain(net, &train_set.images, &train_set.labels, &training_params); | 448 | nnTrain(net, &train_set.images, &train_set.labels, &training_params); |
428 | 449 | ||
429 | // Test. | 450 | // Test. |
430 | int hits = 0; | 451 | int hits = 0; |
431 | query = nnMakeQueryObject(net, /*num_inputs=*/1); | 452 | query = nnMakeQueryObject(net, /*num_inputs=*/1); |
432 | for (int i = 0; i < test_set.count; ++i) { | 453 | for (int i = 0; i < test_set.count; ++i) { |
433 | const nnMatrix test_image = nnMatrixBorrowRows(&test_set.images, i, 1); | 454 | const nnMatrix test_image = nnMatrixBorrowRows(&test_set.images, i, 1); |
434 | const nnMatrix test_label = nnMatrixBorrowRows(&test_set.labels, i, 1); | 455 | const nnMatrix test_label = nnMatrixBorrowRows(&test_set.labels, i, 1); |
@@ -444,7 +465,7 @@ int main(int argc, const char** argv) { | |||
444 | } | 465 | } |
445 | const R hit_ratio = (R)hits / (R)test_set.count; | 466 | const R hit_ratio = (R)hits / (R)test_set.count; |
446 | printf("Test images: %d\n", test_set.count); | 467 | printf("Test images: %d\n", test_set.count); |
447 | printf("Hits: %d/%d (%.3f%%)\n", hits, test_set.count, hit_ratio*100); | 468 | printf("Hits: %d/%d (%.3f%%)\n", hits, test_set.count, hit_ratio * 100); |
448 | 469 | ||
449 | success = true; | 470 | success = true; |
450 | 471 | ||