Improve use of the test framework in the SM2 internal tests
[openssl.git] / crypto / sm2 / sm2_sign.c
1 /*
2  * Copyright 2017-2018 The OpenSSL Project Authors. All Rights Reserved.
3  * Copyright 2017 Ribose Inc. All Rights Reserved.
4  * Ported from Ribose contributions from Botan.
5  *
6  * Licensed under the OpenSSL license (the "License").  You may not use
7  * this file except in compliance with the License.  You can obtain a copy
8  * in the file LICENSE in the source distribution or at
9  * https://www.openssl.org/source/license.html
10  */
11
12 #include "internal/sm2.h"
13 #include "internal/sm2err.h"
14 #include <openssl/err.h>
15 #include <openssl/evp.h>
16 #include <openssl/err.h>
17 #include <openssl/bn.h>
18 #include <string.h>
19
20 static BIGNUM *SM2_compute_msg_hash(const EVP_MD *digest,
21                                     const EC_KEY *key,
22                                     const char *user_id,
23                                     const uint8_t *msg, size_t msg_len)
24 {
25     EVP_MD_CTX *hash = EVP_MD_CTX_new();
26     const int md_size = EVP_MD_size(digest);
27     uint8_t *za = OPENSSL_zalloc(md_size);
28     BIGNUM *e = NULL;
29
30     if (za == NULL)
31         {
32         SM2err(SM2_F_SM2_COMPUTE_MSG_HASH, ERR_R_MALLOC_FAILURE);
33         goto done;
34         }
35
36     if (hash == NULL)
37         {
38         SM2err(SM2_F_SM2_COMPUTE_MSG_HASH, ERR_R_MALLOC_FAILURE);
39         goto done;
40         }
41
42     if (SM2_compute_userid_digest(za, digest, user_id, key) == 0)
43         goto done;
44
45     if (EVP_DigestInit(hash, digest) == 0)
46         {
47         SM2err(SM2_F_SM2_COMPUTE_MSG_HASH, ERR_R_EVP_LIB);
48         goto done;
49         }
50
51     if (EVP_DigestUpdate(hash, za, md_size) == 0)
52         {
53         SM2err(SM2_F_SM2_COMPUTE_MSG_HASH, ERR_R_EVP_LIB);
54         goto done;
55         }
56
57     if (EVP_DigestUpdate(hash, msg, msg_len) == 0)
58         {
59         SM2err(SM2_F_SM2_COMPUTE_MSG_HASH, ERR_R_EVP_LIB);
60         goto done;
61         }
62
63     /* reuse za buffer to hold H(ZA || M) */
64     if (EVP_DigestFinal(hash, za, NULL) == 0)
65         {
66         SM2err(SM2_F_SM2_COMPUTE_MSG_HASH, ERR_R_EVP_LIB);
67         goto done;
68         }
69
70     e = BN_bin2bn(za, md_size, NULL);
71
72  done:
73     OPENSSL_free(za);
74     EVP_MD_CTX_free(hash);
75     return e;
76 }
77
78 static
79 ECDSA_SIG *SM2_sig_gen(const EC_KEY *key, const BIGNUM *e)
80 {
81     const BIGNUM *dA = EC_KEY_get0_private_key(key);
82     const EC_GROUP *group = EC_KEY_get0_group(key);
83     const BIGNUM *order = EC_GROUP_get0_order(group);
84
85     ECDSA_SIG *sig = NULL;
86     EC_POINT *kG = NULL;
87     BN_CTX *ctx = NULL;
88     BIGNUM *k = NULL;
89     BIGNUM *rk = NULL;
90     BIGNUM *r = NULL;
91     BIGNUM *s = NULL;
92     BIGNUM *x1 = NULL;
93     BIGNUM *tmp = NULL;
94
95     kG = EC_POINT_new(group);
96     if (kG == NULL)
97         {
98         SM2err(SM2_F_SM2_SIG_GEN, ERR_R_MALLOC_FAILURE);
99         goto done;
100         }
101
102     ctx = BN_CTX_new();
103     if (ctx == NULL)
104         {
105         SM2err(SM2_F_SM2_SIG_GEN, ERR_R_MALLOC_FAILURE);
106         goto done;
107         }
108
109     BN_CTX_start(ctx);
110
111     k = BN_CTX_get(ctx);
112     rk = BN_CTX_get(ctx);
113     x1 = BN_CTX_get(ctx);
114     tmp = BN_CTX_get(ctx);
115
116     if (tmp == NULL)
117         {
118         SM2err(SM2_F_SM2_SIG_GEN, ERR_R_BN_LIB);
119         goto done;
120         }
121
122     /* These values are returned and so should not be allocated out of the context */
123     r = BN_new();
124     s = BN_new();
125
126     if (r == NULL || s == NULL)
127         {
128         SM2err(SM2_F_SM2_SIG_GEN, ERR_R_MALLOC_FAILURE);
129         goto done;
130         }
131
132     for (;;) {
133         BN_priv_rand_range(k, order);
134
135         if (EC_POINT_mul(group, kG, k, NULL, NULL, ctx) == 0)
136             {
137             SM2err(SM2_F_SM2_SIG_GEN, ERR_R_EC_LIB);
138             goto done;
139             }
140
141         if (EC_POINT_get_affine_coordinates_GFp(group, kG, x1, NULL, ctx) == 0)
142             {
143             SM2err(SM2_F_SM2_SIG_GEN, ERR_R_EC_LIB);
144             goto done;
145             }
146
147         if (BN_mod_add(r, e, x1, order, ctx) == 0)
148             {
149             SM2err(SM2_F_SM2_SIG_GEN, ERR_R_BN_LIB);
150             goto done;
151             }
152
153         /* try again if r == 0 or r+k == n */
154         if (BN_is_zero(r))
155             continue;
156
157         BN_add(rk, r, k);
158
159         if (BN_cmp(rk, order) == 0)
160             continue;
161
162         if(BN_add(s, dA, BN_value_one()) == 0)
163            {
164            SM2err(SM2_F_SM2_SIG_GEN, ERR_R_BN_LIB);
165            goto done;
166            }
167
168         if(BN_mod_inverse(s, s, order, ctx) == 0)
169            {
170            SM2err(SM2_F_SM2_SIG_GEN, ERR_R_BN_LIB);
171            goto done;
172            }
173
174         if(BN_mod_mul(tmp, dA, r, order, ctx) == 0)
175            {
176            SM2err(SM2_F_SM2_SIG_GEN, ERR_R_BN_LIB);
177            goto done;
178            }
179
180         if(BN_sub(tmp, k, tmp) == 0)
181            {
182            SM2err(SM2_F_SM2_SIG_GEN, ERR_R_BN_LIB);
183            goto done;
184            }
185
186         if(BN_mod_mul(s, s, tmp, order, ctx) == 0)
187            {
188            SM2err(SM2_F_SM2_SIG_GEN, ERR_R_BN_LIB);
189            goto done;
190            }
191
192         sig = ECDSA_SIG_new();
193
194         if (sig == NULL)
195             {
196             SM2err(SM2_F_SM2_SIG_GEN, ERR_R_MALLOC_FAILURE);
197             goto done;
198             }
199
200          /* takes ownership of r and s */
201         ECDSA_SIG_set0(sig, r, s);
202         break;
203     }
204
205  done:
206
207     if (sig == NULL) {
208         BN_free(r);
209         BN_free(s);
210     }
211
212     BN_CTX_free(ctx);
213     EC_POINT_free(kG);
214     return sig;
215
216 }
217
218 static
219 int SM2_sig_verify(const EC_KEY *key, const ECDSA_SIG *sig, const BIGNUM *e)
220 {
221     int ret = 0;
222     const EC_GROUP *group = EC_KEY_get0_group(key);
223     const BIGNUM *order = EC_GROUP_get0_order(group);
224     BN_CTX *ctx = NULL;
225     EC_POINT *pt = NULL;
226
227     BIGNUM *t = NULL;
228     BIGNUM *x1 = NULL;
229     const BIGNUM *r = NULL;
230     const BIGNUM *s = NULL;
231
232     ctx = BN_CTX_new();
233     if (ctx == NULL)
234         {
235         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_MALLOC_FAILURE);
236         goto done;
237         }
238
239     pt = EC_POINT_new(group);
240     if (pt == NULL)
241         {
242         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_MALLOC_FAILURE);
243         goto done;
244         }
245
246     BN_CTX_start(ctx);
247
248     t = BN_CTX_get(ctx);
249     x1 = BN_CTX_get(ctx);
250
251     if (x1 == NULL)
252        {
253        SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_MALLOC_FAILURE);
254        goto done;
255        }
256
257     /*
258        B1: verify whether r' in [1,n-1], verification failed if not
259        B2: vefify whether s' in [1,n-1], verification failed if not
260        B3: set M'~=ZA || M'
261        B4: calculate e'=Hv(M'~)
262        B5: calculate t = (r' + s') modn, verification failed if t=0
263        B6: calculate the point (x1', y1')=[s']G + [t]PA
264        B7: calculate R=(e'+x1') modn, verfication pass if yes, otherwise failed
265      */
266
267     ECDSA_SIG_get0(sig, &r, &s);
268
269     if ((BN_cmp(r, BN_value_one()) < 0) || (BN_cmp(s, BN_value_one()) < 0))
270         {
271         SM2err(SM2_F_SM2_SIG_VERIFY, SM2_R_BAD_SIGNATURE);
272         goto done;
273         }
274
275     if ((BN_cmp(order, r) <= 0) || (BN_cmp(order, s) <= 0))
276         {
277         SM2err(SM2_F_SM2_SIG_VERIFY, SM2_R_BAD_SIGNATURE);
278         goto done;
279         }
280
281     if (BN_mod_add(t, r, s, order, ctx) == 0)
282         {
283         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_BN_LIB);
284         goto done;
285         }
286
287     if (BN_is_zero(t) == 1)
288         {
289         SM2err(SM2_F_SM2_SIG_VERIFY, SM2_R_BAD_SIGNATURE);
290         goto done;
291         }
292
293     if (EC_POINT_mul(group, pt, s, EC_KEY_get0_public_key(key), t, ctx) == 0)
294         {
295         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_EC_LIB);
296         goto done;
297         }
298
299     if (EC_POINT_get_affine_coordinates_GFp(group, pt, x1, NULL, ctx) == 0)
300         {
301         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_EC_LIB);
302         goto done;
303         }
304
305     if (BN_mod_add(t, e, x1, order, ctx) == 0)
306         {
307         SM2err(SM2_F_SM2_SIG_VERIFY, ERR_R_BN_LIB);
308         goto done;
309         }
310
311     if (BN_cmp(r, t) == 0)
312         ret = 1;
313
314  done:
315     EC_POINT_free(pt);
316     BN_CTX_free(ctx);
317     return ret;
318 }
319
320 ECDSA_SIG *SM2_do_sign(const EC_KEY *key,
321                        const EVP_MD *digest,
322                        const char *user_id, const uint8_t *msg, size_t msg_len)
323 {
324     BIGNUM *e = NULL;
325     ECDSA_SIG *sig = NULL;
326
327     e = SM2_compute_msg_hash(digest, key, user_id, msg, msg_len);
328     if (e == NULL)
329         goto done;
330
331     sig = SM2_sig_gen(key, e);
332
333  done:
334     BN_free(e);
335     return sig;
336 }
337
338 int SM2_do_verify(const EC_KEY *key,
339                   const EVP_MD *digest,
340                   const ECDSA_SIG *sig,
341                   const char *user_id, const uint8_t *msg, size_t msg_len)
342 {
343     BIGNUM *e = NULL;
344     int ret = -1;
345
346     e = SM2_compute_msg_hash(digest, key, user_id, msg, msg_len);
347     if (e == NULL)
348         goto done;
349
350     ret = SM2_sig_verify(key, sig, e);
351
352  done:
353     BN_free(e);
354     return ret;
355 }
356
357 int SM2_sign(int type, const unsigned char *dgst, int dgstlen,
358              unsigned char *sig, unsigned int *siglen, EC_KEY *eckey)
359 {
360     BIGNUM *e = NULL;
361     ECDSA_SIG *s = NULL;
362     int ret = -1;
363
364     if (type != NID_sm3 || dgstlen != 32)
365        {
366        SM2err(SM2_F_SM2_SIGN, SM2_R_INVALID_DIGEST_TYPE);
367        goto done;
368        }
369
370     e = BN_bin2bn(dgst, dgstlen, NULL);
371
372     s = SM2_sig_gen(eckey, e);
373
374     *siglen = i2d_ECDSA_SIG(s, &sig);
375
376     ret = 1;
377
378  done:
379     ECDSA_SIG_free(s);
380     BN_free(e);
381     return ret;
382 }
383
384 int SM2_verify(int type, const unsigned char *dgst, int dgstlen,
385                const unsigned char *sig, int sig_len, EC_KEY *eckey)
386 {
387     ECDSA_SIG *s = NULL;
388     BIGNUM *e = NULL;
389     const unsigned char *p = sig;
390     unsigned char *der = NULL;
391     int derlen = -1;
392     int ret = -1;
393
394     if (type != NID_sm3)
395        {
396        SM2err(SM2_F_SM2_VERIFY, SM2_R_INVALID_DIGEST_TYPE);
397        goto done;
398        }
399
400     s = ECDSA_SIG_new();
401     if (s == NULL)
402         {
403         SM2err(SM2_F_SM2_VERIFY, ERR_R_MALLOC_FAILURE);
404         goto done;
405         }
406     if (d2i_ECDSA_SIG(&s, &p, sig_len) == NULL)
407         {
408         SM2err(SM2_F_SM2_VERIFY, SM2_R_INVALID_ENCODING);
409         goto done;
410         }
411     /* Ensure signature uses DER and doesn't have trailing garbage */
412     derlen = i2d_ECDSA_SIG(s, &der);
413     if (derlen != sig_len || memcmp(sig, der, derlen) != 0)
414         {
415         SM2err(SM2_F_SM2_VERIFY, SM2_R_INVALID_ENCODING);
416         goto done;
417         }
418
419     e = BN_bin2bn(dgst, dgstlen, NULL);
420
421     ret = SM2_sig_verify(eckey, s, e);
422
423  done:
424     OPENSSL_free(der);
425     BN_free(e);
426     ECDSA_SIG_free(s);
427     return ret;
428 }