Convert RSA blinding to new multi-threading API
[openssl.git] / crypto / rsa / rsa_ossl.c
index 0752f5fd8b4cdc0d00ebfff6ed4e8b43cc95211c..8d3383bfb0814351126e1643913d0f5e2d485af0 100644 (file)
@@ -220,7 +220,7 @@ static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
 
     if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
         if (!BN_MONT_CTX_set_locked
 
     if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
         if (!BN_MONT_CTX_set_locked
-            (&rsa->_method_mod_n, CRYPTO_LOCK_RSA, rsa->n, ctx))
+            (&rsa->_method_mod_n, rsa->lock, rsa->n, ctx))
             goto err;
 
     if (!rsa->meth->bn_mod_exp(ret, f, rsa->e, rsa->n, ctx,
             goto err;
 
     if (!rsa->meth->bn_mod_exp(ret, f, rsa->e, rsa->n, ctx,
@@ -248,26 +248,18 @@ static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
 static BN_BLINDING *rsa_get_blinding(RSA *rsa, int *local, BN_CTX *ctx)
 {
     BN_BLINDING *ret;
 static BN_BLINDING *rsa_get_blinding(RSA *rsa, int *local, BN_CTX *ctx)
 {
     BN_BLINDING *ret;
-    int got_write_lock = 0;
-    CRYPTO_THREADID cur;
 
 
-    CRYPTO_r_lock(CRYPTO_LOCK_RSA);
+    CRYPTO_THREAD_write_lock(rsa->lock);
 
     if (rsa->blinding == NULL) {
 
     if (rsa->blinding == NULL) {
-        CRYPTO_r_unlock(CRYPTO_LOCK_RSA);
-        CRYPTO_w_lock(CRYPTO_LOCK_RSA);
-        got_write_lock = 1;
-
-        if (rsa->blinding == NULL)
-            rsa->blinding = RSA_setup_blinding(rsa, ctx);
+        rsa->blinding = RSA_setup_blinding(rsa, ctx);
     }
 
     ret = rsa->blinding;
     if (ret == NULL)
         goto err;
 
     }
 
     ret = rsa->blinding;
     if (ret == NULL)
         goto err;
 
-    CRYPTO_THREADID_current(&cur);
-    if (!CRYPTO_THREADID_cmp(&cur, BN_BLINDING_thread_id(ret))) {
+    if (BN_BLINDING_is_current_thread(ret)) {
         /* rsa->blinding is ours! */
 
         *local = 1;
         /* rsa->blinding is ours! */
 
         *local = 1;
@@ -282,23 +274,13 @@ static BN_BLINDING *rsa_get_blinding(RSA *rsa, int *local, BN_CTX *ctx)
         *local = 0;
 
         if (rsa->mt_blinding == NULL) {
         *local = 0;
 
         if (rsa->mt_blinding == NULL) {
-            if (!got_write_lock) {
-                CRYPTO_r_unlock(CRYPTO_LOCK_RSA);
-                CRYPTO_w_lock(CRYPTO_LOCK_RSA);
-                got_write_lock = 1;
-            }
-
-            if (rsa->mt_blinding == NULL)
-                rsa->mt_blinding = RSA_setup_blinding(rsa, ctx);
+            rsa->mt_blinding = RSA_setup_blinding(rsa, ctx);
         }
         ret = rsa->mt_blinding;
     }
 
  err:
         }
         ret = rsa->mt_blinding;
     }
 
  err:
-    if (got_write_lock)
-        CRYPTO_w_unlock(CRYPTO_LOCK_RSA);
-    else
-        CRYPTO_r_unlock(CRYPTO_LOCK_RSA);
+    CRYPTO_THREAD_unlock(rsa->lock);
     return ret;
 }
 
     return ret;
 }
 
@@ -315,9 +297,11 @@ static int rsa_blinding_convert(BN_BLINDING *b, BIGNUM *f, BIGNUM *unblind,
          * Shared blinding: store the unblinding factor outside BN_BLINDING.
          */
         int ret;
          * Shared blinding: store the unblinding factor outside BN_BLINDING.
          */
         int ret;
-        CRYPTO_w_lock(CRYPTO_LOCK_RSA_BLINDING);
+
+        BN_BLINDING_lock(b);
         ret = BN_BLINDING_convert_ex(f, unblind, b, ctx);
         ret = BN_BLINDING_convert_ex(f, unblind, b, ctx);
-        CRYPTO_w_unlock(CRYPTO_LOCK_RSA_BLINDING);
+        BN_BLINDING_unlock(b);
+
         return ret;
     }
 }
         return ret;
     }
 }
@@ -432,7 +416,7 @@ static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
 
         if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
             if (!BN_MONT_CTX_set_locked
 
         if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
             if (!BN_MONT_CTX_set_locked
-                (&rsa->_method_mod_n, CRYPTO_LOCK_RSA, rsa->n, ctx)) {
+                (&rsa->_method_mod_n, rsa->lock, rsa->n, ctx)) {
                 BN_free(local_d);
                 goto err;
             }
                 BN_free(local_d);
                 goto err;
             }
@@ -566,7 +550,7 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
 
         if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
             if (!BN_MONT_CTX_set_locked
 
         if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
             if (!BN_MONT_CTX_set_locked
-                (&rsa->_method_mod_n, CRYPTO_LOCK_RSA, rsa->n, ctx)) {
+                (&rsa->_method_mod_n, rsa->lock, rsa->n, ctx)) {
                 BN_free(local_d);
                 goto err;
             }
                 BN_free(local_d);
                 goto err;
             }
@@ -674,7 +658,7 @@ static int rsa_ossl_public_decrypt(int flen, const unsigned char *from,
 
     if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
         if (!BN_MONT_CTX_set_locked
 
     if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
         if (!BN_MONT_CTX_set_locked
-            (&rsa->_method_mod_n, CRYPTO_LOCK_RSA, rsa->n, ctx))
+            (&rsa->_method_mod_n, rsa->lock, rsa->n, ctx))
             goto err;
 
     if (!rsa->meth->bn_mod_exp(ret, f, rsa->e, rsa->n, ctx,
             goto err;
 
     if (!rsa->meth->bn_mod_exp(ret, f, rsa->e, rsa->n, ctx,
@@ -729,7 +713,7 @@ static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *I, RSA *rsa, BN_CTX *ctx)
         BIGNUM *p = NULL, *q = NULL;
 
         /*
         BIGNUM *p = NULL, *q = NULL;
 
         /*
-         * Make sure BN_mod_inverse in Montgomery intialization uses the
+         * Make sure BN_mod_inverse in Montgomery initialization uses the
          * BN_FLG_CONSTTIME flag (unless RSA_FLAG_NO_CONSTTIME is set)
          */
         if (!(rsa->flags & RSA_FLAG_NO_CONSTTIME)) {
          * BN_FLG_CONSTTIME flag (unless RSA_FLAG_NO_CONSTTIME is set)
          */
         if (!(rsa->flags & RSA_FLAG_NO_CONSTTIME)) {
@@ -751,9 +735,9 @@ static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *I, RSA *rsa, BN_CTX *ctx)
 
         if (rsa->flags & RSA_FLAG_CACHE_PRIVATE) {
             if (!BN_MONT_CTX_set_locked
 
         if (rsa->flags & RSA_FLAG_CACHE_PRIVATE) {
             if (!BN_MONT_CTX_set_locked
-                (&rsa->_method_mod_p, CRYPTO_LOCK_RSA, p, ctx)
+                (&rsa->_method_mod_p, rsa->lock, p, ctx)
                 || !BN_MONT_CTX_set_locked(&rsa->_method_mod_q,
                 || !BN_MONT_CTX_set_locked(&rsa->_method_mod_q,
-                                           CRYPTO_LOCK_RSA, q, ctx)) {
+                                           rsa->lock, q, ctx)) {
                 BN_free(local_p);
                 BN_free(local_q);
                 goto err;
                 BN_free(local_p);
                 BN_free(local_q);
                 goto err;
@@ -769,7 +753,7 @@ static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *I, RSA *rsa, BN_CTX *ctx)
 
     if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
         if (!BN_MONT_CTX_set_locked
 
     if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
         if (!BN_MONT_CTX_set_locked
-            (&rsa->_method_mod_n, CRYPTO_LOCK_RSA, rsa->n, ctx))
+            (&rsa->_method_mod_n, rsa->lock, rsa->n, ctx))
             goto err;
 
     /* compute I mod q */
             goto err;
 
     /* compute I mod q */