s4:heimdal: import lorikeet-heimdal-202201172009 (commit 5a0b45cd723628b3690ea848548b...
[samba.git] / source4 / heimdal / lib / hcrypto / rsa-gmp.c
1 /*
2  * Copyright (c) 2006 - 2007 Kungliga Tekniska Högskolan
3  * (Royal Institute of Technology, Stockholm, Sweden).
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions
8  * are met:
9  *
10  * 1. Redistributions of source code must retain the above copyright
11  *    notice, this list of conditions and the following disclaimer.
12  *
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in the
15  *    documentation and/or other materials provided with the distribution.
16  *
17  * 3. Neither the name of the Institute nor the names of its contributors
18  *    may be used to endorse or promote products derived from this software
19  *    without specific prior written permission.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE INSTITUTE AND CONTRIBUTORS ``AS IS'' AND
22  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
24  * ARE DISCLAIMED.  IN NO EVENT SHALL THE INSTITUTE OR CONTRIBUTORS BE LIABLE
25  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
27  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
28  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
30  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
31  * SUCH DAMAGE.
32  */
33
34 #include <config.h>
35 #include <roken.h>
36 #include <krb5-types.h>
37 #include <assert.h>
38
39 #include <rsa.h>
40
41 #ifdef HAVE_GMP
42
43 #include <gmp.h>
44
45 static void
46 BN2mpz(mpz_t s, const BIGNUM *bn)
47 {
48     size_t len;
49     void *p;
50
51     len = BN_num_bytes(bn);
52     p = malloc(len);
53     BN_bn2bin(bn, p);
54     mpz_init(s);
55     mpz_import(s, len, 1, 1, 1, 0, p);
56
57     free(p);
58 }
59
60
61 static BIGNUM *
62 mpz2BN(mpz_t s)
63 {
64     size_t size;
65     BIGNUM *bn;
66     void *p;
67
68     mpz_export(NULL, &size, 1, 1, 1, 0, s);
69     p = malloc(size);
70     if (p == NULL && size != 0)
71         return NULL;
72     mpz_export(p, &size, 1, 1, 1, 0, s);
73     bn = BN_bin2bn(p, size, NULL);
74     free(p);
75     return bn;
76 }
77
78 static int
79 rsa_private_calculate(mpz_t in, mpz_t p,  mpz_t q,
80                       mpz_t dmp1, mpz_t dmq1, mpz_t iqmp,
81                       mpz_t out)
82 {
83     mpz_t vp, vq, u;
84     mpz_init(vp); mpz_init(vq); mpz_init(u);
85
86     /* vq = c ^ (d mod (q - 1)) mod q */
87     /* vp = c ^ (d mod (p - 1)) mod p */
88     mpz_fdiv_r(vp, in, p);
89     mpz_powm(vp, vp, dmp1, p);
90     mpz_fdiv_r(vq, in, q);
91     mpz_powm(vq, vq, dmq1, q);
92
93     /* C2 = 1/q mod p  (iqmp) */
94     /* u = (vp - vq)C2 mod p. */
95     mpz_sub(u, vp, vq);
96 #if 0
97     if (mp_int_compare_zero(&u) < 0)
98         mp_int_add(&u, p, &u);
99 #endif
100     mpz_mul(u, iqmp, u);
101     mpz_fdiv_r(u, u, p);
102
103     /* c ^ d mod n = vq + u q */
104     mpz_mul(u, q, u);
105     mpz_add(out, u, vq);
106
107     mpz_clear(vp);
108     mpz_clear(vq);
109     mpz_clear(u);
110
111     return 0;
112 }
113
114 /*
115  *
116  */
117
118 static int
119 gmp_rsa_public_encrypt(int flen, const unsigned char* from,
120                         unsigned char* to, RSA* rsa, int padding)
121 {
122     unsigned char *p, *p0;
123     size_t size, padlen;
124     mpz_t enc, dec, n, e;
125
126     if (padding != RSA_PKCS1_PADDING)
127         return -1;
128
129     size = RSA_size(rsa);
130
131     if (size < RSA_PKCS1_PADDING_SIZE || size - RSA_PKCS1_PADDING_SIZE < flen)
132         return -2;
133
134     BN2mpz(n, rsa->n);
135     BN2mpz(e, rsa->e);
136
137     p = p0 = malloc(size - 1);
138     if (p0 == NULL) {
139         mpz_clear(e);
140         mpz_clear(n);
141         return -3;
142     }
143
144     padlen = size - flen - 3;
145     assert(padlen >= 8);
146
147     *p++ = 2;
148     if (RAND_bytes(p, padlen) != 1) {
149         mpz_clear(e);
150         mpz_clear(n);
151         free(p0);
152         return -4;
153     }
154     while(padlen) {
155         if (*p == 0)
156             *p = 1;
157         padlen--;
158         p++;
159     }
160     *p++ = 0;
161     memcpy(p, from, flen);
162     p += flen;
163     assert((p - p0) == size - 1);
164
165     mpz_init(enc);
166     mpz_init(dec);
167     mpz_import(dec, size - 1, 1, 1, 1, 0, p0);
168     free(p0);
169
170     mpz_powm(enc, dec, e, n);
171
172     mpz_clear(dec);
173     mpz_clear(e);
174     mpz_clear(n);
175     {
176         size_t ssize;
177         mpz_export(to, &ssize, 1, 1, 1, 0, enc);
178         assert(size >= ssize);
179         size = ssize;
180     }
181     mpz_clear(enc);
182
183     return size;
184 }
185
186 static int
187 gmp_rsa_public_decrypt(int flen, const unsigned char* from,
188                          unsigned char* to, RSA* rsa, int padding)
189 {
190     unsigned char *p;
191     size_t size;
192     mpz_t s, us, n, e;
193
194     if (padding != RSA_PKCS1_PADDING)
195         return -1;
196
197     if (flen > RSA_size(rsa))
198         return -2;
199
200     BN2mpz(n, rsa->n);
201     BN2mpz(e, rsa->e);
202
203 #if 0
204     /* Check that the exponent is larger then 3 */
205     if (mp_int_compare_value(&e, 3) <= 0) {
206         mp_int_clear(&n);
207         mp_int_clear(&e);
208         return -3;
209     }
210 #endif
211
212     mpz_init(s);
213     mpz_init(us);
214     mpz_import(s, flen, 1, 1, 1, 0, rk_UNCONST(from));
215
216     if (mpz_cmp(s, n) >= 0) {
217         mpz_clear(n);
218         mpz_clear(e);
219         return -4;
220     }
221
222     mpz_powm(us, s, e, n);
223
224     mpz_clear(s);
225     mpz_clear(n);
226     mpz_clear(e);
227
228     p = to;
229
230     mpz_export(p, &size, 1, 1, 1, 0, us);
231     assert(size <= RSA_size(rsa));
232
233     mpz_clear(us);
234
235     /* head zero was skipped by mp_int_to_unsigned */
236     if (*p == 0)
237         return -6;
238     if (*p != 1)
239         return -7;
240     size--; p++;
241     while (size && *p == 0xff) {
242         size--; p++;
243     }
244     if (size == 0 || *p != 0)
245         return -8;
246     size--; p++;
247
248     memmove(to, p, size);
249
250     return size;
251 }
252
253 static int
254 gmp_rsa_private_encrypt(int flen, const unsigned char* from,
255                           unsigned char* to, RSA* rsa, int padding)
256 {
257     unsigned char *p, *p0;
258     size_t size;
259     mpz_t in, out, n, e;
260
261     if (padding != RSA_PKCS1_PADDING)
262         return -1;
263
264     size = RSA_size(rsa);
265
266     if (size < RSA_PKCS1_PADDING_SIZE || size - RSA_PKCS1_PADDING_SIZE < flen)
267         return -2;
268
269     p0 = p = malloc(size);
270     *p++ = 0;
271     *p++ = 1;
272     memset(p, 0xff, size - flen - 3);
273     p += size - flen - 3;
274     *p++ = 0;
275     memcpy(p, from, flen);
276     p += flen;
277     assert((p - p0) == size);
278
279     BN2mpz(n, rsa->n);
280     BN2mpz(e, rsa->e);
281
282     mpz_init(in);
283     mpz_init(out);
284     mpz_import(in, size, 1, 1, 1, 0, p0);
285     free(p0);
286
287 #if 0
288     if(mp_int_compare_zero(&in) < 0 ||
289        mp_int_compare(&in, &n) >= 0) {
290         size = 0;
291         goto out;
292     }
293 #endif
294
295     if (rsa->p && rsa->q && rsa->dmp1 && rsa->dmq1 && rsa->iqmp) {
296         mpz_t p, q, dmp1, dmq1, iqmp;
297
298         BN2mpz(p, rsa->p);
299         BN2mpz(q, rsa->q);
300         BN2mpz(dmp1, rsa->dmp1);
301         BN2mpz(dmq1, rsa->dmq1);
302         BN2mpz(iqmp, rsa->iqmp);
303
304         rsa_private_calculate(in, p, q, dmp1, dmq1, iqmp, out);
305
306         mpz_clear(p);
307         mpz_clear(q);
308         mpz_clear(dmp1);
309         mpz_clear(dmq1);
310         mpz_clear(iqmp);
311     } else {
312         mpz_t d;
313
314         BN2mpz(d, rsa->d);
315         mpz_powm(out, in, d, n);
316         mpz_clear(d);
317     }
318
319     {
320         size_t ssize;
321         mpz_export(to, &ssize, 1, 1, 1, 0, out);
322         assert(size >= ssize);
323         size = ssize;
324     }
325
326     mpz_clear(e);
327     mpz_clear(n);
328     mpz_clear(in);
329     mpz_clear(out);
330
331     return size;
332 }
333
334 static int
335 gmp_rsa_private_decrypt(int flen, const unsigned char* from,
336                           unsigned char* to, RSA* rsa, int padding)
337 {
338     unsigned char *ptr;
339     size_t size;
340     mpz_t in, out, n, e;
341
342     if (padding != RSA_PKCS1_PADDING)
343         return -1;
344
345     size = RSA_size(rsa);
346     if (flen > size)
347         return -2;
348
349     mpz_init(in);
350     mpz_init(out);
351
352     BN2mpz(n, rsa->n);
353     BN2mpz(e, rsa->e);
354
355     mpz_import(in, flen, 1, 1, 1, 0, from);
356
357     if(mpz_cmp_ui(in, 0) < 0 ||
358        mpz_cmp(in, n) >= 0) {
359         size = 0;
360         goto out;
361     }
362
363     if (rsa->p && rsa->q && rsa->dmp1 && rsa->dmq1 && rsa->iqmp) {
364         mpz_t p, q, dmp1, dmq1, iqmp;
365
366         BN2mpz(p, rsa->p);
367         BN2mpz(q, rsa->q);
368         BN2mpz(dmp1, rsa->dmp1);
369         BN2mpz(dmq1, rsa->dmq1);
370         BN2mpz(iqmp, rsa->iqmp);
371
372         rsa_private_calculate(in, p, q, dmp1, dmq1, iqmp, out);
373
374         mpz_clear(p);
375         mpz_clear(q);
376         mpz_clear(dmp1);
377         mpz_clear(dmq1);
378         mpz_clear(iqmp);
379     } else {
380         mpz_t d;
381
382 #if 0
383         if(mp_int_compare_zero(&in) < 0 ||
384            mp_int_compare(&in, &n) >= 0)
385             return MP_RANGE;
386 #endif
387
388         BN2mpz(d, rsa->d);
389         mpz_powm(out, in, d, n);
390         mpz_clear(d);
391     }
392
393     ptr = to;
394     {
395         size_t ssize;
396         mpz_export(ptr, &ssize, 1, 1, 1, 0, out);
397         assert(size >= ssize);
398         size = ssize;
399     }
400
401     /* head zero was skipped by mp_int_to_unsigned */
402     if (*ptr != 2)
403         return -3;
404     size--; ptr++;
405     while (size && *ptr != 0) {
406         size--; ptr++;
407     }
408     if (size == 0)
409         return -4;
410     size--; ptr++;
411
412     memmove(to, ptr, size);
413
414 out:
415     mpz_clear(e);
416     mpz_clear(n);
417     mpz_clear(in);
418     mpz_clear(out);
419
420     return size;
421 }
422
423 static int
424 random_num(mpz_t num, size_t len)
425 {
426     unsigned char *p;
427
428     len = (len + 7) / 8;
429     p = malloc(len);
430     if (p == NULL)
431         return 1;
432     if (RAND_bytes(p, len) != 1) {
433         free(p);
434         return 1;
435     }
436     mpz_import(num, len, 1, 1, 1, 0, p);
437     free(p);
438     return 0;
439 }
440
441
442 static int
443 gmp_rsa_generate_key(RSA *rsa, int bits, BIGNUM *e, BN_GENCB *cb)
444 {
445     mpz_t el, p, q, n, d, dmp1, dmq1, iqmp, t1, t2, t3;
446     int counter, ret;
447
448     if (bits < 789)
449         return -1;
450
451     ret = -1;
452
453     mpz_init(el);
454     mpz_init(p);
455     mpz_init(q);
456     mpz_init(n);
457     mpz_init(d);
458     mpz_init(dmp1);
459     mpz_init(dmq1);
460     mpz_init(iqmp);
461     mpz_init(t1);
462     mpz_init(t2);
463     mpz_init(t3);
464
465     BN2mpz(el, e);
466
467     /* generate p and q so that p != q and bits(pq) ~ bits */
468
469     counter = 0;
470     do {
471         BN_GENCB_call(cb, 2, counter++);
472         random_num(p, bits / 2 + 1);
473         mpz_nextprime(p, p);
474
475         mpz_sub_ui(t1, p, 1);
476         mpz_gcd(t2, t1, el);
477     } while(mpz_cmp_ui(t2, 1) != 0);
478
479     BN_GENCB_call(cb, 3, 0);
480
481     counter = 0;
482     do {
483         BN_GENCB_call(cb, 2, counter++);
484         random_num(q, bits / 2 + 1);
485         mpz_nextprime(q, q);
486
487         mpz_sub_ui(t1, q, 1);
488         mpz_gcd(t2, t1, el);
489     } while(mpz_cmp_ui(t2, 1) != 0);
490
491     /* make p > q */
492     if (mpz_cmp(p, q) < 0)
493         mpz_swap(p, q);
494
495     BN_GENCB_call(cb, 3, 1);
496
497     /* calculate n,             n = p * q */
498     mpz_mul(n, p, q);
499
500     /* calculate d,             d = 1/e mod (p - 1)(q - 1) */
501     mpz_sub_ui(t1, p, 1);
502     mpz_sub_ui(t2, q, 1);
503     mpz_mul(t3, t1, t2);
504     mpz_invert(d, el, t3);
505
506     /* calculate dmp1           dmp1 = d mod (p-1) */
507     mpz_mod(dmp1, d, t1);
508     /* calculate dmq1           dmq1 = d mod (q-1) */
509     mpz_mod(dmq1, d, t2);
510     /* calculate iqmp           iqmp = 1/q mod p */
511     mpz_invert(iqmp, q, p);
512
513     /* fill in RSA key */
514
515     rsa->e = mpz2BN(el);
516     rsa->p = mpz2BN(p);
517     rsa->q = mpz2BN(q);
518     rsa->n = mpz2BN(n);
519     rsa->d = mpz2BN(d);
520     rsa->dmp1 = mpz2BN(dmp1);
521     rsa->dmq1 = mpz2BN(dmq1);
522     rsa->iqmp = mpz2BN(iqmp);
523
524     ret = 1;
525
526     mpz_clear(el);
527     mpz_clear(p);
528     mpz_clear(q);
529     mpz_clear(n);
530     mpz_clear(d);
531     mpz_clear(dmp1);
532     mpz_clear(dmq1);
533     mpz_clear(iqmp);
534     mpz_clear(t1);
535     mpz_clear(t2);
536     mpz_clear(t3);
537
538     return ret;
539 }
540
541 static int
542 gmp_rsa_init(RSA *rsa)
543 {
544     return 1;
545 }
546
547 static int
548 gmp_rsa_finish(RSA *rsa)
549 {
550     return 1;
551 }
552
553 const RSA_METHOD hc_rsa_gmp_method = {
554     "hcrypto GMP RSA",
555     gmp_rsa_public_encrypt,
556     gmp_rsa_public_decrypt,
557     gmp_rsa_private_encrypt,
558     gmp_rsa_private_decrypt,
559     NULL,
560     NULL,
561     gmp_rsa_init,
562     gmp_rsa_finish,
563     0,
564     NULL,
565     NULL,
566     NULL,
567     gmp_rsa_generate_key
568 };
569
570 #endif /* HAVE_GMP */
571
572 /**
573  * RSA implementation using Gnu Multipresistion Library.
574  */
575
576 const RSA_METHOD *
577 RSA_gmp_method(void)
578 {
579 #ifdef HAVE_GMP
580     return &hc_rsa_gmp_method;
581 #else
582     return NULL;
583 #endif
584 }