Ensure EC private keys retain leading zeros
[openssl.git] / crypto / ec / ec_asn1.c
index 6ff94a356362e99d755efef3935f523377b23cdf..4ad8494981bfc4c8e81926f14b753fc2729fc713 100644 (file)
@@ -1114,7 +1114,7 @@ int i2d_ECPrivateKey(EC_KEY *a, unsigned char **out)
 {
     int ret = 0, ok = 0;
     unsigned char *buffer = NULL;
-    size_t buf_len = 0, tmp_len;
+    size_t buf_len = 0, tmp_len, bn_len;
     EC_PRIVATEKEY *priv_key = NULL;
 
     if (a == NULL || a->group == NULL || a->priv_key == NULL ||
@@ -1130,18 +1130,32 @@ int i2d_ECPrivateKey(EC_KEY *a, unsigned char **out)
 
     priv_key->version = a->version;
 
-    buf_len = (size_t)BN_num_bytes(a->priv_key);
+    bn_len = (size_t)BN_num_bytes(a->priv_key);
+
+    /* Octetstring may need leading zeros if BN is to short */
+
+    buf_len = (EC_GROUP_get_degree(a->group) + 7) / 8;
+
+    if (bn_len > buf_len) {
+        ECerr(EC_F_I2D_ECPRIVATEKEY, EC_R_BUFFER_TOO_SMALL);
+        goto err;
+    }
+
     buffer = OPENSSL_malloc(buf_len);
     if (buffer == NULL) {
         ECerr(EC_F_I2D_ECPRIVATEKEY, ERR_R_MALLOC_FAILURE);
         goto err;
     }
 
-    if (!BN_bn2bin(a->priv_key, buffer)) {
+    if (!BN_bn2bin(a->priv_key, buffer + buf_len - bn_len)) {
         ECerr(EC_F_I2D_ECPRIVATEKEY, ERR_R_BN_LIB);
         goto err;
     }
 
+    if (buf_len - bn_len > 0) {
+        memset(buffer, 0, buf_len - bn_len);
+    }
+
     if (!M_ASN1_OCTET_STRING_set(priv_key->privateKey, buffer, buf_len)) {
         ECerr(EC_F_I2D_ECPRIVATEKEY, ERR_R_ASN1_LIB);
         goto err;
@@ -1226,16 +1240,19 @@ EC_KEY *d2i_ECParameters(EC_KEY **a, const unsigned char **in, long len)
             ECerr(EC_F_D2I_ECPARAMETERS, ERR_R_MALLOC_FAILURE);
             return NULL;
         }
-        if (a)
-            *a = ret;
     } else
         ret = *a;
 
     if (!d2i_ECPKParameters(&ret->group, in, len)) {
         ECerr(EC_F_D2I_ECPARAMETERS, ERR_R_EC_LIB);
+        if (a == NULL || *a != ret)
+             EC_KEY_free(ret);
         return NULL;
     }
 
+    if (a)
+        *a = ret;
+
     return ret;
 }