Add some new constant time functions needed by curve448
[openssl.git] / crypto / ec / curve448 / constant_time.h
index 3f02694e370df6fdf21a4246d834adc7355a7fba..61389a2b2182d4ac852361871e4ae00d2f47c771 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2017 The OpenSSL Project Authors. All Rights Reserved.
+ * Copyright 2017-2018 The OpenSSL Project Authors. All Rights Reserved.
  * Copyright 2014 Cryptography Research, Inc.
  *
  * Licensed under the OpenSSL license (the "License").  You may not use
  * Instead, we're putting our trust in the loop unroller and unswitcher.
  */
 
+# if defined(__GNUC__) || defined(__clang__)
 /*
  * Big (possibly vector) register that may sit at an unaligned address.
  */
 typedef struct {
     big_register_t unaligned;
-} __attribute__ ((packed)) unaligned_br_t;
+} __attribute((packed)) unaligned_br_t;
 
 /*
  * Unaligned word register, for architectures where that matters.
  */
 typedef struct {
     word_t unaligned;
-} __attribute__ ((packed)) unaligned_word_t;
+} __attribute((packed)) unaligned_word_t;
+
+#  define HAS_UNALIGNED_STRUCTS
+#  define RESTRICT __restrict__
+#else
+#  define RESTRICT
+# endif
 
 /*
  * Constant-time conditional swap.
@@ -58,26 +65,41 @@ typedef struct {
  * *a and *b must not alias.  On CPUs that require aligned accesses, each
  * must also be aligned to at least the alignment implied by elem_bytes.
  */
-static ossl_inline void constant_time_cond_swap(void *__restrict__ a_,
-                                                void *__restrict__ b_,
+static ossl_inline void constant_time_cond_swap(void *RESTRICT a_,
+                                                void *RESTRICT b_,
                                                 word_t elem_bytes,
                                                 mask_t doswap)
 {
     word_t k;
     unsigned char *a = (unsigned char *)a_;
     unsigned char *b = (unsigned char *)b_;
-
     big_register_t br_mask = br_set_to_mask(doswap);
+# ifndef HAS_UNALIGNED_STRUCTS
+    unsigned char doswapc = (unsigned char)(doswap & 0xFF);
+# endif
+
     for (k = 0; k <= elem_bytes - sizeof(big_register_t);
          k += sizeof(big_register_t)) {
         if (elem_bytes % sizeof(big_register_t)) {
             /* unaligned */
+# ifdef HAS_UNALIGNED_STRUCTS
             big_register_t xor = ((unaligned_br_t *) (&a[k]))->unaligned
                                  ^ ((unaligned_br_t *) (&b[k]))->unaligned;
 
             xor &= br_mask;
             ((unaligned_br_t *)(&a[k]))->unaligned ^= xor;
             ((unaligned_br_t *)(&b[k]))->unaligned ^= xor;
+# else
+            size_t i;
+
+            for (i = 0; i < sizeof(big_register_t); i++) {
+                unsigned char xor = a[k + i] ^ b[k + i];
+
+                xor &= doswapc;
+                a[k + i] ^= xor;
+                b[k + i] ^= xor;
+            }
+# endif
         } else {
             /* aligned */
             big_register_t xor = *((big_register_t *) (&a[k]))
@@ -92,12 +114,24 @@ static ossl_inline void constant_time_cond_swap(void *__restrict__ a_,
         for (; k <= elem_bytes - sizeof(word_t); k += sizeof(word_t)) {
             if (elem_bytes % sizeof(word_t)) {
                 /* unaligned */
+# ifdef HAS_UNALIGNED_STRUCTS
                 word_t xor = ((unaligned_word_t *)(&a[k]))->unaligned
                              ^ ((unaligned_word_t *)(&b[k]))->unaligned;
 
                 xor &= doswap;
                 ((unaligned_word_t *)(&a[k]))->unaligned ^= xor;
                 ((unaligned_word_t *)(&b[k]))->unaligned ^= xor;
+# else
+                size_t i;
+
+                for (i = 0; i < sizeof(word_t); i++) {
+                    unsigned char xor = a[k + i] ^ b[k + i];
+
+                    xor &= doswapc;
+                    a[k + i] ^= xor;
+                    b[k + i] ^= xor;
+                }
+# endif
             } else {
                 /* aligned */
                 word_t xor = *((word_t *) (&a[k])) ^ *((word_t *) (&b[k]));
@@ -127,7 +161,7 @@ static ossl_inline void constant_time_cond_swap(void *__restrict__ a_,
  *
  * The table and output must not alias.
  */
-static ossl_inline void constant_time_lookup(void *__restrict__ out_,
+static ossl_inline void constant_time_lookup(void *RESTRICT out_,
                                              const void *table_,
                                              word_t elem_bytes,
                                              word_t n_table,
@@ -139,20 +173,36 @@ static ossl_inline void constant_time_lookup(void *__restrict__ out_,
     unsigned char *out = (unsigned char *)out_;
     const unsigned char *table = (const unsigned char *)table_;
     word_t j, k;
+# ifndef HAS_UNALIGNED_STRUCTS
+    unsigned char maskc;
+# endif
 
     memset(out, 0, elem_bytes);
     for (j = 0; j < n_table; j++, big_i -= big_one) {
         big_register_t br_mask = br_is_zero(big_i);
         word_t mask;
 
+# ifndef HAS_UNALIGNED_STRUCTS
+        maskc = (unsigned char)br_mask;
+# endif
+
         for (k = 0; k <= elem_bytes - sizeof(big_register_t);
              k += sizeof(big_register_t)) {
             if (elem_bytes % sizeof(big_register_t)) {
                 /* unaligned */
+# ifdef HAS_UNALIGNED_STRUCTS
                 ((unaligned_br_t *)(out + k))->unaligned |=
                         br_mask
                         & ((const unaligned_br_t *)
                            (&table[k + j * elem_bytes]))->unaligned;
+# else
+                size_t i;
+
+                for (i = 0; i < sizeof(big_register_t); i++)
+                    out[k + i] |= maskc
+                                  & ((unsigned char *) table)
+                                    [k + (j * elem_bytes) + i];
+# endif
             } else {
                 /* aligned */
                 *(big_register_t *)(out + k) |=
@@ -162,14 +212,26 @@ static ossl_inline void constant_time_lookup(void *__restrict__ out_,
         }
 
         mask = word_is_zero(idx ^ j);
+# ifndef HAS_UNALIGNED_STRUCTS
+        maskc = (unsigned char)mask;
+# endif
         if (elem_bytes % sizeof(big_register_t) >= sizeof(word_t)) {
             for (; k <= elem_bytes - sizeof(word_t); k += sizeof(word_t)) {
                 if (elem_bytes % sizeof(word_t)) {
                     /* input unaligned, output aligned */
+# ifdef HAS_UNALIGNED_STRUCTS
                     *(word_t *)(out + k) |=
                             mask
                             & ((const unaligned_word_t *)
                                (&table[k + j * elem_bytes]))->unaligned;
+# else
+                    size_t i;
+
+                    for (i = 0; i < sizeof(word_t); i++)
+                        out[k + i] |= maskc
+                                      & ((unsigned char *)table)
+                                         [k + (j * elem_bytes) + i];
+# endif
                 } else {
                     /* aligned */
                     *(word_t *)(out + k) |=
@@ -196,18 +258,21 @@ static ossl_inline void constant_time_lookup(void *__restrict__ out_,
  * Note that the output is not restrict-qualified: it may overlap an input,
  * but only by coinciding with it exactly, never by partial overlap.
  */
-static ossl_inline void constant_time_select(void *a_,
-                                             const void *bFalse_,
-                                             const void *bTrue_,
-                                             word_t elem_bytes,
-                                             mask_t mask,
-                                             size_t alignment_bytes)
+static ossl_inline void constant_time_select_c448(void *a_,
+                                                  const void *bFalse_,
+                                                  const void *bTrue_,
+                                                  word_t elem_bytes,
+                                                  mask_t mask,
+                                                  size_t alignment_bytes)
 {
     unsigned char *a = (unsigned char *)a_;
     const unsigned char *bTrue = (const unsigned char *)bTrue_;
     const unsigned char *bFalse = (const unsigned char *)bFalse_;
     word_t k;
     big_register_t br_mask = br_set_to_mask(mask);
+# ifndef HAS_UNALIGNED_STRUCTS
+    unsigned char maskc = (unsigned char)mask;
+# endif
 
     alignment_bytes |= elem_bytes;
 
@@ -215,10 +280,18 @@ static ossl_inline void constant_time_select(void *a_,
          k += sizeof(big_register_t)) {
         if (alignment_bytes % sizeof(big_register_t)) {
             /* unaligned */
+# ifdef HAS_UNALIGNED_STRUCTS
             ((unaligned_br_t *)(&a[k]))->unaligned =
                     (br_mask & ((const unaligned_br_t *)(&bTrue[k]))->unaligned)
                     | (~br_mask
                        & ((const unaligned_br_t *)(&bFalse[k]))->unaligned);
+# else
+                    size_t i;
+
+                    for (i = 0; i < sizeof(big_register_t); i++)
+                        a[k + i] = (maskc & ((unsigned char *)bTrue)[k + i])
+                                   | (~maskc & ((unsigned char *)bFalse)[k + i]);
+# endif
         } else {
             /* aligned */
             *(big_register_t *) (a + k) =
@@ -231,10 +304,18 @@ static ossl_inline void constant_time_select(void *a_,
         for (; k <= elem_bytes - sizeof(word_t); k += sizeof(word_t)) {
             if (alignment_bytes % sizeof(word_t)) {
                 /* unaligned */
+# ifdef HAS_UNALIGNED_STRUCTS
                 ((unaligned_word_t *) (&a[k]))->unaligned =
                     (mask & ((const unaligned_word_t *)(&bTrue[k]))->unaligned)
                     | (~mask &
                        ((const unaligned_word_t *)(&bFalse[k]))->unaligned);
+# else
+                size_t i;
+
+                for (i = 0; i < sizeof(word_t); i++)
+                    a[k + i] = (maskc & ((unsigned char *)bTrue)[k + i])
+                               | (~maskc & ((unsigned char *)bFalse)[k + i]);
+# endif
             } else {
                 /* aligned */
                 *(word_t *) (a + k) = (mask & *(const word_t *)(&bTrue[k]))
@@ -250,4 +331,7 @@ static ossl_inline void constant_time_select(void *a_,
     }
 }
 
+#undef RESTRICT
+#undef HAS_UNALIGNED_STRUCTS
+
 #endif                          /* __CONSTANT_TIME_H__ */