Remove unnecessary sm2_za.c
[openssl.git] / crypto / sm2 / sm2_sign.c
1 /*
2  * Copyright 2017-2018 The OpenSSL Project Authors. All Rights Reserved.
3  * Copyright 2017 Ribose Inc. All Rights Reserved.
4  * Ported from Ribose contributions from Botan.
5  *
6  * Licensed under the OpenSSL license (the "License").  You may not use
7  * this file except in compliance with the License.  You can obtain a copy
8  * in the file LICENSE in the source distribution or at
9  * https://www.openssl.org/source/license.html
10  */
11
12 #include "internal/sm2.h"
13 #include "internal/sm2err.h"
14 #include "internal/ec_int.h" /* ec_group_do_inverse_ord() */
15 #include <openssl/err.h>
16 #include <openssl/evp.h>
17 #include <openssl/err.h>
18 #include <openssl/bn.h>
19 #include <string.h>
20
21 static int sm2_compute_userid_digest(uint8_t *out,
22                                      const EVP_MD *digest,
23                                      const char *user_id,
24                                      const EC_KEY *key)
25 {
26     int rc = 0;
27     const EC_GROUP *group = EC_KEY_get0_group(key);
28     BN_CTX *ctx = NULL;
29     EVP_MD_CTX *hash = NULL;
30     BIGNUM *p = NULL;
31     BIGNUM *a = NULL;
32     BIGNUM *b = NULL;
33     BIGNUM *xG = NULL;
34     BIGNUM *yG = NULL;
35     BIGNUM *xA = NULL;
36     BIGNUM *yA = NULL;
37     int p_bytes = 0;
38     uint8_t *buf = NULL;
39     size_t uid_len = 0;
40     uint16_t entla = 0;
41     uint8_t e_byte = 0;
42
43     hash = EVP_MD_CTX_new();
44     ctx = BN_CTX_new();
45     if (hash == NULL || ctx == NULL) {
46         SM2err(SM2_F_SM2_COMPUTE_USERID_DIGEST, ERR_R_MALLOC_FAILURE);
47         goto done;
48     }
49
50     p = BN_CTX_get(ctx);
51     a = BN_CTX_get(ctx);
52     b = BN_CTX_get(ctx);
53     xG = BN_CTX_get(ctx);
54     yG = BN_CTX_get(ctx);
55     xA = BN_CTX_get(ctx);
56     yA = BN_CTX_get(ctx);
57
58     if (yA == NULL) {
59         SM2err(SM2_F_SM2_COMPUTE_USERID_DIGEST, ERR_R_MALLOC_FAILURE);
60         goto done;
61     }
62
63     if (!EVP_DigestInit(hash, digest)) {
64         SM2err(SM2_F_SM2_COMPUTE_USERID_DIGEST, ERR_R_EVP_LIB);
65         goto done;
66     }
67
68     /* Z = SM3(ENTLA || IDA || a || b || xG || yG || xA || yA) */
69
70     uid_len = strlen(user_id);
71     if (uid_len >= (UINT16_MAX / 8)) {
72         /* too large */
73         SM2err(SM2_F_SM2_COMPUTE_USERID_DIGEST, SM2_R_USER_ID_TOO_LARGE);
74         goto done;
75     }
76
77     entla = (uint16_t)(8 * uid_len);
78
79     e_byte = entla >> 8;
80     if (!EVP_DigestUpdate(hash, &e_byte, 1)) {
81         SM2err(SM2_F_SM2_COMPUTE_USERID_DIGEST, ERR_R_EVP_LIB);
82         goto done;
83     }
84     e_byte = entla & 0xFF;
85     if (!EVP_DigestUpdate(hash, &e_byte, 1)
86             || !EVP_DigestUpdate(hash, user_id, uid_len)) {
87         SM2err(SM2_F_SM2_COMPUTE_USERID_DIGEST, ERR_R_EVP_LIB);
88         goto done;
89     }
90
91     if (!EC_GROUP_get_curve(group, p, a, b, ctx)) {
92         SM2err(SM2_F_SM2_COMPUTE_USERID_DIGEST, ERR_R_EC_LIB);
93         goto done;
94     }
95
96     p_bytes = BN_num_bytes(p);
97     buf = OPENSSL_zalloc(p_bytes);
98     if (buf == NULL) {
99         SM2err(SM2_F_SM2_COMPUTE_USERID_DIGEST, ERR_R_MALLOC_FAILURE);
100         goto done;
101     }
102
103     if (BN_bn2binpad(a, buf, p_bytes) < 0
104             || !EVP_DigestUpdate(hash, buf, p_bytes)
105             || BN_bn2binpad(b, buf, p_bytes) < 0
106             || !EVP_DigestUpdate(hash, buf, p_bytes)
107             || !EC_POINT_get_affine_coordinates(group,
108                                                 EC_GROUP_get0_generator(group),
109                                                 xG, yG, ctx)
110             || BN_bn2binpad(xG, buf, p_bytes) < 0
111             || !EVP_DigestUpdate(hash, buf, p_bytes)
112             || BN_bn2binpad(yG, buf, p_bytes) < 0
113             || !EVP_DigestUpdate(hash, buf, p_bytes)
114             || !EC_POINT_get_affine_coordinates(group,
115                                                 EC_KEY_get0_public_key(key),
116                                                 xA, yA, ctx)
117             || BN_bn2binpad(xA, buf, p_bytes) < 0
118             || !EVP_DigestUpdate(hash, buf, p_bytes)
119             || BN_bn2binpad(yA, buf, p_bytes) < 0
120             || !EVP_DigestUpdate(hash, buf, p_bytes)
121             || !EVP_DigestFinal(hash, out, NULL)) {
122         SM2err(SM2_F_SM2_COMPUTE_USERID_DIGEST, ERR_R_INTERNAL_ERROR);
123         goto done;
124     }
125
126     rc = 1;
127
128  done:
129     OPENSSL_free(buf);
130     BN_CTX_free(ctx);
131     EVP_MD_CTX_free(hash);
132     return rc;
133 }
134
135 static BIGNUM *sm2_compute_msg_hash(const EVP_MD *digest,
136                                     const EC_KEY *key,
137                                     const char *user_id,
138                                     const uint8_t *msg, size_t msg_len)
139 {
140     EVP_MD_CTX *hash = EVP_MD_CTX_new();
141     const int md_size = EVP_MD_size(digest);
142     uint8_t *za = NULL;
143     BIGNUM *e = NULL;
144
145     if (md_size < 0) {
146         SM2err(SM2_F_SM2_COMPUTE_MSG_HASH, SM2_R_INVALID_DIGEST);
147         goto done;
148     }
149
150     za = OPENSSL_zalloc(md_size);
151     if (hash == NULL || za == NULL) {
152         SM2err(SM2_F_SM2_COMPUTE_MSG_HASH, ERR_R_MALLOC_FAILURE);
153         goto done;
154     }
155
156     if (!sm2_compute_userid_digest(za, digest, user_id, key)) {
157         /* SM2err already called */
158         goto done;
159     }
160
161     if (!EVP_DigestInit(hash, digest)
162             || !EVP_DigestUpdate(hash, za, md_size)
163             || !EVP_DigestUpdate(hash, msg, msg_len)
164                /* reuse za buffer to hold H(ZA || M) */
165             || !EVP_DigestFinal(hash, za, NULL)) {
166         SM2err(SM2_F_SM2_COMPUTE_MSG_HASH, ERR_R_EVP_LIB);
167         goto done;
168     }
169
170     e = BN_bin2bn(za, md_size, NULL);
171     if (e == NULL)
172         SM2err(SM2_F_SM2_COMPUTE_MSG_HASH, ERR_R_INTERNAL_ERROR);
173
174  done:
175     OPENSSL_free(za);
176     EVP_MD_CTX_free(hash);
177     return e;
178 }
179
180 static ECDSA_SIG *sm2_sig_gen(const EC_KEY *key, const BIGNUM *e)
181 {
182     const BIGNUM *dA = EC_KEY_get0_private_key(key);
183     const EC_GROUP *group = EC_KEY_get0_group(key);
184     const BIGNUM *order = EC_GROUP_get0_order(group);
185     ECDSA_SIG *sig = NULL;
186     EC_POINT *kG = NULL;
187     BN_CTX *ctx = NULL;
188     BIGNUM *k = NULL;
189     BIGNUM *rk = NULL;
190     BIGNUM *r = NULL;
191     BIGNUM *s = NULL;
192     BIGNUM *x1 = NULL;
193     BIGNUM *tmp = NULL;
194
195     kG = EC_POINT_new(group);
196     ctx = BN_CTX_new();
197     if (kG == NULL || ctx == NULL) {
198         SM2err(SM2_F_SM2_SIG_GEN, ERR_R_MALLOC_FAILURE);
199         goto done;
200     }
201
202     BN_CTX_start(ctx);
203     k = BN_CTX_get(ctx);
204     rk = BN_CTX_get(ctx);
205     x1 = BN_CTX_get(ctx);
206     tmp = BN_CTX_get(ctx);
207     if (tmp == NULL) {
208         SM2err(SM2_F_SM2_SIG_GEN, ERR_R_MALLOC_FAILURE);
209         goto done;
210     }
211
212     /*
213      * These values are returned and so should not be allocated out of the
214      * context
215      */
216     r = BN_new();
217     s = BN_new();
218
219     if (r == NULL || s == NULL) {
220         SM2err(SM2_F_SM2_SIG_GEN, ERR_R_MALLOC_FAILURE);
221         goto done;
222     }
223
224     for (;;) {
225         if (!BN_priv_rand_range(k, order)) {
226             SM2err(SM2_F_SM2_SIG_GEN, ERR_R_INTERNAL_ERROR);
227             goto done;
228         }
229
230         if (!EC_POINT_mul(group, kG, k, NULL, NULL, ctx)
231                 || !EC_POINT_get_affine_coordinates(group, kG, x1, NULL,
232                                                     ctx)
233                 || !BN_mod_add(r, e, x1, order, ctx)) {
234             SM2err(SM2_F_SM2_SIG_GEN, ERR_R_INTERNAL_ERROR);
235             goto done;
236         }
237
238         /* try again if r == 0 or r+k == n */
239         if (BN_is_zero(r))
240             continue;
241
242         if (!BN_add(rk, r, k)) {
243             SM2err(SM2_F_SM2_SIG_GEN, ERR_R_INTERNAL_ERROR);
244             goto done;
245         }
246
247         if (BN_cmp(rk, order) == 0)
248             continue;
249
250         if (!BN_add(s, dA, BN_value_one())
251                 || !ec_group_do_inverse_ord(group, s, s, ctx)
252                 || !BN_mod_mul(tmp, dA, r, order, ctx)
253                 || !BN_sub(tmp, k, tmp)
254                 || !BN_mod_mul(s, s, tmp, order, ctx)) {
255             SM2err(SM2_F_SM2_SIG_GEN, ERR_R_BN_LIB);
256             goto done;
257         }
258
259         sig = ECDSA_SIG_new();
260         if (sig == NULL) {
261             SM2err(SM2_F_SM2_SIG_GEN, ERR_R_MALLOC_FAILURE);
262             goto done;
263         }
264
265          /* takes ownership of r and s */
266         ECDSA_SIG_set0(sig, r, s);
267         break;
268     }
269
270  done:
271     if (sig == NULL) {
272         BN_free(r);
273         BN_free(s);
274     }
275
276     BN_CTX_free(ctx);
277     EC_POINT_free(kG);
278     return sig;
279 }
280
281 static int sm2_sig_verify(const EC_KEY *key, const ECDSA_SIG *sig,
282                           const BIGNUM *e)
283 {
284     int ret = 0;
285     const EC_GROUP *group = EC_KEY_get0_group(key);
286     const BIGNUM *order = EC_GROUP_get0_order(group);
287     BN_CTX *ctx = NULL;
288     EC_POINT *pt = NULL;
289     BIGNUM *t = NULL;
290     BIGNUM *x1 = NULL;
291     const BIGNUM *r = NULL;
292     const BIGNUM *s = NULL;
293
294     ctx = BN_CTX_new();
295     pt = EC_POINT_new(group);
296     if (ctx == NULL || pt == NULL) {
297         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_MALLOC_FAILURE);
298         goto done;
299     }
300
301     BN_CTX_start(ctx);
302     t = BN_CTX_get(ctx);
303     x1 = BN_CTX_get(ctx);
304     if (x1 == NULL) {
305         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_MALLOC_FAILURE);
306         goto done;
307     }
308
309     /*
310      * B1: verify whether r' in [1,n-1], verification failed if not
311      * B2: vefify whether s' in [1,n-1], verification failed if not
312      * B3: set M'~=ZA || M'
313      * B4: calculate e'=Hv(M'~)
314      * B5: calculate t = (r' + s') modn, verification failed if t=0
315      * B6: calculate the point (x1', y1')=[s']G + [t]PA
316      * B7: calculate R=(e'+x1') modn, verfication pass if yes, otherwise failed
317      */
318
319     ECDSA_SIG_get0(sig, &r, &s);
320
321     if (BN_cmp(r, BN_value_one()) < 0
322             || BN_cmp(s, BN_value_one()) < 0
323             || BN_cmp(order, r) <= 0
324             || BN_cmp(order, s) <= 0) {
325         SM2err(SM2_F_SM2_SIG_VERIFY, SM2_R_BAD_SIGNATURE);
326         goto done;
327     }
328
329     if (!BN_mod_add(t, r, s, order, ctx)) {
330         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_BN_LIB);
331         goto done;
332     }
333
334     if (BN_is_zero(t)) {
335         SM2err(SM2_F_SM2_SIG_VERIFY, SM2_R_BAD_SIGNATURE);
336         goto done;
337     }
338
339     if (!EC_POINT_mul(group, pt, s, EC_KEY_get0_public_key(key), t, ctx)
340             || !EC_POINT_get_affine_coordinates(group, pt, x1, NULL, ctx)) {
341         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_EC_LIB);
342         goto done;
343     }
344
345     if (!BN_mod_add(t, e, x1, order, ctx)) {
346         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_BN_LIB);
347         goto done;
348     }
349
350     if (BN_cmp(r, t) == 0)
351         ret = 1;
352
353  done:
354     EC_POINT_free(pt);
355     BN_CTX_free(ctx);
356     return ret;
357 }
358
359 ECDSA_SIG *sm2_do_sign(const EC_KEY *key,
360                        const EVP_MD *digest,
361                        const char *user_id, const uint8_t *msg, size_t msg_len)
362 {
363     BIGNUM *e = NULL;
364     ECDSA_SIG *sig = NULL;
365
366     e = sm2_compute_msg_hash(digest, key, user_id, msg, msg_len);
367     if (e == NULL) {
368         /* SM2err already called */
369         goto done;
370     }
371
372     sig = sm2_sig_gen(key, e);
373
374  done:
375     BN_free(e);
376     return sig;
377 }
378
379 int sm2_do_verify(const EC_KEY *key,
380                   const EVP_MD *digest,
381                   const ECDSA_SIG *sig,
382                   const char *user_id, const uint8_t *msg, size_t msg_len)
383 {
384     BIGNUM *e = NULL;
385     int ret = 0;
386
387     e = sm2_compute_msg_hash(digest, key, user_id, msg, msg_len);
388     if (e == NULL) {
389         /* SM2err already called */
390         goto done;
391     }
392
393     ret = sm2_sig_verify(key, sig, e);
394
395  done:
396     BN_free(e);
397     return ret;
398 }
399
400 int sm2_sign(const unsigned char *dgst, int dgstlen,
401              unsigned char *sig, unsigned int *siglen, EC_KEY *eckey)
402 {
403     BIGNUM *e = NULL;
404     ECDSA_SIG *s = NULL;
405     int sigleni;
406     int ret = -1;
407
408     e = BN_bin2bn(dgst, dgstlen, NULL);
409     if (e == NULL) {
410        SM2err(SM2_F_SM2_SIGN, ERR_R_BN_LIB);
411        goto done;
412     }
413
414     s = sm2_sig_gen(eckey, e);
415
416     sigleni = i2d_ECDSA_SIG(s, &sig);
417     if (sigleni < 0) {
418        SM2err(SM2_F_SM2_SIGN, ERR_R_INTERNAL_ERROR);
419        goto done;
420     }
421     *siglen = (unsigned int)sigleni;
422
423     ret = 1;
424
425  done:
426     ECDSA_SIG_free(s);
427     BN_free(e);
428     return ret;
429 }
430
431 int sm2_verify(const unsigned char *dgst, int dgstlen,
432                const unsigned char *sig, int sig_len, EC_KEY *eckey)
433 {
434     ECDSA_SIG *s = NULL;
435     BIGNUM *e = NULL;
436     const unsigned char *p = sig;
437     unsigned char *der = NULL;
438     int derlen = -1;
439     int ret = -1;
440
441     s = ECDSA_SIG_new();
442     if (s == NULL) {
443         SM2err(SM2_F_SM2_VERIFY, ERR_R_MALLOC_FAILURE);
444         goto done;
445     }
446     if (d2i_ECDSA_SIG(&s, &p, sig_len) == NULL) {
447         SM2err(SM2_F_SM2_VERIFY, SM2_R_INVALID_ENCODING);
448         goto done;
449     }
450     /* Ensure signature uses DER and doesn't have trailing garbage */
451     derlen = i2d_ECDSA_SIG(s, &der);
452     if (derlen != sig_len || memcmp(sig, der, derlen) != 0) {
453         SM2err(SM2_F_SM2_VERIFY, SM2_R_INVALID_ENCODING);
454         goto done;
455     }
456
457     e = BN_bin2bn(dgst, dgstlen, NULL);
458     if (e == NULL) {
459         SM2err(SM2_F_SM2_VERIFY, ERR_R_BN_LIB);
460         goto done;
461     }
462
463     ret = sm2_sig_verify(eckey, s, e);
464
465  done:
466     OPENSSL_free(der);
467     BN_free(e);
468     ECDSA_SIG_free(s);
469     return ret;
470 }