RT3066: rewrite RSA padding checks to be slightly more constant time.
[openssl.git] / crypto / constant_time_locl.h
index 782da6c8b252e57dc3d042bc7f40e63bce772b8f..ccf7b62f5fd147c2f8e08a5ec5b544a9eb286a9f 100644 (file)
@@ -54,7 +54,7 @@ extern "C" {
 #endif
 
 /*
- * The following methods return a bitmask of all ones (0xff...f) for true
+ * The boolean methods return a bitmask of all ones (0xff...f) for true
  * and 0 for false. This is useful for choosing a value based on the result
  * of a conditional in constant time. For example,
  *
@@ -67,7 +67,7 @@ extern "C" {
  * can be written as
  *
  * unsigned int lt = constant_time_lt(a, b);
- * c = a & lt | b & ~lt;
+ * c = constant_time_select(lt, a, b);
  */
 
 /*
@@ -81,38 +81,53 @@ static inline unsigned int constant_time_msb(unsigned int a);
 /*
  * Returns 0xff..f if a < b and 0 otherwise.
  */
-inline unsigned int constant_time_lt(unsigned int a, unsigned int b);
+static inline unsigned int constant_time_lt(unsigned int a, unsigned int b);
 /* Convenience method for getting an 8-bit mask. */
-inline unsigned char constant_time_lt_8(unsigned int a, unsigned int b);
+static inline unsigned char constant_time_lt_8(unsigned int a, unsigned int b);
 
 /*
  * Returns 0xff..f if a >= b and 0 otherwise.
  */
-inline unsigned int constant_time_ge(unsigned int a, unsigned int b);
+static inline unsigned int constant_time_ge(unsigned int a, unsigned int b);
 /* Convenience method for getting an 8-bit mask. */
-inline unsigned char constant_time_ge_8(unsigned int a, unsigned int b);
+static inline unsigned char constant_time_ge_8(unsigned int a, unsigned int b);
 
 /*
  * Returns 0xff..f if a == 0 and 0 otherwise.
  */
-inline unsigned int constant_time_is_zero(unsigned int a);
+static inline unsigned int constant_time_is_zero(unsigned int a);
 /* Convenience method for getting an 8-bit mask. */
-inline unsigned char constant_time_is_zero_8(unsigned int a);
+static inline unsigned char constant_time_is_zero_8(unsigned int a);
 
 
 /*
  * Returns 0xff..f if a == b and 0 otherwise.
  */
-inline unsigned int constant_time_eq(unsigned int a, unsigned int b);
+static inline unsigned int constant_time_eq(unsigned int a, unsigned int b);
 /* Convenience method for getting an 8-bit mask. */
-inline unsigned char constant_time_eq_8(unsigned int a, unsigned int b);
+static inline unsigned char constant_time_eq_8(unsigned int a, unsigned int b);
+
+/*
+ * Returns (mask & a) | (~mask & b).
+ *
+ * When |mask| is all 1s or all 0s (as returned by the methods above),
+ * the select methods return either |a| (if |mask| is nonzero) or |b|
+ * (if |mask| is zero).
+ */
+static inline unsigned int constant_time_select(unsigned int mask,
+       unsigned int a, unsigned int b);
+/* Convenience method for unsigned chars. */
+static inline unsigned char constant_time_select_8(unsigned char mask,
+       unsigned char a, unsigned char b);
+/* Convenience method for signed integers. */
+static inline int constant_time_select_int(unsigned int mask, int a, int b);
 
 static inline unsigned int constant_time_msb(unsigned int a)
        {
        return (unsigned int)((int)(a) >> (sizeof(int) * 8 - 1));
        }
 
-inline unsigned int constant_time_lt(unsigned int a, unsigned int b)
+static inline unsigned int constant_time_lt(unsigned int a, unsigned int b)
        {
        unsigned int lt;
        /* Case 1: msb(a) == msb(b). a < b iff the MSB of a - b is set.*/
@@ -122,12 +137,12 @@ inline unsigned int constant_time_lt(unsigned int a, unsigned int b)
        return constant_time_msb(lt);
        }
 
-inline unsigned char constant_time_lt_8(unsigned int a, unsigned int b)
+static inline unsigned char constant_time_lt_8(unsigned int a, unsigned int b)
        {
        return (unsigned char)(constant_time_lt(a, b));
        }
 
-inline unsigned int constant_time_ge(unsigned int a, unsigned int b)
+static inline unsigned int constant_time_ge(unsigned int a, unsigned int b)
        {
        unsigned int ge;
        /* Case 1: msb(a) == msb(b). a >= b iff the MSB of a - b is not set.*/
@@ -137,31 +152,48 @@ inline unsigned int constant_time_ge(unsigned int a, unsigned int b)
        return constant_time_msb(ge);
        }
 
-inline unsigned char constant_time_ge_8(unsigned int a, unsigned int b)
+static inline unsigned char constant_time_ge_8(unsigned int a, unsigned int b)
        {
        return (unsigned char)(constant_time_ge(a, b));
        }
 
-inline unsigned int constant_time_is_zero(unsigned int a)
+static inline unsigned int constant_time_is_zero(unsigned int a)
        {
        return constant_time_msb(~a & (a - 1));
        }
 
-inline unsigned char constant_time_is_zero_8(unsigned int a)
+static inline unsigned char constant_time_is_zero_8(unsigned int a)
        {
        return (unsigned char)(constant_time_is_zero(a));
        }
 
-inline unsigned int constant_time_eq(unsigned int a, unsigned int b)
+static inline unsigned int constant_time_eq(unsigned int a, unsigned int b)
        {
        return constant_time_is_zero(a ^ b);
        }
 
-inline unsigned char constant_time_eq_8(unsigned int a, unsigned int b)
+static inline unsigned char constant_time_eq_8(unsigned int a, unsigned int b)
        {
        return (unsigned char)(constant_time_eq(a, b));
        }
 
+static inline unsigned int constant_time_select(unsigned int mask,
+       unsigned int a, unsigned int b)
+       {
+       return (mask & a) | (~mask & b);
+       }
+
+static inline unsigned char constant_time_select_8(unsigned char mask,
+       unsigned char a, unsigned char b)
+       {
+       return (unsigned char)(constant_time_select(mask, a, b));
+       }
+
+inline int constant_time_select_int(unsigned int mask, int a, int b)
+       {
+       return (int)(constant_time_select(mask, (unsigned)(a), (unsigned)(b)));
+       }
+
 #ifdef __cplusplus
 }
 #endif