[SM2_sign] fix double free and return value
[openssl.git] / crypto / sm2 / sm2_sign.c
1 /*
2  * Copyright 2017-2018 The OpenSSL Project Authors. All Rights Reserved.
3  * Copyright 2017 Ribose Inc. All Rights Reserved.
4  * Ported from Ribose contributions from Botan.
5  *
6  * Licensed under the OpenSSL license (the "License").  You may not use
7  * this file except in compliance with the License.  You can obtain a copy
8  * in the file LICENSE in the source distribution or at
9  * https://www.openssl.org/source/license.html
10  */
11
12 #include <openssl/sm2.h>
13 #include <openssl/evp.h>
14 #include <openssl/bn.h>
15 #include <string.h>
16
17 static BIGNUM *compute_msg_hash(const EVP_MD *digest,
18                                 const EC_KEY *key,
19                                 const char *user_id,
20                                 const uint8_t *msg, size_t msg_len)
21 {
22     EVP_MD_CTX *hash = EVP_MD_CTX_new();
23     const int md_size = EVP_MD_size(digest);
24     uint8_t *za = OPENSSL_zalloc(md_size);
25     BIGNUM *e = NULL;
26
27     if (za == NULL)
28         goto done;
29
30     if (hash == NULL)
31         goto done;
32
33     if (SM2_compute_userid_digest(za, digest, user_id, key) == 0)
34         goto done;
35
36     if (EVP_DigestInit(hash, digest) == 0)
37         goto done;
38
39     if (EVP_DigestUpdate(hash, za, md_size) == 0)
40         goto done;
41
42     if (EVP_DigestUpdate(hash, msg, msg_len) == 0)
43         goto done;
44
45     /* reuse za buffer to hold H(ZA || M) */
46     if (EVP_DigestFinal(hash, za, NULL) == 0)
47         goto done;
48
49     e = BN_bin2bn(za, md_size, NULL);
50
51  done:
52     OPENSSL_free(za);
53     EVP_MD_CTX_free(hash);
54     return e;
55 }
56
57 static
58 ECDSA_SIG *SM2_sig_gen(const EC_KEY *key, const BIGNUM *e)
59 {
60     const BIGNUM *dA = EC_KEY_get0_private_key(key);
61     const EC_GROUP *group = EC_KEY_get0_group(key);
62     const BIGNUM *order = EC_GROUP_get0_order(group);
63
64     ECDSA_SIG *sig = NULL;
65     EC_POINT *kG = NULL;
66     BN_CTX *ctx = NULL;
67     BIGNUM *k = NULL;
68     BIGNUM *rk = NULL;
69     BIGNUM *r = NULL;
70     BIGNUM *s = NULL;
71     BIGNUM *x1 = NULL;
72     BIGNUM *tmp = NULL;
73
74     kG = EC_POINT_new(group);
75     if (kG == NULL)
76         goto done;
77
78     ctx = BN_CTX_new();
79     if (ctx == NULL)
80         goto done;
81
82     BN_CTX_start(ctx);
83
84     k = BN_CTX_get(ctx);
85     rk = BN_CTX_get(ctx);
86     x1 = BN_CTX_get(ctx);
87     tmp = BN_CTX_get(ctx);
88
89     if (tmp == NULL)
90         goto done;
91
92     /* These values are returned and so should not be allocated out of the context */
93     r = BN_new();
94     s = BN_new();
95
96     if (r == NULL || s == NULL)
97         goto done;
98
99     for (;;) {
100         BN_priv_rand_range(k, order);
101
102         if (EC_POINT_mul(group, kG, k, NULL, NULL, ctx) == 0)
103             goto done;
104
105         if (EC_POINT_get_affine_coordinates_GFp(group, kG, x1, NULL, ctx) == 0)
106             goto done;
107
108         if (BN_mod_add(r, e, x1, order, ctx) == 0)
109             goto done;
110
111         /* try again if r == 0 or r+k == n */
112         if (BN_is_zero(r))
113             continue;
114
115         BN_add(rk, r, k);
116
117         if (BN_cmp(rk, order) == 0)
118             continue;
119
120         BN_add(s, dA, BN_value_one());
121         BN_mod_inverse(s, s, order, ctx);
122
123         BN_mod_mul(tmp, dA, r, order, ctx);
124         BN_sub(tmp, k, tmp);
125
126         BN_mod_mul(s, s, tmp, order, ctx);
127
128         sig = ECDSA_SIG_new();
129
130         if (sig == NULL)
131             goto done;
132
133          /* takes ownership of r and s */
134         ECDSA_SIG_set0(sig, r, s);
135         break;
136     }
137
138  done:
139
140     if (sig == NULL) {
141         BN_free(r);
142         BN_free(s);
143     }
144
145     BN_CTX_free(ctx);
146     EC_POINT_free(kG);
147     return sig;
148
149 }
150
151 static
152 int SM2_sig_verify(const EC_KEY *key, const ECDSA_SIG *sig, const BIGNUM *e)
153 {
154     int ret = 0;
155     const EC_GROUP *group = EC_KEY_get0_group(key);
156     const BIGNUM *order = EC_GROUP_get0_order(group);
157     BN_CTX *ctx = NULL;
158     EC_POINT *pt = NULL;
159
160     BIGNUM *t = NULL;
161     BIGNUM *x1 = NULL;
162     const BIGNUM *r = NULL;
163     const BIGNUM *s = NULL;
164
165     ctx = BN_CTX_new();
166     if (ctx == NULL)
167         goto done;
168     pt = EC_POINT_new(group);
169     if (pt == NULL)
170         goto done;
171
172     BN_CTX_start(ctx);
173
174     t = BN_CTX_get(ctx);
175     x1 = BN_CTX_get(ctx);
176
177     if (x1 == NULL)
178         goto done;
179
180     /*
181        B1: verify whether r' in [1,n-1], verification failed if not
182        B2: vefify whether s' in [1,n-1], verification failed if not
183        B3: set M'~=ZA || M'
184        B4: calculate e'=Hv(M'~)
185        B5: calculate t = (r' + s') modn, verification failed if t=0
186        B6: calculate the point (x1', y1')=[s']G + [t]PA
187        B7: calculate R=(e'+x1') modn, verfication pass if yes, otherwise failed
188      */
189
190     ECDSA_SIG_get0(sig, &r, &s);
191
192     if (BN_cmp(r, BN_value_one()) < 0)
193         goto done;
194     if (BN_cmp(s, BN_value_one()) < 0)
195         goto done;
196
197     if (BN_cmp(order, r) <= 0)
198         goto done;
199     if (BN_cmp(order, s) <= 0)
200         goto done;
201
202     if (BN_mod_add(t, r, s, order, ctx) == 0)
203         goto done;
204
205     if (BN_is_zero(t) == 1)
206         goto done;
207
208     if (EC_POINT_mul(group, pt, s, EC_KEY_get0_public_key(key), t, ctx) == 0)
209         goto done;
210
211     if (EC_POINT_get_affine_coordinates_GFp(group, pt, x1, NULL, ctx) == 0)
212         goto done;
213
214     if (BN_mod_add(t, e, x1, order, ctx) == 0)
215         goto done;
216
217     if (BN_cmp(r, t) == 0)
218         ret = 1;
219
220  done:
221     EC_POINT_free(pt);
222     BN_CTX_free(ctx);
223     return ret;
224 }
225
226 ECDSA_SIG *SM2_do_sign(const EC_KEY *key,
227                        const EVP_MD *digest,
228                        const char *user_id, const uint8_t *msg, size_t msg_len)
229 {
230     BIGNUM *e = NULL;
231     ECDSA_SIG *sig = NULL;
232
233     e = compute_msg_hash(digest, key, user_id, msg, msg_len);
234     if (e == NULL)
235         goto done;
236
237     sig = SM2_sig_gen(key, e);
238
239  done:
240     BN_free(e);
241     return sig;
242 }
243
244 int SM2_do_verify(const EC_KEY *key,
245                   const EVP_MD *digest,
246                   const ECDSA_SIG *sig,
247                   const char *user_id, const uint8_t *msg, size_t msg_len)
248 {
249     BIGNUM *e = NULL;
250     int ret = -1;
251
252     e = compute_msg_hash(digest, key, user_id, msg, msg_len);
253     if (e == NULL)
254         goto done;
255
256     ret = SM2_sig_verify(key, sig, e);
257
258  done:
259     BN_free(e);
260     return ret;
261 }
262
263 int SM2_sign(int type, const unsigned char *dgst, int dgstlen,
264              unsigned char *sig, unsigned int *siglen, EC_KEY *eckey)
265 {
266     BIGNUM *e = NULL;
267     ECDSA_SIG *s = NULL;
268     int ret = -1;
269
270     if (type != NID_sm3)
271         goto done;
272
273     if (dgstlen != 32)          /* expected length of SM3 hash */
274         goto done;
275
276     e = BN_bin2bn(dgst, dgstlen, NULL);
277
278     s = SM2_sig_gen(eckey, e);
279
280     *siglen = i2d_ECDSA_SIG(s, &sig);
281
282     ret = 1;
283
284  done:
285     ECDSA_SIG_free(s);
286     BN_free(e);
287     return ret;
288 }
289
290 int SM2_verify(int type, const unsigned char *dgst, int dgstlen,
291                const unsigned char *sig, int sig_len, EC_KEY *eckey)
292 {
293     ECDSA_SIG *s = NULL;
294     BIGNUM *e = NULL;
295     const unsigned char *p = sig;
296     unsigned char *der = NULL;
297     int derlen = -1;
298     int ret = -1;
299
300     if (type != NID_sm3)
301         goto done;
302
303     s = ECDSA_SIG_new();
304     if (s == NULL)
305         goto done;
306     if (d2i_ECDSA_SIG(&s, &p, sig_len) == NULL)
307         goto done;
308     /* Ensure signature uses DER and doesn't have trailing garbage */
309     derlen = i2d_ECDSA_SIG(s, &der);
310     if (derlen != sig_len || memcmp(sig, der, derlen) != 0)
311         goto done;
312
313     e = BN_bin2bn(dgst, dgstlen, NULL);
314
315     ret = SM2_sig_verify(eckey, s, e);
316
317  done:
318     OPENSSL_free(der);
319     BN_free(e);
320     ECDSA_SIG_free(s);
321     return ret;
322 }