ec/curve25519.c: facilitate assembly implementations.
[openssl.git] / crypto / ec / curve25519.c
index 72580334ff539279bca01ee093ecd4546b1f5d9a..f354107c5dcf64e16e12cd0fd75a31e0627bdec5 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2016 The OpenSSL Project Authors. All Rights Reserved.
+ * Copyright 2016-2018 The OpenSSL Project Authors. All Rights Reserved.
  *
  * Licensed under the OpenSSL license (the "License").  You may not use
  * this file except in compliance with the License.  You can obtain a copy
  * https://www.openssl.org/source/license.html
  */
 
-/* This code is mostly taken from the ref10 version of Ed25519 in SUPERCOP
- * 20141124 (http://bench.cr.yp.to/supercop.html).
- *
- * The field functions are shared by Ed25519 and X25519 where possible. */
-
 #include <string.h>
 #include "ec_lcl.h"
 #include <openssl/sha.h>
 
+#if defined(X25519_ASM) \
+    || ( !defined(PEDANTIC) && \
+         !defined(__sparc__) && \
+         (defined(__SIZEOF_INT128__) && __SIZEOF_INT128__==16) )
+/*
+ * Base 2^51 implementation.
+ */
+# define BASE_2_51_IMPLEMENTED
+
+typedef uint64_t fe51[5];
+# if !defined(X25519_ASM)
+typedef unsigned __int128 u128;
+# endif
+
+static const uint64_t MASK51 = 0x7ffffffffffff;
+
+static uint64_t load_7(const uint8_t *in)
+{
+    uint64_t result;
+
+    result = in[0];
+    result |= ((uint64_t)in[1]) << 8;
+    result |= ((uint64_t)in[2]) << 16;
+    result |= ((uint64_t)in[3]) << 24;
+    result |= ((uint64_t)in[4]) << 32;
+    result |= ((uint64_t)in[5]) << 40;
+    result |= ((uint64_t)in[6]) << 48;
+
+    return result;
+}
+
+static uint64_t load_6(const uint8_t *in)
+{
+    uint64_t result;
+
+    result = in[0];
+    result |= ((uint64_t)in[1]) << 8;
+    result |= ((uint64_t)in[2]) << 16;
+    result |= ((uint64_t)in[3]) << 24;
+    result |= ((uint64_t)in[4]) << 32;
+    result |= ((uint64_t)in[5]) << 40;
+
+    return result;
+}
+
+static void fe51_frombytes(fe51 h, const uint8_t *s)
+{
+    uint64_t h0 = load_7(s);                                /* 56 bits */
+    uint64_t h1 = load_6(s + 7) << 5;                       /* 53 bits */
+    uint64_t h2 = load_7(s + 13) << 2;                      /* 58 bits */
+    uint64_t h3 = load_6(s + 20) << 7;                      /* 55 bits */
+    uint64_t h4 = (load_6(s + 26) & 0x7fffffffffff) << 4;   /* 51 bits */
+
+    h1 |= h0 >> 51; h0 &= MASK51;
+    h2 |= h1 >> 51; h1 &= MASK51;
+    h3 |= h2 >> 51; h2 &= MASK51;
+    h4 |= h3 >> 51; h3 &= MASK51;
+
+    h[0] = h0;
+    h[1] = h1;
+    h[2] = h2;
+    h[3] = h3;
+    h[4] = h4;
+}
+
+static void fe51_tobytes(uint8_t *s, const fe51 h)
+{
+    uint64_t h0 = h[0];
+    uint64_t h1 = h[1];
+    uint64_t h2 = h[2];
+    uint64_t h3 = h[3];
+    uint64_t h4 = h[4];
+    uint64_t q;
+
+    /* compare to modulus */
+    q = (h0 + 19) >> 51;
+    q = (h1 + q) >> 51;
+    q = (h2 + q) >> 51;
+    q = (h3 + q) >> 51;
+    q = (h4 + q) >> 51;
+
+    /* full reduce */
+    h0 += 19 * q;
+    h1 += h0 >> 51; h0 &= MASK51;
+    h2 += h1 >> 51; h1 &= MASK51;
+    h3 += h2 >> 51; h2 &= MASK51;
+    h4 += h3 >> 51; h3 &= MASK51;
+                    h4 &= MASK51;
+
+    /* smash */
+    s[0] = h0 >> 0;
+    s[1] = h0 >> 8;
+    s[2] = h0 >> 16;
+    s[3] = h0 >> 24;
+    s[4] = h0 >> 32;
+    s[5] = h0 >> 40;
+    s[6] = (h0 >> 48) | ((uint32_t)h1 << 3);
+    s[7] = h1 >> 5;
+    s[8] = h1 >> 13;
+    s[9] = h1 >> 21;
+    s[10] = h1 >> 29;
+    s[11] = h1 >> 37;
+    s[12] = (h1 >> 45) | ((uint32_t)h2 << 6);
+    s[13] = h2 >> 2;
+    s[14] = h2 >> 10;
+    s[15] = h2 >> 18;
+    s[16] = h2 >> 26;
+    s[17] = h2 >> 34;
+    s[18] = h2 >> 42;
+    s[19] = (h2 >> 50) | ((uint32_t)h3 << 1);
+    s[20] = h3 >> 7;
+    s[21] = h3 >> 15;
+    s[22] = h3 >> 23;
+    s[23] = h3 >> 31;
+    s[24] = h3 >> 39;
+    s[25] = (h3 >> 47) | ((uint32_t)h4 << 4);
+    s[26] = h4 >> 4;
+    s[27] = h4 >> 12;
+    s[28] = h4 >> 20;
+    s[29] = h4 >> 28;
+    s[30] = h4 >> 36;
+    s[31] = h4 >> 44;
+}
+
+# ifdef X25519_ASM
+void x25519_fe51_mul(fe51 h, const fe51 f, const fe51 g);
+void x25519_fe51_sqr(fe51 h, const fe51 f);
+void x25519_fe51_mul121666(fe51 h, fe51 f);
+#  define fe51_mul x25519_fe51_mul
+#  define fe51_sq  x25519_fe51_sqr
+#  define fe51_mul121666 x25519_fe51_mul121666
+
+#  if defined(__x86_64) || defined(__x86_64__) || \
+      defined(_M_AMD64) || defined(_M_X64)
+
+#   define BASE_2_64_IMPLEMENTED
+
+typedef uint64_t fe64[4];
+
+int x25519_fe64_eligible();
+
+/*
+ * There are no reference C implementations for this radix.
+ */
+void x25519_fe64_mul(fe64 h, const fe64 f, const fe64 g);
+void x25519_fe64_sqr(fe64 h, const fe64 f);
+void x25519_fe64_mul121666(fe64 h, fe64 f);
+void x25519_fe64_add(fe64 h, const fe64 f, const fe64 g);
+void x25519_fe64_sub(fe64 h, const fe64 f, const fe64 g);
+void x25519_fe64_tobytes(uint8_t *s, const fe64 f);
+#   define fe64_mul x25519_fe64_mul
+#   define fe64_sqr x25519_fe64_sqr
+#   define fe64_mul121666 x25519_fe64_mul121666
+#   define fe64_add x25519_fe64_add
+#   define fe64_sub x25519_fe64_sub
+#   define fe64_tobytes x25519_fe64_tobytes
+
+static uint64_t load_8(const uint8_t *in)
+{
+    uint64_t result;
+
+    result = in[0];
+    result |= ((uint64_t)in[1]) << 8;
+    result |= ((uint64_t)in[2]) << 16;
+    result |= ((uint64_t)in[3]) << 24;
+    result |= ((uint64_t)in[4]) << 32;
+    result |= ((uint64_t)in[5]) << 40;
+    result |= ((uint64_t)in[6]) << 48;
+    result |= ((uint64_t)in[7]) << 56;
+
+    return result;
+}
+
+static void fe64_frombytes(fe64 h, const uint8_t *s)
+{
+    h[0] = load_8(s);
+    h[1] = load_8(s + 8);
+    h[2] = load_8(s + 16);
+    h[3] = load_8(s + 24) & 0x7fffffffffffffff;
+}
+
+static void fe64_0(fe64 h)
+{
+    h[0] = 0;
+    h[1] = 0;
+    h[2] = 0;
+    h[3] = 0;
+}
+
+static void fe64_1(fe64 h)
+{
+    h[0] = 1;
+    h[1] = 0;
+    h[2] = 0;
+    h[3] = 0;
+}
+
+static void fe64_copy(fe64 h, const fe64 f)
+{
+    h[0] = f[0];
+    h[1] = f[1];
+    h[2] = f[2];
+    h[3] = f[3];
+}
+
+static void fe64_cswap(fe64 f, fe64 g, unsigned int b)
+{
+    int i;
+    uint64_t mask = 0 - (uint64_t)b;
+
+    for (i = 0; i < 4; i++) {
+        uint64_t x = f[i] ^ g[i];
+        x &= mask;
+        f[i] ^= x;
+        g[i] ^= x;
+    }
+}
+
+static void fe64_invert(fe64 out, const fe64 z)
+{
+    fe64 t0;
+    fe64 t1;
+    fe64 t2;
+    fe64 t3;
+    int i;
+
+    /*
+     * Compute z ** -1 = z ** (2 ** 255 - 19 - 2) with the exponent as
+     * 2 ** 255 - 21 = (2 ** 5) * (2 ** 250 - 1) + 11.
+     */
+
+    /* t0 = z ** 2 */
+    fe64_sqr(t0, z);
+
+    /* t1 = t0 ** (2 ** 2) = z ** 8 */
+    fe64_sqr(t1, t0);
+    fe64_sqr(t1, t1);
+
+    /* t1 = z * t1 = z ** 9 */
+    fe64_mul(t1, z, t1);
+    /* t0 = t0 * t1 = z ** 11 -- stash t0 away for the end. */
+    fe64_mul(t0, t0, t1);
+
+    /* t2 = t0 ** 2 = z ** 22 */
+    fe64_sqr(t2, t0);
+
+    /* t1 = t1 * t2 = z ** (2 ** 5 - 1) */
+    fe64_mul(t1, t1, t2);
+
+    /* t2 = t1 ** (2 ** 5) = z ** ((2 ** 5) * (2 ** 5 - 1)) */
+    fe64_sqr(t2, t1);
+    for (i = 1; i < 5; ++i)
+        fe64_sqr(t2, t2);
+
+    /* t1 = t1 * t2 = z ** ((2 ** 5 + 1) * (2 ** 5 - 1)) = z ** (2 ** 10 - 1) */
+    fe64_mul(t1, t2, t1);
+
+    /* Continuing similarly... */
+
+    /* t2 = z ** (2 ** 20 - 1) */
+    fe64_sqr(t2, t1);
+    for (i = 1; i < 10; ++i)
+        fe64_sqr(t2, t2);
+
+    fe64_mul(t2, t2, t1);
+
+    /* t2 = z ** (2 ** 40 - 1) */
+    fe64_sqr(t3, t2);
+    for (i = 1; i < 20; ++i)
+        fe64_sqr(t3, t3);
+
+    fe64_mul(t2, t3, t2);
+
+    /* t2 = z ** (2 ** 10) * (2 ** 40 - 1) */
+    for (i = 0; i < 10; ++i)
+        fe64_sqr(t2, t2);
+
+    /* t1 = z ** (2 ** 50 - 1) */
+    fe64_mul(t1, t2, t1);
+
+    /* t2 = z ** (2 ** 100 - 1) */
+    fe64_sqr(t2, t1);
+    for (i = 1; i < 50; ++i)
+        fe64_sqr(t2, t2);
+
+    fe64_mul(t2, t2, t1);
+
+    /* t2 = z ** (2 ** 200 - 1) */
+    fe64_sqr(t3, t2);
+    for (i = 1; i < 100; ++i)
+        fe64_sqr(t3, t3);
+
+    fe64_mul(t2, t3, t2);
+
+    /* t2 = z ** ((2 ** 50) * (2 ** 200 - 1) */
+    for (i = 0; i < 50; ++i)
+        fe64_sqr(t2, t2);
+
+    /* t1 = z ** (2 ** 250 - 1) */
+    fe64_mul(t1, t2, t1);
+
+    /* t1 = z ** ((2 ** 5) * (2 ** 250 - 1)) */
+    for (i = 0; i < 5; ++i)
+        fe64_sqr(t1, t1);
+
+    /* Recall t0 = z ** 11; out = z ** (2 ** 255 - 21) */
+    fe64_mul(out, t1, t0);
+}
+
+/*
+ * Duplicate of original x25519_scalar_mult_generic, but using
+ * fe64_* subroutines.
+ */
+static void x25519_scalar_mulx(uint8_t out[32], const uint8_t scalar[32],
+                               const uint8_t point[32])
+{
+    fe64 x1, x2, z2, x3, z3, tmp0, tmp1;
+    uint8_t e[32];
+    unsigned swap = 0;
+    int pos;
+
+    memcpy(e, scalar, 32);
+    e[0]  &= 0xf8;
+    e[31] &= 0x7f;
+    e[31] |= 0x40;
+    fe64_frombytes(x1, point);
+    fe64_1(x2);
+    fe64_0(z2);
+    fe64_copy(x3, x1);
+    fe64_1(z3);
+
+    for (pos = 254; pos >= 0; --pos) {
+        unsigned int b = 1 & (e[pos / 8] >> (pos & 7));
+
+        swap ^= b;
+        fe64_cswap(x2, x3, swap);
+        fe64_cswap(z2, z3, swap);
+        swap = b;
+        fe64_sub(tmp0, x3, z3);
+        fe64_sub(tmp1, x2, z2);
+        fe64_add(x2, x2, z2);
+        fe64_add(z2, x3, z3);
+        fe64_mul(z3, x2, tmp0);
+        fe64_mul(z2, z2, tmp1);
+        fe64_sqr(tmp0, tmp1);
+        fe64_sqr(tmp1, x2);
+        fe64_add(x3, z3, z2);
+        fe64_sub(z2, z3, z2);
+        fe64_mul(x2, tmp1, tmp0);
+        fe64_sub(tmp1, tmp1, tmp0);
+        fe64_sqr(z2, z2);
+        fe64_mul121666(z3, tmp1);
+        fe64_sqr(x3, x3);
+        fe64_add(tmp0, tmp0, z3);
+        fe64_mul(z3, x1, z2);
+        fe64_mul(z2, tmp1, tmp0);
+    }
+
+    fe64_invert(z2, z2);
+    fe64_mul(x2, x2, z2);
+    fe64_tobytes(out, x2);
+
+    OPENSSL_cleanse(e, sizeof(e));
+}
+#  endif
+
+# else
+
+static void fe51_mul(fe51 h, const fe51 f, const fe51 g)
+{
+    u128 h0, h1, h2, h3, h4;
+    uint64_t f_i, g0, g1, g2, g3, g4;
+
+    f_i = f[0];
+    h0 = (u128)f_i * (g0 = g[0]);
+    h1 = (u128)f_i * (g1 = g[1]);
+    h2 = (u128)f_i * (g2 = g[2]);
+    h3 = (u128)f_i * (g3 = g[3]);
+    h4 = (u128)f_i * (g4 = g[4]);
+
+    f_i = f[1];
+    h0 += (u128)f_i * (g4 *= 19);
+    h1 += (u128)f_i * g0;
+    h2 += (u128)f_i * g1;
+    h3 += (u128)f_i * g2;
+    h4 += (u128)f_i * g3;
+
+    f_i = f[2];
+    h0 += (u128)f_i * (g3 *= 19);
+    h1 += (u128)f_i * g4;
+    h2 += (u128)f_i * g0;
+    h3 += (u128)f_i * g1;
+    h4 += (u128)f_i * g2;
+
+    f_i = f[3];
+    h0 += (u128)f_i * (g2 *= 19);
+    h1 += (u128)f_i * g3;
+    h2 += (u128)f_i * g4;
+    h3 += (u128)f_i * g0;
+    h4 += (u128)f_i * g1;
+
+    f_i = f[4];
+    h0 += (u128)f_i * (g1 *= 19);
+    h1 += (u128)f_i * g2;
+    h2 += (u128)f_i * g3;
+    h3 += (u128)f_i * g4;
+    h4 += (u128)f_i * g0;
+
+    /* partial [lazy] reduction */
+    h3 += (uint64_t)(h2 >> 51); g2 = (uint64_t)h2 & MASK51;
+    h1 += (uint64_t)(h0 >> 51); g0 = (uint64_t)h0 & MASK51;
+
+    h4 += (uint64_t)(h3 >> 51); g3 = (uint64_t)h3 & MASK51;
+    g2 += (uint64_t)(h1 >> 51); g1 = (uint64_t)h1 & MASK51;
+
+    g0 += (uint64_t)(h4 >> 51) * 19; g4 = (uint64_t)h4 & MASK51;
+    g3 += g2 >> 51; g2 &= MASK51;
+    g1 += g0 >> 51; g0 &= MASK51;
+
+    h[0] = g0;
+    h[1] = g1;
+    h[2] = g2;
+    h[3] = g3;
+    h[4] = g4;
+}
+
+static void fe51_sq(fe51 h, const fe51 f)
+{
+#  if defined(OPENSSL_SMALL_FOOTPRINT)
+    fe51_mul(h, f, f);
+#  else
+    /* dedicated squaring gives 16-25% overall improvement */
+    uint64_t g0 = f[0];
+    uint64_t g1 = f[1];
+    uint64_t g2 = f[2];
+    uint64_t g3 = f[3];
+    uint64_t g4 = f[4];
+    u128 h0, h1, h2, h3, h4;
+
+    h0 = (u128)g0 * g0;     g0 *= 2;
+    h1 = (u128)g0 * g1;
+    h2 = (u128)g0 * g2;
+    h3 = (u128)g0 * g3;
+    h4 = (u128)g0 * g4;
+
+    g0 = g4;                /* borrow g0 */
+    h3 += (u128)g0 * (g4 *= 19);
+
+    h2 += (u128)g1 * g1;    g1 *= 2;
+    h3 += (u128)g1 * g2;
+    h4 += (u128)g1 * g3;
+    h0 += (u128)g1 * g4;
+
+    g0 = g3;                /* borrow g0 */
+    h1 += (u128)g0 * (g3 *= 19);
+    h2 += (u128)(g0 * 2) * g4;
+
+    h4 += (u128)g2 * g2;    g2 *= 2;
+    h0 += (u128)g2 * g3;
+    h1 += (u128)g2 * g4;
+
+    /* partial [lazy] reduction */
+    h3 += (uint64_t)(h2 >> 51); g2 = (uint64_t)h2 & MASK51;
+    h1 += (uint64_t)(h0 >> 51); g0 = (uint64_t)h0 & MASK51;
+
+    h4 += (uint64_t)(h3 >> 51); g3 = (uint64_t)h3 & MASK51;
+    g2 += (uint64_t)(h1 >> 51); g1 = (uint64_t)h1 & MASK51;
+
+    g0 += (uint64_t)(h4 >> 51) * 19; g4 = (uint64_t)h4 & MASK51;
+    g3 += g2 >> 51; g2 &= MASK51;
+    g1 += g0 >> 51; g0 &= MASK51;
+
+    h[0] = g0;
+    h[1] = g1;
+    h[2] = g2;
+    h[3] = g3;
+    h[4] = g4;
+#  endif
+}
+
+static void fe51_mul121666(fe51 h, fe51 f)
+{
+    u128 h0 = f[0] * (u128)121666;
+    u128 h1 = f[1] * (u128)121666;
+    u128 h2 = f[2] * (u128)121666;
+    u128 h3 = f[3] * (u128)121666;
+    u128 h4 = f[4] * (u128)121666;
+    uint64_t g0, g1, g2, g3, g4;
+
+    h3 += (uint64_t)(h2 >> 51); g2 = (uint64_t)h2 & MASK51;
+    h1 += (uint64_t)(h0 >> 51); g0 = (uint64_t)h0 & MASK51;
+
+    h4 += (uint64_t)(h3 >> 51); g3 = (uint64_t)h3 & MASK51;
+    g2 += (uint64_t)(h1 >> 51); g1 = (uint64_t)h1 & MASK51;
+
+    g0 += (uint64_t)(h4 >> 51) * 19; g4 = (uint64_t)h4 & MASK51;
+    g3 += g2 >> 51; g2 &= MASK51;
+    g1 += g0 >> 51; g0 &= MASK51;
+
+    h[0] = g0;
+    h[1] = g1;
+    h[2] = g2;
+    h[3] = g3;
+    h[4] = g4;
+}
+# endif
+
+static void fe51_add(fe51 h, const fe51 f, const fe51 g)
+{
+    h[0] = f[0] + g[0];
+    h[1] = f[1] + g[1];
+    h[2] = f[2] + g[2];
+    h[3] = f[3] + g[3];
+    h[4] = f[4] + g[4];
+}
+
+static void fe51_sub(fe51 h, const fe51 f, const fe51 g)
+{
+    /*
+     * Add 2*modulus to ensure that result remains positive
+     * even if subtrahend is partially reduced.
+     */
+    h[0] = (f[0] + 0xfffffffffffda) - g[0];
+    h[1] = (f[1] + 0xffffffffffffe) - g[1];
+    h[2] = (f[2] + 0xffffffffffffe) - g[2];
+    h[3] = (f[3] + 0xffffffffffffe) - g[3];
+    h[4] = (f[4] + 0xffffffffffffe) - g[4];
+}
+
+static void fe51_0(fe51 h)
+{
+    h[0] = 0;
+    h[1] = 0;
+    h[2] = 0;
+    h[3] = 0;
+    h[4] = 0;
+}
+
+static void fe51_1(fe51 h)
+{
+    h[0] = 1;
+    h[1] = 0;
+    h[2] = 0;
+    h[3] = 0;
+    h[4] = 0;
+}
+
+static void fe51_copy(fe51 h, const fe51 f)
+{
+    h[0] = f[0];
+    h[1] = f[1];
+    h[2] = f[2];
+    h[3] = f[3];
+    h[4] = f[4];
+}
+
+static void fe51_cswap(fe51 f, fe51 g, unsigned int b)
+{
+    int i;
+    uint64_t mask = 0 - (uint64_t)b;
+
+    for (i = 0; i < 5; i++) {
+        int64_t x = f[i] ^ g[i];
+        x &= mask;
+        f[i] ^= x;
+        g[i] ^= x;
+    }
+}
+
+static void fe51_invert(fe51 out, const fe51 z)
+{
+    fe51 t0;
+    fe51 t1;
+    fe51 t2;
+    fe51 t3;
+    int i;
+
+    /*
+     * Compute z ** -1 = z ** (2 ** 255 - 19 - 2) with the exponent as
+     * 2 ** 255 - 21 = (2 ** 5) * (2 ** 250 - 1) + 11.
+     */
+
+    /* t0 = z ** 2 */
+    fe51_sq(t0, z);
+
+    /* t1 = t0 ** (2 ** 2) = z ** 8 */
+    fe51_sq(t1, t0);
+    fe51_sq(t1, t1);
+
+    /* t1 = z * t1 = z ** 9 */
+    fe51_mul(t1, z, t1);
+    /* t0 = t0 * t1 = z ** 11 -- stash t0 away for the end. */
+    fe51_mul(t0, t0, t1);
+
+    /* t2 = t0 ** 2 = z ** 22 */
+    fe51_sq(t2, t0);
+
+    /* t1 = t1 * t2 = z ** (2 ** 5 - 1) */
+    fe51_mul(t1, t1, t2);
+
+    /* t2 = t1 ** (2 ** 5) = z ** ((2 ** 5) * (2 ** 5 - 1)) */
+    fe51_sq(t2, t1);
+    for (i = 1; i < 5; ++i)
+        fe51_sq(t2, t2);
+
+    /* t1 = t1 * t2 = z ** ((2 ** 5 + 1) * (2 ** 5 - 1)) = z ** (2 ** 10 - 1) */
+    fe51_mul(t1, t2, t1);
+
+    /* Continuing similarly... */
+
+    /* t2 = z ** (2 ** 20 - 1) */
+    fe51_sq(t2, t1);
+    for (i = 1; i < 10; ++i)
+        fe51_sq(t2, t2);
+
+    fe51_mul(t2, t2, t1);
+
+    /* t2 = z ** (2 ** 40 - 1) */
+    fe51_sq(t3, t2);
+    for (i = 1; i < 20; ++i)
+        fe51_sq(t3, t3);
+
+    fe51_mul(t2, t3, t2);
+
+    /* t2 = z ** (2 ** 10) * (2 ** 40 - 1) */
+    for (i = 0; i < 10; ++i)
+        fe51_sq(t2, t2);
+
+    /* t1 = z ** (2 ** 50 - 1) */
+    fe51_mul(t1, t2, t1);
+
+    /* t2 = z ** (2 ** 100 - 1) */
+    fe51_sq(t2, t1);
+    for (i = 1; i < 50; ++i)
+        fe51_sq(t2, t2);
+
+    fe51_mul(t2, t2, t1);
+
+    /* t2 = z ** (2 ** 200 - 1) */
+    fe51_sq(t3, t2);
+    for (i = 1; i < 100; ++i)
+        fe51_sq(t3, t3);
+
+    fe51_mul(t2, t3, t2);
+
+    /* t2 = z ** ((2 ** 50) * (2 ** 200 - 1) */
+    for (i = 0; i < 50; ++i)
+        fe51_sq(t2, t2);
+
+    /* t1 = z ** (2 ** 250 - 1) */
+    fe51_mul(t1, t2, t1);
+
+    /* t1 = z ** ((2 ** 5) * (2 ** 250 - 1)) */
+    for (i = 0; i < 5; ++i)
+        fe51_sq(t1, t1);
+
+    /* Recall t0 = z ** 11; out = z ** (2 ** 255 - 21) */
+    fe51_mul(out, t1, t0);
+}
+
+/*
+ * Duplicate of original x25519_scalar_mult_generic, but using
+ * fe51_* subroutines.
+ */
+static void x25519_scalar_mult(uint8_t out[32], const uint8_t scalar[32],
+                               const uint8_t point[32])
+{
+    fe51 x1, x2, z2, x3, z3, tmp0, tmp1;
+    uint8_t e[32];
+    unsigned swap = 0;
+    int pos;
+
+# ifdef BASE_2_64_IMPLEMENTED
+    if (x25519_fe64_eligible()) {
+        x25519_scalar_mulx(out, scalar, point);
+        return;
+    }
+# endif
+
+    memcpy(e, scalar, 32);
+    e[0]  &= 0xf8;
+    e[31] &= 0x7f;
+    e[31] |= 0x40;
+    fe51_frombytes(x1, point);
+    fe51_1(x2);
+    fe51_0(z2);
+    fe51_copy(x3, x1);
+    fe51_1(z3);
+
+    for (pos = 254; pos >= 0; --pos) {
+        unsigned int b = 1 & (e[pos / 8] >> (pos & 7));
+
+        swap ^= b;
+        fe51_cswap(x2, x3, swap);
+        fe51_cswap(z2, z3, swap);
+        swap = b;
+        fe51_sub(tmp0, x3, z3);
+        fe51_sub(tmp1, x2, z2);
+        fe51_add(x2, x2, z2);
+        fe51_add(z2, x3, z3);
+        fe51_mul(z3, tmp0, x2);
+        fe51_mul(z2, z2, tmp1);
+        fe51_sq(tmp0, tmp1);
+        fe51_sq(tmp1, x2);
+        fe51_add(x3, z3, z2);
+        fe51_sub(z2, z3, z2);
+        fe51_mul(x2, tmp1, tmp0);
+        fe51_sub(tmp1, tmp1, tmp0);
+        fe51_sq(z2, z2);
+        fe51_mul121666(z3, tmp1);
+        fe51_sq(x3, x3);
+        fe51_add(tmp0, tmp0, z3);
+        fe51_mul(z3, x1, z2);
+        fe51_mul(z2, tmp1, tmp0);
+    }
+
+    fe51_invert(z2, z2);
+    fe51_mul(x2, x2, z2);
+    fe51_tobytes(out, x2);
+
+    OPENSSL_cleanse(e, sizeof(e));
+}
+#endif
+
+/*
+ * Reference base 2^25.5 implementation.
+ */
+/*
+ * This code is mostly taken from the ref10 version of Ed25519 in SUPERCOP
+ * 20141124 (http://bench.cr.yp.to/supercop.html).
+ *
+ * The field functions are shared by Ed25519 and X25519 where possible.
+ */
 
 /* fe means field element. Here the field is \Z/(2^255-19). An element t,
  * entries t[0]...t[9], represents the integer t[0]+2^26 t[1]+2^51 t[2]+2^77
@@ -80,16 +808,16 @@ static void fe_frombytes(fe h, const uint8_t *s) {
   carry6 = h6 + (1 << 25); h7 += carry6 >> 26; h6 -= carry6 & kTop38Bits;
   carry8 = h8 + (1 << 25); h9 += carry8 >> 26; h8 -= carry8 & kTop38Bits;
 
-  h[0] = h0;
-  h[1] = h1;
-  h[2] = h2;
-  h[3] = h3;
-  h[4] = h4;
-  h[5] = h5;
-  h[6] = h6;
-  h[7] = h7;
-  h[8] = h8;
-  h[9] = h9;
+  h[0] = (int32_t)h0;
+  h[1] = (int32_t)h1;
+  h[2] = (int32_t)h2;
+  h[3] = (int32_t)h3;
+  h[4] = (int32_t)h4;
+  h[5] = (int32_t)h5;
+  h[6] = (int32_t)h6;
+  h[7] = (int32_t)h7;
+  h[8] = (int32_t)h8;
+  h[9] = (int32_t)h9;
 }
 
 /* Preconditions:
@@ -471,16 +1199,16 @@ static void fe_mul(fe h, const fe f, const fe g) {
   /* |h0| <= 2^25; from now on fits into int32 unchanged */
   /* |h1| <= 1.01*2^24 */
 
-  h[0] = h0;
-  h[1] = h1;
-  h[2] = h2;
-  h[3] = h3;
-  h[4] = h4;
-  h[5] = h5;
-  h[6] = h6;
-  h[7] = h7;
-  h[8] = h8;
-  h[9] = h9;
+  h[0] = (int32_t)h0;
+  h[1] = (int32_t)h1;
+  h[2] = (int32_t)h2;
+  h[3] = (int32_t)h3;
+  h[4] = (int32_t)h4;
+  h[5] = (int32_t)h5;
+  h[6] = (int32_t)h6;
+  h[7] = (int32_t)h7;
+  h[8] = (int32_t)h8;
+  h[9] = (int32_t)h9;
 }
 
 /* h = f * f
@@ -612,16 +1340,16 @@ static void fe_sq(fe h, const fe f) {
 
   carry0 = h0 + (1 << 25); h1 += carry0 >> 26; h0 -= carry0 & kTop38Bits;
 
-  h[0] = h0;
-  h[1] = h1;
-  h[2] = h2;
-  h[3] = h3;
-  h[4] = h4;
-  h[5] = h5;
-  h[6] = h6;
-  h[7] = h7;
-  h[8] = h8;
-  h[9] = h9;
+  h[0] = (int32_t)h0;
+  h[1] = (int32_t)h1;
+  h[2] = (int32_t)h2;
+  h[3] = (int32_t)h3;
+  h[4] = (int32_t)h4;
+  h[5] = (int32_t)h5;
+  h[6] = (int32_t)h6;
+  h[7] = (int32_t)h7;
+  h[8] = (int32_t)h8;
+  h[9] = (int32_t)h9;
 }
 
 static void fe_invert(fe out, const fe z) {
@@ -911,16 +1639,16 @@ static void fe_sq2(fe h, const fe f) {
 
   carry0 = h0 + (1 << 25); h1 += carry0 >> 26; h0 -= carry0 & kTop38Bits;
 
-  h[0] = h0;
-  h[1] = h1;
-  h[2] = h2;
-  h[3] = h3;
-  h[4] = h4;
-  h[5] = h5;
-  h[6] = h6;
-  h[7] = h7;
-  h[8] = h8;
-  h[9] = h9;
+  h[0] = (int32_t)h0;
+  h[1] = (int32_t)h1;
+  h[2] = (int32_t)h2;
+  h[3] = (int32_t)h3;
+  h[4] = (int32_t)h4;
+  h[5] = (int32_t)h5;
+  h[6] = (int32_t)h6;
+  h[7] = (int32_t)h7;
+  h[8] = (int32_t)h8;
+  h[9] = (int32_t)h9;
 }
 
 static void fe_pow22523(fe out, const fe z) {
@@ -3448,8 +4176,11 @@ static void ge_scalarmult_base(ge_p3 *h, const uint8_t *a) {
     ge_madd(&r, h, &t);
     ge_p1p1_to_p3(h, &r);
   }
+
+  OPENSSL_cleanse(e, sizeof(e));
 }
 
+#if !defined(BASE_2_51_IMPLEMENTED)
 /* Replace (f,g) with (g,f) if b == 1;
  * replace (f,g) with (f,g) if b == 0.
  *
@@ -3517,16 +4248,16 @@ static void fe_mul121666(fe h, fe f) {
   carry6 = h6 + (1 << 25); h7 += carry6 >> 26; h6 -= carry6 & kTop38Bits;
   carry8 = h8 + (1 << 25); h9 += carry8 >> 26; h8 -= carry8 & kTop38Bits;
 
-  h[0] = h0;
-  h[1] = h1;
-  h[2] = h2;
-  h[3] = h3;
-  h[4] = h4;
-  h[5] = h5;
-  h[6] = h6;
-  h[7] = h7;
-  h[8] = h8;
-  h[9] = h9;
+  h[0] = (int32_t)h0;
+  h[1] = (int32_t)h1;
+  h[2] = (int32_t)h2;
+  h[3] = (int32_t)h3;
+  h[4] = (int32_t)h4;
+  h[5] = (int32_t)h5;
+  h[6] = (int32_t)h6;
+  h[7] = (int32_t)h7;
+  h[8] = (int32_t)h8;
+  h[9] = (int32_t)h9;
 }
 
 static void x25519_scalar_mult_generic(uint8_t out[32],
@@ -3572,18 +4303,19 @@ static void x25519_scalar_mult_generic(uint8_t out[32],
     fe_mul(z3, x1, z2);
     fe_mul(z2, tmp1, tmp0);
   }
-  fe_cswap(x2, x3, swap);
-  fe_cswap(z2, z3, swap);
 
   fe_invert(z2, z2);
   fe_mul(x2, x2, z2);
   fe_tobytes(out, x2);
+
+  OPENSSL_cleanse(e, sizeof(e));
 }
 
 static void x25519_scalar_mult(uint8_t out[32], const uint8_t scalar[32],
                                const uint8_t point[32]) {
   x25519_scalar_mult_generic(out, scalar, point);
 }
+#endif
 
 static void slide(signed char *r, const uint8_t *a) {
   int i;
@@ -3862,38 +4594,38 @@ static void x25519_sc_reduce(uint8_t *s) {
 
   carry6 = (s6 + (1 << 20)) >> 21;
   s7 += carry6;
-  s6 -= carry6 << 21;
+  s6 -= carry6 * (1 << 21);
   carry8 = (s8 + (1 << 20)) >> 21;
   s9 += carry8;
-  s8 -= carry8 << 21;
+  s8 -= carry8 * (1 << 21);
   carry10 = (s10 + (1 << 20)) >> 21;
   s11 += carry10;
-  s10 -= carry10 << 21;
+  s10 -= carry10 * (1 << 21);
   carry12 = (s12 + (1 << 20)) >> 21;
   s13 += carry12;
-  s12 -= carry12 << 21;
+  s12 -= carry12 * (1 << 21);
   carry14 = (s14 + (1 << 20)) >> 21;
   s15 += carry14;
-  s14 -= carry14 << 21;
+  s14 -= carry14 * (1 << 21);
   carry16 = (s16 + (1 << 20)) >> 21;
   s17 += carry16;
-  s16 -= carry16 << 21;
+  s16 -= carry16 * (1 << 21);
 
   carry7 = (s7 + (1 << 20)) >> 21;
   s8 += carry7;
-  s7 -= carry7 << 21;
+  s7 -= carry7 * (1 << 21);
   carry9 = (s9 + (1 << 20)) >> 21;
   s10 += carry9;
-  s9 -= carry9 << 21;
+  s9 -= carry9 * (1 << 21);
   carry11 = (s11 + (1 << 20)) >> 21;
   s12 += carry11;
-  s11 -= carry11 << 21;
+  s11 -= carry11 * (1 << 21);
   carry13 = (s13 + (1 << 20)) >> 21;
   s14 += carry13;
-  s13 -= carry13 << 21;
+  s13 -= carry13 * (1 << 21);
   carry15 = (s15 + (1 << 20)) >> 21;
   s16 += carry15;
-  s15 -= carry15 << 21;
+  s15 -= carry15 * (1 << 21);
 
   s5 += s17 * 666643;
   s6 += s17 * 470296;
@@ -3945,41 +4677,41 @@ static void x25519_sc_reduce(uint8_t *s) {
 
   carry0 = (s0 + (1 << 20)) >> 21;
   s1 += carry0;
-  s0 -= carry0 << 21;
+  s0 -= carry0 * (1 << 21);
   carry2 = (s2 + (1 << 20)) >> 21;
   s3 += carry2;
-  s2 -= carry2 << 21;
+  s2 -= carry2 * (1 << 21);
   carry4 = (s4 + (1 << 20)) >> 21;
   s5 += carry4;
-  s4 -= carry4 << 21;
+  s4 -= carry4 * (1 << 21);
   carry6 = (s6 + (1 << 20)) >> 21;
   s7 += carry6;
-  s6 -= carry6 << 21;
+  s6 -= carry6 * (1 << 21);
   carry8 = (s8 + (1 << 20)) >> 21;
   s9 += carry8;
-  s8 -= carry8 << 21;
+  s8 -= carry8 * (1 << 21);
   carry10 = (s10 + (1 << 20)) >> 21;
   s11 += carry10;
-  s10 -= carry10 << 21;
+  s10 -= carry10 * (1 << 21);
 
   carry1 = (s1 + (1 << 20)) >> 21;
   s2 += carry1;
-  s1 -= carry1 << 21;
+  s1 -= carry1 * (1 << 21);
   carry3 = (s3 + (1 << 20)) >> 21;
   s4 += carry3;
-  s3 -= carry3 << 21;
+  s3 -= carry3 * (1 << 21);
   carry5 = (s5 + (1 << 20)) >> 21;
   s6 += carry5;
-  s5 -= carry5 << 21;
+  s5 -= carry5 * (1 << 21);
   carry7 = (s7 + (1 << 20)) >> 21;
   s8 += carry7;
-  s7 -= carry7 << 21;
+  s7 -= carry7 * (1 << 21);
   carry9 = (s9 + (1 << 20)) >> 21;
   s10 += carry9;
-  s9 -= carry9 << 21;
+  s9 -= carry9 * (1 << 21);
   carry11 = (s11 + (1 << 20)) >> 21;
   s12 += carry11;
-  s11 -= carry11 << 21;
+  s11 -= carry11 * (1 << 21);
 
   s0 += s12 * 666643;
   s1 += s12 * 470296;
@@ -3991,40 +4723,40 @@ static void x25519_sc_reduce(uint8_t *s) {
 
   carry0 = s0 >> 21;
   s1 += carry0;
-  s0 -= carry0 << 21;
+  s0 -= carry0 * (1 << 21);
   carry1 = s1 >> 21;
   s2 += carry1;
-  s1 -= carry1 << 21;
+  s1 -= carry1 * (1 << 21);
   carry2 = s2 >> 21;
   s3 += carry2;
-  s2 -= carry2 << 21;
+  s2 -= carry2 * (1 << 21);
   carry3 = s3 >> 21;
   s4 += carry3;
-  s3 -= carry3 << 21;
+  s3 -= carry3 * (1 << 21);
   carry4 = s4 >> 21;
   s5 += carry4;
-  s4 -= carry4 << 21;
+  s4 -= carry4 * (1 << 21);
   carry5 = s5 >> 21;
   s6 += carry5;
-  s5 -= carry5 << 21;
+  s5 -= carry5 * (1 << 21);
   carry6 = s6 >> 21;
   s7 += carry6;
-  s6 -= carry6 << 21;
+  s6 -= carry6 * (1 << 21);
   carry7 = s7 >> 21;
   s8 += carry7;
-  s7 -= carry7 << 21;
+  s7 -= carry7 * (1 << 21);
   carry8 = s8 >> 21;
   s9 += carry8;
-  s8 -= carry8 << 21;
+  s8 -= carry8 * (1 << 21);
   carry9 = s9 >> 21;
   s10 += carry9;
-  s9 -= carry9 << 21;
+  s9 -= carry9 * (1 << 21);
   carry10 = s10 >> 21;
   s11 += carry10;
-  s10 -= carry10 << 21;
+  s10 -= carry10 * (1 << 21);
   carry11 = s11 >> 21;
   s12 += carry11;
-  s11 -= carry11 << 21;
+  s11 -= carry11 * (1 << 21);
 
   s0 += s12 * 666643;
   s1 += s12 * 470296;
@@ -4036,70 +4768,70 @@ static void x25519_sc_reduce(uint8_t *s) {
 
   carry0 = s0 >> 21;
   s1 += carry0;
-  s0 -= carry0 << 21;
+  s0 -= carry0 * (1 << 21);
   carry1 = s1 >> 21;
   s2 += carry1;
-  s1 -= carry1 << 21;
+  s1 -= carry1 * (1 << 21);
   carry2 = s2 >> 21;
   s3 += carry2;
-  s2 -= carry2 << 21;
+  s2 -= carry2 * (1 << 21);
   carry3 = s3 >> 21;
   s4 += carry3;
-  s3 -= carry3 << 21;
+  s3 -= carry3 * (1 << 21);
   carry4 = s4 >> 21;
   s5 += carry4;
-  s4 -= carry4 << 21;
+  s4 -= carry4 * (1 << 21);
   carry5 = s5 >> 21;
   s6 += carry5;
-  s5 -= carry5 << 21;
+  s5 -= carry5 * (1 << 21);
   carry6 = s6 >> 21;
   s7 += carry6;
-  s6 -= carry6 << 21;
+  s6 -= carry6 * (1 << 21);
   carry7 = s7 >> 21;
   s8 += carry7;
-  s7 -= carry7 << 21;
+  s7 -= carry7 * (1 << 21);
   carry8 = s8 >> 21;
   s9 += carry8;
-  s8 -= carry8 << 21;
+  s8 -= carry8 * (1 << 21);
   carry9 = s9 >> 21;
   s10 += carry9;
-  s9 -= carry9 << 21;
+  s9 -= carry9 * (1 << 21);
   carry10 = s10 >> 21;
   s11 += carry10;
-  s10 -= carry10 << 21;
-
-  s[0] = s0 >> 0;
-  s[1] = s0 >> 8;
-  s[2] = (s0 >> 16) | (s1 << 5);
-  s[3] = s1 >> 3;
-  s[4] = s1 >> 11;
-  s[5] = (s1 >> 19) | (s2 << 2);
-  s[6] = s2 >> 6;
-  s[7] = (s2 >> 14) | (s3 << 7);
-  s[8] = s3 >> 1;
-  s[9] = s3 >> 9;
-  s[10] = (s3 >> 17) | (s4 << 4);
-  s[11] = s4 >> 4;
-  s[12] = s4 >> 12;
-  s[13] = (s4 >> 20) | (s5 << 1);
-  s[14] = s5 >> 7;
-  s[15] = (s5 >> 15) | (s6 << 6);
-  s[16] = s6 >> 2;
-  s[17] = s6 >> 10;
-  s[18] = (s6 >> 18) | (s7 << 3);
-  s[19] = s7 >> 5;
-  s[20] = s7 >> 13;
-  s[21] = s8 >> 0;
-  s[22] = s8 >> 8;
-  s[23] = (s8 >> 16) | (s9 << 5);
-  s[24] = s9 >> 3;
-  s[25] = s9 >> 11;
-  s[26] = (s9 >> 19) | (s10 << 2);
-  s[27] = s10 >> 6;
-  s[28] = (s10 >> 14) | (s11 << 7);
-  s[29] = s11 >> 1;
-  s[30] = s11 >> 9;
-  s[31] = s11 >> 17;
+  s10 -= carry10 * (1 << 21);
+
+  s[0] = (uint8_t)(s0 >> 0);
+  s[1] = (uint8_t)(s0 >> 8);
+  s[2] = (uint8_t)((s0 >> 16) | (s1 << 5));
+  s[3] = (uint8_t)(s1 >> 3);
+  s[4] = (uint8_t)(s1 >> 11);
+  s[5] = (uint8_t)((s1 >> 19) | (s2 << 2));
+  s[6] = (uint8_t)(s2 >> 6);
+  s[7] = (uint8_t)((s2 >> 14) | (s3 << 7));
+  s[8] = (uint8_t)(s3 >> 1);
+  s[9] = (uint8_t)(s3 >> 9);
+  s[10] = (uint8_t)((s3 >> 17) | (s4 << 4));
+  s[11] = (uint8_t)(s4 >> 4);
+  s[12] = (uint8_t)(s4 >> 12);
+  s[13] = (uint8_t)((s4 >> 20) | (s5 << 1));
+  s[14] = (uint8_t)(s5 >> 7);
+  s[15] = (uint8_t)((s5 >> 15) | (s6 << 6));
+  s[16] = (uint8_t)(s6 >> 2);
+  s[17] = (uint8_t)(s6 >> 10);
+  s[18] = (uint8_t)((s6 >> 18) | (s7 << 3));
+  s[19] = (uint8_t)(s7 >> 5);
+  s[20] = (uint8_t)(s7 >> 13);
+  s[21] = (uint8_t)(s8 >> 0);
+  s[22] = (uint8_t)(s8 >> 8);
+  s[23] = (uint8_t)((s8 >> 16) | (s9 << 5));
+  s[24] = (uint8_t)(s9 >> 3);
+  s[25] = (uint8_t)(s9 >> 11);
+  s[26] = (uint8_t)((s9 >> 19) | (s10 << 2));
+  s[27] = (uint8_t)(s10 >> 6);
+  s[28] = (uint8_t)((s10 >> 14) | (s11 << 7));
+  s[29] = (uint8_t)(s11 >> 1);
+  s[30] = (uint8_t)(s11 >> 9);
+  s[31] = (uint8_t)(s11 >> 17);
 }
 
 /* Input:
@@ -4232,74 +4964,74 @@ static void sc_muladd(uint8_t *s, const uint8_t *a, const uint8_t *b,
 
   carry0 = (s0 + (1 << 20)) >> 21;
   s1 += carry0;
-  s0 -= carry0 << 21;
+  s0 -= carry0 * (1 << 21);
   carry2 = (s2 + (1 << 20)) >> 21;
   s3 += carry2;
-  s2 -= carry2 << 21;
+  s2 -= carry2 * (1 << 21);
   carry4 = (s4 + (1 << 20)) >> 21;
   s5 += carry4;
-  s4 -= carry4 << 21;
+  s4 -= carry4 * (1 << 21);
   carry6 = (s6 + (1 << 20)) >> 21;
   s7 += carry6;
-  s6 -= carry6 << 21;
+  s6 -= carry6 * (1 << 21);
   carry8 = (s8 + (1 << 20)) >> 21;
   s9 += carry8;
-  s8 -= carry8 << 21;
+  s8 -= carry8 * (1 << 21);
   carry10 = (s10 + (1 << 20)) >> 21;
   s11 += carry10;
-  s10 -= carry10 << 21;
+  s10 -= carry10 * (1 << 21);
   carry12 = (s12 + (1 << 20)) >> 21;
   s13 += carry12;
-  s12 -= carry12 << 21;
+  s12 -= carry12 * (1 << 21);
   carry14 = (s14 + (1 << 20)) >> 21;
   s15 += carry14;
-  s14 -= carry14 << 21;
+  s14 -= carry14 * (1 << 21);
   carry16 = (s16 + (1 << 20)) >> 21;
   s17 += carry16;
-  s16 -= carry16 << 21;
+  s16 -= carry16 * (1 << 21);
   carry18 = (s18 + (1 << 20)) >> 21;
   s19 += carry18;
-  s18 -= carry18 << 21;
+  s18 -= carry18 * (1 << 21);
   carry20 = (s20 + (1 << 20)) >> 21;
   s21 += carry20;
-  s20 -= carry20 << 21;
+  s20 -= carry20 * (1 << 21);
   carry22 = (s22 + (1 << 20)) >> 21;
   s23 += carry22;
-  s22 -= carry22 << 21;
+  s22 -= carry22 * (1 << 21);
 
   carry1 = (s1 + (1 << 20)) >> 21;
   s2 += carry1;
-  s1 -= carry1 << 21;
+  s1 -= carry1 * (1 << 21);
   carry3 = (s3 + (1 << 20)) >> 21;
   s4 += carry3;
-  s3 -= carry3 << 21;
+  s3 -= carry3 * (1 << 21);
   carry5 = (s5 + (1 << 20)) >> 21;
   s6 += carry5;
-  s5 -= carry5 << 21;
+  s5 -= carry5 * (1 << 21);
   carry7 = (s7 + (1 << 20)) >> 21;
   s8 += carry7;
-  s7 -= carry7 << 21;
+  s7 -= carry7 * (1 << 21);
   carry9 = (s9 + (1 << 20)) >> 21;
   s10 += carry9;
-  s9 -= carry9 << 21;
+  s9 -= carry9 * (1 << 21);
   carry11 = (s11 + (1 << 20)) >> 21;
   s12 += carry11;
-  s11 -= carry11 << 21;
+  s11 -= carry11 * (1 << 21);
   carry13 = (s13 + (1 << 20)) >> 21;
   s14 += carry13;
-  s13 -= carry13 << 21;
+  s13 -= carry13 * (1 << 21);
   carry15 = (s15 + (1 << 20)) >> 21;
   s16 += carry15;
-  s15 -= carry15 << 21;
+  s15 -= carry15 * (1 << 21);
   carry17 = (s17 + (1 << 20)) >> 21;
   s18 += carry17;
-  s17 -= carry17 << 21;
+  s17 -= carry17 * (1 << 21);
   carry19 = (s19 + (1 << 20)) >> 21;
   s20 += carry19;
-  s19 -= carry19 << 21;
+  s19 -= carry19 * (1 << 21);
   carry21 = (s21 + (1 << 20)) >> 21;
   s22 += carry21;
-  s21 -= carry21 << 21;
+  s21 -= carry21 * (1 << 21);
 
   s11 += s23 * 666643;
   s12 += s23 * 470296;
@@ -4351,38 +5083,38 @@ static void sc_muladd(uint8_t *s, const uint8_t *a, const uint8_t *b,
 
   carry6 = (s6 + (1 << 20)) >> 21;
   s7 += carry6;
-  s6 -= carry6 << 21;
+  s6 -= carry6 * (1 << 21);
   carry8 = (s8 + (1 << 20)) >> 21;
   s9 += carry8;
-  s8 -= carry8 << 21;
+  s8 -= carry8 * (1 << 21);
   carry10 = (s10 + (1 << 20)) >> 21;
   s11 += carry10;
-  s10 -= carry10 << 21;
+  s10 -= carry10 * (1 << 21);
   carry12 = (s12 + (1 << 20)) >> 21;
   s13 += carry12;
-  s12 -= carry12 << 21;
+  s12 -= carry12 * (1 << 21);
   carry14 = (s14 + (1 << 20)) >> 21;
   s15 += carry14;
-  s14 -= carry14 << 21;
+  s14 -= carry14 * (1 << 21);
   carry16 = (s16 + (1 << 20)) >> 21;
   s17 += carry16;
-  s16 -= carry16 << 21;
+  s16 -= carry16 * (1 << 21);
 
   carry7 = (s7 + (1 << 20)) >> 21;
   s8 += carry7;
-  s7 -= carry7 << 21;
+  s7 -= carry7 * (1 << 21);
   carry9 = (s9 + (1 << 20)) >> 21;
   s10 += carry9;
-  s9 -= carry9 << 21;
+  s9 -= carry9 * (1 << 21);
   carry11 = (s11 + (1 << 20)) >> 21;
   s12 += carry11;
-  s11 -= carry11 << 21;
+  s11 -= carry11 * (1 << 21);
   carry13 = (s13 + (1 << 20)) >> 21;
   s14 += carry13;
-  s13 -= carry13 << 21;
+  s13 -= carry13 * (1 << 21);
   carry15 = (s15 + (1 << 20)) >> 21;
   s16 += carry15;
-  s15 -= carry15 << 21;
+  s15 -= carry15 * (1 << 21);
 
   s5 += s17 * 666643;
   s6 += s17 * 470296;
@@ -4434,41 +5166,41 @@ static void sc_muladd(uint8_t *s, const uint8_t *a, const uint8_t *b,
 
   carry0 = (s0 + (1 << 20)) >> 21;
   s1 += carry0;
-  s0 -= carry0 << 21;
+  s0 -= carry0 * (1 << 21);
   carry2 = (s2 + (1 << 20)) >> 21;
   s3 += carry2;
-  s2 -= carry2 << 21;
+  s2 -= carry2 * (1 << 21);
   carry4 = (s4 + (1 << 20)) >> 21;
   s5 += carry4;
-  s4 -= carry4 << 21;
+  s4 -= carry4 * (1 << 21);
   carry6 = (s6 + (1 << 20)) >> 21;
   s7 += carry6;
-  s6 -= carry6 << 21;
+  s6 -= carry6 * (1 << 21);
   carry8 = (s8 + (1 << 20)) >> 21;
   s9 += carry8;
-  s8 -= carry8 << 21;
+  s8 -= carry8 * (1 << 21);
   carry10 = (s10 + (1 << 20)) >> 21;
   s11 += carry10;
-  s10 -= carry10 << 21;
+  s10 -= carry10 * (1 << 21);
 
   carry1 = (s1 + (1 << 20)) >> 21;
   s2 += carry1;
-  s1 -= carry1 << 21;
+  s1 -= carry1 * (1 << 21);
   carry3 = (s3 + (1 << 20)) >> 21;
   s4 += carry3;
-  s3 -= carry3 << 21;
+  s3 -= carry3 * (1 << 21);
   carry5 = (s5 + (1 << 20)) >> 21;
   s6 += carry5;
-  s5 -= carry5 << 21;
+  s5 -= carry5 * (1 << 21);
   carry7 = (s7 + (1 << 20)) >> 21;
   s8 += carry7;
-  s7 -= carry7 << 21;
+  s7 -= carry7 * (1 << 21);
   carry9 = (s9 + (1 << 20)) >> 21;
   s10 += carry9;
-  s9 -= carry9 << 21;
+  s9 -= carry9 * (1 << 21);
   carry11 = (s11 + (1 << 20)) >> 21;
   s12 += carry11;
-  s11 -= carry11 << 21;
+  s11 -= carry11 * (1 << 21);
 
   s0 += s12 * 666643;
   s1 += s12 * 470296;
@@ -4480,40 +5212,40 @@ static void sc_muladd(uint8_t *s, const uint8_t *a, const uint8_t *b,
 
   carry0 = s0 >> 21;
   s1 += carry0;
-  s0 -= carry0 << 21;
+  s0 -= carry0 * (1 << 21);
   carry1 = s1 >> 21;
   s2 += carry1;
-  s1 -= carry1 << 21;
+  s1 -= carry1 * (1 << 21);
   carry2 = s2 >> 21;
   s3 += carry2;
-  s2 -= carry2 << 21;
+  s2 -= carry2 * (1 << 21);
   carry3 = s3 >> 21;
   s4 += carry3;
-  s3 -= carry3 << 21;
+  s3 -= carry3 * (1 << 21);
   carry4 = s4 >> 21;
   s5 += carry4;
-  s4 -= carry4 << 21;
+  s4 -= carry4 * (1 << 21);
   carry5 = s5 >> 21;
   s6 += carry5;
-  s5 -= carry5 << 21;
+  s5 -= carry5 * (1 << 21);
   carry6 = s6 >> 21;
   s7 += carry6;
-  s6 -= carry6 << 21;
+  s6 -= carry6 * (1 << 21);
   carry7 = s7 >> 21;
   s8 += carry7;
-  s7 -= carry7 << 21;
+  s7 -= carry7 * (1 << 21);
   carry8 = s8 >> 21;
   s9 += carry8;
-  s8 -= carry8 << 21;
+  s8 -= carry8 * (1 << 21);
   carry9 = s9 >> 21;
   s10 += carry9;
-  s9 -= carry9 << 21;
+  s9 -= carry9 * (1 << 21);
   carry10 = s10 >> 21;
   s11 += carry10;
-  s10 -= carry10 << 21;
+  s10 -= carry10 * (1 << 21);
   carry11 = s11 >> 21;
   s12 += carry11;
-  s11 -= carry11 << 21;
+  s11 -= carry11 * (1 << 21);
 
   s0 += s12 * 666643;
   s1 += s12 * 470296;
@@ -4525,70 +5257,70 @@ static void sc_muladd(uint8_t *s, const uint8_t *a, const uint8_t *b,
 
   carry0 = s0 >> 21;
   s1 += carry0;
-  s0 -= carry0 << 21;
+  s0 -= carry0 * (1 << 21);
   carry1 = s1 >> 21;
   s2 += carry1;
-  s1 -= carry1 << 21;
+  s1 -= carry1 * (1 << 21);
   carry2 = s2 >> 21;
   s3 += carry2;
-  s2 -= carry2 << 21;
+  s2 -= carry2 * (1 << 21);
   carry3 = s3 >> 21;
   s4 += carry3;
-  s3 -= carry3 << 21;
+  s3 -= carry3 * (1 << 21);
   carry4 = s4 >> 21;
   s5 += carry4;
-  s4 -= carry4 << 21;
+  s4 -= carry4 * (1 << 21);
   carry5 = s5 >> 21;
   s6 += carry5;
-  s5 -= carry5 << 21;
+  s5 -= carry5 * (1 << 21);
   carry6 = s6 >> 21;
   s7 += carry6;
-  s6 -= carry6 << 21;
+  s6 -= carry6 * (1 << 21);
   carry7 = s7 >> 21;
   s8 += carry7;
-  s7 -= carry7 << 21;
+  s7 -= carry7 * (1 << 21);
   carry8 = s8 >> 21;
   s9 += carry8;
-  s8 -= carry8 << 21;
+  s8 -= carry8 * (1 << 21);
   carry9 = s9 >> 21;
   s10 += carry9;
-  s9 -= carry9 << 21;
+  s9 -= carry9 * (1 << 21);
   carry10 = s10 >> 21;
   s11 += carry10;
-  s10 -= carry10 << 21;
-
-  s[0] = s0 >> 0;
-  s[1] = s0 >> 8;
-  s[2] = (s0 >> 16) | (s1 << 5);
-  s[3] = s1 >> 3;
-  s[4] = s1 >> 11;
-  s[5] = (s1 >> 19) | (s2 << 2);
-  s[6] = s2 >> 6;
-  s[7] = (s2 >> 14) | (s3 << 7);
-  s[8] = s3 >> 1;
-  s[9] = s3 >> 9;
-  s[10] = (s3 >> 17) | (s4 << 4);
-  s[11] = s4 >> 4;
-  s[12] = s4 >> 12;
-  s[13] = (s4 >> 20) | (s5 << 1);
-  s[14] = s5 >> 7;
-  s[15] = (s5 >> 15) | (s6 << 6);
-  s[16] = s6 >> 2;
-  s[17] = s6 >> 10;
-  s[18] = (s6 >> 18) | (s7 << 3);
-  s[19] = s7 >> 5;
-  s[20] = s7 >> 13;
-  s[21] = s8 >> 0;
-  s[22] = s8 >> 8;
-  s[23] = (s8 >> 16) | (s9 << 5);
-  s[24] = s9 >> 3;
-  s[25] = s9 >> 11;
-  s[26] = (s9 >> 19) | (s10 << 2);
-  s[27] = s10 >> 6;
-  s[28] = (s10 >> 14) | (s11 << 7);
-  s[29] = s11 >> 1;
-  s[30] = s11 >> 9;
-  s[31] = s11 >> 17;
+  s10 -= carry10 * (1 << 21);
+
+  s[0] = (uint8_t)(s0 >> 0);
+  s[1] = (uint8_t)(s0 >> 8);
+  s[2] = (uint8_t)((s0 >> 16) | (s1 << 5));
+  s[3] = (uint8_t)(s1 >> 3);
+  s[4] = (uint8_t)(s1 >> 11);
+  s[5] = (uint8_t)((s1 >> 19) | (s2 << 2));
+  s[6] = (uint8_t)(s2 >> 6);
+  s[7] = (uint8_t)((s2 >> 14) | (s3 << 7));
+  s[8] = (uint8_t)(s3 >> 1);
+  s[9] = (uint8_t)(s3 >> 9);
+  s[10] = (uint8_t)((s3 >> 17) | (s4 << 4));
+  s[11] = (uint8_t)(s4 >> 4);
+  s[12] = (uint8_t)(s4 >> 12);
+  s[13] = (uint8_t)((s4 >> 20) | (s5 << 1));
+  s[14] = (uint8_t)(s5 >> 7);
+  s[15] = (uint8_t)((s5 >> 15) | (s6 << 6));
+  s[16] = (uint8_t)(s6 >> 2);
+  s[17] = (uint8_t)(s6 >> 10);
+  s[18] = (uint8_t)((s6 >> 18) | (s7 << 3));
+  s[19] = (uint8_t)(s7 >> 5);
+  s[20] = (uint8_t)(s7 >> 13);
+  s[21] = (uint8_t)(s8 >> 0);
+  s[22] = (uint8_t)(s8 >> 8);
+  s[23] = (uint8_t)((s8 >> 16) | (s9 << 5));
+  s[24] = (uint8_t)(s9 >> 3);
+  s[25] = (uint8_t)(s9 >> 11);
+  s[26] = (uint8_t)((s9 >> 19) | (s10 << 2));
+  s[27] = (uint8_t)(s10 >> 6);
+  s[28] = (uint8_t)((s10 >> 14) | (s11 << 7));
+  s[29] = (uint8_t)(s11 >> 1);
+  s[30] = (uint8_t)(s11 >> 9);
+  s[31] = (uint8_t)(s11 >> 17);
 }
 
 int ED25519_sign(uint8_t *out_sig, const uint8_t *message, size_t message_len,
@@ -4599,7 +5331,9 @@ int ED25519_sign(uint8_t *out_sig, const uint8_t *message, size_t message_len,
   uint8_t hram[SHA512_DIGEST_LENGTH];
   SHA512_CTX hash_ctx;
 
-  SHA512(private_key, 32, az);
+  SHA512_Init(&hash_ctx);
+  SHA512_Update(&hash_ctx, private_key, 32);
+  SHA512_Final(az, &hash_ctx);
 
   az[0] &= 248;
   az[31] &= 63;
@@ -4623,13 +5357,16 @@ int ED25519_sign(uint8_t *out_sig, const uint8_t *message, size_t message_len,
   x25519_sc_reduce(hram);
   sc_muladd(out_sig + 32, hram, az, nonce);
 
+  OPENSSL_cleanse(&hash_ctx, sizeof(hash_ctx));
+  OPENSSL_cleanse(nonce, sizeof(nonce));
+  OPENSSL_cleanse(az, sizeof(az));
+
   return 1;
 }
 
 int ED25519_verify(const uint8_t *message, size_t message_len,
                    const uint8_t signature[64], const uint8_t public_key[32]) {
   ge_p3 A;
-  uint8_t pkcopy[32];
   uint8_t rcopy[32];
   uint8_t scopy[32];
   SHA512_CTX hash_ctx;
@@ -4645,7 +5382,6 @@ int ED25519_verify(const uint8_t *message, size_t message_len,
   fe_neg(A.X, A.X);
   fe_neg(A.T, A.T);
 
-  memcpy(pkcopy, public_key, 32);
   memcpy(rcopy, signature, 32);
   memcpy(scopy, signature + 32, 32);
 
@@ -4677,6 +5413,8 @@ void ED25519_public_from_private(uint8_t out_public_key[32],
 
   ge_scalarmult_base(&A, az);
   ge_p3_tobytes(out_public_key, &A);
+
+  OPENSSL_cleanse(az, sizeof(az));
 }
 
 int X25519(uint8_t out_shared_key[32], const uint8_t private_key[32],
@@ -4707,4 +5445,6 @@ void X25519_public_from_private(uint8_t out_public_value[32],
   fe_invert(zminusy_inv, zminusy);
   fe_mul(zplusy, zplusy, zminusy_inv);
   fe_tobytes(out_public_value, zplusy);
+
+  OPENSSL_cleanse(e, sizeof(e));
 }