Make SM2 ID stick to specification
[openssl.git] / crypto / sm2 / sm2_sign.c
1 /*
2  * Copyright 2017-2018 The OpenSSL Project Authors. All Rights Reserved.
3  * Copyright 2017 Ribose Inc. All Rights Reserved.
4  * Ported from Ribose contributions from Botan.
5  *
6  * Licensed under the OpenSSL license (the "License").  You may not use
7  * this file except in compliance with the License.  You can obtain a copy
8  * in the file LICENSE in the source distribution or at
9  * https://www.openssl.org/source/license.html
10  */
11
12 #include "internal/sm2.h"
13 #include "internal/sm2err.h"
14 #include "internal/ec_int.h" /* ec_group_do_inverse_ord() */
15 #include <openssl/err.h>
16 #include <openssl/evp.h>
17 #include <openssl/err.h>
18 #include <openssl/bn.h>
19 #include <string.h>
20
21 int sm2_compute_userid_digest(uint8_t *out,
22                               const EVP_MD *digest,
23                               const uint8_t *id,
24                               const size_t id_len,
25                               const EC_KEY *key)
26 {
27     int rc = 0;
28     const EC_GROUP *group = EC_KEY_get0_group(key);
29     BN_CTX *ctx = NULL;
30     EVP_MD_CTX *hash = NULL;
31     BIGNUM *p = NULL;
32     BIGNUM *a = NULL;
33     BIGNUM *b = NULL;
34     BIGNUM *xG = NULL;
35     BIGNUM *yG = NULL;
36     BIGNUM *xA = NULL;
37     BIGNUM *yA = NULL;
38     int p_bytes = 0;
39     uint8_t *buf = NULL;
40     uint16_t entla = 0;
41     uint8_t e_byte = 0;
42
43     hash = EVP_MD_CTX_new();
44     ctx = BN_CTX_new();
45     if (hash == NULL || ctx == NULL) {
46         SM2err(SM2_F_SM2_COMPUTE_USERID_DIGEST, ERR_R_MALLOC_FAILURE);
47         goto done;
48     }
49
50     p = BN_CTX_get(ctx);
51     a = BN_CTX_get(ctx);
52     b = BN_CTX_get(ctx);
53     xG = BN_CTX_get(ctx);
54     yG = BN_CTX_get(ctx);
55     xA = BN_CTX_get(ctx);
56     yA = BN_CTX_get(ctx);
57
58     if (yA == NULL) {
59         SM2err(SM2_F_SM2_COMPUTE_USERID_DIGEST, ERR_R_MALLOC_FAILURE);
60         goto done;
61     }
62
63     if (!EVP_DigestInit(hash, digest)) {
64         SM2err(SM2_F_SM2_COMPUTE_USERID_DIGEST, ERR_R_EVP_LIB);
65         goto done;
66     }
67
68     /* Z = SM3(ENTLA || IDA || a || b || xG || yG || xA || yA) */
69
70     if (id_len >= (UINT16_MAX / 8)) {
71         /* too large */
72         SM2err(SM2_F_SM2_COMPUTE_USERID_DIGEST, SM2_R_USER_ID_TOO_LARGE);
73         goto done;
74     }
75
76     entla = (uint16_t)(8 * id_len);
77
78     e_byte = entla >> 8;
79     if (!EVP_DigestUpdate(hash, &e_byte, 1)) {
80         SM2err(SM2_F_SM2_COMPUTE_USERID_DIGEST, ERR_R_EVP_LIB);
81         goto done;
82     }
83     e_byte = entla & 0xFF;
84     if (!EVP_DigestUpdate(hash, &e_byte, 1)
85             || !EVP_DigestUpdate(hash, id, id_len)) {
86         SM2err(SM2_F_SM2_COMPUTE_USERID_DIGEST, ERR_R_EVP_LIB);
87         goto done;
88     }
89
90     if (!EC_GROUP_get_curve(group, p, a, b, ctx)) {
91         SM2err(SM2_F_SM2_COMPUTE_USERID_DIGEST, ERR_R_EC_LIB);
92         goto done;
93     }
94
95     p_bytes = BN_num_bytes(p);
96     buf = OPENSSL_zalloc(p_bytes);
97     if (buf == NULL) {
98         SM2err(SM2_F_SM2_COMPUTE_USERID_DIGEST, ERR_R_MALLOC_FAILURE);
99         goto done;
100     }
101
102     if (BN_bn2binpad(a, buf, p_bytes) < 0
103             || !EVP_DigestUpdate(hash, buf, p_bytes)
104             || BN_bn2binpad(b, buf, p_bytes) < 0
105             || !EVP_DigestUpdate(hash, buf, p_bytes)
106             || !EC_POINT_get_affine_coordinates(group,
107                                                 EC_GROUP_get0_generator(group),
108                                                 xG, yG, ctx)
109             || BN_bn2binpad(xG, buf, p_bytes) < 0
110             || !EVP_DigestUpdate(hash, buf, p_bytes)
111             || BN_bn2binpad(yG, buf, p_bytes) < 0
112             || !EVP_DigestUpdate(hash, buf, p_bytes)
113             || !EC_POINT_get_affine_coordinates(group,
114                                                 EC_KEY_get0_public_key(key),
115                                                 xA, yA, ctx)
116             || BN_bn2binpad(xA, buf, p_bytes) < 0
117             || !EVP_DigestUpdate(hash, buf, p_bytes)
118             || BN_bn2binpad(yA, buf, p_bytes) < 0
119             || !EVP_DigestUpdate(hash, buf, p_bytes)
120             || !EVP_DigestFinal(hash, out, NULL)) {
121         SM2err(SM2_F_SM2_COMPUTE_USERID_DIGEST, ERR_R_INTERNAL_ERROR);
122         goto done;
123     }
124
125     rc = 1;
126
127  done:
128     OPENSSL_free(buf);
129     BN_CTX_free(ctx);
130     EVP_MD_CTX_free(hash);
131     return rc;
132 }
133
134 static BIGNUM *sm2_compute_msg_hash(const EVP_MD *digest,
135                                     const EC_KEY *key,
136                                     const uint8_t *id,
137                                     const size_t id_len,
138                                     const uint8_t *msg, size_t msg_len)
139 {
140     EVP_MD_CTX *hash = EVP_MD_CTX_new();
141     const int md_size = EVP_MD_size(digest);
142     uint8_t *za = NULL;
143     BIGNUM *e = NULL;
144
145     if (md_size < 0) {
146         SM2err(SM2_F_SM2_COMPUTE_MSG_HASH, SM2_R_INVALID_DIGEST);
147         goto done;
148     }
149
150     za = OPENSSL_zalloc(md_size);
151     if (hash == NULL || za == NULL) {
152         SM2err(SM2_F_SM2_COMPUTE_MSG_HASH, ERR_R_MALLOC_FAILURE);
153         goto done;
154     }
155
156     if (!sm2_compute_userid_digest(za, digest, id, id_len, key)) {
157         /* SM2err already called */
158         goto done;
159     }
160
161     if (!EVP_DigestInit(hash, digest)
162             || !EVP_DigestUpdate(hash, za, md_size)
163             || !EVP_DigestUpdate(hash, msg, msg_len)
164                /* reuse za buffer to hold H(ZA || M) */
165             || !EVP_DigestFinal(hash, za, NULL)) {
166         SM2err(SM2_F_SM2_COMPUTE_MSG_HASH, ERR_R_EVP_LIB);
167         goto done;
168     }
169
170     e = BN_bin2bn(za, md_size, NULL);
171     if (e == NULL)
172         SM2err(SM2_F_SM2_COMPUTE_MSG_HASH, ERR_R_INTERNAL_ERROR);
173
174  done:
175     OPENSSL_free(za);
176     EVP_MD_CTX_free(hash);
177     return e;
178 }
179
180 static ECDSA_SIG *sm2_sig_gen(const EC_KEY *key, const BIGNUM *e)
181 {
182     const BIGNUM *dA = EC_KEY_get0_private_key(key);
183     const EC_GROUP *group = EC_KEY_get0_group(key);
184     const BIGNUM *order = EC_GROUP_get0_order(group);
185     ECDSA_SIG *sig = NULL;
186     EC_POINT *kG = NULL;
187     BN_CTX *ctx = NULL;
188     BIGNUM *k = NULL;
189     BIGNUM *rk = NULL;
190     BIGNUM *r = NULL;
191     BIGNUM *s = NULL;
192     BIGNUM *x1 = NULL;
193     BIGNUM *tmp = NULL;
194
195     kG = EC_POINT_new(group);
196     ctx = BN_CTX_new();
197     if (kG == NULL || ctx == NULL) {
198         SM2err(SM2_F_SM2_SIG_GEN, ERR_R_MALLOC_FAILURE);
199         goto done;
200     }
201
202     BN_CTX_start(ctx);
203     k = BN_CTX_get(ctx);
204     rk = BN_CTX_get(ctx);
205     x1 = BN_CTX_get(ctx);
206     tmp = BN_CTX_get(ctx);
207     if (tmp == NULL) {
208         SM2err(SM2_F_SM2_SIG_GEN, ERR_R_MALLOC_FAILURE);
209         goto done;
210     }
211
212     /*
213      * These values are returned and so should not be allocated out of the
214      * context
215      */
216     r = BN_new();
217     s = BN_new();
218
219     if (r == NULL || s == NULL) {
220         SM2err(SM2_F_SM2_SIG_GEN, ERR_R_MALLOC_FAILURE);
221         goto done;
222     }
223
224     for (;;) {
225         if (!BN_priv_rand_range(k, order)) {
226             SM2err(SM2_F_SM2_SIG_GEN, ERR_R_INTERNAL_ERROR);
227             goto done;
228         }
229
230         if (!EC_POINT_mul(group, kG, k, NULL, NULL, ctx)
231                 || !EC_POINT_get_affine_coordinates(group, kG, x1, NULL,
232                                                     ctx)
233                 || !BN_mod_add(r, e, x1, order, ctx)) {
234             SM2err(SM2_F_SM2_SIG_GEN, ERR_R_INTERNAL_ERROR);
235             goto done;
236         }
237
238         /* try again if r == 0 or r+k == n */
239         if (BN_is_zero(r))
240             continue;
241
242         if (!BN_add(rk, r, k)) {
243             SM2err(SM2_F_SM2_SIG_GEN, ERR_R_INTERNAL_ERROR);
244             goto done;
245         }
246
247         if (BN_cmp(rk, order) == 0)
248             continue;
249
250         if (!BN_add(s, dA, BN_value_one())
251                 || !ec_group_do_inverse_ord(group, s, s, ctx)
252                 || !BN_mod_mul(tmp, dA, r, order, ctx)
253                 || !BN_sub(tmp, k, tmp)
254                 || !BN_mod_mul(s, s, tmp, order, ctx)) {
255             SM2err(SM2_F_SM2_SIG_GEN, ERR_R_BN_LIB);
256             goto done;
257         }
258
259         sig = ECDSA_SIG_new();
260         if (sig == NULL) {
261             SM2err(SM2_F_SM2_SIG_GEN, ERR_R_MALLOC_FAILURE);
262             goto done;
263         }
264
265          /* takes ownership of r and s */
266         ECDSA_SIG_set0(sig, r, s);
267         break;
268     }
269
270  done:
271     if (sig == NULL) {
272         BN_free(r);
273         BN_free(s);
274     }
275
276     BN_CTX_free(ctx);
277     EC_POINT_free(kG);
278     return sig;
279 }
280
281 static int sm2_sig_verify(const EC_KEY *key, const ECDSA_SIG *sig,
282                           const BIGNUM *e)
283 {
284     int ret = 0;
285     const EC_GROUP *group = EC_KEY_get0_group(key);
286     const BIGNUM *order = EC_GROUP_get0_order(group);
287     BN_CTX *ctx = NULL;
288     EC_POINT *pt = NULL;
289     BIGNUM *t = NULL;
290     BIGNUM *x1 = NULL;
291     const BIGNUM *r = NULL;
292     const BIGNUM *s = NULL;
293
294     ctx = BN_CTX_new();
295     pt = EC_POINT_new(group);
296     if (ctx == NULL || pt == NULL) {
297         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_MALLOC_FAILURE);
298         goto done;
299     }
300
301     BN_CTX_start(ctx);
302     t = BN_CTX_get(ctx);
303     x1 = BN_CTX_get(ctx);
304     if (x1 == NULL) {
305         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_MALLOC_FAILURE);
306         goto done;
307     }
308
309     /*
310      * B1: verify whether r' in [1,n-1], verification failed if not
311      * B2: vefify whether s' in [1,n-1], verification failed if not
312      * B3: set M'~=ZA || M'
313      * B4: calculate e'=Hv(M'~)
314      * B5: calculate t = (r' + s') modn, verification failed if t=0
315      * B6: calculate the point (x1', y1')=[s']G + [t]PA
316      * B7: calculate R=(e'+x1') modn, verfication pass if yes, otherwise failed
317      */
318
319     ECDSA_SIG_get0(sig, &r, &s);
320
321     if (BN_cmp(r, BN_value_one()) < 0
322             || BN_cmp(s, BN_value_one()) < 0
323             || BN_cmp(order, r) <= 0
324             || BN_cmp(order, s) <= 0) {
325         SM2err(SM2_F_SM2_SIG_VERIFY, SM2_R_BAD_SIGNATURE);
326         goto done;
327     }
328
329     if (!BN_mod_add(t, r, s, order, ctx)) {
330         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_BN_LIB);
331         goto done;
332     }
333
334     if (BN_is_zero(t)) {
335         SM2err(SM2_F_SM2_SIG_VERIFY, SM2_R_BAD_SIGNATURE);
336         goto done;
337     }
338
339     if (!EC_POINT_mul(group, pt, s, EC_KEY_get0_public_key(key), t, ctx)
340             || !EC_POINT_get_affine_coordinates(group, pt, x1, NULL, ctx)) {
341         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_EC_LIB);
342         goto done;
343     }
344
345     if (!BN_mod_add(t, e, x1, order, ctx)) {
346         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_BN_LIB);
347         goto done;
348     }
349
350     if (BN_cmp(r, t) == 0)
351         ret = 1;
352
353  done:
354     EC_POINT_free(pt);
355     BN_CTX_free(ctx);
356     return ret;
357 }
358
359 ECDSA_SIG *sm2_do_sign(const EC_KEY *key,
360                        const EVP_MD *digest,
361                        const uint8_t *id,
362                        const size_t id_len,
363                        const uint8_t *msg, size_t msg_len)
364 {
365     BIGNUM *e = NULL;
366     ECDSA_SIG *sig = NULL;
367
368     e = sm2_compute_msg_hash(digest, key, id, id_len, msg, msg_len);
369     if (e == NULL) {
370         /* SM2err already called */
371         goto done;
372     }
373
374     sig = sm2_sig_gen(key, e);
375
376  done:
377     BN_free(e);
378     return sig;
379 }
380
381 int sm2_do_verify(const EC_KEY *key,
382                   const EVP_MD *digest,
383                   const ECDSA_SIG *sig,
384                   const uint8_t *id,
385                   const size_t id_len,
386                   const uint8_t *msg, size_t msg_len)
387 {
388     BIGNUM *e = NULL;
389     int ret = 0;
390
391     e = sm2_compute_msg_hash(digest, key, id, id_len, msg, msg_len);
392     if (e == NULL) {
393         /* SM2err already called */
394         goto done;
395     }
396
397     ret = sm2_sig_verify(key, sig, e);
398
399  done:
400     BN_free(e);
401     return ret;
402 }
403
404 int sm2_sign(const unsigned char *dgst, int dgstlen,
405              unsigned char *sig, unsigned int *siglen, EC_KEY *eckey)
406 {
407     BIGNUM *e = NULL;
408     ECDSA_SIG *s = NULL;
409     int sigleni;
410     int ret = -1;
411
412     e = BN_bin2bn(dgst, dgstlen, NULL);
413     if (e == NULL) {
414        SM2err(SM2_F_SM2_SIGN, ERR_R_BN_LIB);
415        goto done;
416     }
417
418     s = sm2_sig_gen(eckey, e);
419
420     sigleni = i2d_ECDSA_SIG(s, &sig);
421     if (sigleni < 0) {
422        SM2err(SM2_F_SM2_SIGN, ERR_R_INTERNAL_ERROR);
423        goto done;
424     }
425     *siglen = (unsigned int)sigleni;
426
427     ret = 1;
428
429  done:
430     ECDSA_SIG_free(s);
431     BN_free(e);
432     return ret;
433 }
434
435 int sm2_verify(const unsigned char *dgst, int dgstlen,
436                const unsigned char *sig, int sig_len, EC_KEY *eckey)
437 {
438     ECDSA_SIG *s = NULL;
439     BIGNUM *e = NULL;
440     const unsigned char *p = sig;
441     unsigned char *der = NULL;
442     int derlen = -1;
443     int ret = -1;
444
445     s = ECDSA_SIG_new();
446     if (s == NULL) {
447         SM2err(SM2_F_SM2_VERIFY, ERR_R_MALLOC_FAILURE);
448         goto done;
449     }
450     if (d2i_ECDSA_SIG(&s, &p, sig_len) == NULL) {
451         SM2err(SM2_F_SM2_VERIFY, SM2_R_INVALID_ENCODING);
452         goto done;
453     }
454     /* Ensure signature uses DER and doesn't have trailing garbage */
455     derlen = i2d_ECDSA_SIG(s, &der);
456     if (derlen != sig_len || memcmp(sig, der, derlen) != 0) {
457         SM2err(SM2_F_SM2_VERIFY, SM2_R_INVALID_ENCODING);
458         goto done;
459     }
460
461     e = BN_bin2bn(dgst, dgstlen, NULL);
462     if (e == NULL) {
463         SM2err(SM2_F_SM2_VERIFY, ERR_R_BN_LIB);
464         goto done;
465     }
466
467     ret = sm2_sig_verify(eckey, s, e);
468
469  done:
470     OPENSSL_free(der);
471     BN_free(e);
472     ECDSA_SIG_free(s);
473     return ret;
474 }