Support multi-prime RSA (RFC 8017)
[openssl.git] / crypto / rsa / rsa_ossl.c
index d8af92dc6ce1e5ffb97ee80b6c0415d394f515dc..ced11ad883c98e17b53e073e4da63c5010d5f0ea 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 1995-2016 The OpenSSL Project Authors. All Rights Reserved.
+ * Copyright 1995-2017 The OpenSSL Project Authors. All Rights Reserved.
  *
  * Licensed under the OpenSSL license (the "License").  You may not use
  * this file except in compliance with the License.  You can obtain a copy
@@ -9,11 +9,8 @@
 
 #include "internal/cryptlib.h"
 #include "internal/bn_int.h"
-#include <openssl/rand.h>
 #include "rsa_locl.h"
 
-#ifndef RSA_NULL
-
 static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
                                   unsigned char *to, RSA *rsa, int padding);
 static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
@@ -27,7 +24,7 @@ static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *i, RSA *rsa,
 static int rsa_ossl_init(RSA *rsa);
 static int rsa_ossl_finish(RSA *rsa);
 static RSA_METHOD rsa_pkcs1_ossl_meth = {
-    "OpenSSL PKCS#1 RSA (from Eric Young)",
+    "OpenSSL PKCS#1 RSA",
     rsa_ossl_public_encrypt,
     rsa_ossl_public_decrypt,     /* signature verification */
     rsa_ossl_private_encrypt,    /* signing */
@@ -41,14 +38,32 @@ static RSA_METHOD rsa_pkcs1_ossl_meth = {
     NULL,
     0,                          /* rsa_sign */
     0,                          /* rsa_verify */
-    NULL                        /* rsa_keygen */
+    NULL,                       /* rsa_keygen */
+    NULL                        /* rsa_multi_prime_keygen */
 };
 
+static const RSA_METHOD *default_RSA_meth = &rsa_pkcs1_ossl_meth;
+
+void RSA_set_default_method(const RSA_METHOD *meth)
+{
+    default_RSA_meth = meth;
+}
+
+const RSA_METHOD *RSA_get_default_method(void)
+{
+    return default_RSA_meth;
+}
+
 const RSA_METHOD *RSA_PKCS1_OpenSSL(void)
 {
     return &rsa_pkcs1_ossl_meth;
 }
 
+const RSA_METHOD *RSA_null_method(void)
+{
+    return NULL;
+}
+
 static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
                                   unsigned char *to, RSA *rsa, int padding)
 {
@@ -82,7 +97,7 @@ static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
     ret = BN_CTX_get(ctx);
     num = BN_num_bytes(rsa->n);
     buf = OPENSSL_malloc(num);
-    if (f == NULL || ret == NULL || buf == NULL) {
+    if (ret == NULL || buf == NULL) {
         RSAerr(RSA_F_RSA_OSSL_PUBLIC_ENCRYPT, ERR_R_MALLOC_FAILURE);
         goto err;
     }
@@ -141,7 +156,7 @@ static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
         BN_CTX_end(ctx);
     BN_CTX_free(ctx);
     OPENSSL_clear_free(buf, num);
-    return (r);
+    return r;
 }
 
 static BN_BLINDING *rsa_get_blinding(RSA *rsa, int *local, BN_CTX *ctx)
@@ -186,12 +201,12 @@ static BN_BLINDING *rsa_get_blinding(RSA *rsa, int *local, BN_CTX *ctx)
 static int rsa_blinding_convert(BN_BLINDING *b, BIGNUM *f, BIGNUM *unblind,
                                 BN_CTX *ctx)
 {
-    if (unblind == NULL)
+    if (unblind == NULL) {
         /*
          * Local blinding: store the unblinding factor in BN_BLINDING.
          */
         return BN_BLINDING_convert_ex(f, NULL, b, ctx);
-    else {
+    else {
         /*
          * Shared blinding: store the unblinding factor outside BN_BLINDING.
          */
@@ -243,7 +258,7 @@ static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
     ret = BN_CTX_get(ctx);
     num = BN_num_bytes(rsa->n);
     buf = OPENSSL_malloc(num);
-    if (f == NULL || ret == NULL || buf == NULL) {
+    if (ret == NULL || buf == NULL) {
         RSAerr(RSA_F_RSA_OSSL_PRIVATE_ENCRYPT, ERR_R_MALLOC_FAILURE);
         goto err;
     }
@@ -294,6 +309,7 @@ static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
     }
 
     if ((rsa->flags & RSA_FLAG_EXT_PKEY) ||
+        (rsa->version == RSA_ASN1_VERSION_MULTI) ||
         ((rsa->p != NULL) &&
          (rsa->q != NULL) &&
          (rsa->dmp1 != NULL) && (rsa->dmq1 != NULL) && (rsa->iqmp != NULL))) {
@@ -333,8 +349,9 @@ static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
             res = f;
         else
             res = ret;
-    } else
+    } else {
         res = ret;
+    }
 
     /*
      * put in leading 0 bytes if the number is less than the length of the
@@ -351,7 +368,7 @@ static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
         BN_CTX_end(ctx);
     BN_CTX_free(ctx);
     OPENSSL_clear_free(buf, num);
-    return (r);
+    return r;
 }
 
 static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
@@ -378,7 +395,7 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
     ret = BN_CTX_get(ctx);
     num = BN_num_bytes(rsa->n);
     buf = OPENSSL_malloc(num);
-    if (f == NULL || ret == NULL || buf == NULL) {
+    if (ret == NULL || buf == NULL) {
         RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, ERR_R_MALLOC_FAILURE);
         goto err;
     }
@@ -422,6 +439,7 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
 
     /* do the decrypt */
     if ((rsa->flags & RSA_FLAG_EXT_PKEY) ||
+        (rsa->version == RSA_ASN1_VERSION_MULTI) ||
         ((rsa->p != NULL) &&
          (rsa->q != NULL) &&
          (rsa->dmp1 != NULL) && (rsa->dmq1 != NULL) && (rsa->iqmp != NULL))) {
@@ -482,7 +500,7 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
         BN_CTX_end(ctx);
     BN_CTX_free(ctx);
     OPENSSL_clear_free(buf, num);
-    return (r);
+    return r;
 }
 
 /* signature verification */
@@ -520,7 +538,7 @@ static int rsa_ossl_public_decrypt(int flen, const unsigned char *from,
     ret = BN_CTX_get(ctx);
     num = BN_num_bytes(rsa->n);
     buf = OPENSSL_malloc(num);
-    if (f == NULL || ret == NULL || buf == NULL) {
+    if (ret == NULL || buf == NULL) {
         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, ERR_R_MALLOC_FAILURE);
         goto err;
     }
@@ -581,19 +599,27 @@ static int rsa_ossl_public_decrypt(int flen, const unsigned char *from,
         BN_CTX_end(ctx);
     BN_CTX_free(ctx);
     OPENSSL_clear_free(buf, num);
-    return (r);
+    return r;
 }
 
 static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *I, RSA *rsa, BN_CTX *ctx)
 {
-    BIGNUM *r1, *m1, *vrfy;
-    int ret = 0;
+    BIGNUM *r1, *m1, *vrfy, *r2, *m[RSA_MAX_PRIME_NUM];
+    int ret = 0, i, ex_primes = 0;
+    RSA_PRIME_INFO *pinfo;
 
     BN_CTX_start(ctx);
 
     r1 = BN_CTX_get(ctx);
+    r2 = BN_CTX_get(ctx);
     m1 = BN_CTX_get(ctx);
     vrfy = BN_CTX_get(ctx);
+    if (vrfy == NULL)
+        goto err;
+
+    if (rsa->version == RSA_ASN1_VERSION_MULTI
+        && (ex_primes = sk_RSA_PRIME_INFO_num(rsa->prime_infos)) <= 0)
+        goto err;
 
     {
         BIGNUM *p = BN_new(), *q = BN_new();
@@ -619,6 +645,28 @@ static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *I, RSA *rsa, BN_CTX *ctx)
                 BN_free(q);
                 goto err;
             }
+            if (ex_primes > 0) {
+                /* cache BN_MONT_CTX for other primes */
+                BIGNUM *r = BN_new();
+
+                if (r == NULL) {
+                    BN_free(p);
+                    BN_free(q);
+                    goto err;
+                }
+
+                for (i = 0; i < ex_primes; i++) {
+                    pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
+                    BN_with_flags(r, pinfo->r, BN_FLG_CONSTTIME);
+                    if (!BN_MONT_CTX_set_locked(&pinfo->m, rsa->lock, r, ctx)) {
+                        BN_free(p);
+                        BN_free(q);
+                        BN_free(r);
+                        goto err;
+                    }
+                }
+                BN_free(r);
+            }
         }
         /*
          * We MUST free p and q before any further use of rsa->p and rsa->q
@@ -688,6 +736,56 @@ static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *I, RSA *rsa, BN_CTX *ctx)
         BN_free(dmp1);
     }
 
+    /*
+     * calculate m_i in multi-prime case
+     *
+     * TODO:
+     * 1. squash the following two loops and calculate |m_i| there.
+     * 2. remove cc and reuse |c|.
+     * 3. remove |dmq1| and |dmp1| in previous block and use |di|.
+     *
+     * If these things are done, the code will be more readable.
+     */
+    if (ex_primes > 0) {
+        BIGNUM *di = BN_new(), *cc = BN_new();
+
+        if (cc == NULL || di == NULL) {
+            BN_free(cc);
+            BN_free(di);
+            goto err;
+        }
+
+        for (i = 0; i < ex_primes; i++) {
+            /* prepare m_i */
+            if ((m[i] = BN_CTX_get(ctx)) == NULL) {
+                BN_free(cc);
+                BN_free(di);
+                goto err;
+            }
+
+            pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
+
+            /* prepare c and d_i */
+            BN_with_flags(cc, I, BN_FLG_CONSTTIME);
+            BN_with_flags(di, pinfo->d, BN_FLG_CONSTTIME);
+
+            if (!BN_mod(r1, cc, pinfo->r, ctx)) {
+                BN_free(cc);
+                BN_free(di);
+                goto err;
+            }
+            /* compute r1 ^ d_i mod r_i */
+            if (!rsa->meth->bn_mod_exp(m[i], r1, di, pinfo->r, ctx, pinfo->m)) {
+                BN_free(cc);
+                BN_free(di);
+                goto err;
+            }
+        }
+
+        BN_free(cc);
+        BN_free(di);
+    }
+
     if (!BN_sub(r0, r0, m1))
         goto err;
     /*
@@ -730,6 +828,49 @@ static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *I, RSA *rsa, BN_CTX *ctx)
     if (!BN_add(r0, r1, m1))
         goto err;
 
+    /* add m_i to m in multi-prime case */
+    if (ex_primes > 0) {
+        BIGNUM *pr2 = BN_new();
+
+        if (pr2 == NULL)
+            goto err;
+
+        for (i = 0; i < ex_primes; i++) {
+            pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
+            if (!BN_sub(r1, m[i], r0)) {
+                BN_free(pr2);
+                goto err;
+            }
+
+            if (!BN_mul(r2, r1, pinfo->t, ctx)) {
+                BN_free(pr2);
+                goto err;
+            }
+
+            BN_with_flags(pr2, r2, BN_FLG_CONSTTIME);
+
+            if (!BN_mod(r1, pr2, pinfo->r, ctx)) {
+                BN_free(pr2);
+                goto err;
+            }
+
+            if (BN_is_negative(r1))
+                if (!BN_add(r1, r1, pinfo->r)) {
+                    BN_free(pr2);
+                    goto err;
+                }
+            if (!BN_mul(r1, r1, pinfo->pp, ctx)) {
+                BN_free(pr2);
+                goto err;
+            }
+            if (!BN_add(r0, r0, r1)) {
+                BN_free(pr2);
+                goto err;
+            }
+        }
+        BN_free(pr2);
+    }
+
     if (rsa->e && rsa->n) {
         if (!rsa->meth->bn_mod_exp(vrfy, r0, rsa->e, rsa->n, ctx,
                                    rsa->_method_mod_n))
@@ -771,21 +912,26 @@ static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *I, RSA *rsa, BN_CTX *ctx)
     ret = 1;
  err:
     BN_CTX_end(ctx);
-    return (ret);
+    return ret;
 }
 
 static int rsa_ossl_init(RSA *rsa)
 {
     rsa->flags |= RSA_FLAG_CACHE_PUBLIC | RSA_FLAG_CACHE_PRIVATE;
-    return (1);
+    return 1;
 }
 
 static int rsa_ossl_finish(RSA *rsa)
 {
+    int i;
+    RSA_PRIME_INFO *pinfo;
+
     BN_MONT_CTX_free(rsa->_method_mod_n);
     BN_MONT_CTX_free(rsa->_method_mod_p);
     BN_MONT_CTX_free(rsa->_method_mod_q);
-    return (1);
+    for (i = 0; i < sk_RSA_PRIME_INFO_num(rsa->prime_infos); i++) {
+        pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
+        BN_MONT_CTX_free(pinfo->m);
+    }
+    return 1;
 }
-
-#endif