support params argument to AES cipher init calls
authorPauli <ppzgs1@gmail.com>
Tue, 2 Mar 2021 12:46:24 +0000 (22:46 +1000)
committerPauli <ppzgs1@gmail.com>
Thu, 11 Mar 2021 22:27:21 +0000 (08:27 +1000)
Reviewed-by: Shane Lontis <shane.lontis@oracle.com>
(Merged from https://github.com/openssl/openssl/pull/14383)

providers/implementations/ciphers/cipher_aes_cbc_hmac_sha.c
providers/implementations/ciphers/cipher_aes_cts.inc
providers/implementations/ciphers/cipher_aes_ocb.c
providers/implementations/ciphers/cipher_aes_siv.c
providers/implementations/ciphers/cipher_aes_wrp.c
providers/implementations/ciphers/cipher_aes_xts.c

index b78687ceae8705c08a9dc9bdfa8dd388a4344135..a0eef7c1e516eccc77f2e78d092298da29062574 100644 (file)
@@ -33,6 +33,8 @@ const OSSL_DISPATCH ossl_##nm##kbits##sub##_functions[] = {                    \
 # define AES_CBC_HMAC_SHA_FLAGS (PROV_CIPHER_FLAG_AEAD                         \
                                  | PROV_CIPHER_FLAG_TLS1_MULTIBLOCK)
 
+static OSSL_FUNC_cipher_encrypt_init_fn aes_einit;
+static OSSL_FUNC_cipher_decrypt_init_fn aes_dinit;
 static OSSL_FUNC_cipher_freectx_fn aes_cbc_hmac_sha1_freectx;
 static OSSL_FUNC_cipher_freectx_fn aes_cbc_hmac_sha256_freectx;
 static OSSL_FUNC_cipher_get_ctx_params_fn aes_get_ctx_params;
@@ -40,12 +42,28 @@ static OSSL_FUNC_cipher_gettable_ctx_params_fn aes_gettable_ctx_params;
 static OSSL_FUNC_cipher_set_ctx_params_fn aes_set_ctx_params;
 static OSSL_FUNC_cipher_settable_ctx_params_fn aes_settable_ctx_params;
 # define aes_gettable_params ossl_cipher_generic_gettable_params
-# define aes_einit ossl_cipher_generic_einit
-# define aes_dinit ossl_cipher_generic_dinit
 # define aes_update ossl_cipher_generic_stream_update
 # define aes_final ossl_cipher_generic_stream_final
 # define aes_cipher ossl_cipher_generic_cipher
 
+static int aes_einit(void *ctx, const unsigned char *key, size_t keylen,
+                          const unsigned char *iv, size_t ivlen,
+                          const OSSL_PARAM params[])
+{
+    if (!ossl_cipher_generic_einit(ctx, key, keylen, iv, ivlen, NULL))
+        return 0;
+    return aes_set_ctx_params(ctx, params);
+}
+
+static int aes_dinit(void *ctx, const unsigned char *key, size_t keylen,
+                          const unsigned char *iv, size_t ivlen,
+                          const OSSL_PARAM params[])
+{
+    if (!ossl_cipher_generic_dinit(ctx, key, keylen, iv, ivlen, NULL))
+        return 0;
+    return aes_set_ctx_params(ctx, params);
+}
+
 static const OSSL_PARAM cipher_aes_known_settable_ctx_params[] = {
     OSSL_PARAM_octet_string(OSSL_CIPHER_PARAM_AEAD_MAC_KEY, NULL, 0),
     OSSL_PARAM_octet_string(OSSL_CIPHER_PARAM_AEAD_TLS1_AAD, NULL, 0),
@@ -76,6 +94,9 @@ static int aes_set_ctx_params(void *vctx, const OSSL_PARAM params[])
     EVP_CTRL_TLS1_1_MULTIBLOCK_PARAM mb_param;
 # endif
 
+    if (params == NULL)
+        return 1;
+
     p = OSSL_PARAM_locate_const(params, OSSL_CIPHER_PARAM_AEAD_MAC_KEY);
     if (p != NULL) {
         if (p->data_type != OSSL_PARAM_OCTET_STRING) {
index fbd66eb257d049957142f93a5e4d6d77534302da..2a3b88b2c00338b659f29d378cc7f880050b08d0 100644 (file)
@@ -14,6 +14,8 @@
 
 #define AES_CTS_FLAGS PROV_CIPHER_FLAG_CTS
 
+static OSSL_FUNC_cipher_encrypt_init_fn aes_cbc_cts_einit;
+static OSSL_FUNC_cipher_decrypt_init_fn aes_cbc_cts_dinit;
 static OSSL_FUNC_cipher_get_ctx_params_fn aes_cbc_cts_get_ctx_params;
 static OSSL_FUNC_cipher_set_ctx_params_fn aes_cbc_cts_set_ctx_params;
 static OSSL_FUNC_cipher_gettable_ctx_params_fn aes_cbc_cts_gettable_ctx_params;
@@ -23,6 +25,24 @@ CIPHER_DEFAULT_GETTABLE_CTX_PARAMS_START(aes_cbc_cts)
 OSSL_PARAM_utf8_string(OSSL_CIPHER_PARAM_CTS_MODE, NULL, 0),
 CIPHER_DEFAULT_GETTABLE_CTX_PARAMS_END(aes_cbc_cts)
 
+static int aes_cbc_cts_einit(void *ctx, const unsigned char *key, size_t keylen,
+                             const unsigned char *iv, size_t ivlen,
+                             const OSSL_PARAM params[])
+{
+    if (!ossl_cipher_generic_einit(ctx, key, keylen, iv, ivlen, NULL))
+        return 0;
+    return aes_cbc_cts_set_ctx_params(ctx, params);
+}
+
+static int aes_cbc_cts_dinit(void *ctx, const unsigned char *key, size_t keylen,
+                             const unsigned char *iv, size_t ivlen,
+                             const OSSL_PARAM params[])
+{
+    if (!ossl_cipher_generic_dinit(ctx, key, keylen, iv, ivlen, NULL))
+        return 0;
+    return aes_cbc_cts_set_ctx_params(ctx, params);
+}
+
 static int aes_cbc_cts_get_ctx_params(void *vctx, OSSL_PARAM params[])
 {
     PROV_CIPHER_CTX *ctx = (PROV_CIPHER_CTX *)vctx;
@@ -80,8 +100,8 @@ const OSSL_DISPATCH ossl_##alg##kbits##lcmode##_cts_functions[] = {            \
       (void (*)(void)) alg##_##kbits##_##lcmode##_newctx },                    \
     { OSSL_FUNC_CIPHER_FREECTX, (void (*)(void)) alg##_freectx },              \
     { OSSL_FUNC_CIPHER_DUPCTX, (void (*)(void)) alg##_dupctx },                \
-    { OSSL_FUNC_CIPHER_ENCRYPT_INIT, (void (*)(void))ossl_cipher_generic_einit }, \
-    { OSSL_FUNC_CIPHER_DECRYPT_INIT, (void (*)(void))ossl_cipher_generic_dinit }, \
+    { OSSL_FUNC_CIPHER_ENCRYPT_INIT, (void (*)(void))aes_cbc_cts_einit },      \
+    { OSSL_FUNC_CIPHER_DECRYPT_INIT, (void (*)(void))aes_cbc_cts_dinit },      \
     { OSSL_FUNC_CIPHER_UPDATE,                                                 \
       (void (*)(void)) ossl_##alg##_##lcmode##_cts_block_update },             \
     { OSSL_FUNC_CIPHER_FINAL,                                                  \
index 627f146273dc626de73f74144796a85b27995c08..ce377ad57409327533de1905f6e963647ffae9d7 100644 (file)
@@ -102,7 +102,8 @@ static ossl_inline int aes_generic_ocb_copy_ctx(PROV_AES_OCB_CTX *dst,
  * Provider dispatch functions
  */
 static int aes_ocb_init(void *vctx, const unsigned char *key, size_t keylen,
-                        const unsigned char *iv, size_t ivlen, int enc)
+                        const unsigned char *iv, size_t ivlen,
+                        const OSSL_PARAM params[], int enc)
 {
     PROV_AES_OCB_CTX *ctx = (PROV_AES_OCB_CTX *)vctx;
 
@@ -131,21 +132,24 @@ static int aes_ocb_init(void *vctx, const unsigned char *key, size_t keylen,
             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_KEY_LENGTH);
             return 0;
         }
-        return ctx->base.hw->init(&ctx->base, key, keylen);
+        if (!ctx->base.hw->init(&ctx->base, key, keylen))
+            return 0;
     }
-    return 1;
+    return aes_ocb_set_ctx_params(ctx, params);
 }
 
 static int aes_ocb_einit(void *vctx, const unsigned char *key, size_t keylen,
-                         const unsigned char *iv, size_t ivlen)
+                         const unsigned char *iv, size_t ivlen,
+                         const OSSL_PARAM params[])
 {
-    return aes_ocb_init(vctx, key, keylen, iv, ivlen, 1);
+    return aes_ocb_init(vctx, key, keylen, iv, ivlen, params, 1);
 }
 
 static int aes_ocb_dinit(void *vctx, const unsigned char *key, size_t keylen,
-                         const unsigned char *iv, size_t ivlen)
+                         const unsigned char *iv, size_t ivlen,
+                         const OSSL_PARAM params[])
 {
-    return aes_ocb_init(vctx, key, keylen, iv, ivlen, 0);
+    return aes_ocb_init(vctx, key, keylen, iv, ivlen, params, 0);
 }
 
 /*
@@ -354,6 +358,9 @@ static int aes_ocb_set_ctx_params(void *vctx, const OSSL_PARAM params[])
     const OSSL_PARAM *p;
     size_t sz;
 
+    if (params == NULL)
+        return 1;
+
     p = OSSL_PARAM_locate_const(params, OSSL_CIPHER_PARAM_AEAD_TAG);
     if (p != NULL) {
         if (p->data_type != OSSL_PARAM_OCTET_STRING) {
index 9a75f6f5b762cc4274b48b59e959e25b5516c460..dd3346a81ce5b6164c3e29d4bca9aac87a5a9733 100644 (file)
@@ -25,6 +25,8 @@
 #define siv_stream_update siv_cipher
 #define SIV_FLAGS AEAD_FLAGS
 
+static OSSL_FUNC_cipher_set_ctx_params_fn aes_siv_set_ctx_params;
+
 static void *aes_siv_newctx(void *provctx, size_t keybits, unsigned int mode,
                             uint64_t flags)
 {
@@ -75,7 +77,8 @@ static void *siv_dupctx(void *vctx)
 }
 
 static int siv_init(void *vctx, const unsigned char *key, size_t keylen,
-                    const unsigned char *iv, size_t ivlen, int enc)
+                    const unsigned char *iv, size_t ivlen,
+                    const OSSL_PARAM params[], int enc)
 {
     PROV_AES_SIV_CTX *ctx = (PROV_AES_SIV_CTX *)vctx;
 
@@ -89,21 +92,24 @@ static int siv_init(void *vctx, const unsigned char *key, size_t keylen,
             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_KEY_LENGTH);
             return 0;
         }
-        return ctx->hw->initkey(ctx, key, ctx->keylen);
+        if (!ctx->hw->initkey(ctx, key, ctx->keylen))
+            return 0;
     }
-    return 1;
+    return aes_siv_set_ctx_params(ctx, params);
 }
 
 static int siv_einit(void *vctx, const unsigned char *key, size_t keylen,
-                     const unsigned char *iv, size_t ivlen)
+                     const unsigned char *iv, size_t ivlen,
+                     const OSSL_PARAM params[])
 {
-    return siv_init(vctx, key, keylen, iv, ivlen, 1);
+    return siv_init(vctx, key, keylen, iv, ivlen, params, 1);
 }
 
 static int siv_dinit(void *vctx, const unsigned char *key, size_t keylen,
-                     const unsigned char *iv, size_t ivlen)
+                     const unsigned char *iv, size_t ivlen,
+                     const OSSL_PARAM params[])
 {
-    return siv_init(vctx, key, keylen, iv, ivlen, 0);
+    return siv_init(vctx, key, keylen, iv, ivlen, params, 0);
 }
 
 static int siv_cipher(void *vctx, unsigned char *out, size_t *outl,
@@ -195,6 +201,9 @@ static int aes_siv_set_ctx_params(void *vctx, const OSSL_PARAM params[])
     const OSSL_PARAM *p;
     unsigned int speed = 0;
 
+    if (params == NULL)
+        return 1;
+
     p = OSSL_PARAM_locate_const(params, OSSL_CIPHER_PARAM_AEAD_TAG);
     if (p != NULL) {
         if (ctx->enc)
index 4428ff0552d91db6b499478b29513bcdccbdb1d8..f797db4596801801eb65f727d2586be490188cfc 100644 (file)
@@ -34,6 +34,7 @@ static OSSL_FUNC_cipher_decrypt_init_fn aes_wrap_dinit;
 static OSSL_FUNC_cipher_update_fn aes_wrap_cipher;
 static OSSL_FUNC_cipher_final_fn aes_wrap_final;
 static OSSL_FUNC_cipher_freectx_fn aes_wrap_freectx;
+static OSSL_FUNC_cipher_set_ctx_params_fn aes_wrap_set_ctx_params;
 
 typedef struct prov_aes_wrap_ctx_st {
     PROV_CIPHER_CTX base;
@@ -75,7 +76,7 @@ static void aes_wrap_freectx(void *vctx)
 
 static int aes_wrap_init(void *vctx, const unsigned char *key,
                          size_t keylen, const unsigned char *iv,
-                         size_t ivlen, int enc)
+                         size_t ivlen, const OSSL_PARAM params[], int enc)
 {
     PROV_CIPHER_CTX *ctx = (PROV_CIPHER_CTX *)vctx;
     PROV_AES_WRAP_CTX *wctx = (PROV_AES_WRAP_CTX *)vctx;
@@ -121,19 +122,21 @@ static int aes_wrap_init(void *vctx, const unsigned char *key,
             ctx->block = (block128_f)AES_decrypt;
         }
     }
-    return 1;
+    return aes_wrap_set_ctx_params(ctx, params);
 }
 
 static int aes_wrap_einit(void *ctx, const unsigned char *key, size_t keylen,
-                          const unsigned char *iv, size_t ivlen)
+                          const unsigned char *iv, size_t ivlen,
+                          const OSSL_PARAM params[])
 {
-    return aes_wrap_init(ctx, key, keylen, iv, ivlen, 1);
+    return aes_wrap_init(ctx, key, keylen, iv, ivlen, params, 1);
 }
 
 static int aes_wrap_dinit(void *ctx, const unsigned char *key, size_t keylen,
-                          const unsigned char *iv, size_t ivlen)
+                          const unsigned char *iv, size_t ivlen,
+                          const OSSL_PARAM params[])
 {
-    return aes_wrap_init(ctx, key, keylen, iv, ivlen, 0);
+    return aes_wrap_init(ctx, key, keylen, iv, ivlen, params, 0);
 }
 
 static int aes_wrap_cipher_internal(void *vctx, unsigned char *out,
@@ -226,6 +229,9 @@ static int aes_wrap_set_ctx_params(void *vctx, const OSSL_PARAM params[])
     const OSSL_PARAM *p;
     size_t keylen = 0;
 
+    if (params == NULL)
+        return 1;
+
     p = OSSL_PARAM_locate_const(params, OSSL_CIPHER_PARAM_KEYLEN);
     if (p != NULL) {
         if (!OSSL_PARAM_get_size_t(p, &keylen)) {
index 13552b2a760b612796aa49091bca5b853a72f67e..5cfb22778ec4b7daf60e2105dbafd9e7ced25350 100644 (file)
@@ -66,7 +66,8 @@ static int aes_xts_check_keys_differ(const unsigned char *key, size_t bytes,
  * Provider dispatch functions
  */
 static int aes_xts_init(void *vctx, const unsigned char *key, size_t keylen,
-                        const unsigned char *iv, size_t ivlen, int enc)
+                        const unsigned char *iv, size_t ivlen,
+                        const OSSL_PARAM params[], int enc)
 {
     PROV_AES_XTS_CTX *xctx = (PROV_AES_XTS_CTX *)vctx;
     PROV_CIPHER_CTX *ctx = &xctx->base;
@@ -87,21 +88,24 @@ static int aes_xts_init(void *vctx, const unsigned char *key, size_t keylen,
         }
         if (!aes_xts_check_keys_differ(key, keylen / 2, enc))
             return 0;
-        return ctx->hw->init(ctx, key, keylen);
+        if (!ctx->hw->init(ctx, key, keylen))
+            return 0;
     }
-    return 1;
+    return aes_xts_set_ctx_params(ctx, params);
 }
 
 static int aes_xts_einit(void *vctx, const unsigned char *key, size_t keylen,
-                         const unsigned char *iv, size_t ivlen)
+                         const unsigned char *iv, size_t ivlen,
+                         const OSSL_PARAM params[])
 {
-    return aes_xts_init(vctx, key, keylen, iv, ivlen, 1);
+    return aes_xts_init(vctx, key, keylen, iv, ivlen, params, 1);
 }
 
 static int aes_xts_dinit(void *vctx, const unsigned char *key, size_t keylen,
-                         const unsigned char *iv, size_t ivlen)
+                         const unsigned char *iv, size_t ivlen,
+                         const OSSL_PARAM params[])
 {
-    return aes_xts_init(vctx, key, keylen, iv, ivlen, 0);
+    return aes_xts_init(vctx, key, keylen, iv, ivlen, params, 0);
 }
 
 static void *aes_xts_newctx(void *provctx, unsigned int mode, uint64_t flags,
@@ -229,6 +233,9 @@ static int aes_xts_set_ctx_params(void *vctx, const OSSL_PARAM params[])
     PROV_CIPHER_CTX *ctx = (PROV_CIPHER_CTX *)vctx;
     const OSSL_PARAM *p;
 
+    if (params == NULL)
+        return 1;
+
     p = OSSL_PARAM_locate_const(params, OSSL_CIPHER_PARAM_KEYLEN);
     if (p != NULL) {
         size_t keylen;