Clear secret stack values after use in the ED25519-functions
[openssl.git] / crypto / ec / curve25519.c
index 77f54940363d002a6048b39612f9ed5db6b32799..8002b3e05aba32259f3ed9a97680fa2bd47c6bf4 100644 (file)
@@ -3448,6 +3448,8 @@ static void ge_scalarmult_base(ge_p3 *h, const uint8_t *a) {
     ge_madd(&r, h, &t);
     ge_p1p1_to_p3(h, &r);
   }
+
+  OPENSSL_cleanse(e, sizeof(e));
 }
 
 /* Replace (f,g) with (g,f) if b == 1;
@@ -3578,6 +3580,8 @@ static void x25519_scalar_mult_generic(uint8_t out[32],
   fe_invert(z2, z2);
   fe_mul(x2, x2, z2);
   fe_tobytes(out, x2);
+
+  OPENSSL_cleanse(e, sizeof(e));
 }
 
 static void x25519_scalar_mult(uint8_t out[32], const uint8_t scalar[32],
@@ -3862,38 +3866,38 @@ static void x25519_sc_reduce(uint8_t *s) {
 
   carry6 = (s6 + (1 << 20)) >> 21;
   s7 += carry6;
-  s6 -= carry6 << 21;
+  s6 -= carry6 * (1 << 21);
   carry8 = (s8 + (1 << 20)) >> 21;
   s9 += carry8;
-  s8 -= carry8 << 21;
+  s8 -= carry8 * (1 << 21);
   carry10 = (s10 + (1 << 20)) >> 21;
   s11 += carry10;
-  s10 -= carry10 << 21;
+  s10 -= carry10 * (1 << 21);
   carry12 = (s12 + (1 << 20)) >> 21;
   s13 += carry12;
-  s12 -= carry12 << 21;
+  s12 -= carry12 * (1 << 21);
   carry14 = (s14 + (1 << 20)) >> 21;
   s15 += carry14;
-  s14 -= carry14 << 21;
+  s14 -= carry14 * (1 << 21);
   carry16 = (s16 + (1 << 20)) >> 21;
   s17 += carry16;
-  s16 -= carry16 << 21;
+  s16 -= carry16 * (1 << 21);
 
   carry7 = (s7 + (1 << 20)) >> 21;
   s8 += carry7;
-  s7 -= carry7 << 21;
+  s7 -= carry7 * (1 << 21);
   carry9 = (s9 + (1 << 20)) >> 21;
   s10 += carry9;
-  s9 -= carry9 << 21;
+  s9 -= carry9 * (1 << 21);
   carry11 = (s11 + (1 << 20)) >> 21;
   s12 += carry11;
-  s11 -= carry11 << 21;
+  s11 -= carry11 * (1 << 21);
   carry13 = (s13 + (1 << 20)) >> 21;
   s14 += carry13;
-  s13 -= carry13 << 21;
+  s13 -= carry13 * (1 << 21);
   carry15 = (s15 + (1 << 20)) >> 21;
   s16 += carry15;
-  s15 -= carry15 << 21;
+  s15 -= carry15 * (1 << 21);
 
   s5 += s17 * 666643;
   s6 += s17 * 470296;
@@ -3945,41 +3949,41 @@ static void x25519_sc_reduce(uint8_t *s) {
 
   carry0 = (s0 + (1 << 20)) >> 21;
   s1 += carry0;
-  s0 -= carry0 << 21;
+  s0 -= carry0 * (1 << 21);
   carry2 = (s2 + (1 << 20)) >> 21;
   s3 += carry2;
-  s2 -= carry2 << 21;
+  s2 -= carry2 * (1 << 21);
   carry4 = (s4 + (1 << 20)) >> 21;
   s5 += carry4;
-  s4 -= carry4 << 21;
+  s4 -= carry4 * (1 << 21);
   carry6 = (s6 + (1 << 20)) >> 21;
   s7 += carry6;
-  s6 -= carry6 << 21;
+  s6 -= carry6 * (1 << 21);
   carry8 = (s8 + (1 << 20)) >> 21;
   s9 += carry8;
-  s8 -= carry8 << 21;
+  s8 -= carry8 * (1 << 21);
   carry10 = (s10 + (1 << 20)) >> 21;
   s11 += carry10;
-  s10 -= carry10 << 21;
+  s10 -= carry10 * (1 << 21);
 
   carry1 = (s1 + (1 << 20)) >> 21;
   s2 += carry1;
-  s1 -= carry1 << 21;
+  s1 -= carry1 * (1 << 21);
   carry3 = (s3 + (1 << 20)) >> 21;
   s4 += carry3;
-  s3 -= carry3 << 21;
+  s3 -= carry3 * (1 << 21);
   carry5 = (s5 + (1 << 20)) >> 21;
   s6 += carry5;
-  s5 -= carry5 << 21;
+  s5 -= carry5 * (1 << 21);
   carry7 = (s7 + (1 << 20)) >> 21;
   s8 += carry7;
-  s7 -= carry7 << 21;
+  s7 -= carry7 * (1 << 21);
   carry9 = (s9 + (1 << 20)) >> 21;
   s10 += carry9;
-  s9 -= carry9 << 21;
+  s9 -= carry9 * (1 << 21);
   carry11 = (s11 + (1 << 20)) >> 21;
   s12 += carry11;
-  s11 -= carry11 << 21;
+  s11 -= carry11 * (1 << 21);
 
   s0 += s12 * 666643;
   s1 += s12 * 470296;
@@ -3991,40 +3995,40 @@ static void x25519_sc_reduce(uint8_t *s) {
 
   carry0 = s0 >> 21;
   s1 += carry0;
-  s0 -= carry0 << 21;
+  s0 -= carry0 * (1 << 21);
   carry1 = s1 >> 21;
   s2 += carry1;
-  s1 -= carry1 << 21;
+  s1 -= carry1 * (1 << 21);
   carry2 = s2 >> 21;
   s3 += carry2;
-  s2 -= carry2 << 21;
+  s2 -= carry2 * (1 << 21);
   carry3 = s3 >> 21;
   s4 += carry3;
-  s3 -= carry3 << 21;
+  s3 -= carry3 * (1 << 21);
   carry4 = s4 >> 21;
   s5 += carry4;
-  s4 -= carry4 << 21;
+  s4 -= carry4 * (1 << 21);
   carry5 = s5 >> 21;
   s6 += carry5;
-  s5 -= carry5 << 21;
+  s5 -= carry5 * (1 << 21);
   carry6 = s6 >> 21;
   s7 += carry6;
-  s6 -= carry6 << 21;
+  s6 -= carry6 * (1 << 21);
   carry7 = s7 >> 21;
   s8 += carry7;
-  s7 -= carry7 << 21;
+  s7 -= carry7 * (1 << 21);
   carry8 = s8 >> 21;
   s9 += carry8;
-  s8 -= carry8 << 21;
+  s8 -= carry8 * (1 << 21);
   carry9 = s9 >> 21;
   s10 += carry9;
-  s9 -= carry9 << 21;
+  s9 -= carry9 * (1 << 21);
   carry10 = s10 >> 21;
   s11 += carry10;
-  s10 -= carry10 << 21;
+  s10 -= carry10 * (1 << 21);
   carry11 = s11 >> 21;
   s12 += carry11;
-  s11 -= carry11 << 21;
+  s11 -= carry11 * (1 << 21);
 
   s0 += s12 * 666643;
   s1 += s12 * 470296;
@@ -4036,37 +4040,37 @@ static void x25519_sc_reduce(uint8_t *s) {
 
   carry0 = s0 >> 21;
   s1 += carry0;
-  s0 -= carry0 << 21;
+  s0 -= carry0 * (1 << 21);
   carry1 = s1 >> 21;
   s2 += carry1;
-  s1 -= carry1 << 21;
+  s1 -= carry1 * (1 << 21);
   carry2 = s2 >> 21;
   s3 += carry2;
-  s2 -= carry2 << 21;
+  s2 -= carry2 * (1 << 21);
   carry3 = s3 >> 21;
   s4 += carry3;
-  s3 -= carry3 << 21;
+  s3 -= carry3 * (1 << 21);
   carry4 = s4 >> 21;
   s5 += carry4;
-  s4 -= carry4 << 21;
+  s4 -= carry4 * (1 << 21);
   carry5 = s5 >> 21;
   s6 += carry5;
-  s5 -= carry5 << 21;
+  s5 -= carry5 * (1 << 21);
   carry6 = s6 >> 21;
   s7 += carry6;
-  s6 -= carry6 << 21;
+  s6 -= carry6 * (1 << 21);
   carry7 = s7 >> 21;
   s8 += carry7;
-  s7 -= carry7 << 21;
+  s7 -= carry7 * (1 << 21);
   carry8 = s8 >> 21;
   s9 += carry8;
-  s8 -= carry8 << 21;
+  s8 -= carry8 * (1 << 21);
   carry9 = s9 >> 21;
   s10 += carry9;
-  s9 -= carry9 << 21;
+  s9 -= carry9 * (1 << 21);
   carry10 = s10 >> 21;
   s11 += carry10;
-  s10 -= carry10 << 21;
+  s10 -= carry10 * (1 << 21);
 
   s[0] = s0 >> 0;
   s[1] = s0 >> 8;
@@ -4232,74 +4236,74 @@ static void sc_muladd(uint8_t *s, const uint8_t *a, const uint8_t *b,
 
   carry0 = (s0 + (1 << 20)) >> 21;
   s1 += carry0;
-  s0 -= carry0 << 21;
+  s0 -= carry0 * (1 << 21);
   carry2 = (s2 + (1 << 20)) >> 21;
   s3 += carry2;
-  s2 -= carry2 << 21;
+  s2 -= carry2 * (1 << 21);
   carry4 = (s4 + (1 << 20)) >> 21;
   s5 += carry4;
-  s4 -= carry4 << 21;
+  s4 -= carry4 * (1 << 21);
   carry6 = (s6 + (1 << 20)) >> 21;
   s7 += carry6;
-  s6 -= carry6 << 21;
+  s6 -= carry6 * (1 << 21);
   carry8 = (s8 + (1 << 20)) >> 21;
   s9 += carry8;
-  s8 -= carry8 << 21;
+  s8 -= carry8 * (1 << 21);
   carry10 = (s10 + (1 << 20)) >> 21;
   s11 += carry10;
-  s10 -= carry10 << 21;
+  s10 -= carry10 * (1 << 21);
   carry12 = (s12 + (1 << 20)) >> 21;
   s13 += carry12;
-  s12 -= carry12 << 21;
+  s12 -= carry12 * (1 << 21);
   carry14 = (s14 + (1 << 20)) >> 21;
   s15 += carry14;
-  s14 -= carry14 << 21;
+  s14 -= carry14 * (1 << 21);
   carry16 = (s16 + (1 << 20)) >> 21;
   s17 += carry16;
-  s16 -= carry16 << 21;
+  s16 -= carry16 * (1 << 21);
   carry18 = (s18 + (1 << 20)) >> 21;
   s19 += carry18;
-  s18 -= carry18 << 21;
+  s18 -= carry18 * (1 << 21);
   carry20 = (s20 + (1 << 20)) >> 21;
   s21 += carry20;
-  s20 -= carry20 << 21;
+  s20 -= carry20 * (1 << 21);
   carry22 = (s22 + (1 << 20)) >> 21;
   s23 += carry22;
-  s22 -= carry22 << 21;
+  s22 -= carry22 * (1 << 21);
 
   carry1 = (s1 + (1 << 20)) >> 21;
   s2 += carry1;
-  s1 -= carry1 << 21;
+  s1 -= carry1 * (1 << 21);
   carry3 = (s3 + (1 << 20)) >> 21;
   s4 += carry3;
-  s3 -= carry3 << 21;
+  s3 -= carry3 * (1 << 21);
   carry5 = (s5 + (1 << 20)) >> 21;
   s6 += carry5;
-  s5 -= carry5 << 21;
+  s5 -= carry5 * (1 << 21);
   carry7 = (s7 + (1 << 20)) >> 21;
   s8 += carry7;
-  s7 -= carry7 << 21;
+  s7 -= carry7 * (1 << 21);
   carry9 = (s9 + (1 << 20)) >> 21;
   s10 += carry9;
-  s9 -= carry9 << 21;
+  s9 -= carry9 * (1 << 21);
   carry11 = (s11 + (1 << 20)) >> 21;
   s12 += carry11;
-  s11 -= carry11 << 21;
+  s11 -= carry11 * (1 << 21);
   carry13 = (s13 + (1 << 20)) >> 21;
   s14 += carry13;
-  s13 -= carry13 << 21;
+  s13 -= carry13 * (1 << 21);
   carry15 = (s15 + (1 << 20)) >> 21;
   s16 += carry15;
-  s15 -= carry15 << 21;
+  s15 -= carry15 * (1 << 21);
   carry17 = (s17 + (1 << 20)) >> 21;
   s18 += carry17;
-  s17 -= carry17 << 21;
+  s17 -= carry17 * (1 << 21);
   carry19 = (s19 + (1 << 20)) >> 21;
   s20 += carry19;
-  s19 -= carry19 << 21;
+  s19 -= carry19 * (1 << 21);
   carry21 = (s21 + (1 << 20)) >> 21;
   s22 += carry21;
-  s21 -= carry21 << 21;
+  s21 -= carry21 * (1 << 21);
 
   s11 += s23 * 666643;
   s12 += s23 * 470296;
@@ -4351,38 +4355,38 @@ static void sc_muladd(uint8_t *s, const uint8_t *a, const uint8_t *b,
 
   carry6 = (s6 + (1 << 20)) >> 21;
   s7 += carry6;
-  s6 -= carry6 << 21;
+  s6 -= carry6 * (1 << 21);
   carry8 = (s8 + (1 << 20)) >> 21;
   s9 += carry8;
-  s8 -= carry8 << 21;
+  s8 -= carry8 * (1 << 21);
   carry10 = (s10 + (1 << 20)) >> 21;
   s11 += carry10;
-  s10 -= carry10 << 21;
+  s10 -= carry10 * (1 << 21);
   carry12 = (s12 + (1 << 20)) >> 21;
   s13 += carry12;
-  s12 -= carry12 << 21;
+  s12 -= carry12 * (1 << 21);
   carry14 = (s14 + (1 << 20)) >> 21;
   s15 += carry14;
-  s14 -= carry14 << 21;
+  s14 -= carry14 * (1 << 21);
   carry16 = (s16 + (1 << 20)) >> 21;
   s17 += carry16;
-  s16 -= carry16 << 21;
+  s16 -= carry16 * (1 << 21);
 
   carry7 = (s7 + (1 << 20)) >> 21;
   s8 += carry7;
-  s7 -= carry7 << 21;
+  s7 -= carry7 * (1 << 21);
   carry9 = (s9 + (1 << 20)) >> 21;
   s10 += carry9;
-  s9 -= carry9 << 21;
+  s9 -= carry9 * (1 << 21);
   carry11 = (s11 + (1 << 20)) >> 21;
   s12 += carry11;
-  s11 -= carry11 << 21;
+  s11 -= carry11 * (1 << 21);
   carry13 = (s13 + (1 << 20)) >> 21;
   s14 += carry13;
-  s13 -= carry13 << 21;
+  s13 -= carry13 * (1 << 21);
   carry15 = (s15 + (1 << 20)) >> 21;
   s16 += carry15;
-  s15 -= carry15 << 21;
+  s15 -= carry15 * (1 << 21);
 
   s5 += s17 * 666643;
   s6 += s17 * 470296;
@@ -4434,41 +4438,41 @@ static void sc_muladd(uint8_t *s, const uint8_t *a, const uint8_t *b,
 
   carry0 = (s0 + (1 << 20)) >> 21;
   s1 += carry0;
-  s0 -= carry0 << 21;
+  s0 -= carry0 * (1 << 21);
   carry2 = (s2 + (1 << 20)) >> 21;
   s3 += carry2;
-  s2 -= carry2 << 21;
+  s2 -= carry2 * (1 << 21);
   carry4 = (s4 + (1 << 20)) >> 21;
   s5 += carry4;
-  s4 -= carry4 << 21;
+  s4 -= carry4 * (1 << 21);
   carry6 = (s6 + (1 << 20)) >> 21;
   s7 += carry6;
-  s6 -= carry6 << 21;
+  s6 -= carry6 * (1 << 21);
   carry8 = (s8 + (1 << 20)) >> 21;
   s9 += carry8;
-  s8 -= carry8 << 21;
+  s8 -= carry8 * (1 << 21);
   carry10 = (s10 + (1 << 20)) >> 21;
   s11 += carry10;
-  s10 -= carry10 << 21;
+  s10 -= carry10 * (1 << 21);
 
   carry1 = (s1 + (1 << 20)) >> 21;
   s2 += carry1;
-  s1 -= carry1 << 21;
+  s1 -= carry1 * (1 << 21);
   carry3 = (s3 + (1 << 20)) >> 21;
   s4 += carry3;
-  s3 -= carry3 << 21;
+  s3 -= carry3 * (1 << 21);
   carry5 = (s5 + (1 << 20)) >> 21;
   s6 += carry5;
-  s5 -= carry5 << 21;
+  s5 -= carry5 * (1 << 21);
   carry7 = (s7 + (1 << 20)) >> 21;
   s8 += carry7;
-  s7 -= carry7 << 21;
+  s7 -= carry7 * (1 << 21);
   carry9 = (s9 + (1 << 20)) >> 21;
   s10 += carry9;
-  s9 -= carry9 << 21;
+  s9 -= carry9 * (1 << 21);
   carry11 = (s11 + (1 << 20)) >> 21;
   s12 += carry11;
-  s11 -= carry11 << 21;
+  s11 -= carry11 * (1 << 21);
 
   s0 += s12 * 666643;
   s1 += s12 * 470296;
@@ -4480,40 +4484,40 @@ static void sc_muladd(uint8_t *s, const uint8_t *a, const uint8_t *b,
 
   carry0 = s0 >> 21;
   s1 += carry0;
-  s0 -= carry0 << 21;
+  s0 -= carry0 * (1 << 21);
   carry1 = s1 >> 21;
   s2 += carry1;
-  s1 -= carry1 << 21;
+  s1 -= carry1 * (1 << 21);
   carry2 = s2 >> 21;
   s3 += carry2;
-  s2 -= carry2 << 21;
+  s2 -= carry2 * (1 << 21);
   carry3 = s3 >> 21;
   s4 += carry3;
-  s3 -= carry3 << 21;
+  s3 -= carry3 * (1 << 21);
   carry4 = s4 >> 21;
   s5 += carry4;
-  s4 -= carry4 << 21;
+  s4 -= carry4 * (1 << 21);
   carry5 = s5 >> 21;
   s6 += carry5;
-  s5 -= carry5 << 21;
+  s5 -= carry5 * (1 << 21);
   carry6 = s6 >> 21;
   s7 += carry6;
-  s6 -= carry6 << 21;
+  s6 -= carry6 * (1 << 21);
   carry7 = s7 >> 21;
   s8 += carry7;
-  s7 -= carry7 << 21;
+  s7 -= carry7 * (1 << 21);
   carry8 = s8 >> 21;
   s9 += carry8;
-  s8 -= carry8 << 21;
+  s8 -= carry8 * (1 << 21);
   carry9 = s9 >> 21;
   s10 += carry9;
-  s9 -= carry9 << 21;
+  s9 -= carry9 * (1 << 21);
   carry10 = s10 >> 21;
   s11 += carry10;
-  s10 -= carry10 << 21;
+  s10 -= carry10 * (1 << 21);
   carry11 = s11 >> 21;
   s12 += carry11;
-  s11 -= carry11 << 21;
+  s11 -= carry11 * (1 << 21);
 
   s0 += s12 * 666643;
   s1 += s12 * 470296;
@@ -4525,37 +4529,37 @@ static void sc_muladd(uint8_t *s, const uint8_t *a, const uint8_t *b,
 
   carry0 = s0 >> 21;
   s1 += carry0;
-  s0 -= carry0 << 21;
+  s0 -= carry0 * (1 << 21);
   carry1 = s1 >> 21;
   s2 += carry1;
-  s1 -= carry1 << 21;
+  s1 -= carry1 * (1 << 21);
   carry2 = s2 >> 21;
   s3 += carry2;
-  s2 -= carry2 << 21;
+  s2 -= carry2 * (1 << 21);
   carry3 = s3 >> 21;
   s4 += carry3;
-  s3 -= carry3 << 21;
+  s3 -= carry3 * (1 << 21);
   carry4 = s4 >> 21;
   s5 += carry4;
-  s4 -= carry4 << 21;
+  s4 -= carry4 * (1 << 21);
   carry5 = s5 >> 21;
   s6 += carry5;
-  s5 -= carry5 << 21;
+  s5 -= carry5 * (1 << 21);
   carry6 = s6 >> 21;
   s7 += carry6;
-  s6 -= carry6 << 21;
+  s6 -= carry6 * (1 << 21);
   carry7 = s7 >> 21;
   s8 += carry7;
-  s7 -= carry7 << 21;
+  s7 -= carry7 * (1 << 21);
   carry8 = s8 >> 21;
   s9 += carry8;
-  s8 -= carry8 << 21;
+  s8 -= carry8 * (1 << 21);
   carry9 = s9 >> 21;
   s10 += carry9;
-  s9 -= carry9 << 21;
+  s9 -= carry9 * (1 << 21);
   carry10 = s10 >> 21;
   s11 += carry10;
-  s10 -= carry10 << 21;
+  s10 -= carry10 * (1 << 21);
 
   s[0] = s0 >> 0;
   s[1] = s0 >> 8;
@@ -4635,7 +4639,6 @@ int ED25519_sign(uint8_t *out_sig, const uint8_t *message, size_t message_len,
 int ED25519_verify(const uint8_t *message, size_t message_len,
                    const uint8_t signature[64], const uint8_t public_key[32]) {
   ge_p3 A;
-  uint8_t pkcopy[32];
   uint8_t rcopy[32];
   uint8_t scopy[32];
   SHA512_CTX hash_ctx;
@@ -4651,7 +4654,6 @@ int ED25519_verify(const uint8_t *message, size_t message_len,
   fe_neg(A.X, A.X);
   fe_neg(A.T, A.T);
 
-  memcpy(pkcopy, public_key, 32);
   memcpy(rcopy, signature, 32);
   memcpy(scopy, signature + 32, 32);
 
@@ -4683,6 +4685,8 @@ void ED25519_public_from_private(uint8_t out_public_key[32],
 
   ge_scalarmult_base(&A, az);
   ge_p3_tobytes(out_public_key, &A);
+
+  OPENSSL_cleanse(az, sizeof(az));
 }
 
 int X25519(uint8_t out_shared_key[32], const uint8_t private_key[32],
@@ -4713,4 +4717,6 @@ void X25519_public_from_private(uint8_t out_public_value[32],
   fe_invert(zminusy_inv, zminusy);
   fe_mul(zplusy, zplusy, zminusy_inv);
   fe_tobytes(out_public_value, zplusy);
+
+  OPENSSL_cleanse(e, sizeof(e));
 }