Skip to content

Commit

Permalink
Alternative fix for CVE-2022-4304
Browse files Browse the repository at this point in the history
This is about a timing leak in the topmost limb
of the internal result of RSA_private_decrypt,
before the padding check.

There are in fact at least three bugs together that
caused the timing leak:

First and probably most important is the fact that
the blinding did not use the constant time code path
at all when the RSA object was used for a private
decrypt, due to the fact that the Montgomery context
rsa->_method_mod_n was not set up early enough in
rsa_ossl_private_decrypt, when BN_BLINDING_create_param
needed it, and that was persisted as blinding->m_ctx,
although the RSA object creates the Montgomery context
just a bit later.

Then the infamous bn_correct_top was used on the
secret value right after the blinding was removed.

And finally the function BN_bn2binpad did not use
the constant-time code path since the BN_FLG_CONSTTIME
was not set on the secret value.

In order to address the first problem, this patch
makes sure that the rsa->_method_mod_n is initialized
right before the blinding context.

And to fix the second problem, we add a new utility
function bn_correct_top_consttime, a const-time
variant of bn_correct_top.

Together with the fact, that BN_bn2binpad is already
constant time if the flag BN_FLG_CONSTTIME is set,
this should eliminate the timing oracle completely.

In addition the no-asm variant may also have
branches that depend on secret values, because the last
invocation of bn_sub_words in bn_from_montgomery_word
had branches when the function is compiled by certain
gcc compiler versions, due to the clumsy coding style.

So additionally this patch stream-lined the no-asm
C-code in order to avoid branches where possible and
improve the resulting code quality.

Reviewed-by: Paul Dale <pauli@openssl.org>
Reviewed-by: Tomas Mraz <tomas@openssl.org>
(Merged from #20284)
  • Loading branch information
bernd-edlinger committed Mar 31, 2023
1 parent 0372649 commit 3f499b2
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 69 deletions.
10 changes: 10 additions & 0 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@

Changes between 1.1.1t and 1.1.1u [xx XXX xxxx]

*) Reworked the Fix for the Timing Oracle in RSA Decryption (CVE-2022-4304).
The previous fix for this timing side channel turned out to cause
a severe 2-3x performance regression in the typical use case
compared to 1.1.1s. The new fix uses existing constant time
code paths, and restores the previous performance level while
fully eliminating all existing timing side channels.
The fix was developed by Bernd Edlinger with testing support
by Hubert Kario.
[Bernd Edlinger]

*) Corrected documentation of X509_VERIFY_PARAM_add0_policy() to mention
that it does not enable policy checking. Thanks to
David Benjamin for discovering this issue. (CVE-2023-0466)
Expand Down
106 changes: 58 additions & 48 deletions crypto/bn/bn_asm.c
Original file line number Diff line number Diff line change
Expand Up @@ -381,25 +381,33 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
#ifndef OPENSSL_SMALL_FOOTPRINT
while (n & ~3) {
t1 = a[0];
t2 = b[0];
r[0] = (t1 - t2 - c) & BN_MASK2;
if (t1 != t2)
c = (t1 < t2);
t2 = (t1 - c) & BN_MASK2;
c = (t2 > t1);
t1 = b[0];
t1 = (t2 - t1) & BN_MASK2;
r[0] = t1;
c += (t1 > t2);
t1 = a[1];
t2 = b[1];
r[1] = (t1 - t2 - c) & BN_MASK2;
if (t1 != t2)
c = (t1 < t2);
t2 = (t1 - c) & BN_MASK2;
c = (t2 > t1);
t1 = b[1];
t1 = (t2 - t1) & BN_MASK2;
r[1] = t1;
c += (t1 > t2);
t1 = a[2];
t2 = b[2];
r[2] = (t1 - t2 - c) & BN_MASK2;
if (t1 != t2)
c = (t1 < t2);
t2 = (t1 - c) & BN_MASK2;
c = (t2 > t1);
t1 = b[2];
t1 = (t2 - t1) & BN_MASK2;
r[2] = t1;
c += (t1 > t2);
t1 = a[3];
t2 = b[3];
r[3] = (t1 - t2 - c) & BN_MASK2;
if (t1 != t2)
c = (t1 < t2);
t2 = (t1 - c) & BN_MASK2;
c = (t2 > t1);
t1 = b[3];
t1 = (t2 - t1) & BN_MASK2;
r[3] = t1;
c += (t1 > t2);
a += 4;
b += 4;
r += 4;
Expand All @@ -408,10 +416,12 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
#endif
while (n) {
t1 = a[0];
t2 = b[0];
r[0] = (t1 - t2 - c) & BN_MASK2;
if (t1 != t2)
c = (t1 < t2);
t2 = (t1 - c) & BN_MASK2;
c = (t2 > t1);
t1 = b[0];
t1 = (t2 - t1) & BN_MASK2;
r[0] = t1;
c += (t1 > t2);
a++;
b++;
r++;
Expand Down Expand Up @@ -446,7 +456,7 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
t += c0; /* no carry */ \
c0 = (BN_ULONG)Lw(t); \
hi = (BN_ULONG)Hw(t); \
c1 = (c1+hi)&BN_MASK2; if (c1<hi) c2++; \
c1 = (c1+hi)&BN_MASK2; c2 += (c1<hi); \
} while(0)

# define mul_add_c2(a,b,c0,c1,c2) do { \
Expand All @@ -455,11 +465,11 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
BN_ULLONG tt = t+c0; /* no carry */ \
c0 = (BN_ULONG)Lw(tt); \
hi = (BN_ULONG)Hw(tt); \
c1 = (c1+hi)&BN_MASK2; if (c1<hi) c2++; \
c1 = (c1+hi)&BN_MASK2; c2 += (c1<hi); \
t += c0; /* no carry */ \
c0 = (BN_ULONG)Lw(t); \
hi = (BN_ULONG)Hw(t); \
c1 = (c1+hi)&BN_MASK2; if (c1<hi) c2++; \
c1 = (c1+hi)&BN_MASK2; c2 += (c1<hi); \
} while(0)

# define sqr_add_c(a,i,c0,c1,c2) do { \
Expand All @@ -468,7 +478,7 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
t += c0; /* no carry */ \
c0 = (BN_ULONG)Lw(t); \
hi = (BN_ULONG)Hw(t); \
c1 = (c1+hi)&BN_MASK2; if (c1<hi) c2++; \
c1 = (c1+hi)&BN_MASK2; c2 += (c1<hi); \
} while(0)

# define sqr_add_c2(a,i,j,c0,c1,c2) \
Expand All @@ -483,26 +493,26 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
BN_ULONG ta = (a), tb = (b); \
BN_ULONG lo, hi; \
BN_UMULT_LOHI(lo,hi,ta,tb); \
c0 += lo; hi += (c0<lo)?1:0; \
c1 += hi; c2 += (c1<hi)?1:0; \
c0 += lo; hi += (c0<lo); \
c1 += hi; c2 += (c1<hi); \
} while(0)

# define mul_add_c2(a,b,c0,c1,c2) do { \
BN_ULONG ta = (a), tb = (b); \
BN_ULONG lo, hi, tt; \
BN_UMULT_LOHI(lo,hi,ta,tb); \
c0 += lo; tt = hi+((c0<lo)?1:0); \
c1 += tt; c2 += (c1<tt)?1:0; \
c0 += lo; hi += (c0<lo)?1:0; \
c1 += hi; c2 += (c1<hi)?1:0; \
c0 += lo; tt = hi + (c0<lo); \
c1 += tt; c2 += (c1<tt); \
c0 += lo; hi += (c0<lo); \
c1 += hi; c2 += (c1<hi); \
} while(0)

# define sqr_add_c(a,i,c0,c1,c2) do { \
BN_ULONG ta = (a)[i]; \
BN_ULONG lo, hi; \
BN_UMULT_LOHI(lo,hi,ta,ta); \
c0 += lo; hi += (c0<lo)?1:0; \
c1 += hi; c2 += (c1<hi)?1:0; \
c0 += lo; hi += (c0<lo); \
c1 += hi; c2 += (c1<hi); \
} while(0)

# define sqr_add_c2(a,i,j,c0,c1,c2) \
Expand All @@ -517,26 +527,26 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
BN_ULONG ta = (a), tb = (b); \
BN_ULONG lo = ta * tb; \
BN_ULONG hi = BN_UMULT_HIGH(ta,tb); \
c0 += lo; hi += (c0<lo)?1:0; \
c1 += hi; c2 += (c1<hi)?1:0; \
c0 += lo; hi += (c0<lo); \
c1 += hi; c2 += (c1<hi); \
} while(0)

# define mul_add_c2(a,b,c0,c1,c2) do { \
BN_ULONG ta = (a), tb = (b), tt; \
BN_ULONG lo = ta * tb; \
BN_ULONG hi = BN_UMULT_HIGH(ta,tb); \
c0 += lo; tt = hi + ((c0<lo)?1:0); \
c1 += tt; c2 += (c1<tt)?1:0; \
c0 += lo; hi += (c0<lo)?1:0; \
c1 += hi; c2 += (c1<hi)?1:0; \
c0 += lo; tt = hi + (c0<lo); \
c1 += tt; c2 += (c1<tt); \
c0 += lo; hi += (c0<lo); \
c1 += hi; c2 += (c1<hi); \
} while(0)

# define sqr_add_c(a,i,c0,c1,c2) do { \
BN_ULONG ta = (a)[i]; \
BN_ULONG lo = ta * ta; \
BN_ULONG hi = BN_UMULT_HIGH(ta,ta); \
c0 += lo; hi += (c0<lo)?1:0; \
c1 += hi; c2 += (c1<hi)?1:0; \
c0 += lo; hi += (c0<lo); \
c1 += hi; c2 += (c1<hi); \
} while(0)

# define sqr_add_c2(a,i,j,c0,c1,c2) \
Expand All @@ -551,8 +561,8 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
BN_ULONG lo = LBITS(a), hi = HBITS(a); \
BN_ULONG bl = LBITS(b), bh = HBITS(b); \
mul64(lo,hi,bl,bh); \
c0 = (c0+lo)&BN_MASK2; if (c0<lo) hi++; \
c1 = (c1+hi)&BN_MASK2; if (c1<hi) c2++; \
c0 = (c0+lo)&BN_MASK2; hi += (c0<lo); \
c1 = (c1+hi)&BN_MASK2; c2 += (c1<hi); \
} while(0)

# define mul_add_c2(a,b,c0,c1,c2) do { \
Expand All @@ -561,17 +571,17 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
BN_ULONG bl = LBITS(b), bh = HBITS(b); \
mul64(lo,hi,bl,bh); \
tt = hi; \
c0 = (c0+lo)&BN_MASK2; if (c0<lo) tt++; \
c1 = (c1+tt)&BN_MASK2; if (c1<tt) c2++; \
c0 = (c0+lo)&BN_MASK2; if (c0<lo) hi++; \
c1 = (c1+hi)&BN_MASK2; if (c1<hi) c2++; \
c0 = (c0+lo)&BN_MASK2; tt += (c0<lo); \
c1 = (c1+tt)&BN_MASK2; c2 += (c1<tt); \
c0 = (c0+lo)&BN_MASK2; hi += (c0<lo); \
c1 = (c1+hi)&BN_MASK2; c2 += (c1<hi); \
} while(0)

# define sqr_add_c(a,i,c0,c1,c2) do { \
BN_ULONG lo, hi; \
sqr64(lo,hi,(a)[i]); \
c0 = (c0+lo)&BN_MASK2; if (c0<lo) hi++; \
c1 = (c1+hi)&BN_MASK2; if (c1<hi) c2++; \
c0 = (c0+lo)&BN_MASK2; hi += (c0<lo); \
c1 = (c1+hi)&BN_MASK2; c2 += (c1<hi); \
} while(0)

# define sqr_add_c2(a,i,j,c0,c1,c2) \
Expand Down
3 changes: 2 additions & 1 deletion crypto/bn/bn_blind.c
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ int BN_BLINDING_invert_ex(BIGNUM *n, const BIGNUM *r, BN_BLINDING *b,
n->top = (int)(rtop & ~mask) | (ntop & mask);
n->flags |= (BN_FLG_FIXED_TOP & ~mask);
}
ret = BN_mod_mul_montgomery(n, n, r, b->m_ctx, ctx);
ret = bn_mul_mont_fixed_top(n, n, r, b->m_ctx, ctx);
bn_correct_top_consttime(n);
} else {
ret = BN_mod_mul(n, n, r, b->mod, ctx);
}
Expand Down
22 changes: 22 additions & 0 deletions crypto/bn/bn_lib.c
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,28 @@ BIGNUM *bn_wexpand(BIGNUM *a, int words)
return (words <= a->dmax) ? a : bn_expand2(a, words);
}

void bn_correct_top_consttime(BIGNUM *a)
{
int j, atop;
BN_ULONG limb;
unsigned int mask;

for (j = 0, atop = 0; j < a->dmax; j++) {
limb = a->d[j];
limb |= 0 - limb;
limb >>= BN_BITS2 - 1;
limb = 0 - limb;
mask = (unsigned int)limb;
mask &= constant_time_msb(j - a->top);
atop = constant_time_select_int(mask, j + 1, atop);
}

mask = constant_time_eq_int(atop, 0);
a->top = atop;
a->neg = constant_time_select_int(mask, 0, a->neg);
a->flags &= ~BN_FLG_FIXED_TOP;
}

void bn_correct_top(BIGNUM *a)
{
BN_ULONG *ftl;
Expand Down
26 changes: 13 additions & 13 deletions crypto/bn/bn_local.h
Original file line number Diff line number Diff line change
Expand Up @@ -515,10 +515,10 @@ unsigned __int64 _umul128(unsigned __int64 a, unsigned __int64 b,
ret = (r); \
BN_UMULT_LOHI(low,high,w,tmp); \
ret += (c); \
(c) = (ret<(c))?1:0; \
(c) = (ret<(c)); \
(c) += high; \
ret += low; \
(c) += (ret<low)?1:0; \
(c) += (ret<low); \
(r) = ret; \
}

Expand All @@ -527,7 +527,7 @@ unsigned __int64 _umul128(unsigned __int64 a, unsigned __int64 b,
BN_UMULT_LOHI(low,high,w,ta); \
ret = low + (c); \
(c) = high; \
(c) += (ret<low)?1:0; \
(c) += (ret<low); \
(r) = ret; \
}

Expand All @@ -543,10 +543,10 @@ unsigned __int64 _umul128(unsigned __int64 a, unsigned __int64 b,
high= BN_UMULT_HIGH(w,tmp); \
ret += (c); \
low = (w) * tmp; \
(c) = (ret<(c))?1:0; \
(c) = (ret<(c)); \
(c) += high; \
ret += low; \
(c) += (ret<low)?1:0; \
(c) += (ret<low); \
(r) = ret; \
}

Expand All @@ -556,7 +556,7 @@ unsigned __int64 _umul128(unsigned __int64 a, unsigned __int64 b,
high= BN_UMULT_HIGH(w,ta); \
ret = low + (c); \
(c) = high; \
(c) += (ret<low)?1:0; \
(c) += (ret<low); \
(r) = ret; \
}

Expand Down Expand Up @@ -589,10 +589,10 @@ unsigned __int64 _umul128(unsigned __int64 a, unsigned __int64 b,
lt=(bl)*(lt); \
m1=(bl)*(ht); \
ht =(bh)*(ht); \
m=(m+m1)&BN_MASK2; if (m < m1) ht+=L2HBITS((BN_ULONG)1); \
m=(m+m1)&BN_MASK2; ht += L2HBITS((BN_ULONG)(m < m1)); \
ht+=HBITS(m); \
m1=L2HBITS(m); \
lt=(lt+m1)&BN_MASK2; if (lt < m1) ht++; \
lt=(lt+m1)&BN_MASK2; ht += (lt < m1); \
(l)=lt; \
(h)=ht; \
}
Expand All @@ -609,7 +609,7 @@ unsigned __int64 _umul128(unsigned __int64 a, unsigned __int64 b,
h*=h; \
h+=(m&BN_MASK2h1)>>(BN_BITS4-1); \
m =(m&BN_MASK2l)<<(BN_BITS4+1); \
l=(l+m)&BN_MASK2; if (l < m) h++; \
l=(l+m)&BN_MASK2; h += (l < m); \
(lo)=l; \
(ho)=h; \
}
Expand All @@ -623,9 +623,9 @@ unsigned __int64 _umul128(unsigned __int64 a, unsigned __int64 b,
mul64(l,h,(bl),(bh)); \
\
/* non-multiply part */ \
l=(l+(c))&BN_MASK2; if (l < (c)) h++; \
l=(l+(c))&BN_MASK2; h += (l < (c)); \
(c)=(r); \
l=(l+(c))&BN_MASK2; if (l < (c)) h++; \
l=(l+(c))&BN_MASK2; h += (l < (c)); \
(c)=h&BN_MASK2; \
(r)=l; \
}
Expand All @@ -639,7 +639,7 @@ unsigned __int64 _umul128(unsigned __int64 a, unsigned __int64 b,
mul64(l,h,(bl),(bh)); \
\
/* non-multiply part */ \
l+=(c); if ((l&BN_MASK2) < (c)) h++; \
l+=(c); h += ((l&BN_MASK2) < (c)); \
(c)=h&BN_MASK2; \
(r)=l&BN_MASK2; \
}
Expand Down Expand Up @@ -669,7 +669,7 @@ BN_ULONG bn_sub_part_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
int cl, int dl);
int bn_mul_mont(BN_ULONG *rp, const BN_ULONG *ap, const BN_ULONG *bp,
const BN_ULONG *np, const BN_ULONG *n0, int num);

void bn_correct_top_consttime(BIGNUM *a);
BIGNUM *int_bn_mod_inverse(BIGNUM *in,
const BIGNUM *a, const BIGNUM *n, BN_CTX *ctx,
int *noinv);
Expand Down
13 changes: 6 additions & 7 deletions crypto/rsa/rsa_ossl.c
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ static int rsa_blinding_invert(BN_BLINDING *b, BIGNUM *f, BIGNUM *unblind,
* will only read the modulus from BN_BLINDING. In both cases it's safe
* to access the blinding without a lock.
*/
BN_set_flags(f, BN_FLG_CONSTTIME);
return BN_BLINDING_invert_ex(f, unblind, b, ctx);
}

Expand Down Expand Up @@ -412,6 +413,11 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
goto err;
}

if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
rsa->n, ctx))
goto err;

if (!(rsa->flags & RSA_FLAG_NO_BLINDING)) {
blinding = rsa_get_blinding(rsa, &local_blinding, ctx);
if (blinding == NULL) {
Expand Down Expand Up @@ -449,13 +455,6 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
goto err;
}
BN_with_flags(d, rsa->d, BN_FLG_CONSTTIME);

if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
rsa->n, ctx)) {
BN_free(d);
goto err;
}
if (!rsa->meth->bn_mod_exp(ret, f, d, rsa->n, ctx,
rsa->_method_mod_n)) {
BN_free(d);
Expand Down

0 comments on commit 3f499b2

Please sign in to comment.