Add evp_test fixes.
[openssl.git] / crypto / sm2 / sm2_sign.c
1 /*
2  * Copyright 2017-2018 The OpenSSL Project Authors. All Rights Reserved.
3  * Copyright 2017 Ribose Inc. All Rights Reserved.
4  * Ported from Ribose contributions from Botan.
5  *
6  * Licensed under the Apache License 2.0 (the "License").  You may not use
7  * this file except in compliance with the License.  You can obtain a copy
8  * in the file LICENSE in the source distribution or at
9  * https://www.openssl.org/source/license.html
10  */
11
12 #include "crypto/sm2.h"
13 #include "crypto/sm2err.h"
14 #include "crypto/ec.h" /* ec_group_do_inverse_ord() */
15 #include "internal/numbers.h"
16 #include <openssl/err.h>
17 #include <openssl/evp.h>
18 #include <openssl/err.h>
19 #include <openssl/bn.h>
20 #include <string.h>
21
22 int sm2_compute_z_digest(uint8_t *out,
23                          const EVP_MD *digest,
24                          const uint8_t *id,
25                          const size_t id_len,
26                          const EC_KEY *key)
27 {
28     int rc = 0;
29     const EC_GROUP *group = EC_KEY_get0_group(key);
30     BN_CTX *ctx = NULL;
31     EVP_MD_CTX *hash = NULL;
32     BIGNUM *p = NULL;
33     BIGNUM *a = NULL;
34     BIGNUM *b = NULL;
35     BIGNUM *xG = NULL;
36     BIGNUM *yG = NULL;
37     BIGNUM *xA = NULL;
38     BIGNUM *yA = NULL;
39     int p_bytes = 0;
40     uint8_t *buf = NULL;
41     uint16_t entl = 0;
42     uint8_t e_byte = 0;
43
44     hash = EVP_MD_CTX_new();
45     ctx = BN_CTX_new_ex(ec_key_get_libctx(key));
46     if (hash == NULL || ctx == NULL) {
47         SM2err(SM2_F_SM2_COMPUTE_Z_DIGEST, ERR_R_MALLOC_FAILURE);
48         goto done;
49     }
50
51     p = BN_CTX_get(ctx);
52     a = BN_CTX_get(ctx);
53     b = BN_CTX_get(ctx);
54     xG = BN_CTX_get(ctx);
55     yG = BN_CTX_get(ctx);
56     xA = BN_CTX_get(ctx);
57     yA = BN_CTX_get(ctx);
58
59     if (yA == NULL) {
60         SM2err(SM2_F_SM2_COMPUTE_Z_DIGEST, ERR_R_MALLOC_FAILURE);
61         goto done;
62     }
63
64     if (!EVP_DigestInit(hash, digest)) {
65         SM2err(SM2_F_SM2_COMPUTE_Z_DIGEST, ERR_R_EVP_LIB);
66         goto done;
67     }
68
69     /* Z = h(ENTL || ID || a || b || xG || yG || xA || yA) */
70
71     if (id_len >= (UINT16_MAX / 8)) {
72         /* too large */
73         SM2err(SM2_F_SM2_COMPUTE_Z_DIGEST, SM2_R_ID_TOO_LARGE);
74         goto done;
75     }
76
77     entl = (uint16_t)(8 * id_len);
78
79     e_byte = entl >> 8;
80     if (!EVP_DigestUpdate(hash, &e_byte, 1)) {
81         SM2err(SM2_F_SM2_COMPUTE_Z_DIGEST, ERR_R_EVP_LIB);
82         goto done;
83     }
84     e_byte = entl & 0xFF;
85     if (!EVP_DigestUpdate(hash, &e_byte, 1)) {
86         SM2err(SM2_F_SM2_COMPUTE_Z_DIGEST, ERR_R_EVP_LIB);
87         goto done;
88     }
89
90     if (id_len > 0 && !EVP_DigestUpdate(hash, id, id_len)) {
91         SM2err(SM2_F_SM2_COMPUTE_Z_DIGEST, ERR_R_EVP_LIB);
92         goto done;
93     }
94
95     if (!EC_GROUP_get_curve(group, p, a, b, ctx)) {
96         SM2err(SM2_F_SM2_COMPUTE_Z_DIGEST, ERR_R_EC_LIB);
97         goto done;
98     }
99
100     p_bytes = BN_num_bytes(p);
101     buf = OPENSSL_zalloc(p_bytes);
102     if (buf == NULL) {
103         SM2err(SM2_F_SM2_COMPUTE_Z_DIGEST, ERR_R_MALLOC_FAILURE);
104         goto done;
105     }
106
107     if (BN_bn2binpad(a, buf, p_bytes) < 0
108             || !EVP_DigestUpdate(hash, buf, p_bytes)
109             || BN_bn2binpad(b, buf, p_bytes) < 0
110             || !EVP_DigestUpdate(hash, buf, p_bytes)
111             || !EC_POINT_get_affine_coordinates(group,
112                                                 EC_GROUP_get0_generator(group),
113                                                 xG, yG, ctx)
114             || BN_bn2binpad(xG, buf, p_bytes) < 0
115             || !EVP_DigestUpdate(hash, buf, p_bytes)
116             || BN_bn2binpad(yG, buf, p_bytes) < 0
117             || !EVP_DigestUpdate(hash, buf, p_bytes)
118             || !EC_POINT_get_affine_coordinates(group,
119                                                 EC_KEY_get0_public_key(key),
120                                                 xA, yA, ctx)
121             || BN_bn2binpad(xA, buf, p_bytes) < 0
122             || !EVP_DigestUpdate(hash, buf, p_bytes)
123             || BN_bn2binpad(yA, buf, p_bytes) < 0
124             || !EVP_DigestUpdate(hash, buf, p_bytes)
125             || !EVP_DigestFinal(hash, out, NULL)) {
126         SM2err(SM2_F_SM2_COMPUTE_Z_DIGEST, ERR_R_INTERNAL_ERROR);
127         goto done;
128     }
129
130     rc = 1;
131
132  done:
133     OPENSSL_free(buf);
134     BN_CTX_free(ctx);
135     EVP_MD_CTX_free(hash);
136     return rc;
137 }
138
139 static BIGNUM *sm2_compute_msg_hash(const EVP_MD *digest,
140                                     const EC_KEY *key,
141                                     const uint8_t *id,
142                                     const size_t id_len,
143                                     const uint8_t *msg, size_t msg_len)
144 {
145     EVP_MD_CTX *hash = EVP_MD_CTX_new();
146     const int md_size = EVP_MD_size(digest);
147     uint8_t *z = NULL;
148     BIGNUM *e = NULL;
149     EVP_MD *fetched_digest = NULL;
150     OPENSSL_CTX *libctx = ec_key_get_libctx(key);
151     const char *propq = ec_key_get0_propq(key);
152
153     if (md_size < 0) {
154         SM2err(SM2_F_SM2_COMPUTE_MSG_HASH, SM2_R_INVALID_DIGEST);
155         goto done;
156     }
157
158     z = OPENSSL_zalloc(md_size);
159     if (hash == NULL || z == NULL) {
160         SM2err(SM2_F_SM2_COMPUTE_MSG_HASH, ERR_R_MALLOC_FAILURE);
161         goto done;
162     }
163
164     fetched_digest = EVP_MD_fetch(libctx, EVP_MD_name(digest), propq);
165     if (fetched_digest == NULL) {
166         SM2err(SM2_F_SM2_COMPUTE_MSG_HASH, ERR_R_INTERNAL_ERROR);
167         goto done;
168     }
169
170     if (!sm2_compute_z_digest(z, fetched_digest, id, id_len, key)) {
171         /* SM2err already called */
172         goto done;
173     }
174
175     if (!EVP_DigestInit(hash, fetched_digest)
176             || !EVP_DigestUpdate(hash, z, md_size)
177             || !EVP_DigestUpdate(hash, msg, msg_len)
178                /* reuse z buffer to hold H(Z || M) */
179             || !EVP_DigestFinal(hash, z, NULL)) {
180         SM2err(SM2_F_SM2_COMPUTE_MSG_HASH, ERR_R_EVP_LIB);
181         goto done;
182     }
183
184     e = BN_bin2bn(z, md_size, NULL);
185     if (e == NULL)
186         SM2err(SM2_F_SM2_COMPUTE_MSG_HASH, ERR_R_INTERNAL_ERROR);
187
188  done:
189     EVP_MD_free(fetched_digest);
190     OPENSSL_free(z);
191     EVP_MD_CTX_free(hash);
192     return e;
193 }
194
195 static ECDSA_SIG *sm2_sig_gen(const EC_KEY *key, const BIGNUM *e)
196 {
197     const BIGNUM *dA = EC_KEY_get0_private_key(key);
198     const EC_GROUP *group = EC_KEY_get0_group(key);
199     const BIGNUM *order = EC_GROUP_get0_order(group);
200     ECDSA_SIG *sig = NULL;
201     EC_POINT *kG = NULL;
202     BN_CTX *ctx = NULL;
203     BIGNUM *k = NULL;
204     BIGNUM *rk = NULL;
205     BIGNUM *r = NULL;
206     BIGNUM *s = NULL;
207     BIGNUM *x1 = NULL;
208     BIGNUM *tmp = NULL;
209     OPENSSL_CTX *libctx = ec_key_get_libctx(key);
210
211     kG = EC_POINT_new(group);
212     ctx = BN_CTX_new_ex(libctx);
213     if (kG == NULL || ctx == NULL) {
214         SM2err(SM2_F_SM2_SIG_GEN, ERR_R_MALLOC_FAILURE);
215         goto done;
216     }
217
218     BN_CTX_start(ctx);
219     k = BN_CTX_get(ctx);
220     rk = BN_CTX_get(ctx);
221     x1 = BN_CTX_get(ctx);
222     tmp = BN_CTX_get(ctx);
223     if (tmp == NULL) {
224         SM2err(SM2_F_SM2_SIG_GEN, ERR_R_MALLOC_FAILURE);
225         goto done;
226     }
227
228     /*
229      * These values are returned and so should not be allocated out of the
230      * context
231      */
232     r = BN_new();
233     s = BN_new();
234
235     if (r == NULL || s == NULL) {
236         SM2err(SM2_F_SM2_SIG_GEN, ERR_R_MALLOC_FAILURE);
237         goto done;
238     }
239
240     for (;;) {
241         if (!BN_priv_rand_range_ex(k, order, ctx)) {
242             SM2err(SM2_F_SM2_SIG_GEN, ERR_R_INTERNAL_ERROR);
243             goto done;
244         }
245
246         if (!EC_POINT_mul(group, kG, k, NULL, NULL, ctx)
247                 || !EC_POINT_get_affine_coordinates(group, kG, x1, NULL,
248                                                     ctx)
249                 || !BN_mod_add(r, e, x1, order, ctx)) {
250             SM2err(SM2_F_SM2_SIG_GEN, ERR_R_INTERNAL_ERROR);
251             goto done;
252         }
253
254         /* try again if r == 0 or r+k == n */
255         if (BN_is_zero(r))
256             continue;
257
258         if (!BN_add(rk, r, k)) {
259             SM2err(SM2_F_SM2_SIG_GEN, ERR_R_INTERNAL_ERROR);
260             goto done;
261         }
262
263         if (BN_cmp(rk, order) == 0)
264             continue;
265
266         if (!BN_add(s, dA, BN_value_one())
267                 || !ec_group_do_inverse_ord(group, s, s, ctx)
268                 || !BN_mod_mul(tmp, dA, r, order, ctx)
269                 || !BN_sub(tmp, k, tmp)
270                 || !BN_mod_mul(s, s, tmp, order, ctx)) {
271             SM2err(SM2_F_SM2_SIG_GEN, ERR_R_BN_LIB);
272             goto done;
273         }
274
275         sig = ECDSA_SIG_new();
276         if (sig == NULL) {
277             SM2err(SM2_F_SM2_SIG_GEN, ERR_R_MALLOC_FAILURE);
278             goto done;
279         }
280
281          /* takes ownership of r and s */
282         ECDSA_SIG_set0(sig, r, s);
283         break;
284     }
285
286  done:
287     if (sig == NULL) {
288         BN_free(r);
289         BN_free(s);
290     }
291
292     BN_CTX_free(ctx);
293     EC_POINT_free(kG);
294     return sig;
295 }
296
297 static int sm2_sig_verify(const EC_KEY *key, const ECDSA_SIG *sig,
298                           const BIGNUM *e)
299 {
300     int ret = 0;
301     const EC_GROUP *group = EC_KEY_get0_group(key);
302     const BIGNUM *order = EC_GROUP_get0_order(group);
303     BN_CTX *ctx = NULL;
304     EC_POINT *pt = NULL;
305     BIGNUM *t = NULL;
306     BIGNUM *x1 = NULL;
307     const BIGNUM *r = NULL;
308     const BIGNUM *s = NULL;
309     OPENSSL_CTX *libctx = ec_key_get_libctx(key);
310
311     ctx = BN_CTX_new_ex(libctx);
312     pt = EC_POINT_new(group);
313     if (ctx == NULL || pt == NULL) {
314         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_MALLOC_FAILURE);
315         goto done;
316     }
317
318     BN_CTX_start(ctx);
319     t = BN_CTX_get(ctx);
320     x1 = BN_CTX_get(ctx);
321     if (x1 == NULL) {
322         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_MALLOC_FAILURE);
323         goto done;
324     }
325
326     /*
327      * B1: verify whether r' in [1,n-1], verification failed if not
328      * B2: verify whether s' in [1,n-1], verification failed if not
329      * B3: set M'~=ZA || M'
330      * B4: calculate e'=Hv(M'~)
331      * B5: calculate t = (r' + s') modn, verification failed if t=0
332      * B6: calculate the point (x1', y1')=[s']G + [t]PA
333      * B7: calculate R=(e'+x1') modn, verification pass if yes, otherwise failed
334      */
335
336     ECDSA_SIG_get0(sig, &r, &s);
337
338     if (BN_cmp(r, BN_value_one()) < 0
339             || BN_cmp(s, BN_value_one()) < 0
340             || BN_cmp(order, r) <= 0
341             || BN_cmp(order, s) <= 0) {
342         SM2err(SM2_F_SM2_SIG_VERIFY, SM2_R_BAD_SIGNATURE);
343         goto done;
344     }
345
346     if (!BN_mod_add(t, r, s, order, ctx)) {
347         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_BN_LIB);
348         goto done;
349     }
350
351     if (BN_is_zero(t)) {
352         SM2err(SM2_F_SM2_SIG_VERIFY, SM2_R_BAD_SIGNATURE);
353         goto done;
354     }
355
356     if (!EC_POINT_mul(group, pt, s, EC_KEY_get0_public_key(key), t, ctx)
357             || !EC_POINT_get_affine_coordinates(group, pt, x1, NULL, ctx)) {
358         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_EC_LIB);
359         goto done;
360     }
361
362     if (!BN_mod_add(t, e, x1, order, ctx)) {
363         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_BN_LIB);
364         goto done;
365     }
366
367     if (BN_cmp(r, t) == 0)
368         ret = 1;
369
370  done:
371     EC_POINT_free(pt);
372     BN_CTX_free(ctx);
373     return ret;
374 }
375
376 ECDSA_SIG *sm2_do_sign(const EC_KEY *key,
377                        const EVP_MD *digest,
378                        const uint8_t *id,
379                        const size_t id_len,
380                        const uint8_t *msg, size_t msg_len)
381 {
382     BIGNUM *e = NULL;
383     ECDSA_SIG *sig = NULL;
384
385     e = sm2_compute_msg_hash(digest, key, id, id_len, msg, msg_len);
386     if (e == NULL) {
387         /* SM2err already called */
388         goto done;
389     }
390
391     sig = sm2_sig_gen(key, e);
392
393  done:
394     BN_free(e);
395     return sig;
396 }
397
398 int sm2_do_verify(const EC_KEY *key,
399                   const EVP_MD *digest,
400                   const ECDSA_SIG *sig,
401                   const uint8_t *id,
402                   const size_t id_len,
403                   const uint8_t *msg, size_t msg_len)
404 {
405     BIGNUM *e = NULL;
406     int ret = 0;
407
408     e = sm2_compute_msg_hash(digest, key, id, id_len, msg, msg_len);
409     if (e == NULL) {
410         /* SM2err already called */
411         goto done;
412     }
413
414     ret = sm2_sig_verify(key, sig, e);
415
416  done:
417     BN_free(e);
418     return ret;
419 }
420
421 int sm2_sign(const unsigned char *dgst, int dgstlen,
422              unsigned char *sig, unsigned int *siglen, EC_KEY *eckey)
423 {
424     BIGNUM *e = NULL;
425     ECDSA_SIG *s = NULL;
426     int sigleni;
427     int ret = -1;
428
429     e = BN_bin2bn(dgst, dgstlen, NULL);
430     if (e == NULL) {
431        SM2err(SM2_F_SM2_SIGN, ERR_R_BN_LIB);
432        goto done;
433     }
434
435     s = sm2_sig_gen(eckey, e);
436     if (s == NULL) {
437         SM2err(SM2_F_SM2_SIGN, ERR_R_INTERNAL_ERROR);
438         goto done;
439     }
440
441     sigleni = i2d_ECDSA_SIG(s, &sig);
442     if (sigleni < 0) {
443        SM2err(SM2_F_SM2_SIGN, ERR_R_INTERNAL_ERROR);
444        goto done;
445     }
446     *siglen = (unsigned int)sigleni;
447
448     ret = 1;
449
450  done:
451     ECDSA_SIG_free(s);
452     BN_free(e);
453     return ret;
454 }
455
456 int sm2_verify(const unsigned char *dgst, int dgstlen,
457                const unsigned char *sig, int sig_len, EC_KEY *eckey)
458 {
459     ECDSA_SIG *s = NULL;
460     BIGNUM *e = NULL;
461     const unsigned char *p = sig;
462     unsigned char *der = NULL;
463     int derlen = -1;
464     int ret = -1;
465
466     s = ECDSA_SIG_new();
467     if (s == NULL) {
468         SM2err(SM2_F_SM2_VERIFY, ERR_R_MALLOC_FAILURE);
469         goto done;
470     }
471     if (d2i_ECDSA_SIG(&s, &p, sig_len) == NULL) {
472         SM2err(SM2_F_SM2_VERIFY, SM2_R_INVALID_ENCODING);
473         goto done;
474     }
475     /* Ensure signature uses DER and doesn't have trailing garbage */
476     derlen = i2d_ECDSA_SIG(s, &der);
477     if (derlen != sig_len || memcmp(sig, der, derlen) != 0) {
478         SM2err(SM2_F_SM2_VERIFY, SM2_R_INVALID_ENCODING);
479         goto done;
480     }
481
482     e = BN_bin2bn(dgst, dgstlen, NULL);
483     if (e == NULL) {
484         SM2err(SM2_F_SM2_VERIFY, ERR_R_BN_LIB);
485         goto done;
486     }
487
488     ret = sm2_sig_verify(eckey, s, e);
489
490  done:
491     OPENSSL_free(der);
492     BN_free(e);
493     ECDSA_SIG_free(s);
494     return ret;
495 }