Add a test for unrecognised record types
[openssl.git] / test / constant_time_test.c
1 /*
2  * Copyright 2014-2016 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the OpenSSL license (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include "internal/constant_time_locl.h"
11 #include "e_os.h"
12
13 #include <limits.h>
14 #include <stdio.h>
15 #include <stdlib.h>
16
17 static const unsigned int CONSTTIME_TRUE = (unsigned)(~0);
18 static const unsigned int CONSTTIME_FALSE = 0;
19 static const unsigned char CONSTTIME_TRUE_8 = 0xff;
20 static const unsigned char CONSTTIME_FALSE_8 = 0;
21
22 static int test_binary_op(unsigned int (*op) (unsigned int a, unsigned int b),
23                           const char *op_name, unsigned int a, unsigned int b,
24                           int is_true)
25 {
26     unsigned c = op(a, b);
27     if (is_true && c != CONSTTIME_TRUE) {
28         fprintf(stderr, "Test failed for %s(%du, %du): expected %du "
29                 "(TRUE), got %du\n", op_name, a, b, CONSTTIME_TRUE, c);
30         return 1;
31     } else if (!is_true && c != CONSTTIME_FALSE) {
32         fprintf(stderr, "Test failed for  %s(%du, %du): expected %du "
33                 "(FALSE), got %du\n", op_name, a, b, CONSTTIME_FALSE, c);
34         return 1;
35     }
36     return 0;
37 }
38
39 static int test_binary_op_8(unsigned
40                             char (*op) (unsigned int a, unsigned int b),
41                             const char *op_name, unsigned int a,
42                             unsigned int b, int is_true)
43 {
44     unsigned char c = op(a, b);
45     if (is_true && c != CONSTTIME_TRUE_8) {
46         fprintf(stderr, "Test failed for %s(%du, %du): expected %u "
47                 "(TRUE), got %u\n", op_name, a, b, CONSTTIME_TRUE_8, c);
48         return 1;
49     } else if (!is_true && c != CONSTTIME_FALSE_8) {
50         fprintf(stderr, "Test failed for  %s(%du, %du): expected %u "
51                 "(FALSE), got %u\n", op_name, a, b, CONSTTIME_FALSE_8, c);
52         return 1;
53     }
54     return 0;
55 }
56
57 static int test_is_zero(unsigned int a)
58 {
59     unsigned int c = constant_time_is_zero(a);
60     if (a == 0 && c != CONSTTIME_TRUE) {
61         fprintf(stderr, "Test failed for constant_time_is_zero(%du): "
62                 "expected %du (TRUE), got %du\n", a, CONSTTIME_TRUE, c);
63         return 1;
64     } else if (a != 0 && c != CONSTTIME_FALSE) {
65         fprintf(stderr, "Test failed for constant_time_is_zero(%du): "
66                 "expected %du (FALSE), got %du\n", a, CONSTTIME_FALSE, c);
67         return 1;
68     }
69     return 0;
70 }
71
72 static int test_is_zero_8(unsigned int a)
73 {
74     unsigned char c = constant_time_is_zero_8(a);
75     if (a == 0 && c != CONSTTIME_TRUE_8) {
76         fprintf(stderr, "Test failed for constant_time_is_zero(%du): "
77                 "expected %u (TRUE), got %u\n", a, CONSTTIME_TRUE_8, c);
78         return 1;
79     } else if (a != 0 && c != CONSTTIME_FALSE) {
80         fprintf(stderr, "Test failed for constant_time_is_zero(%du): "
81                 "expected %u (FALSE), got %u\n", a, CONSTTIME_FALSE_8, c);
82         return 1;
83     }
84     return 0;
85 }
86
87 static int test_select(unsigned int a, unsigned int b)
88 {
89     unsigned int selected = constant_time_select(CONSTTIME_TRUE, a, b);
90     if (selected != a) {
91         fprintf(stderr, "Test failed for constant_time_select(%du, %du,"
92                 "%du): expected %du(first value), got %du\n",
93                 CONSTTIME_TRUE, a, b, a, selected);
94         return 1;
95     }
96     selected = constant_time_select(CONSTTIME_FALSE, a, b);
97     if (selected != b) {
98         fprintf(stderr, "Test failed for constant_time_select(%du, %du,"
99                 "%du): expected %du(second value), got %du\n",
100                 CONSTTIME_FALSE, a, b, b, selected);
101         return 1;
102     }
103     return 0;
104 }
105
106 static int test_select_8(unsigned char a, unsigned char b)
107 {
108     unsigned char selected = constant_time_select_8(CONSTTIME_TRUE_8, a, b);
109     if (selected != a) {
110         fprintf(stderr, "Test failed for constant_time_select(%u, %u,"
111                 "%u): expected %u(first value), got %u\n",
112                 CONSTTIME_TRUE, a, b, a, selected);
113         return 1;
114     }
115     selected = constant_time_select_8(CONSTTIME_FALSE_8, a, b);
116     if (selected != b) {
117         fprintf(stderr, "Test failed for constant_time_select(%u, %u,"
118                 "%u): expected %u(second value), got %u\n",
119                 CONSTTIME_FALSE, a, b, b, selected);
120         return 1;
121     }
122     return 0;
123 }
124
125 static int test_select_int(int a, int b)
126 {
127     int selected = constant_time_select_int(CONSTTIME_TRUE, a, b);
128     if (selected != a) {
129         fprintf(stderr, "Test failed for constant_time_select(%du, %d,"
130                 "%d): expected %d(first value), got %d\n",
131                 CONSTTIME_TRUE, a, b, a, selected);
132         return 1;
133     }
134     selected = constant_time_select_int(CONSTTIME_FALSE, a, b);
135     if (selected != b) {
136         fprintf(stderr, "Test failed for constant_time_select(%du, %d,"
137                 "%d): expected %d(second value), got %d\n",
138                 CONSTTIME_FALSE, a, b, b, selected);
139         return 1;
140     }
141     return 0;
142 }
143
144 static int test_eq_int(int a, int b)
145 {
146     unsigned int equal = constant_time_eq_int(a, b);
147     if (a == b && equal != CONSTTIME_TRUE) {
148         fprintf(stderr, "Test failed for constant_time_eq_int(%d, %d): "
149                 "expected %du(TRUE), got %du\n", a, b, CONSTTIME_TRUE, equal);
150         return 1;
151     } else if (a != b && equal != CONSTTIME_FALSE) {
152         fprintf(stderr, "Test failed for constant_time_eq_int(%d, %d): "
153                 "expected %du(FALSE), got %du\n",
154                 a, b, CONSTTIME_FALSE, equal);
155         return 1;
156     }
157     return 0;
158 }
159
160 static int test_eq_int_8(int a, int b)
161 {
162     unsigned char equal = constant_time_eq_int_8(a, b);
163     if (a == b && equal != CONSTTIME_TRUE_8) {
164         fprintf(stderr, "Test failed for constant_time_eq_int_8(%d, %d): "
165                 "expected %u(TRUE), got %u\n", a, b, CONSTTIME_TRUE_8, equal);
166         return 1;
167     } else if (a != b && equal != CONSTTIME_FALSE_8) {
168         fprintf(stderr, "Test failed for constant_time_eq_int_8(%d, %d): "
169                 "expected %u(FALSE), got %u\n",
170                 a, b, CONSTTIME_FALSE_8, equal);
171         return 1;
172     }
173     return 0;
174 }
175
176 static unsigned int test_values[] =
177     { 0, 1, 1024, 12345, 32000, UINT_MAX / 2 - 1,
178     UINT_MAX / 2, UINT_MAX / 2 + 1, UINT_MAX - 1,
179     UINT_MAX
180 };
181
182 static unsigned char test_values_8[] =
183     { 0, 1, 2, 20, 32, 127, 128, 129, 255 };
184
185 static int signed_test_values[] = { 0, 1, -1, 1024, -1024, 12345, -12345,
186     32000, -32000, INT_MAX, INT_MIN, INT_MAX - 1,
187     INT_MIN + 1
188 };
189
190 int main(int argc, char *argv[])
191 {
192     unsigned int a, b, i, j;
193     int c, d;
194     unsigned char e, f;
195     int num_failed = 0, num_all = 0;
196     fprintf(stdout, "Testing constant time operations...\n");
197
198     for (i = 0; i < OSSL_NELEM(test_values); ++i) {
199         a = test_values[i];
200         num_failed += test_is_zero(a);
201         num_failed += test_is_zero_8(a);
202         num_all += 2;
203         for (j = 0; j < OSSL_NELEM(test_values); ++j) {
204             b = test_values[j];
205             num_failed += test_binary_op(&constant_time_lt,
206                                          "constant_time_lt", a, b, a < b);
207             num_failed += test_binary_op_8(&constant_time_lt_8,
208                                            "constant_time_lt_8", a, b, a < b);
209             num_failed += test_binary_op(&constant_time_lt,
210                                          "constant_time_lt_8", b, a, b < a);
211             num_failed += test_binary_op_8(&constant_time_lt_8,
212                                            "constant_time_lt_8", b, a, b < a);
213             num_failed += test_binary_op(&constant_time_ge,
214                                          "constant_time_ge", a, b, a >= b);
215             num_failed += test_binary_op_8(&constant_time_ge_8,
216                                            "constant_time_ge_8", a, b,
217                                            a >= b);
218             num_failed +=
219                 test_binary_op(&constant_time_ge, "constant_time_ge", b, a,
220                                b >= a);
221             num_failed +=
222                 test_binary_op_8(&constant_time_ge_8, "constant_time_ge_8", b,
223                                  a, b >= a);
224             num_failed +=
225                 test_binary_op(&constant_time_eq, "constant_time_eq", a, b,
226                                a == b);
227             num_failed +=
228                 test_binary_op_8(&constant_time_eq_8, "constant_time_eq_8", a,
229                                  b, a == b);
230             num_failed +=
231                 test_binary_op(&constant_time_eq, "constant_time_eq", b, a,
232                                b == a);
233             num_failed +=
234                 test_binary_op_8(&constant_time_eq_8, "constant_time_eq_8", b,
235                                  a, b == a);
236             num_failed += test_select(a, b);
237             num_all += 13;
238         }
239     }
240
241     for (i = 0; i < OSSL_NELEM(signed_test_values); ++i) {
242         c = signed_test_values[i];
243         for (j = 0; j < OSSL_NELEM(signed_test_values); ++j) {
244             d = signed_test_values[j];
245             num_failed += test_select_int(c, d);
246             num_failed += test_eq_int(c, d);
247             num_failed += test_eq_int_8(c, d);
248             num_all += 3;
249         }
250     }
251
252     for (i = 0; i < sizeof(test_values_8); ++i) {
253         e = test_values_8[i];
254         for (j = 0; j < sizeof(test_values_8); ++j) {
255             f = test_values_8[j];
256             num_failed += test_select_8(e, f);
257             num_all += 1;
258         }
259     }
260
261     if (!num_failed) {
262         fprintf(stdout, "success (ran %d tests)\n", num_all);
263         return EXIT_SUCCESS;
264     } else {
265         fprintf(stdout, "%d of %d tests failed!\n", num_failed, num_all);
266         return EXIT_FAILURE;
267     }
268 }