Teach the RSA implementation about TLS RSA Key Transport
[openssl.git] / providers / implementations / asymciphers / rsa_enc.c
index 9b17377..53fc6de 100644 (file)
 #include <openssl/rsa.h>
 #include <openssl/params.h>
 #include <openssl/err.h>
+/* Just for SSL_MAX_MASTER_KEY_LENGTH */
+#include <openssl/ssl.h>
 #include "internal/constant_time.h"
+#include "crypto/rsa.h"
 #include "prov/providercommonerr.h"
 #include "prov/provider_ctx.h"
 #include "prov/implementations.h"
@@ -51,6 +54,9 @@ typedef struct {
     /* OAEP label */
     unsigned char *oaep_label;
     size_t oaep_labellen;
+    /* TLS padding */
+    unsigned int client_version;
+    unsigned int alt_version;
 } PROV_RSA_CTX;
 
 static void *rsa_newctx(void *provctx)
@@ -130,38 +136,70 @@ static int rsa_decrypt(void *vprsactx, unsigned char *out, size_t *outlen,
 {
     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
     int ret;
+    size_t len = RSA_size(prsactx->rsa);
 
-    if (out == NULL) {
-        size_t len = RSA_size(prsactx->rsa);
+    if (prsactx->pad_mode == RSA_PKCS1_WITH_TLS_PADDING) {
+        if (out == NULL) {
+            *outlen = SSL_MAX_MASTER_KEY_LENGTH;
+            return 1;
+        }
+        if (outsize < SSL_MAX_MASTER_KEY_LENGTH) {
+            ERR_raise(ERR_LIB_PROV, PROV_R_BAD_LENGTH);
+            return 0;
+        }
+    } else {
+        if (out == NULL) {
+            if (len == 0) {
+                ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_KEY);
+                return 0;
+            }
+            *outlen = len;
+            return 1;
+        }
 
-        if (len == 0) {
-            ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_KEY);
+        if (outsize < len) {
+            ERR_raise(ERR_LIB_PROV, PROV_R_BAD_LENGTH);
             return 0;
         }
-        *outlen = len;
-        return 1;
     }
 
-    if (prsactx->pad_mode == RSA_PKCS1_OAEP_PADDING) {
-        int rsasize = RSA_size(prsactx->rsa);
+    if (prsactx->pad_mode == RSA_PKCS1_OAEP_PADDING
+            || prsactx->pad_mode == RSA_PKCS1_WITH_TLS_PADDING) {
         unsigned char *tbuf;
 
-        if ((tbuf = OPENSSL_malloc(rsasize)) == NULL) {
+        if ((tbuf = OPENSSL_malloc(len)) == NULL) {
             PROVerr(0, ERR_R_MALLOC_FAILURE);
             return 0;
         }
         ret = RSA_private_decrypt(inlen, in, tbuf, prsactx->rsa,
                                   RSA_NO_PADDING);
-        if (ret <= 0) {
+        /*
+         * With no padding then, on success ret should be len, otherwise an
+         * error occurred (non-constant time)
+         */
+        if (ret != (int)len) {
             OPENSSL_free(tbuf);
+            ERR_raise(ERR_LIB_PROV, PROV_R_FAILED_TO_DECRYPT);
             return 0;
         }
-        ret = RSA_padding_check_PKCS1_OAEP_mgf1(out, ret, tbuf,
-                                                ret, ret,
-                                                prsactx->oaep_label,
-                                                prsactx->oaep_labellen,
-                                                prsactx->oaep_md,
-                                                prsactx->mgf1_md);
+        if (prsactx->pad_mode == RSA_PKCS1_OAEP_PADDING) {
+            ret = RSA_padding_check_PKCS1_OAEP_mgf1(out, outsize, tbuf,
+                                                    len, len,
+                                                    prsactx->oaep_label,
+                                                    prsactx->oaep_labellen,
+                                                    prsactx->oaep_md,
+                                                    prsactx->mgf1_md);
+        } else {
+            /* RSA_PKCS1_WITH_TLS_PADDING */
+            if (prsactx->client_version <= 0) {
+                ERR_raise(ERR_LIB_PROV, PROV_R_BAD_TLS_CLIENT_VERSION);
+                return 0;
+            }
+            ret = rsa_padding_check_PKCS1_type_2_TLS(out, outsize,
+                                                     tbuf, len,
+                                                     prsactx->client_version,
+                                                     prsactx->alt_version);
+        }
         OPENSSL_free(tbuf);
     } else {
         ret = RSA_private_decrypt(inlen, in, out, prsactx->rsa,
@@ -252,6 +290,14 @@ static int rsa_get_ctx_params(void *vprsactx, OSSL_PARAM *params)
     if (p != NULL && !OSSL_PARAM_set_size_t(p, prsactx->oaep_labellen))
         return 0;
 
+    p = OSSL_PARAM_locate(params, OSSL_ASYM_CIPHER_PARAM_TLS_CLIENT_VERSION);
+    if (p != NULL && !OSSL_PARAM_set_uint(p, prsactx->client_version))
+        return 0;
+
+    p = OSSL_PARAM_locate(params, OSSL_ASYM_CIPHER_PARAM_TLS_NEGOTIATED_VERSION);
+    if (p != NULL && !OSSL_PARAM_set_uint(p, prsactx->alt_version))
+        return 0;
+
     return 1;
 }
 
@@ -262,6 +308,8 @@ static const OSSL_PARAM known_gettable_ctx_params[] = {
     OSSL_PARAM_DEFN(OSSL_ASYM_CIPHER_PARAM_OAEP_LABEL, OSSL_PARAM_OCTET_PTR,
                     NULL, 0),
     OSSL_PARAM_size_t(OSSL_ASYM_CIPHER_PARAM_OAEP_LABEL_LEN, NULL),
+    OSSL_PARAM_uint(OSSL_ASYM_CIPHER_PARAM_TLS_CLIENT_VERSION, NULL),
+    OSSL_PARAM_uint(OSSL_ASYM_CIPHER_PARAM_TLS_NEGOTIATED_VERSION, NULL),
     OSSL_PARAM_END
 };
 
@@ -354,6 +402,24 @@ static int rsa_set_ctx_params(void *vprsactx, const OSSL_PARAM params[])
         prsactx->oaep_labellen = tmp_labellen;
     }
 
+    p = OSSL_PARAM_locate_const(params, OSSL_ASYM_CIPHER_PARAM_TLS_CLIENT_VERSION);
+    if (p != NULL) {
+        unsigned int client_version;
+
+        if (!OSSL_PARAM_get_uint(p, &client_version))
+            return 0;
+        prsactx->client_version = client_version;
+    }
+
+    p = OSSL_PARAM_locate_const(params, OSSL_ASYM_CIPHER_PARAM_TLS_NEGOTIATED_VERSION);
+    if (p != NULL) {
+        unsigned int alt_version;
+
+        if (!OSSL_PARAM_get_uint(p, &alt_version))
+            return 0;
+        prsactx->alt_version = alt_version;
+    }
+
     return 1;
 }
 
@@ -363,6 +429,8 @@ static const OSSL_PARAM known_settable_ctx_params[] = {
     OSSL_PARAM_utf8_string(OSSL_ASYM_CIPHER_PARAM_MGF1_DIGEST, NULL, 0),
     OSSL_PARAM_utf8_string(OSSL_ASYM_CIPHER_PARAM_MGF1_DIGEST_PROPS, NULL, 0),
     OSSL_PARAM_octet_string(OSSL_ASYM_CIPHER_PARAM_OAEP_LABEL, NULL, 0),
+    OSSL_PARAM_uint(OSSL_ASYM_CIPHER_PARAM_TLS_CLIENT_VERSION, NULL),
+    OSSL_PARAM_uint(OSSL_ASYM_CIPHER_PARAM_TLS_NEGOTIATED_VERSION, NULL),
     OSSL_PARAM_END
 };