Check return from BN_sub
[openssl.git] / crypto / rsa / rsa_ossl.c
index 782606645bd1b4b4a504e5dbc58ffeb8da85f274..c441905526c82b0e49c1945e9262498bbb0c6688 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 1995-2016 The OpenSSL Project Authors. All Rights Reserved.
+ * Copyright 1995-2017 The OpenSSL Project Authors. All Rights Reserved.
  *
  * Licensed under the OpenSSL license (the "License").  You may not use
  * this file except in compliance with the License.  You can obtain a copy
@@ -11,8 +11,6 @@
 #include "internal/bn_int.h"
 #include "rsa_locl.h"
 
-#ifndef RSA_NULL
-
 static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
                                   unsigned char *to, RSA *rsa, int padding);
 static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
@@ -26,7 +24,7 @@ static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *i, RSA *rsa,
 static int rsa_ossl_init(RSA *rsa);
 static int rsa_ossl_finish(RSA *rsa);
 static RSA_METHOD rsa_pkcs1_ossl_meth = {
-    "OpenSSL PKCS#1 RSA (from Eric Young)",
+    "OpenSSL PKCS#1 RSA",
     rsa_ossl_public_encrypt,
     rsa_ossl_public_decrypt,     /* signature verification */
     rsa_ossl_private_encrypt,    /* signing */
@@ -40,19 +38,37 @@ static RSA_METHOD rsa_pkcs1_ossl_meth = {
     NULL,
     0,                          /* rsa_sign */
     0,                          /* rsa_verify */
-    NULL                        /* rsa_keygen */
+    NULL,                       /* rsa_keygen */
+    NULL                        /* rsa_multi_prime_keygen */
 };
 
+static const RSA_METHOD *default_RSA_meth = &rsa_pkcs1_ossl_meth;
+
+void RSA_set_default_method(const RSA_METHOD *meth)
+{
+    default_RSA_meth = meth;
+}
+
+const RSA_METHOD *RSA_get_default_method(void)
+{
+    return default_RSA_meth;
+}
+
 const RSA_METHOD *RSA_PKCS1_OpenSSL(void)
 {
     return &rsa_pkcs1_ossl_meth;
 }
 
+const RSA_METHOD *RSA_null_method(void)
+{
+    return NULL;
+}
+
 static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
                                   unsigned char *to, RSA *rsa, int padding)
 {
     BIGNUM *f, *ret;
-    int i, j, k, num = 0, r = -1;
+    int i, num = 0, r = -1;
     unsigned char *buf = NULL;
     BN_CTX *ctx = NULL;
 
@@ -81,7 +97,7 @@ static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
     ret = BN_CTX_get(ctx);
     num = BN_num_bytes(rsa->n);
     buf = OPENSSL_malloc(num);
-    if (f == NULL || ret == NULL || buf == NULL) {
+    if (ret == NULL || buf == NULL) {
         RSAerr(RSA_F_RSA_OSSL_PUBLIC_ENCRYPT, ERR_R_MALLOC_FAILURE);
         goto err;
     }
@@ -126,21 +142,16 @@ static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
         goto err;
 
     /*
-     * put in leading 0 bytes if the number is less than the length of the
-     * modulus
+     * BN_bn2binpad puts in leading 0 bytes if the number is less than
+     * the length of the modulus.
      */
-    j = BN_num_bytes(ret);
-    i = BN_bn2bin(ret, &(to[num - j]));
-    for (k = 0; k < (num - i); k++)
-        to[k] = 0;
-
-    r = num;
+    r = BN_bn2binpad(ret, to, num);
  err:
     if (ctx != NULL)
         BN_CTX_end(ctx);
     BN_CTX_free(ctx);
     OPENSSL_clear_free(buf, num);
-    return (r);
+    return r;
 }
 
 static BN_BLINDING *rsa_get_blinding(RSA *rsa, int *local, BN_CTX *ctx)
@@ -185,12 +196,12 @@ static BN_BLINDING *rsa_get_blinding(RSA *rsa, int *local, BN_CTX *ctx)
 static int rsa_blinding_convert(BN_BLINDING *b, BIGNUM *f, BIGNUM *unblind,
                                 BN_CTX *ctx)
 {
-    if (unblind == NULL)
+    if (unblind == NULL) {
         /*
          * Local blinding: store the unblinding factor in BN_BLINDING.
          */
         return BN_BLINDING_convert_ex(f, NULL, b, ctx);
-    else {
+    else {
         /*
          * Shared blinding: store the unblinding factor outside BN_BLINDING.
          */
@@ -223,7 +234,7 @@ static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
                                    unsigned char *to, RSA *rsa, int padding)
 {
     BIGNUM *f, *ret, *res;
-    int i, j, k, num = 0, r = -1;
+    int i, num = 0, r = -1;
     unsigned char *buf = NULL;
     BN_CTX *ctx = NULL;
     int local_blinding = 0;
@@ -242,7 +253,7 @@ static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
     ret = BN_CTX_get(ctx);
     num = BN_num_bytes(rsa->n);
     buf = OPENSSL_malloc(num);
-    if (f == NULL || ret == NULL || buf == NULL) {
+    if (ret == NULL || buf == NULL) {
         RSAerr(RSA_F_RSA_OSSL_PRIVATE_ENCRYPT, ERR_R_MALLOC_FAILURE);
         goto err;
     }
@@ -293,6 +304,7 @@ static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
     }
 
     if ((rsa->flags & RSA_FLAG_EXT_PKEY) ||
+        (rsa->version == RSA_ASN1_VERSION_MULTI) ||
         ((rsa->p != NULL) &&
          (rsa->q != NULL) &&
          (rsa->dmp1 != NULL) && (rsa->dmq1 != NULL) && (rsa->iqmp != NULL))) {
@@ -327,30 +339,27 @@ static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
             goto err;
 
     if (padding == RSA_X931_PADDING) {
-        BN_sub(f, rsa->n, ret);
+        if (!BN_sub(f, rsa->n, ret))
+            goto err;
         if (BN_cmp(ret, f) > 0)
             res = f;
         else
             res = ret;
-    } else
+    } else {
         res = ret;
+    }
 
     /*
-     * put in leading 0 bytes if the number is less than the length of the
-     * modulus
+     * BN_bn2binpad puts in leading 0 bytes if the number is less than
+     * the length of the modulus.
      */
-    j = BN_num_bytes(res);
-    i = BN_bn2bin(res, &(to[num - j]));
-    for (k = 0; k < (num - i); k++)
-        to[k] = 0;
-
-    r = num;
+    r = BN_bn2binpad(res, to, num);
  err:
     if (ctx != NULL)
         BN_CTX_end(ctx);
     BN_CTX_free(ctx);
     OPENSSL_clear_free(buf, num);
-    return (r);
+    return r;
 }
 
 static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
@@ -358,7 +367,6 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
 {
     BIGNUM *f, *ret;
     int j, num = 0, r = -1;
-    unsigned char *p;
     unsigned char *buf = NULL;
     BN_CTX *ctx = NULL;
     int local_blinding = 0;
@@ -377,7 +385,7 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
     ret = BN_CTX_get(ctx);
     num = BN_num_bytes(rsa->n);
     buf = OPENSSL_malloc(num);
-    if (f == NULL || ret == NULL || buf == NULL) {
+    if (ret == NULL || buf == NULL) {
         RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, ERR_R_MALLOC_FAILURE);
         goto err;
     }
@@ -421,6 +429,7 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
 
     /* do the decrypt */
     if ((rsa->flags & RSA_FLAG_EXT_PKEY) ||
+        (rsa->version == RSA_ASN1_VERSION_MULTI) ||
         ((rsa->p != NULL) &&
          (rsa->q != NULL) &&
          (rsa->dmp1 != NULL) && (rsa->dmq1 != NULL) && (rsa->iqmp != NULL))) {
@@ -453,8 +462,7 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
         if (!rsa_blinding_invert(blinding, ret, unblind, ctx))
             goto err;
 
-    p = buf;
-    j = BN_bn2bin(ret, p);      /* j is only used with no-padding mode */
+    j = BN_bn2binpad(ret, buf, num);
 
     switch (padding) {
     case RSA_PKCS1_PADDING:
@@ -467,7 +475,7 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
         r = RSA_padding_check_SSLv23(to, num, buf, j, num);
         break;
     case RSA_NO_PADDING:
-        r = RSA_padding_check_none(to, num, buf, j, num);
+        memcpy(to, buf, (r = j));
         break;
     default:
         RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, RSA_R_UNKNOWN_PADDING_TYPE);
@@ -481,7 +489,7 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
         BN_CTX_end(ctx);
     BN_CTX_free(ctx);
     OPENSSL_clear_free(buf, num);
-    return (r);
+    return r;
 }
 
 /* signature verification */
@@ -490,7 +498,6 @@ static int rsa_ossl_public_decrypt(int flen, const unsigned char *from,
 {
     BIGNUM *f, *ret;
     int i, num = 0, r = -1;
-    unsigned char *p;
     unsigned char *buf = NULL;
     BN_CTX *ctx = NULL;
 
@@ -519,7 +526,7 @@ static int rsa_ossl_public_decrypt(int flen, const unsigned char *from,
     ret = BN_CTX_get(ctx);
     num = BN_num_bytes(rsa->n);
     buf = OPENSSL_malloc(num);
-    if (f == NULL || ret == NULL || buf == NULL) {
+    if (ret == NULL || buf == NULL) {
         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, ERR_R_MALLOC_FAILURE);
         goto err;
     }
@@ -555,8 +562,7 @@ static int rsa_ossl_public_decrypt(int flen, const unsigned char *from,
         if (!BN_sub(ret, rsa->n, ret))
             goto err;
 
-    p = buf;
-    i = BN_bn2bin(ret, p);
+    i = BN_bn2binpad(ret, buf, num);
 
     switch (padding) {
     case RSA_PKCS1_PADDING:
@@ -566,7 +572,7 @@ static int rsa_ossl_public_decrypt(int flen, const unsigned char *from,
         r = RSA_padding_check_X931(to, num, buf, i, num);
         break;
     case RSA_NO_PADDING:
-        r = RSA_padding_check_none(to, num, buf, i, num);
+        memcpy(to, buf, (r = i));
         break;
     default:
         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, RSA_R_UNKNOWN_PADDING_TYPE);
@@ -580,19 +586,28 @@ static int rsa_ossl_public_decrypt(int flen, const unsigned char *from,
         BN_CTX_end(ctx);
     BN_CTX_free(ctx);
     OPENSSL_clear_free(buf, num);
-    return (r);
+    return r;
 }
 
 static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *I, RSA *rsa, BN_CTX *ctx)
 {
-    BIGNUM *r1, *m1, *vrfy;
-    int ret = 0;
+    BIGNUM *r1, *m1, *vrfy, *r2, *m[RSA_MAX_PRIME_NUM - 2];
+    int ret = 0, i, ex_primes = 0;
+    RSA_PRIME_INFO *pinfo;
 
     BN_CTX_start(ctx);
 
     r1 = BN_CTX_get(ctx);
+    r2 = BN_CTX_get(ctx);
     m1 = BN_CTX_get(ctx);
     vrfy = BN_CTX_get(ctx);
+    if (vrfy == NULL)
+        goto err;
+
+    if (rsa->version == RSA_ASN1_VERSION_MULTI
+        && ((ex_primes = sk_RSA_PRIME_INFO_num(rsa->prime_infos)) <= 0
+             || ex_primes > RSA_MAX_PRIME_NUM - 2))
+        goto err;
 
     {
         BIGNUM *p = BN_new(), *q = BN_new();
@@ -618,6 +633,28 @@ static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *I, RSA *rsa, BN_CTX *ctx)
                 BN_free(q);
                 goto err;
             }
+            if (ex_primes > 0) {
+                /* cache BN_MONT_CTX for other primes */
+                BIGNUM *r = BN_new();
+
+                if (r == NULL) {
+                    BN_free(p);
+                    BN_free(q);
+                    goto err;
+                }
+
+                for (i = 0; i < ex_primes; i++) {
+                    pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
+                    BN_with_flags(r, pinfo->r, BN_FLG_CONSTTIME);
+                    if (!BN_MONT_CTX_set_locked(&pinfo->m, rsa->lock, r, ctx)) {
+                        BN_free(p);
+                        BN_free(q);
+                        BN_free(r);
+                        goto err;
+                    }
+                }
+                BN_free(r);
+            }
         }
         /*
          * We MUST free p and q before any further use of rsa->p and rsa->q
@@ -687,6 +724,56 @@ static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *I, RSA *rsa, BN_CTX *ctx)
         BN_free(dmp1);
     }
 
+    /*
+     * calculate m_i in multi-prime case
+     *
+     * TODO:
+     * 1. squash the following two loops and calculate |m_i| there.
+     * 2. remove cc and reuse |c|.
+     * 3. remove |dmq1| and |dmp1| in previous block and use |di|.
+     *
+     * If these things are done, the code will be more readable.
+     */
+    if (ex_primes > 0) {
+        BIGNUM *di = BN_new(), *cc = BN_new();
+
+        if (cc == NULL || di == NULL) {
+            BN_free(cc);
+            BN_free(di);
+            goto err;
+        }
+
+        for (i = 0; i < ex_primes; i++) {
+            /* prepare m_i */
+            if ((m[i] = BN_CTX_get(ctx)) == NULL) {
+                BN_free(cc);
+                BN_free(di);
+                goto err;
+            }
+
+            pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
+
+            /* prepare c and d_i */
+            BN_with_flags(cc, I, BN_FLG_CONSTTIME);
+            BN_with_flags(di, pinfo->d, BN_FLG_CONSTTIME);
+
+            if (!BN_mod(r1, cc, pinfo->r, ctx)) {
+                BN_free(cc);
+                BN_free(di);
+                goto err;
+            }
+            /* compute r1 ^ d_i mod r_i */
+            if (!rsa->meth->bn_mod_exp(m[i], r1, di, pinfo->r, ctx, pinfo->m)) {
+                BN_free(cc);
+                BN_free(di);
+                goto err;
+            }
+        }
+
+        BN_free(cc);
+        BN_free(di);
+    }
+
     if (!BN_sub(r0, r0, m1))
         goto err;
     /*
@@ -729,6 +816,49 @@ static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *I, RSA *rsa, BN_CTX *ctx)
     if (!BN_add(r0, r1, m1))
         goto err;
 
+    /* add m_i to m in multi-prime case */
+    if (ex_primes > 0) {
+        BIGNUM *pr2 = BN_new();
+
+        if (pr2 == NULL)
+            goto err;
+
+        for (i = 0; i < ex_primes; i++) {
+            pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
+            if (!BN_sub(r1, m[i], r0)) {
+                BN_free(pr2);
+                goto err;
+            }
+
+            if (!BN_mul(r2, r1, pinfo->t, ctx)) {
+                BN_free(pr2);
+                goto err;
+            }
+
+            BN_with_flags(pr2, r2, BN_FLG_CONSTTIME);
+
+            if (!BN_mod(r1, pr2, pinfo->r, ctx)) {
+                BN_free(pr2);
+                goto err;
+            }
+
+            if (BN_is_negative(r1))
+                if (!BN_add(r1, r1, pinfo->r)) {
+                    BN_free(pr2);
+                    goto err;
+                }
+            if (!BN_mul(r1, r1, pinfo->pp, ctx)) {
+                BN_free(pr2);
+                goto err;
+            }
+            if (!BN_add(r0, r0, r1)) {
+                BN_free(pr2);
+                goto err;
+            }
+        }
+        BN_free(pr2);
+    }
+
     if (rsa->e && rsa->n) {
         if (!rsa->meth->bn_mod_exp(vrfy, r0, rsa->e, rsa->n, ctx,
                                    rsa->_method_mod_n))
@@ -770,21 +900,26 @@ static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *I, RSA *rsa, BN_CTX *ctx)
     ret = 1;
  err:
     BN_CTX_end(ctx);
-    return (ret);
+    return ret;
 }
 
 static int rsa_ossl_init(RSA *rsa)
 {
     rsa->flags |= RSA_FLAG_CACHE_PUBLIC | RSA_FLAG_CACHE_PRIVATE;
-    return (1);
+    return 1;
 }
 
 static int rsa_ossl_finish(RSA *rsa)
 {
+    int i;
+    RSA_PRIME_INFO *pinfo;
+
     BN_MONT_CTX_free(rsa->_method_mod_n);
     BN_MONT_CTX_free(rsa->_method_mod_p);
     BN_MONT_CTX_free(rsa->_method_mod_q);
-    return (1);
+    for (i = 0; i < sk_RSA_PRIME_INFO_num(rsa->prime_infos); i++) {
+        pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
+        BN_MONT_CTX_free(pinfo->m);
+    }
+    return 1;
 }
-
-#endif