rsaz_exp_x2.c: Remove unused ALIGN64 macro
[openssl.git] / crypto / bn / rsaz_exp_x2.c
1 /*
2  * Copyright 2020-2022 The OpenSSL Project Authors. All Rights Reserved.
3  * Copyright (c) 2020, Intel Corporation. All Rights Reserved.
4  *
5  * Licensed under the Apache License 2.0 (the "License").  You may not use
6  * this file except in compliance with the License.  You can obtain a copy
7  * in the file LICENSE in the source distribution or at
8  * https://www.openssl.org/source/license.html
9  *
10  *
11  * Originally written by Ilya Albrekht, Sergey Kirillov and Andrey Matyukov
12  * Intel Corporation
13  *
14  */
15
16 #include <openssl/opensslconf.h>
17 #include <openssl/crypto.h>
18 #include "rsaz_exp.h"
19
20 #ifndef RSAZ_ENABLED
21 NON_EMPTY_TRANSLATION_UNIT
22 #else
23 # include <assert.h>
24 # include <string.h>
25
26 # define ALIGN_OF(ptr, boundary) \
27     ((unsigned char *)(ptr) + (boundary - (((size_t)(ptr)) & (boundary - 1))))
28
29 /* Internal radix */
30 # define DIGIT_SIZE (52)
31 /* 52-bit mask */
32 # define DIGIT_MASK ((uint64_t)0xFFFFFFFFFFFFF)
33
34 # define BITS2WORD8_SIZE(x)  (((x) + 7) >> 3)
35 # define BITS2WORD64_SIZE(x) (((x) + 63) >> 6)
36
37 typedef uint64_t ALIGN1 uint64_t_align1;
38
39 static ossl_inline uint64_t get_digit52(const uint8_t *in, int in_len);
40 static ossl_inline void put_digit52(uint8_t *out, int out_len, uint64_t digit);
41 static void to_words52(BN_ULONG *out, int out_len, const BN_ULONG *in,
42                        int in_bitsize);
43 static void from_words52(BN_ULONG *bn_out, int out_bitsize, const BN_ULONG *in);
44 static ossl_inline void set_bit(BN_ULONG *a, int idx);
45
46 /* Number of |digit_size|-bit digits in |bitsize|-bit value */
47 static ossl_inline int number_of_digits(int bitsize, int digit_size)
48 {
49     return (bitsize + digit_size - 1) / digit_size;
50 }
51
52 typedef void (*AMM52)(BN_ULONG *res, const BN_ULONG *base,
53                       const BN_ULONG *exp, const BN_ULONG *m, BN_ULONG k0);
54 typedef void (*EXP52_x2)(BN_ULONG *res, const BN_ULONG *base,
55                          const BN_ULONG *exp[2], const BN_ULONG *m,
56                          const BN_ULONG *rr, const BN_ULONG k0[2]);
57
58 /*
59  * For details of the methods declared below please refer to
60  *    crypto/bn/asm/rsaz-avx512.pl
61  *
62  * Naming notes:
63  *  amm = Almost Montgomery Multiplication
64  *  ams = Almost Montgomery Squaring
65  *  52x20 - data represented as array of 20 digits in 52-bit radix
66  *  _x1_/_x2_ - 1 or 2 independent inputs/outputs
67  *  _256 suffix - uses 256-bit (AVX512VL) registers
68  */
69
70 /*AMM = Almost Montgomery Multiplication. */
71 void ossl_rsaz_amm52x20_x1_256(BN_ULONG *res, const BN_ULONG *base,
72                                const BN_ULONG *exp, const BN_ULONG *m,
73                                BN_ULONG k0);
74 static void RSAZ_exp52x20_x2_256(BN_ULONG *res, const BN_ULONG *base,
75                                  const BN_ULONG *exp[2], const BN_ULONG *m,
76                                  const BN_ULONG *rr, const BN_ULONG k0[2]);
77 void ossl_rsaz_amm52x20_x2_256(BN_ULONG *out, const BN_ULONG *a,
78                                const BN_ULONG *b, const BN_ULONG *m,
79                                const BN_ULONG k0[2]);
80 void ossl_extract_multiplier_2x20_win5(BN_ULONG *red_Y,
81                                        const BN_ULONG *red_table,
82                                        int red_table_idx, int tbl_idx);
83
84 /*
85  * Dual Montgomery modular exponentiation using prime moduli of the
86  * same bit size, optimized with AVX512 ISA.
87  *
88  * Input and output parameters for each exponentiation are independent and
89  * denoted here by index |i|, i = 1..2.
90  *
91  * Input and output are all in regular 2^64 radix.
92  *
93  * Each moduli shall be |factor_size| bit size.
94  *
95  * NOTE: currently only 2x1024 case is supported.
96  *
97  *  [out] res|i|      - result of modular exponentiation: array of qword values
98  *                      in regular (2^64) radix. Size of array shall be enough
99  *                      to hold |factor_size| bits.
100  *  [in]  base|i|     - base
101  *  [in]  exp|i|      - exponent
102  *  [in]  m|i|        - moduli
103  *  [in]  rr|i|       - Montgomery parameter RR = R^2 mod m|i|
104  *  [in]  k0_|i|      - Montgomery parameter k0 = -1/m|i| mod 2^64
105  *  [in]  factor_size - moduli bit size
106  *
107  * \return 0 in case of failure,
108  *         1 in case of success.
109  */
110 int ossl_rsaz_mod_exp_avx512_x2(BN_ULONG *res1,
111                                 const BN_ULONG *base1,
112                                 const BN_ULONG *exp1,
113                                 const BN_ULONG *m1,
114                                 const BN_ULONG *rr1,
115                                 BN_ULONG k0_1,
116                                 BN_ULONG *res2,
117                                 const BN_ULONG *base2,
118                                 const BN_ULONG *exp2,
119                                 const BN_ULONG *m2,
120                                 const BN_ULONG *rr2,
121                                 BN_ULONG k0_2,
122                                 int factor_size)
123 {
124     int ret = 0;
125
126     /*
127      * Number of word-size (BN_ULONG) digits to store exponent in redundant
128      * representation.
129      */
130     int exp_digits = number_of_digits(factor_size + 2, DIGIT_SIZE);
131     int coeff_pow = 4 * (DIGIT_SIZE * exp_digits - factor_size);
132     BN_ULONG *base1_red, *m1_red, *rr1_red;
133     BN_ULONG *base2_red, *m2_red, *rr2_red;
134     BN_ULONG *coeff_red;
135     BN_ULONG *storage = NULL;
136     BN_ULONG *storage_aligned = NULL;
137     BN_ULONG storage_len_bytes = 7 * exp_digits * sizeof(BN_ULONG);
138
139     /* AMM = Almost Montgomery Multiplication */
140     AMM52 amm = NULL;
141     /* Dual (2-exps in parallel) exponentiation */
142     EXP52_x2 exp_x2 = NULL;
143
144     const BN_ULONG *exp[2] = {0};
145     BN_ULONG k0[2] = {0};
146
147     /* Only 1024-bit factor size is supported now */
148     switch (factor_size) {
149     case 1024:
150         amm = ossl_rsaz_amm52x20_x1_256;
151         exp_x2 = RSAZ_exp52x20_x2_256;
152         break;
153     default:
154         goto err;
155     }
156
157     storage = (BN_ULONG *)OPENSSL_malloc(storage_len_bytes + 64);
158     if (storage == NULL)
159         goto err;
160     storage_aligned = (BN_ULONG *)ALIGN_OF(storage, 64);
161
162     /* Memory layout for red(undant) representations */
163     base1_red = storage_aligned;
164     base2_red = storage_aligned + 1 * exp_digits;
165     m1_red    = storage_aligned + 2 * exp_digits;
166     m2_red    = storage_aligned + 3 * exp_digits;
167     rr1_red   = storage_aligned + 4 * exp_digits;
168     rr2_red   = storage_aligned + 5 * exp_digits;
169     coeff_red = storage_aligned + 6 * exp_digits;
170
171     /* Convert base_i, m_i, rr_i, from regular to 52-bit radix */
172     to_words52(base1_red, exp_digits, base1, factor_size);
173     to_words52(base2_red, exp_digits, base2, factor_size);
174     to_words52(m1_red, exp_digits, m1, factor_size);
175     to_words52(m2_red, exp_digits, m2, factor_size);
176     to_words52(rr1_red, exp_digits, rr1, factor_size);
177     to_words52(rr2_red, exp_digits, rr2, factor_size);
178
179     /*
180      * Compute target domain Montgomery converters RR' for each modulus
181      * based on precomputed original domain's RR.
182      *
183      * RR -> RR' transformation steps:
184      *  (1) coeff = 2^k
185      *  (2) t = AMM(RR,RR) = RR^2 / R' mod m
186      *  (3) RR' = AMM(t, coeff) = RR^2 * 2^k / R'^2 mod m
187      * where
188      *  k = 4 * (52 * digits52 - modlen)
189      *  R  = 2^(64 * ceil(modlen/64)) mod m
190      *  RR = R^2 mod M
191      *  R' = 2^(52 * ceil(modlen/52)) mod m
192      *
193      *  modlen = 1024: k = 64, RR = 2^2048 mod m, RR' = 2^2080 mod m
194      */
195     memset(coeff_red, 0, exp_digits * sizeof(BN_ULONG));
196     /* (1) in reduced domain representation */
197     set_bit(coeff_red, 64 * (int)(coeff_pow / 52) + coeff_pow % 52);
198
199     amm(rr1_red, rr1_red, rr1_red, m1_red, k0_1);     /* (2) for m1 */
200     amm(rr1_red, rr1_red, coeff_red, m1_red, k0_1);   /* (3) for m1 */
201
202     amm(rr2_red, rr2_red, rr2_red, m2_red, k0_2);     /* (2) for m2 */
203     amm(rr2_red, rr2_red, coeff_red, m2_red, k0_2);   /* (3) for m2 */
204
205     exp[0] = exp1;
206     exp[1] = exp2;
207
208     k0[0] = k0_1;
209     k0[1] = k0_2;
210
211     exp_x2(rr1_red, base1_red, exp, m1_red, rr1_red, k0);
212
213     /* Convert rr_i back to regular radix */
214     from_words52(res1, factor_size, rr1_red);
215     from_words52(res2, factor_size, rr2_red);
216
217     /* bn_reduce_once_in_place expects number of BN_ULONG, not bit size */
218     factor_size /= sizeof(BN_ULONG) * 8;
219
220     bn_reduce_once_in_place(res1, /*carry=*/0, m1, storage, factor_size);
221     bn_reduce_once_in_place(res2, /*carry=*/0, m2, storage, factor_size);
222
223     ret = 1;
224 err:
225     if (storage != NULL) {
226         OPENSSL_cleanse(storage, storage_len_bytes);
227         OPENSSL_free(storage);
228     }
229     return ret;
230 }
231
232 /*
233  * Dual 1024-bit w-ary modular exponentiation using prime moduli of the same
234  * bit size using Almost Montgomery Multiplication, optimized with AVX512_IFMA
235  * ISA.
236  *
237  * The parameter w (window size) = 5.
238  *
239  *  [out] res      - result of modular exponentiation: 2x20 qword
240  *                   values in 2^52 radix.
241  *  [in]  base     - base (2x20 qword values in 2^52 radix)
242  *  [in]  exp      - array of 2 pointers to 16 qword values in 2^64 radix.
243  *                   Exponent is not converted to redundant representation.
244  *  [in]  m        - moduli (2x20 qword values in 2^52 radix)
245  *  [in]  rr       - Montgomery parameter for 2 moduli: RR = 2^2080 mod m.
246  *                   (2x20 qword values in 2^52 radix)
247  *  [in]  k0       - Montgomery parameter for 2 moduli: k0 = -1/m mod 2^64
248  *
249  * \return (void).
250  */
251 static void RSAZ_exp52x20_x2_256(BN_ULONG *out,          /* [2][20] */
252                                  const BN_ULONG *base,   /* [2][20] */
253                                  const BN_ULONG *exp[2], /* 2x16    */
254                                  const BN_ULONG *m,      /* [2][20] */
255                                  const BN_ULONG *rr,     /* [2][20] */
256                                  const BN_ULONG k0[2])
257 {
258 # define BITSIZE_MODULUS (1024)
259 # define EXP_WIN_SIZE (5)
260 # define EXP_WIN_MASK ((1U << EXP_WIN_SIZE) - 1)
261 /*
262  * Number of digits (64-bit words) in redundant representation to handle
263  * modulus bits
264  */
265 # define RED_DIGITS (20)
266 # define EXP_DIGITS (16)
267 # define DAMM ossl_rsaz_amm52x20_x2_256
268 /*
269  * Squaring is done using multiplication now. That can be a subject of
270  * optimization in future.
271  */
272 # define DAMS(r,a,m,k0) \
273               ossl_rsaz_amm52x20_x2_256((r),(a),(a),(m),(k0))
274
275     /* Allocate stack for red(undant) result Y and multiplier X */
276     ALIGN64 BN_ULONG red_Y[2][RED_DIGITS];
277     ALIGN64 BN_ULONG red_X[2][RED_DIGITS];
278
279     /* Allocate expanded exponent */
280     ALIGN64 BN_ULONG expz[2][EXP_DIGITS + 1];
281
282     /* Pre-computed table of base powers */
283     ALIGN64 BN_ULONG red_table[1U << EXP_WIN_SIZE][2][RED_DIGITS];
284
285     int idx;
286
287     memset(red_Y, 0, sizeof(red_Y));
288     memset(red_table, 0, sizeof(red_table));
289     memset(red_X, 0, sizeof(red_X));
290
291     /*
292      * Compute table of powers base^i, i = 0, ..., (2^EXP_WIN_SIZE) - 1
293      *   table[0] = mont(x^0) = mont(1)
294      *   table[1] = mont(x^1) = mont(x)
295      */
296     red_X[0][0] = 1;
297     red_X[1][0] = 1;
298     DAMM(red_table[0][0], (const BN_ULONG*)red_X, rr, m, k0);
299     DAMM(red_table[1][0], base,  rr, m, k0);
300
301     for (idx = 1; idx < (int)((1U << EXP_WIN_SIZE) / 2); idx++) {
302         DAMS(red_table[2 * idx + 0][0], red_table[1 * idx][0], m, k0);
303         DAMM(red_table[2 * idx + 1][0], red_table[2 * idx][0], red_table[1][0], m, k0);
304     }
305
306     /* Copy and expand exponents */
307     memcpy(expz[0], exp[0], EXP_DIGITS * sizeof(BN_ULONG));
308     expz[0][EXP_DIGITS] = 0;
309     memcpy(expz[1], exp[1], EXP_DIGITS * sizeof(BN_ULONG));
310     expz[1][EXP_DIGITS] = 0;
311
312     /* Exponentiation */
313     {
314         const int rem = BITSIZE_MODULUS % EXP_WIN_SIZE;
315         BN_ULONG table_idx_mask = EXP_WIN_MASK;
316
317         int exp_bit_no = BITSIZE_MODULUS - rem;
318         int exp_chunk_no = exp_bit_no / 64;
319         int exp_chunk_shift = exp_bit_no % 64;
320
321         BN_ULONG red_table_idx_0, red_table_idx_1;
322
323         /*
324          * If rem == 0, then
325          *      exp_bit_no = modulus_bitsize - exp_win_size
326          * However, this isn't possible because rem is { 1024, 1536, 2048 } % 5
327          * which is { 4, 1, 3 } respectively.
328          *
329          * If this assertion ever fails the fix above is easy.
330          */
331         OPENSSL_assert(rem != 0);
332
333         /* Process 1-st exp window - just init result */
334         red_table_idx_0 = expz[0][exp_chunk_no];
335         red_table_idx_1 = expz[1][exp_chunk_no];
336         /*
337          * The function operates with fixed moduli sizes divisible by 64,
338          * thus table index here is always in supported range [0, EXP_WIN_SIZE).
339          */
340         red_table_idx_0 >>= exp_chunk_shift;
341         red_table_idx_1 >>= exp_chunk_shift;
342
343         ossl_extract_multiplier_2x20_win5(red_Y[0], (const BN_ULONG*)red_table,
344                                           (int)red_table_idx_0, 0);
345         ossl_extract_multiplier_2x20_win5(red_Y[1], (const BN_ULONG*)red_table,
346                                           (int)red_table_idx_1, 1);
347
348         /* Process other exp windows */
349         for (exp_bit_no -= EXP_WIN_SIZE; exp_bit_no >= 0; exp_bit_no -= EXP_WIN_SIZE) {
350             /* Extract pre-computed multiplier from the table */
351             {
352                 BN_ULONG T;
353
354                 exp_chunk_no = exp_bit_no / 64;
355                 exp_chunk_shift = exp_bit_no % 64;
356                 {
357                     red_table_idx_0 = expz[0][exp_chunk_no];
358                     T = expz[0][exp_chunk_no + 1];
359
360                     red_table_idx_0 >>= exp_chunk_shift;
361                     /*
362                      * Get additional bits from then next quadword
363                      * when 64-bit boundaries are crossed.
364                      */
365                     if (exp_chunk_shift > 64 - EXP_WIN_SIZE) {
366                         T <<= (64 - exp_chunk_shift);
367                         red_table_idx_0 ^= T;
368                     }
369                     red_table_idx_0 &= table_idx_mask;
370
371                     ossl_extract_multiplier_2x20_win5(red_X[0],
372                                                       (const BN_ULONG*)red_table,
373                                                       (int)red_table_idx_0, 0);
374                 }
375                 {
376                     red_table_idx_1 = expz[1][exp_chunk_no];
377                     T = expz[1][exp_chunk_no + 1];
378
379                     red_table_idx_1 >>= exp_chunk_shift;
380                     /*
381                      * Get additional bits from then next quadword
382                      * when 64-bit boundaries are crossed.
383                      */
384                     if (exp_chunk_shift > 64 - EXP_WIN_SIZE) {
385                         T <<= (64 - exp_chunk_shift);
386                         red_table_idx_1 ^= T;
387                     }
388                     red_table_idx_1 &= table_idx_mask;
389
390                     ossl_extract_multiplier_2x20_win5(red_X[1],
391                                                       (const BN_ULONG*)red_table,
392                                                       (int)red_table_idx_1, 1);
393                 }
394             }
395
396             /* Series of squaring */
397             DAMS((BN_ULONG*)red_Y, (const BN_ULONG*)red_Y, m, k0);
398             DAMS((BN_ULONG*)red_Y, (const BN_ULONG*)red_Y, m, k0);
399             DAMS((BN_ULONG*)red_Y, (const BN_ULONG*)red_Y, m, k0);
400             DAMS((BN_ULONG*)red_Y, (const BN_ULONG*)red_Y, m, k0);
401             DAMS((BN_ULONG*)red_Y, (const BN_ULONG*)red_Y, m, k0);
402
403             DAMM((BN_ULONG*)red_Y, (const BN_ULONG*)red_Y, (const BN_ULONG*)red_X, m, k0);
404         }
405     }
406
407     /*
408      *
409      * NB: After the last AMM of exponentiation in Montgomery domain, the result
410      * may be 1025-bit, but the conversion out of Montgomery domain performs an
411      * AMM(x,1) which guarantees that the final result is less than |m|, so no
412      * conditional subtraction is needed here. See "Efficient Software
413      * Implementations of Modular Exponentiation" (by Shay Gueron) paper for details.
414      */
415
416     /* Convert result back in regular 2^52 domain */
417     memset(red_X, 0, sizeof(red_X));
418     red_X[0][0] = 1;
419     red_X[1][0] = 1;
420     DAMM(out, (const BN_ULONG*)red_Y, (const BN_ULONG*)red_X, m, k0);
421
422     /* Clear exponents */
423     OPENSSL_cleanse(expz, sizeof(expz));
424     OPENSSL_cleanse(red_Y, sizeof(red_Y));
425
426 # undef DAMS
427 # undef DAMM
428 # undef EXP_DIGITS
429 # undef RED_DIGITS
430 # undef EXP_WIN_MASK
431 # undef EXP_WIN_SIZE
432 # undef BITSIZE_MODULUS
433 }
434
435 static ossl_inline uint64_t get_digit52(const uint8_t *in, int in_len)
436 {
437     uint64_t digit = 0;
438
439     assert(in != NULL);
440
441     for (; in_len > 0; in_len--) {
442         digit <<= 8;
443         digit += (uint64_t)(in[in_len - 1]);
444     }
445     return digit;
446 }
447
448 /*
449  * Convert array of words in regular (base=2^64) representation to array of
450  * words in redundant (base=2^52) one.
451  */
452 static void to_words52(BN_ULONG *out, int out_len,
453                        const BN_ULONG *in, int in_bitsize)
454 {
455     uint8_t *in_str = NULL;
456
457     assert(out != NULL);
458     assert(in != NULL);
459     /* Check destination buffer capacity */
460     assert(out_len >= number_of_digits(in_bitsize, DIGIT_SIZE));
461
462     in_str = (uint8_t *)in;
463
464     for (; in_bitsize >= (2 * DIGIT_SIZE); in_bitsize -= (2 * DIGIT_SIZE), out += 2) {
465         uint64_t digit;
466
467         memcpy(&digit, in_str, sizeof(digit));
468         out[0] = digit & DIGIT_MASK;
469         in_str += 6;
470         memcpy(&digit, in_str, sizeof(digit));
471         out[1] = (digit >> 4) & DIGIT_MASK;
472         in_str += 7;
473         out_len -= 2;
474     }
475
476     if (in_bitsize > DIGIT_SIZE) {
477         uint64_t digit = get_digit52(in_str, 7);
478
479         out[0] = digit & DIGIT_MASK;
480         in_str += 6;
481         in_bitsize -= DIGIT_SIZE;
482         digit = get_digit52(in_str, BITS2WORD8_SIZE(in_bitsize));
483         out[1] = digit >> 4;
484         out += 2;
485         out_len -= 2;
486     } else if (in_bitsize > 0) {
487         out[0] = get_digit52(in_str, BITS2WORD8_SIZE(in_bitsize));
488         out++;
489         out_len--;
490     }
491
492     while (out_len > 0) {
493         *out = 0;
494         out_len--;
495         out++;
496     }
497 }
498
499 static ossl_inline void put_digit52(uint8_t *pStr, int strLen, uint64_t digit)
500 {
501     assert(pStr != NULL);
502
503     for (; strLen > 0; strLen--) {
504         *pStr++ = (uint8_t)(digit & 0xFF);
505         digit >>= 8;
506     }
507 }
508
509 /*
510  * Convert array of words in redundant (base=2^52) representation to array of
511  * words in regular (base=2^64) one.
512  */
513 static void from_words52(BN_ULONG *out, int out_bitsize, const BN_ULONG *in)
514 {
515     int i;
516     int out_len = BITS2WORD64_SIZE(out_bitsize);
517
518     assert(out != NULL);
519     assert(in != NULL);
520
521     for (i = 0; i < out_len; i++)
522         out[i] = 0;
523
524     {
525         uint8_t *out_str = (uint8_t *)out;
526
527         for (; out_bitsize >= (2 * DIGIT_SIZE);
528                out_bitsize -= (2 * DIGIT_SIZE), in += 2) {
529             uint64_t digit;
530
531             digit = in[0];
532             memcpy(out_str, &digit, sizeof(digit));
533             out_str += 6;
534             digit = digit >> 48 | in[1] << 4;
535             memcpy(out_str, &digit, sizeof(digit));
536             out_str += 7;
537         }
538
539         if (out_bitsize > DIGIT_SIZE) {
540             put_digit52(out_str, 7, in[0]);
541             out_str += 6;
542             out_bitsize -= DIGIT_SIZE;
543             put_digit52(out_str, BITS2WORD8_SIZE(out_bitsize),
544                         (in[1] << 4 | in[0] >> 48));
545         } else if (out_bitsize) {
546             put_digit52(out_str, BITS2WORD8_SIZE(out_bitsize), in[0]);
547         }
548     }
549 }
550
551 /*
552  * Set bit at index |idx| in the words array |a|.
553  * It does not do any boundaries checks, make sure the index is valid before
554  * calling the function.
555  */
556 static ossl_inline void set_bit(BN_ULONG *a, int idx)
557 {
558     assert(a != NULL);
559
560     {
561         int i, j;
562
563         i = idx / BN_BITS2;
564         j = idx % BN_BITS2;
565         a[i] |= (((BN_ULONG)1) << j);
566     }
567 }
568
569 #endif