Fix KMAC bounds checks.
[openssl.git] / providers / implementations / macs / kmac_prov.c
index 111e0e8ba72b9a76d86c9f81289263239ecc4ddf..c95cf57ffbeeb49c94b121e17c61ae4c34819cae 100644 (file)
@@ -78,10 +78,14 @@ static OSSL_FUNC_mac_init_fn kmac_init;
 static OSSL_FUNC_mac_update_fn kmac_update;
 static OSSL_FUNC_mac_final_fn kmac_final;
 
-#define KMAC_MAX_BLOCKSIZE ((1600 - 128*2) / 8) /* 168 */
+#define KMAC_MAX_BLOCKSIZE ((1600 - 128 * 2) / 8) /* 168 */
 
-/* Length encoding will be  a 1 byte size + length in bits (2 bytes max) */
-#define KMAC_MAX_ENCODED_HEADER_LEN 3
+/*
+ * Length encoding will be  a 1 byte size + length in bits (3 bytes max)
+ * This gives a range of 0..0XFFFFFF bits = 2097151 bytes).
+ */
+#define KMAC_MAX_OUTPUT_LEN (0xFFFFFF / 8)
+#define KMAC_MAX_ENCODED_HEADER_LEN (1 + 3)
 
 /*
  * Restrict the maximum length of the customisation string.  This must not
@@ -92,12 +96,13 @@ static OSSL_FUNC_mac_final_fn kmac_final;
 /* Maximum size of encoded custom string */
 #define KMAC_MAX_CUSTOM_ENCODED (KMAC_MAX_CUSTOM + KMAC_MAX_ENCODED_HEADER_LEN)
 
-/* Maximum key size in bytes = 2040 / 8 */
-#define KMAC_MAX_KEY 255
+/* Maximum key size in bytes = 256 (2048 bits) */
+#define KMAC_MAX_KEY 256
+#define KMAC_MIN_KEY 4
 
 /*
  * Maximum Encoded Key size will be padded to a multiple of the blocksize
- * i.e KMAC_MAX_KEY + KMAC_MAX_ENCODED_LEN = 258
+ * i.e KMAC_MAX_KEY + KMAC_MAX_ENCODED_HEADER_LEN = 256 + 4
  * Padded to a multiple of KMAC_MAX_BLOCKSIZE
  */
 #define KMAC_MAX_KEY_ENCODED (KMAC_MAX_BLOCKSIZE * 2)
@@ -107,7 +112,6 @@ static const unsigned char kmac_string[] = {
     0x01, 0x20, 0x4B, 0x4D, 0x41, 0x43
 };
 
-
 #define KMAC_FLAG_XOF_MODE          1
 
 struct kmac_data_st {
@@ -124,14 +128,16 @@ struct kmac_data_st {
     unsigned char custom[KMAC_MAX_CUSTOM_ENCODED];
 };
 
-static int encode_string(unsigned char *out, size_t *out_len,
+static int encode_string(unsigned char *out, size_t out_max_len, size_t *out_len,
                          const unsigned char *in, size_t in_len);
-static int right_encode(unsigned char *out, size_t *out_len, size_t bits);
+static int right_encode(unsigned char *out, size_t out_max_len, size_t *out_len,
+                        size_t bits);
 static int bytepad(unsigned char *out, size_t *out_len,
                    const unsigned char *in1, size_t in1_len,
                    const unsigned char *in2, size_t in2_len,
                    size_t w);
-static int kmac_bytepad_encode_key(unsigned char *out, size_t *out_len,
+static int kmac_bytepad_encode_key(unsigned char *out, size_t out_max_len,
+                                   size_t *out_len,
                                    const unsigned char *in, size_t in_len,
                                    size_t w);
 
@@ -246,7 +252,7 @@ static int kmac_setkey(struct kmac_data_st *kctx, const unsigned char *key,
     const EVP_MD *digest = ossl_prov_digest_md(&kctx->digest);
     int w = EVP_MD_block_size(digest);
 
-    if (keylen < 4 || keylen > KMAC_MAX_KEY) {
+    if (keylen < KMAC_MIN_KEY || keylen > KMAC_MAX_KEY) {
         ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_KEY_LENGTH);
         return 0;
     }
@@ -254,7 +260,7 @@ static int kmac_setkey(struct kmac_data_st *kctx, const unsigned char *key,
         ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH);
         return 0;
     }
-    if (!kmac_bytepad_encode_key(kctx->key, &kctx->key_len,
+    if (!kmac_bytepad_encode_key(kctx->key, sizeof(kctx->key), &kctx->key_len,
                                  key, keylen, (size_t)w))
         return 0;
     return 1;
@@ -346,7 +352,7 @@ static int kmac_final(void *vmacctx, unsigned char *out, size_t *outl,
     /* KMAC XOF mode sets the encoded length to 0 */
     lbits = (kctx->xof_mode ? 0 : (kctx->out_len * 8));
 
-    ok = right_encode(encoded_outlen, &len, lbits)
+    ok = right_encode(encoded_outlen, sizeof(encoded_outlen), &len, lbits)
         && EVP_DigestUpdate(ctx, encoded_outlen, len)
         && EVP_DigestFinalXOF(ctx, out, kctx->out_len);
     *outl = kctx->out_len;
@@ -406,9 +412,17 @@ static int kmac_set_ctx_params(void *vmacctx, const OSSL_PARAM *params)
     if ((p = OSSL_PARAM_locate_const(params, OSSL_MAC_PARAM_XOF)) != NULL
         && !OSSL_PARAM_get_int(p, &kctx->xof_mode))
         return 0;
-    if (((p = OSSL_PARAM_locate_const(params, OSSL_MAC_PARAM_SIZE)) != NULL)
-        && !OSSL_PARAM_get_size_t(p, &kctx->out_len))
-        return 0;
+    if ((p = OSSL_PARAM_locate_const(params, OSSL_MAC_PARAM_SIZE)) != NULL) {
+        size_t sz = 0;
+
+        if (!OSSL_PARAM_get_size_t(p, &sz))
+            return 0;
+        if (sz > KMAC_MAX_OUTPUT_LEN) {
+            ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_OUTPUT_LENGTH);
+            return 0;
+        }
+        kctx->out_len = sz;
+    }
     if ((p = OSSL_PARAM_locate_const(params, OSSL_MAC_PARAM_KEY)) != NULL
             && !kmac_setkey(kctx, p->data, p->data_size))
         return 0;
@@ -418,16 +432,14 @@ static int kmac_set_ctx_params(void *vmacctx, const OSSL_PARAM *params)
             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_CUSTOM_LENGTH);
             return 0;
         }
-        if (!encode_string(kctx->custom, &kctx->custom_len,
+        if (!encode_string(kctx->custom, sizeof(kctx->custom), &kctx->custom_len,
                            p->data, p->data_size))
             return 0;
     }
     return 1;
 }
 
-/*
- * Encoding/Padding Methods.
- */
+/* Encoding/Padding Methods. */
 
 /* Returns the number of bytes required to store 'bits' into a byte array */
 static unsigned int get_encode_size(size_t bits)
@@ -450,15 +462,14 @@ static unsigned int get_encode_size(size_t bits)
  * *out_len.
  *
  * e.g if bits = 32, out[2] = { 0x20, 0x01 }
- *
  */
-static int right_encode(unsigned char *out, size_t *out_len, size_t bits)
+static int right_encode(unsigned char *out, size_t out_max_len, size_t *out_len,
+                        size_t bits)
 {
     unsigned int len = get_encode_size(bits);
     int i;
 
-    /* The length is constrained to a single byte: 2040/8 = 255 */
-    if (len > 0xFF) {
+    if (len >= out_max_len) {
         ERR_raise(ERR_LIB_PROV, PROV_R_LENGTH_TOO_LARGE);
         return 0;
     }
@@ -483,17 +494,19 @@ static int right_encode(unsigned char *out, size_t *out_len, size_t bits)
  * e.g- in="KMAC" gives out[6] = { 0x01, 0x20, 0x4B, 0x4D, 0x41, 0x43 }
  *                                 len   bits    K     M     A     C
  */
-static int encode_string(unsigned char *out, size_t *out_len,
+static int encode_string(unsigned char *out, size_t out_max_len, size_t *out_len,
                          const unsigned char *in, size_t in_len)
 {
     if (in == NULL) {
         *out_len = 0;
     } else {
-        size_t i, bits, len;
+        size_t i, bits, len, sz;
 
         bits = 8 * in_len;
         len = get_encode_size(bits);
-        if (len > 0xFF) {
+        sz = 1 + len + in_len;
+
+        if (sz > out_max_len) {
             ERR_raise(ERR_LIB_PROV, PROV_R_LENGTH_TOO_LARGE);
             return 0;
         }
@@ -504,7 +517,7 @@ static int encode_string(unsigned char *out, size_t *out_len,
             bits >>= 8;
         }
         memcpy(out + len + 1, in, in_len);
-        *out_len = (1 + len + in_len);
+        *out_len = sz;
     }
     return 1;
 }
@@ -560,20 +573,22 @@ static int bytepad(unsigned char *out, size_t *out_len,
     return 1;
 }
 
-/*
- * Returns out = bytepad(encode_string(in), w)
- */
-static int kmac_bytepad_encode_key(unsigned char *out, size_t *out_len,
+/* Returns out = bytepad(encode_string(in), w) */
+static int kmac_bytepad_encode_key(unsigned char *out, size_t out_max_len,
+                                   size_t *out_len,
                                    const unsigned char *in, size_t in_len,
                                    size_t w)
 {
     unsigned char tmp[KMAC_MAX_KEY + KMAC_MAX_ENCODED_HEADER_LEN];
     size_t tmp_len;
 
-    if (!encode_string(tmp, &tmp_len, in, in_len))
+    if (!encode_string(tmp, sizeof(tmp), &tmp_len, in, in_len))
         return 0;
-
-    return bytepad(out, out_len, tmp, tmp_len, NULL, 0, w);
+    if (!bytepad(NULL, out_len, tmp, tmp_len, NULL, 0, w))
+        return 0;
+    if (!ossl_assert(*out_len <= out_max_len))
+        return 0;
+    return bytepad(out, NULL, tmp, tmp_len, NULL, 0, w);
 }
 
 const OSSL_DISPATCH ossl_kmac128_functions[] = {