aboutsummaryrefslogtreecommitdiff
path: root/src/lib/test/matrix_test.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib/test/matrix_test.c')
-rw-r--r--src/lib/test/matrix_test.c350
1 files changed, 350 insertions, 0 deletions
diff --git a/src/lib/test/matrix_test.c b/src/lib/test/matrix_test.c
new file mode 100644
index 0000000..8191c97
--- /dev/null
+++ b/src/lib/test/matrix_test.c
@@ -0,0 +1,350 @@
1#include <neuralnet/matrix.h>
2
3#include "test.h"
4#include "test_util.h"
5
6#include <assert.h>
7#include <stdlib.h>
8
9// static void PrintMatrix(const nnMatrix* matrix) {
10// assert(matrix);
11
12// for (int i = 0; i < matrix->rows; ++i) {
13// for (int j = 0; j < matrix->cols; ++j) {
14// printf("%f ", nnMatrixAt(matrix, i, j));
15// }
16// printf("\n");
17// }
18// }
19
20TEST_CASE(nnMatrixMake_1x1) {
21 nnMatrix A = nnMatrixMake(1, 1);
22 TEST_EQUAL(A.rows, 1);
23 TEST_EQUAL(A.cols, 1);
24}
25
26TEST_CASE(nnMatrixMake_3x1) {
27 nnMatrix A = nnMatrixMake(3, 1);
28 TEST_EQUAL(A.rows, 3);
29 TEST_EQUAL(A.cols, 1);
30}
31
32TEST_CASE(nnMatrixInit_3x1) {
33 nnMatrix A = nnMatrixMake(3, 1);
34 nnMatrixInit(&A, (R[]) { 1, 2, 3 });
35 TEST_EQUAL(A.values[0], 1);
36 TEST_EQUAL(A.values[1], 2);
37 TEST_EQUAL(A.values[2], 3);
38}
39
40TEST_CASE(nnMatrixCopyCol_test) {
41 nnMatrix A = nnMatrixMake(3, 2);
42 nnMatrix B = nnMatrixMake(3, 1);
43
44 nnMatrixInit(&A, (R[]) {
45 1, 2,
46 3, 4,
47 5, 6,
48 });
49
50 nnMatrixCopyCol(&A, &B, 1, 0);
51
52 TEST_EQUAL(nnMatrixAt(&B, 0, 0), 2);
53 TEST_EQUAL(nnMatrixAt(&B, 1, 0), 4);
54 TEST_EQUAL(nnMatrixAt(&B, 2, 0), 6);
55
56 nnMatrixDel(&A);
57 nnMatrixDel(&B);
58}
59
60TEST_CASE(nnMatrixMul_square_3x3) {
61 nnMatrix A = nnMatrixMake(3, 3);
62 nnMatrix B = nnMatrixMake(3, 3);
63 nnMatrix O = nnMatrixMake(3, 3);
64
65 nnMatrixInit(&A, (const R[]){
66 1, 2, 3,
67 4, 5, 6,
68 7, 8, 9,
69 });
70 nnMatrixInit(&B, (const R[]){
71 2, 4, 3,
72 6, 8, 5,
73 1, 7, 9,
74 });
75 nnMatrixMul(&A, &B, &O);
76
77 const R expected[3][3] = {
78 { 17, 41, 40 },
79 { 44, 98, 91 },
80 { 71, 155, 142 },
81 };
82 for (int i = 0; i < O.rows; ++i) {
83 for (int j = 0; j < O.cols; ++j) {
84 TEST_TRUE(double_eq(nnMatrixAt(&O, i, j), expected[i][j], EPS));
85 }
86 }
87
88 nnMatrixDel(&A);
89 nnMatrixDel(&B);
90 nnMatrixDel(&O);
91}
92
93TEST_CASE(nnMatrixMul_non_square_2x3_3x1) {
94 nnMatrix A = nnMatrixMake(2, 3);
95 nnMatrix B = nnMatrixMake(3, 1);
96 nnMatrix O = nnMatrixMake(2, 1);
97
98 nnMatrixInit(&A, (const R[]){
99 1, 2, 3,
100 4, 5, 6,
101 });
102 nnMatrixInit(&B, (const R[]){
103 2,
104 6,
105 1,
106 });
107 nnMatrixMul(&A, &B, &O);
108
109 const R expected[2][1] = {
110 { 17 },
111 { 44 },
112 };
113 for (int i = 0; i < O.rows; ++i) {
114 for (int j = 0; j < O.cols; ++j) {
115 TEST_TRUE(double_eq(nnMatrixAt(&O, i, j), expected[i][j], EPS));
116 }
117 }
118
119 nnMatrixDel(&A);
120 nnMatrixDel(&B);
121 nnMatrixDel(&O);
122}
123
124TEST_CASE(nnMatrixMulAdd_test) {
125 nnMatrix A = nnMatrixMake(2, 3);
126 nnMatrix B = nnMatrixMake(2, 3);
127 nnMatrix O = nnMatrixMake(2, 3);
128 const R scale = 2;
129
130 nnMatrixInit(&A, (const R[]){
131 1, 2, 3,
132 4, 5, 6,
133 });
134 nnMatrixInit(&B, (const R[]){
135 2, 3, 1,
136 7, 4, 3
137 });
138 nnMatrixMulAdd(&A, &B, scale, &O); // O = A + B * scale
139
140 const R expected[2][3] = {
141 { 5, 8, 5 },
142 { 18, 13, 12 },
143 };
144 for (int i = 0; i < O.rows; ++i) {
145 for (int j = 0; j < O.cols; ++j) {
146 TEST_TRUE(double_eq(nnMatrixAt(&O, i, j), expected[i][j], EPS));
147 }
148 }
149
150 nnMatrixDel(&A);
151 nnMatrixDel(&B);
152 nnMatrixDel(&O);
153}
154
155TEST_CASE(nnMatrixMulSub_test) {
156 nnMatrix A = nnMatrixMake(2, 3);
157 nnMatrix B = nnMatrixMake(2, 3);
158 nnMatrix O = nnMatrixMake(2, 3);
159 const R scale = 2;
160
161 nnMatrixInit(&A, (const R[]){
162 1, 2, 3,
163 4, 5, 6,
164 });
165 nnMatrixInit(&B, (const R[]){
166 2, 3, 1,
167 7, 4, 3
168 });
169 nnMatrixMulSub(&A, &B, scale, &O); // O = A - B * scale
170
171 const R expected[2][3] = {
172 { -3, -4, 1 },
173 { -10, -3, 0 },
174 };
175 for (int i = 0; i < O.rows; ++i) {
176 for (int j = 0; j < O.cols; ++j) {
177 TEST_TRUE(double_eq(nnMatrixAt(&O, i, j), expected[i][j], EPS));
178 }
179 }
180
181 nnMatrixDel(&A);
182 nnMatrixDel(&B);
183 nnMatrixDel(&O);
184}
185
186TEST_CASE(nnMatrixMulPairs_2x3) {
187 nnMatrix A = nnMatrixMake(2, 3);
188 nnMatrix B = nnMatrixMake(2, 3);
189 nnMatrix O = nnMatrixMake(2, 3);
190
191 nnMatrixInit(&A, (const R[]){
192 1, 2, 3,
193 4, 5, 6,
194 });
195 nnMatrixInit(&B, (const R[]){
196 2, 3, 1,
197 7, 4, 3
198 });
199 nnMatrixMulPairs(&A, &B, &O);
200
201 const R expected[2][3] = {
202 { 2, 6, 3 },
203 { 28, 20, 18 },
204 };
205 for (int i = 0; i < O.rows; ++i) {
206 for (int j = 0; j < O.cols; ++j) {
207 TEST_TRUE(double_eq(nnMatrixAt(&O, i, j), expected[i][j], EPS));
208 }
209 }
210
211 nnMatrixDel(&A);
212 nnMatrixDel(&B);
213 nnMatrixDel(&O);
214}
215
216TEST_CASE(nnMatrixAdd_square_2x2) {
217 nnMatrix A = nnMatrixMake(2, 2);
218 nnMatrix B = nnMatrixMake(2, 2);
219 nnMatrix C = nnMatrixMake(2, 2);
220
221 nnMatrixInit(&A, (R[]) {
222 1, 2,
223 3, 4,
224 });
225 nnMatrixInit(&B, (R[]) {
226 2, 1,
227 5, 3,
228 });
229
230 nnMatrixAdd(&A, &B, &C);
231
232 TEST_TRUE(double_eq(nnMatrixAt(&C, 0, 0), 3, EPS));
233 TEST_TRUE(double_eq(nnMatrixAt(&C, 0, 1), 3, EPS));
234 TEST_TRUE(double_eq(nnMatrixAt(&C, 1, 0), 8, EPS));
235 TEST_TRUE(double_eq(nnMatrixAt(&C, 1, 1), 7, EPS));
236
237 nnMatrixDel(&A);
238 nnMatrixDel(&B);
239 nnMatrixDel(&C);
240}
241
242TEST_CASE(nnMatrixSub_square_2x2) {
243 nnMatrix A = nnMatrixMake(2, 2);
244 nnMatrix B = nnMatrixMake(2, 2);
245 nnMatrix C = nnMatrixMake(2, 2);
246
247 nnMatrixInit(&A, (R[]) {
248 1, 2,
249 3, 4,
250 });
251 nnMatrixInit(&B, (R[]) {
252 2, 1,
253 5, 3,
254 });
255
256 nnMatrixSub(&A, &B, &C);
257
258 TEST_TRUE(double_eq(nnMatrixAt(&C, 0, 0), -1, EPS));
259 TEST_TRUE(double_eq(nnMatrixAt(&C, 0, 1), +1, EPS));
260 TEST_TRUE(double_eq(nnMatrixAt(&C, 1, 0), -2, EPS));
261 TEST_TRUE(double_eq(nnMatrixAt(&C, 1, 1), +1, EPS));
262
263 nnMatrixDel(&A);
264 nnMatrixDel(&B);
265 nnMatrixDel(&C);
266}
267
268TEST_CASE(nnMatrixAddRow_test) {
269 nnMatrix A = nnMatrixMake(2, 3);
270 nnMatrix B = nnMatrixMake(1, 3);
271 nnMatrix C = nnMatrixMake(2, 3);
272
273 nnMatrixInit(&A, (R[]) {
274 1, 2, 3,
275 4, 5, 6,
276 });
277 nnMatrixInit(&B, (R[]) {
278 2, 1, 3,
279 });
280
281 nnMatrixAddRow(&A, &B, &C);
282
283 TEST_TRUE(double_eq(nnMatrixAt(&C, 0, 0), 3, EPS));
284 TEST_TRUE(double_eq(nnMatrixAt(&C, 0, 1), 3, EPS));
285 TEST_TRUE(double_eq(nnMatrixAt(&C, 0, 2), 6, EPS));
286 TEST_TRUE(double_eq(nnMatrixAt(&C, 1, 0), 6, EPS));
287 TEST_TRUE(double_eq(nnMatrixAt(&C, 1, 1), 6, EPS));
288 TEST_TRUE(double_eq(nnMatrixAt(&C, 1, 2), 9, EPS));
289
290 nnMatrixDel(&A);
291 nnMatrixDel(&B);
292 nnMatrixDel(&C);
293}
294
295TEST_CASE(nnMatrixTranspose_square_2x2) {
296 nnMatrix A = nnMatrixMake(2, 2);
297 nnMatrix B = nnMatrixMake(2, 2);
298
299 nnMatrixInit(&A, (R[]) {
300 1, 2,
301 3, 4
302 });
303
304 nnMatrixTranspose(&A, &B);
305 TEST_TRUE(double_eq(nnMatrixAt(&B, 0, 0), 1, EPS));
306 TEST_TRUE(double_eq(nnMatrixAt(&B, 0, 1), 3, EPS));
307 TEST_TRUE(double_eq(nnMatrixAt(&B, 1, 0), 2, EPS));
308 TEST_TRUE(double_eq(nnMatrixAt(&B, 1, 1), 4, EPS));
309
310 nnMatrixDel(&A);
311 nnMatrixDel(&B);
312}
313
314TEST_CASE(nnMatrixTranspose_non_square_2x1) {
315 nnMatrix A = nnMatrixMake(2, 1);
316 nnMatrix B = nnMatrixMake(1, 2);
317
318 nnMatrixInit(&A, (R[]) {
319 1,
320 3,
321 });
322
323 nnMatrixTranspose(&A, &B);
324 TEST_TRUE(double_eq(nnMatrixAt(&B, 0, 0), 1, EPS));
325 TEST_TRUE(double_eq(nnMatrixAt(&B, 0, 1), 3, EPS));
326
327 nnMatrixDel(&A);
328 nnMatrixDel(&B);
329}
330
331TEST_CASE(nnMatrixGt_test) {
332 nnMatrix A = nnMatrixMake(2, 3);
333 nnMatrix B = nnMatrixMake(2, 3);
334
335 nnMatrixInit(&A, (R[]) {
336 -3, 2, 0,
337 4, -1, 5
338 });
339
340 nnMatrixGt(&A, 0, &B);
341 TEST_TRUE(double_eq(nnMatrixAt(&B, 0, 0), 0, EPS));
342 TEST_TRUE(double_eq(nnMatrixAt(&B, 0, 1), 1, EPS));
343 TEST_TRUE(double_eq(nnMatrixAt(&B, 0, 2), 0, EPS));
344 TEST_TRUE(double_eq(nnMatrixAt(&B, 1, 0), 1, EPS));
345 TEST_TRUE(double_eq(nnMatrixAt(&B, 1, 1), 0, EPS));
346 TEST_TRUE(double_eq(nnMatrixAt(&B, 1, 2), 1, EPS));
347
348 nnMatrixDel(&A);
349 nnMatrixDel(&B);
350}