evp/e_aes.c: wire hardware-assisted block function to OCB.
[openssl.git] / crypto / evp / e_aes.c
index eaceab2e632eab32f327ddcf56dcab775c76dcef..b067dcfdff0420bfb322598c90e3bb0afb6b3cfa 100644 (file)
@@ -50,6 +50,7 @@
 
 #include <openssl/opensslconf.h>
 #ifndef OPENSSL_NO_AES
+#include <openssl/crypto.h>
 # include <openssl/evp.h>
 # include <openssl/err.h>
 # include <string.h>
@@ -109,14 +110,21 @@ typedef struct {
     int tag_set;                /* Set if tag is valid */
     int len_set;                /* Set if message length set */
     int L, M;                   /* L and M parameters from RFC3610 */
+    int tls_aad_len;            /* TLS AAD length */
     CCM128_CONTEXT ccm;
     ccm128_f str;
 } EVP_AES_CCM_CTX;
 
 # ifndef OPENSSL_NO_OCB
 typedef struct {
-    AES_KEY ksenc;              /* AES key schedule to use for encryption */
-    AES_KEY ksdec;              /* AES key schedule to use for decryption */
+    union {
+        double align;
+        AES_KEY ks;
+    } ksenc;                    /* AES key schedule to use for encryption */
+    union {
+        double align;
+        AES_KEY ks;
+    } ksdec;                    /* AES key schedule to use for decryption */
     int key_set;                /* Set if key initialised */
     int iv_set;                 /* Set if an iv is set */
     OCB128_CONTEXT ocb;
@@ -453,6 +461,19 @@ static int aesni_ccm_cipher(EVP_CIPHER_CTX *ctx, unsigned char *out,
                             const unsigned char *in, size_t len);
 
 #  ifndef OPENSSL_NO_OCB
+void aesni_ocb_encrypt(const unsigned char *in, unsigned char *out,
+                       size_t blocks, const void *key,
+                       size_t start_block_num,
+                       unsigned char offset_i[16],
+                       const unsigned char L_[][16],
+                       unsigned char checksum[16]);
+void aesni_ocb_decrypt(const unsigned char *in, unsigned char *out,
+                       size_t blocks, const void *key,
+                       size_t start_block_num,
+                       unsigned char offset_i[16],
+                       const unsigned char L_[][16],
+                       unsigned char checksum[16]);
+
 static int aesni_ocb_init_key(EVP_CIPHER_CTX *ctx, const unsigned char *key,
                               const unsigned char *iv, int enc)
 {
@@ -466,11 +487,14 @@ static int aesni_ocb_init_key(EVP_CIPHER_CTX *ctx, const unsigned char *key,
              * needs both. We could possibly optimise to remove setting the
              * decrypt for an encryption operation.
              */
-            aesni_set_encrypt_key(key, ctx->key_len * 8, &octx->ksenc);
-            aesni_set_decrypt_key(key, ctx->key_len * 8, &octx->ksdec);
-            if (!CRYPTO_ocb128_init(&octx->ocb, &octx->ksenc, &octx->ksdec,
+            aesni_set_encrypt_key(key, ctx->key_len * 8, &octx->ksenc.ks);
+            aesni_set_decrypt_key(key, ctx->key_len * 8, &octx->ksdec.ks);
+            if (!CRYPTO_ocb128_init(&octx->ocb,
+                                    &octx->ksenc.ks, &octx->ksdec.ks,
                                     (block128_f) aesni_encrypt,
-                                    (block128_f) aesni_decrypt))
+                                    (block128_f) aesni_decrypt,
+                                    enc ? aesni_ocb_encrypt
+                                        : aesni_ocb_decrypt))
                 return 0;
         }
         while (0);
@@ -829,6 +853,7 @@ static int aes_t4_ccm_init_key(EVP_CIPHER_CTX *ctx, const unsigned char *key,
         aes_t4_set_encrypt_key(key, bits, &cctx->ks.ks);
         CRYPTO_ccm128_init(&cctx->ccm, cctx->M, cctx->L,
                            &cctx->ks, (block128_f) aes_t4_encrypt);
+        cctx->str = NULL;
         cctx->key_set = 1;
     }
     if (iv) {
@@ -856,11 +881,13 @@ static int aes_t4_ocb_init_key(EVP_CIPHER_CTX *ctx, const unsigned char *key,
              * needs both. We could possibly optimise to remove setting the
              * decrypt for an encryption operation.
              */
-            aes_t4_set_encrypt_key(key, ctx->key_len * 8, &octx->ksenc);
-            aes_t4_set_decrypt_key(key, ctx->key_len * 8, &octx->ksdec);
-            if (!CRYPTO_ocb128_init(&octx->ocb, &octx->ksenc, &octx->ksdec,
+            aes_t4_set_encrypt_key(key, ctx->key_len * 8, &octx->ksenc.ks);
+            aes_t4_set_decrypt_key(key, ctx->key_len * 8, &octx->ksdec.ks);
+            if (!CRYPTO_ocb128_init(&octx->ocb,
+                                    &octx->ksenc.ks, &octx->ksdec.ks,
                                     (block128_f) aes_t4_encrypt,
-                                    (block128_f) aes_t4_decrypt))
+                                    (block128_f) aes_t4_decrypt,
+                                    NULL))
                 return 0;
         }
         while (0);
@@ -971,6 +998,9 @@ const EVP_CIPHER *EVP_aes_##keylen##_##mode(void) \
 #   if defined(BSAES_ASM)
 #    define BSAES_CAPABLE (OPENSSL_armcap_P & ARMV7_NEON)
 #   endif
+#   if defined(VPAES_ASM)
+#    define VPAES_CAPABLE (OPENSSL_armcap_P & ARMV7_NEON)
+#   endif
 #   define HWAES_CAPABLE (OPENSSL_armcap_P & ARMV8_AES)
 #   define HWAES_set_encrypt_key aes_v8_set_encrypt_key
 #   define HWAES_set_decrypt_key aes_v8_set_decrypt_key
@@ -1251,7 +1281,7 @@ static int aes_gcm_ctrl(EVP_CIPHER_CTX *c, int type, int arg, void *ptr)
             if (gctx->iv != c->iv)
                 OPENSSL_free(gctx->iv);
             gctx->iv = OPENSSL_malloc(arg);
-            if (!gctx->iv)
+            if (gctx->iv == NULL)
                 return 0;
         }
         gctx->ivlen = arg;
@@ -1315,7 +1345,7 @@ static int aes_gcm_ctrl(EVP_CIPHER_CTX *c, int type, int arg, void *ptr)
 
     case EVP_CTRL_AEAD_TLS1_AAD:
         /* Save the AAD for later use */
-        if (arg != 13)
+        if (arg != EVP_AEAD_TLS1_AAD_LEN)
             return 0;
         memcpy(c->buf, ptr, arg);
         gctx->tls_aad_len = arg;
@@ -1345,7 +1375,7 @@ static int aes_gcm_ctrl(EVP_CIPHER_CTX *c, int type, int arg, void *ptr)
                 gctx_out->iv = out->iv;
             else {
                 gctx_out->iv = OPENSSL_malloc(gctx->ivlen);
-                if (!gctx_out->iv)
+                if (gctx_out->iv == NULL)
                     return 0;
                 memcpy(gctx_out->iv, gctx->iv, gctx->ivlen);
             }
@@ -1543,7 +1573,7 @@ static int aes_gcm_tls_cipher(EVP_CIPHER_CTX *ctx, unsigned char *out,
         /* Retrieve tag */
         CRYPTO_gcm128_tag(&gctx->gcm, ctx->buf, EVP_GCM_TLS_TAG_LEN);
         /* If tag mismatch wipe buffer */
-        if (memcmp(ctx->buf, in + len, EVP_GCM_TLS_TAG_LEN)) {
+        if (CRYPTO_memcmp(ctx->buf, in + len, EVP_GCM_TLS_TAG_LEN)) {
             OPENSSL_cleanse(out, len);
             goto err;
         }
@@ -1840,6 +1870,34 @@ static int aes_ccm_ctrl(EVP_CIPHER_CTX *c, int type, int arg, void *ptr)
         cctx->M = 12;
         cctx->tag_set = 0;
         cctx->len_set = 0;
+        cctx->tls_aad_len = -1;
+        return 1;
+
+    case EVP_CTRL_AEAD_TLS1_AAD:
+        /* Save the AAD for later use */
+        if (arg != EVP_AEAD_TLS1_AAD_LEN)
+            return 0;
+        memcpy(c->buf, ptr, arg);
+        cctx->tls_aad_len = arg;
+        {
+            uint16_t len = c->buf[arg - 2] << 8 | c->buf[arg - 1];
+            /* Correct length for explicit IV */
+            len -= EVP_CCM_TLS_EXPLICIT_IV_LEN;
+            /* If decrypting correct for tag too */
+            if (!c->encrypt)
+                len -= cctx->M;
+            c->buf[arg - 2] = len >> 8;
+            c->buf[arg - 1] = len & 0xff;
+        }
+        /* Extra padding: tag appended to record */
+        return cctx->M;
+
+    case EVP_CTRL_CCM_SET_IV_FIXED:
+        /* Sanity check length */
+        if (arg != EVP_CCM_TLS_FIXED_IV_LEN)
+            return 0;
+        /* Just copy to first part of IV */
+        memcpy(c->iv, ptr, arg);
         return 1;
 
     case EVP_CTRL_AEAD_SET_IVLEN:
@@ -1853,7 +1911,7 @@ static int aes_ccm_ctrl(EVP_CIPHER_CTX *c, int type, int arg, void *ptr)
     case EVP_CTRL_AEAD_SET_TAG:
         if ((arg & 1) || arg < 4 || arg > 16)
             return 0;
-        if ((c->encrypt && ptr) || (!c->encrypt && !ptr))
+        if (c->encrypt && ptr)
             return 0;
         if (ptr) {
             cctx->tag_set = 1;
@@ -1932,14 +1990,66 @@ static int aes_ccm_init_key(EVP_CIPHER_CTX *ctx, const unsigned char *key,
     return 1;
 }
 
+static int aes_ccm_tls_cipher(EVP_CIPHER_CTX *ctx, unsigned char *out,
+                              const unsigned char *in, size_t len)
+{
+    EVP_AES_CCM_CTX *cctx = ctx->cipher_data;
+    CCM128_CONTEXT *ccm = &cctx->ccm;
+    /* Encrypt/decrypt must be performed in place */
+    if (out != in || len < (EVP_CCM_TLS_EXPLICIT_IV_LEN + (size_t)cctx->M))
+        return -1;
+    /* If encrypting set explicit IV from sequence number (start of AAD) */
+    if (ctx->encrypt)
+        memcpy(out, ctx->buf, EVP_CCM_TLS_EXPLICIT_IV_LEN);
+    /* Get rest of IV from explicit IV */
+    memcpy(ctx->iv + EVP_CCM_TLS_FIXED_IV_LEN, in, EVP_CCM_TLS_EXPLICIT_IV_LEN);
+    /* Correct length value */
+    len -= EVP_CCM_TLS_EXPLICIT_IV_LEN + cctx->M;
+    if (CRYPTO_ccm128_setiv(ccm, ctx->iv, 15 - cctx->L, len))
+            return -1;
+    /* Use saved AAD */
+    CRYPTO_ccm128_aad(ccm, ctx->buf, cctx->tls_aad_len);
+    /* Fix buffer to point to payload */
+    in += EVP_CCM_TLS_EXPLICIT_IV_LEN;
+    out += EVP_CCM_TLS_EXPLICIT_IV_LEN;
+    if (ctx->encrypt) {
+        if (cctx->str ? CRYPTO_ccm128_encrypt_ccm64(ccm, in, out, len,
+                                                    cctx->str) :
+            CRYPTO_ccm128_encrypt(ccm, in, out, len))
+            return -1;
+        if (!CRYPTO_ccm128_tag(ccm, out + len, cctx->M))
+            return -1;
+        return len + EVP_CCM_TLS_EXPLICIT_IV_LEN + cctx->M;
+    } else {
+        if (cctx->str ? !CRYPTO_ccm128_decrypt_ccm64(ccm, in, out, len,
+                                                     cctx->str) :
+            !CRYPTO_ccm128_decrypt(ccm, in, out, len)) {
+            unsigned char tag[16];
+            if (CRYPTO_ccm128_tag(ccm, tag, cctx->M)) {
+                if (!CRYPTO_memcmp(tag, in + len, cctx->M))
+                    return len;
+            }
+        }
+        OPENSSL_cleanse(out, len);
+        return -1;
+    }
+}
+
 static int aes_ccm_cipher(EVP_CIPHER_CTX *ctx, unsigned char *out,
                           const unsigned char *in, size_t len)
 {
     EVP_AES_CCM_CTX *cctx = ctx->cipher_data;
     CCM128_CONTEXT *ccm = &cctx->ccm;
     /* If not set up, return error */
-    if (!cctx->iv_set && !cctx->key_set)
+    if (!cctx->key_set)
+        return -1;
+
+    if (cctx->tls_aad_len >= 0)
+        return aes_ccm_tls_cipher(ctx, out, in, len);
+
+    if (!cctx->iv_set)
         return -1;
+
     if (!ctx->encrypt && !cctx->tag_set)
         return -1;
     if (!out) {
@@ -1978,7 +2088,7 @@ static int aes_ccm_cipher(EVP_CIPHER_CTX *ctx, unsigned char *out,
             !CRYPTO_ccm128_decrypt(ccm, in, out, len)) {
             unsigned char tag[16];
             if (CRYPTO_ccm128_tag(ccm, tag, cctx->M)) {
-                if (!memcmp(tag, ctx->buf, cctx->M))
+                if (!CRYPTO_memcmp(tag, ctx->buf, cctx->M))
                     rv = len;
             }
         }
@@ -1994,9 +2104,12 @@ static int aes_ccm_cipher(EVP_CIPHER_CTX *ctx, unsigned char *out,
 
 # define aes_ccm_cleanup NULL
 
-BLOCK_CIPHER_custom(NID_aes, 128, 1, 12, ccm, CCM, CUSTOM_FLAGS)
-    BLOCK_CIPHER_custom(NID_aes, 192, 1, 12, ccm, CCM, CUSTOM_FLAGS)
-    BLOCK_CIPHER_custom(NID_aes, 256, 1, 12, ccm, CCM, CUSTOM_FLAGS)
+BLOCK_CIPHER_custom(NID_aes, 128, 1, 12, ccm, CCM,
+                    EVP_CIPH_FLAG_AEAD_CIPHER | CUSTOM_FLAGS)
+    BLOCK_CIPHER_custom(NID_aes, 192, 1, 12, ccm, CCM,
+                        EVP_CIPH_FLAG_AEAD_CIPHER | CUSTOM_FLAGS)
+    BLOCK_CIPHER_custom(NID_aes, 256, 1, 12, ccm, CCM,
+                        EVP_CIPH_FLAG_AEAD_CIPHER | CUSTOM_FLAGS)
 
 typedef struct {
     union {
@@ -2222,7 +2335,8 @@ static int aes_ocb_ctrl(EVP_CIPHER_CTX *c, int type, int arg, void *ptr)
         newc = (EVP_CIPHER_CTX *)ptr;
         new_octx = newc->cipher_data;
         return CRYPTO_ocb128_copy_ctx(&new_octx->ocb, &octx->ocb,
-                                      &new_octx->ksenc, &new_octx->ksdec);
+                                      &new_octx->ksenc.ks,
+                                      &new_octx->ksdec.ks);
 
     default:
         return -1;
@@ -2230,6 +2344,29 @@ static int aes_ocb_ctrl(EVP_CIPHER_CTX *c, int type, int arg, void *ptr)
     }
 }
 
+#  ifdef HWAES_CAPABLE
+#   ifdef HWAES_ocb_encrypt
+void HWAES_ocb_encrypt(const unsigned char *in, unsigned char *out,
+                       size_t blocks, const void *key,
+                       size_t start_block_num,
+                       unsigned char offset_i[16],
+                       const unsigned char L_[][16],
+                       unsigned char checksum[16]);
+#   else
+#     define HWAES_ocb_encrypt NULL
+#   endif
+#   ifdef HWAES_ocb_decrypt
+void HWAES_ocb_decrypt(const unsigned char *in, unsigned char *out,
+                       size_t blocks, const void *key,
+                       size_t start_block_num,
+                       unsigned char offset_i[16],
+                       const unsigned char L_[][16],
+                       unsigned char checksum[16]);
+#   else
+#     define HWAES_ocb_decrypt NULL
+#   endif
+#  endif
+
 static int aes_ocb_init_key(EVP_CIPHER_CTX *ctx, const unsigned char *key,
                             const unsigned char *iv, int enc)
 {
@@ -2243,22 +2380,40 @@ static int aes_ocb_init_key(EVP_CIPHER_CTX *ctx, const unsigned char *key,
              * needs both. We could possibly optimise to remove setting the
              * decrypt for an encryption operation.
              */
+#  ifdef HWAES_CAPABLE
+            if (HWAES_CAPABLE) {
+                HWAES_set_encrypt_key(key, ctx->key_len * 8, &octx->ksenc.ks);
+                HWAES_set_decrypt_key(key, ctx->key_len * 8, &octx->ksdec.ks);
+                if (!CRYPTO_ocb128_init(&octx->ocb,
+                                        &octx->ksenc.ks, &octx->ksdec.ks,
+                                        (block128_f) HWAES_encrypt,
+                                        (block128_f) HWAES_decrypt,
+                                        enc ? HWAES_ocb_encrypt
+                                            : HWAES_ocb_decrypt))
+                    return 0;
+                break;
+            }
+#  endif
 #  ifdef VPAES_CAPABLE
             if (VPAES_CAPABLE) {
-                vpaes_set_encrypt_key(key, ctx->key_len * 8, &octx->ksenc);
-                vpaes_set_decrypt_key(key, ctx->key_len * 8, &octx->ksdec);
-                if (!CRYPTO_ocb128_init
-                    (&octx->ocb, &octx->ksenc, &octx->ksdec,
-                     (block128_f) vpaes_encrypt, (block128_f) vpaes_decrypt))
+                vpaes_set_encrypt_key(key, ctx->key_len * 8, &octx->ksenc.ks);
+                vpaes_set_decrypt_key(key, ctx->key_len * 8, &octx->ksdec.ks);
+                if (!CRYPTO_ocb128_init(&octx->ocb,
+                                        &octx->ksenc.ks, &octx->ksdec.ks,
+                                        (block128_f) vpaes_encrypt,
+                                        (block128_f) vpaes_decrypt,
+                                        NULL))
                     return 0;
                 break;
             }
 #  endif
-            AES_set_encrypt_key(key, ctx->key_len * 8, &octx->ksenc);
-            AES_set_decrypt_key(key, ctx->key_len * 8, &octx->ksdec);
-            if (!CRYPTO_ocb128_init(&octx->ocb, &octx->ksenc, &octx->ksdec,
+            AES_set_encrypt_key(key, ctx->key_len * 8, &octx->ksenc.ks);
+            AES_set_decrypt_key(key, ctx->key_len * 8, &octx->ksdec.ks);
+            if (!CRYPTO_ocb128_init(&octx->ocb,
+                                    &octx->ksenc.ks, &octx->ksdec.ks,
                                     (block128_f) AES_encrypt,
-                                    (block128_f) AES_decrypt))
+                                    (block128_f) AES_decrypt,
+                                    NULL))
                 return 0;
         }
         while (0);