f1185c13378e65045816f0db8fdf4aa213c842e6
[openssl.git] / crypto / sm2 / sm2_sign.c
1 /*
2  * Copyright 2017-2018 The OpenSSL Project Authors. All Rights Reserved.
3  * Copyright 2017 Ribose Inc. All Rights Reserved.
4  * Ported from Ribose contributions from Botan.
5  *
6  * Licensed under the OpenSSL license (the "License").  You may not use
7  * this file except in compliance with the License.  You can obtain a copy
8  * in the file LICENSE in the source distribution or at
9  * https://www.openssl.org/source/license.html
10  */
11
12 #include "internal/sm2.h"
13 #include "internal/sm2err.h"
14 #include "internal/ec_int.h" /* ec_group_do_inverse_ord() */
15 #include <openssl/err.h>
16 #include <openssl/evp.h>
17 #include <openssl/err.h>
18 #include <openssl/bn.h>
19 #include <string.h>
20
21 static BIGNUM *sm2_compute_msg_hash(const EVP_MD *digest,
22                                     const EC_KEY *key,
23                                     const char *user_id,
24                                     const uint8_t *msg, size_t msg_len)
25 {
26     EVP_MD_CTX *hash = EVP_MD_CTX_new();
27     const int md_size = EVP_MD_size(digest);
28     uint8_t *za = NULL;
29     BIGNUM *e = NULL;
30
31     if (md_size < 0) {
32         SM2err(SM2_F_SM2_COMPUTE_MSG_HASH, SM2_R_INVALID_DIGEST);
33         goto done;
34     }
35
36     za = OPENSSL_zalloc(md_size);
37     if (hash == NULL || za == NULL) {
38         SM2err(SM2_F_SM2_COMPUTE_MSG_HASH, ERR_R_MALLOC_FAILURE);
39         goto done;
40     }
41
42     if (!sm2_compute_userid_digest(za, digest, user_id, key)) {
43         /* SM2err already called */
44         goto done;
45     }
46
47     if (!EVP_DigestInit(hash, digest)
48             || !EVP_DigestUpdate(hash, za, md_size)
49             || !EVP_DigestUpdate(hash, msg, msg_len)
50                /* reuse za buffer to hold H(ZA || M) */
51             || !EVP_DigestFinal(hash, za, NULL)) {
52         SM2err(SM2_F_SM2_COMPUTE_MSG_HASH, ERR_R_EVP_LIB);
53         goto done;
54     }
55
56     e = BN_bin2bn(za, md_size, NULL);
57     if (e == NULL)
58         SM2err(SM2_F_SM2_COMPUTE_MSG_HASH, ERR_R_INTERNAL_ERROR);
59
60  done:
61     OPENSSL_free(za);
62     EVP_MD_CTX_free(hash);
63     return e;
64 }
65
66 static ECDSA_SIG *sm2_sig_gen(const EC_KEY *key, const BIGNUM *e)
67 {
68     const BIGNUM *dA = EC_KEY_get0_private_key(key);
69     const EC_GROUP *group = EC_KEY_get0_group(key);
70     const BIGNUM *order = EC_GROUP_get0_order(group);
71     ECDSA_SIG *sig = NULL;
72     EC_POINT *kG = NULL;
73     BN_CTX *ctx = NULL;
74     BIGNUM *k = NULL;
75     BIGNUM *rk = NULL;
76     BIGNUM *r = NULL;
77     BIGNUM *s = NULL;
78     BIGNUM *x1 = NULL;
79     BIGNUM *tmp = NULL;
80
81     kG = EC_POINT_new(group);
82     ctx = BN_CTX_new();
83     if (kG == NULL || ctx == NULL) {
84         SM2err(SM2_F_SM2_SIG_GEN, ERR_R_MALLOC_FAILURE);
85         goto done;
86     }
87
88
89     BN_CTX_start(ctx);
90     k = BN_CTX_get(ctx);
91     rk = BN_CTX_get(ctx);
92     x1 = BN_CTX_get(ctx);
93     tmp = BN_CTX_get(ctx);
94     if (tmp == NULL) {
95         SM2err(SM2_F_SM2_SIG_GEN, ERR_R_MALLOC_FAILURE);
96         goto done;
97     }
98
99     /*
100      * These values are returned and so should not be allocated out of the
101      * context
102      */
103     r = BN_new();
104     s = BN_new();
105
106     if (r == NULL || s == NULL) {
107         SM2err(SM2_F_SM2_SIG_GEN, ERR_R_MALLOC_FAILURE);
108         goto done;
109     }
110
111     for (;;) {
112         if (!BN_priv_rand_range(k, order)) {
113             SM2err(SM2_F_SM2_SIG_GEN, ERR_R_INTERNAL_ERROR);
114             goto done;
115         }
116
117         if (!EC_POINT_mul(group, kG, k, NULL, NULL, ctx)
118                 || !EC_POINT_get_affine_coordinates(group, kG, x1, NULL,
119                                                         ctx)
120                 || !BN_mod_add(r, e, x1, order, ctx)) {
121             SM2err(SM2_F_SM2_SIG_GEN, ERR_R_INTERNAL_ERROR);
122             goto done;
123         }
124
125         /* try again if r == 0 or r+k == n */
126         if (BN_is_zero(r))
127             continue;
128
129         if (!BN_add(rk, r, k)) {
130             SM2err(SM2_F_SM2_SIG_GEN, ERR_R_INTERNAL_ERROR);
131             goto done;
132         }
133
134         if (BN_cmp(rk, order) == 0)
135             continue;
136
137         if (!BN_add(s, dA, BN_value_one())
138                 || !ec_group_do_inverse_ord(group, s, s, ctx)
139                 || !BN_mod_mul(tmp, dA, r, order, ctx)
140                 || !BN_sub(tmp, k, tmp)
141                 || !BN_mod_mul(s, s, tmp, order, ctx)) {
142             SM2err(SM2_F_SM2_SIG_GEN, ERR_R_BN_LIB);
143             goto done;
144         }
145
146         sig = ECDSA_SIG_new();
147         if (sig == NULL) {
148             SM2err(SM2_F_SM2_SIG_GEN, ERR_R_MALLOC_FAILURE);
149             goto done;
150         }
151
152          /* takes ownership of r and s */
153         ECDSA_SIG_set0(sig, r, s);
154         break;
155     }
156
157  done:
158     if (sig == NULL) {
159         BN_free(r);
160         BN_free(s);
161     }
162
163     BN_CTX_free(ctx);
164     EC_POINT_free(kG);
165     return sig;
166 }
167
168 static int sm2_sig_verify(const EC_KEY *key, const ECDSA_SIG *sig,
169                           const BIGNUM *e)
170 {
171     int ret = 0;
172     const EC_GROUP *group = EC_KEY_get0_group(key);
173     const BIGNUM *order = EC_GROUP_get0_order(group);
174     BN_CTX *ctx = NULL;
175     EC_POINT *pt = NULL;
176     BIGNUM *t = NULL;
177     BIGNUM *x1 = NULL;
178     const BIGNUM *r = NULL;
179     const BIGNUM *s = NULL;
180
181     ctx = BN_CTX_new();
182     pt = EC_POINT_new(group);
183     if (ctx == NULL || pt == NULL) {
184         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_MALLOC_FAILURE);
185         goto done;
186     }
187
188     BN_CTX_start(ctx);
189     t = BN_CTX_get(ctx);
190     x1 = BN_CTX_get(ctx);
191     if (x1 == NULL) {
192         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_MALLOC_FAILURE);
193         goto done;
194     }
195
196     /*
197      * B1: verify whether r' in [1,n-1], verification failed if not
198      * B2: vefify whether s' in [1,n-1], verification failed if not
199      * B3: set M'~=ZA || M'
200      * B4: calculate e'=Hv(M'~)
201      * B5: calculate t = (r' + s') modn, verification failed if t=0
202      * B6: calculate the point (x1', y1')=[s']G + [t]PA
203      * B7: calculate R=(e'+x1') modn, verfication pass if yes, otherwise failed
204      */
205
206     ECDSA_SIG_get0(sig, &r, &s);
207
208     if (BN_cmp(r, BN_value_one()) < 0
209             || BN_cmp(s, BN_value_one()) < 0
210             || BN_cmp(order, r) <= 0
211             || BN_cmp(order, s) <= 0) {
212         SM2err(SM2_F_SM2_SIG_VERIFY, SM2_R_BAD_SIGNATURE);
213         goto done;
214     }
215
216     if (!BN_mod_add(t, r, s, order, ctx)) {
217         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_BN_LIB);
218         goto done;
219     }
220
221     if (BN_is_zero(t)) {
222         SM2err(SM2_F_SM2_SIG_VERIFY, SM2_R_BAD_SIGNATURE);
223         goto done;
224     }
225
226     if (!EC_POINT_mul(group, pt, s, EC_KEY_get0_public_key(key), t, ctx)
227             || !EC_POINT_get_affine_coordinates(group, pt, x1, NULL, ctx)) {
228         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_EC_LIB);
229         goto done;
230     }
231
232     if (!BN_mod_add(t, e, x1, order, ctx)) {
233         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_BN_LIB);
234         goto done;
235     }
236
237     if (BN_cmp(r, t) == 0)
238         ret = 1;
239
240  done:
241     EC_POINT_free(pt);
242     BN_CTX_free(ctx);
243     return ret;
244 }
245
246 ECDSA_SIG *sm2_do_sign(const EC_KEY *key,
247                        const EVP_MD *digest,
248                        const char *user_id, const uint8_t *msg, size_t msg_len)
249 {
250     BIGNUM *e = NULL;
251     ECDSA_SIG *sig = NULL;
252
253     e = sm2_compute_msg_hash(digest, key, user_id, msg, msg_len);
254     if (e == NULL) {
255         /* SM2err already called */
256         goto done;
257     }
258
259     sig = sm2_sig_gen(key, e);
260
261  done:
262     BN_free(e);
263     return sig;
264 }
265
266 int sm2_do_verify(const EC_KEY *key,
267                   const EVP_MD *digest,
268                   const ECDSA_SIG *sig,
269                   const char *user_id, const uint8_t *msg, size_t msg_len)
270 {
271     BIGNUM *e = NULL;
272     int ret = 0;
273
274     e = sm2_compute_msg_hash(digest, key, user_id, msg, msg_len);
275     if (e == NULL) {
276         /* SM2err already called */
277         goto done;
278     }
279
280     ret = sm2_sig_verify(key, sig, e);
281
282  done:
283     BN_free(e);
284     return ret;
285 }
286
287 int sm2_sign(const unsigned char *dgst, int dgstlen,
288              unsigned char *sig, unsigned int *siglen, EC_KEY *eckey)
289 {
290     BIGNUM *e = NULL;
291     ECDSA_SIG *s = NULL;
292     int sigleni;
293     int ret = -1;
294
295     e = BN_bin2bn(dgst, dgstlen, NULL);
296     if (e == NULL) {
297        SM2err(SM2_F_SM2_SIGN, ERR_R_BN_LIB);
298        goto done;
299     }
300
301     s = sm2_sig_gen(eckey, e);
302
303     sigleni = i2d_ECDSA_SIG(s, &sig);
304     if (sigleni < 0) {
305        SM2err(SM2_F_SM2_SIGN, ERR_R_INTERNAL_ERROR);
306        goto done;
307     }
308     *siglen = (unsigned int)sigleni;
309
310     ret = 1;
311
312  done:
313     ECDSA_SIG_free(s);
314     BN_free(e);
315     return ret;
316 }
317
318 int sm2_verify(const unsigned char *dgst, int dgstlen,
319                const unsigned char *sig, int sig_len, EC_KEY *eckey)
320 {
321     ECDSA_SIG *s = NULL;
322     BIGNUM *e = NULL;
323     const unsigned char *p = sig;
324     unsigned char *der = NULL;
325     int derlen = -1;
326     int ret = -1;
327
328     s = ECDSA_SIG_new();
329     if (s == NULL) {
330         SM2err(SM2_F_SM2_VERIFY, ERR_R_MALLOC_FAILURE);
331         goto done;
332     }
333     if (d2i_ECDSA_SIG(&s, &p, sig_len) == NULL) {
334         SM2err(SM2_F_SM2_VERIFY, SM2_R_INVALID_ENCODING);
335         goto done;
336     }
337     /* Ensure signature uses DER and doesn't have trailing garbage */
338     derlen = i2d_ECDSA_SIG(s, &der);
339     if (derlen != sig_len || memcmp(sig, der, derlen) != 0) {
340         SM2err(SM2_F_SM2_VERIFY, SM2_R_INVALID_ENCODING);
341         goto done;
342     }
343
344     e = BN_bin2bn(dgst, dgstlen, NULL);
345     if (e == NULL) {
346         SM2err(SM2_F_SM2_VERIFY, ERR_R_BN_LIB);
347         goto done;
348     }
349
350     ret = sm2_sig_verify(eckey, s, e);
351
352  done:
353     OPENSSL_free(der);
354     BN_free(e);
355     ECDSA_SIG_free(s);
356     return ret;
357 }