s4:heimdal: import lorikeet-heimdal-200909210500 (commit 290db8d23647a27c39b97c189a0b...
[amitay/samba.git] / source4 / heimdal / lib / hcrypto / rsa-imath.c
1 /*
2  * Copyright (c) 2006 - 2007 Kungliga Tekniska Högskolan
3  * (Royal Institute of Technology, Stockholm, Sweden).
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions
8  * are met:
9  *
10  * 1. Redistributions of source code must retain the above copyright
11  *    notice, this list of conditions and the following disclaimer.
12  *
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in the
15  *    documentation and/or other materials provided with the distribution.
16  *
17  * 3. Neither the name of the Institute nor the names of its contributors
18  *    may be used to endorse or promote products derived from this software
19  *    without specific prior written permission.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE INSTITUTE AND CONTRIBUTORS ``AS IS'' AND
22  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
24  * ARE DISCLAIMED.  IN NO EVENT SHALL THE INSTITUTE OR CONTRIBUTORS BE LIABLE
25  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
27  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
28  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
30  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
31  * SUCH DAMAGE.
32  */
33
34 #include <config.h>
35
36 #include <stdio.h>
37 #include <stdlib.h>
38 #include <krb5-types.h>
39 #include <assert.h>
40
41 #include <rsa.h>
42
43 #include <roken.h>
44
45 #include "imath/imath.h"
46 #include "imath/iprime.h"
47
48 static void
49 BN2mpz(mpz_t *s, const BIGNUM *bn)
50 {
51     size_t len;
52     void *p;
53
54     mp_int_init(s);
55
56     len = BN_num_bytes(bn);
57     p = malloc(len);
58     BN_bn2bin(bn, p);
59     mp_int_read_unsigned(s, p, len);
60     free(p);
61 }
62
63 static BIGNUM *
64 mpz2BN(mpz_t *s)
65 {
66     size_t size;
67     BIGNUM *bn;
68     void *p;
69
70     size = mp_int_unsigned_len(s);
71     p = malloc(size);
72     if (p == NULL && size != 0)
73         return NULL;
74     mp_int_to_unsigned(s, p, size);
75
76     bn = BN_bin2bn(p, size, NULL);
77     free(p);
78     return bn;
79 }
80
81 static int random_num(mp_int, size_t);
82
83 static void
84 setup_blind(mp_int n, mp_int b, mp_int bi)
85 {
86     mp_int_init(b);
87     mp_int_init(bi);
88     random_num(b, mp_int_count_bits(n));
89     mp_int_mod(b, n, b);
90     mp_int_invmod(b, n, bi);
91 }
92
93 static void
94 blind(mp_int in, mp_int b, mp_int e, mp_int n)
95 {
96     mpz_t t1;
97     mp_int_init(&t1);
98     /* in' = (in * b^e) mod n */
99     mp_int_exptmod(b, e, n, &t1);
100     mp_int_mul(&t1, in, in);
101     mp_int_mod(in, n, in);
102     mp_int_clear(&t1);
103 }
104
105 static void
106 unblind(mp_int out, mp_int bi, mp_int n)
107 {
108     /* out' = (out * 1/b) mod n */
109     mp_int_mul(out, bi, out);
110     mp_int_mod(out, n, out);
111 }
112
113 static mp_result
114 rsa_private_calculate(mp_int in, mp_int p,  mp_int q,
115                       mp_int dmp1, mp_int dmq1, mp_int iqmp,
116                       mp_int out)
117 {
118     mpz_t vp, vq, u;
119     mp_int_init(&vp); mp_int_init(&vq); mp_int_init(&u);
120
121     /* vq = c ^ (d mod (q - 1)) mod q */
122     /* vp = c ^ (d mod (p - 1)) mod p */
123     mp_int_mod(in, p, &u);
124     mp_int_exptmod(&u, dmp1, p, &vp);
125     mp_int_mod(in, q, &u);
126     mp_int_exptmod(&u, dmq1, q, &vq);
127
128     /* C2 = 1/q mod p  (iqmp) */
129     /* u = (vp - vq)C2 mod p. */
130     mp_int_sub(&vp, &vq, &u);
131     if (mp_int_compare_zero(&u) < 0)
132         mp_int_add(&u, p, &u);
133     mp_int_mul(&u, iqmp, &u);
134     mp_int_mod(&u, p, &u);
135
136     /* c ^ d mod n = vq + u q */
137     mp_int_mul(&u, q, &u);
138     mp_int_add(&u, &vq, out);
139
140     mp_int_clear(&vp);
141     mp_int_clear(&vq);
142     mp_int_clear(&u);
143
144     return MP_OK;
145 }
146
147 /*
148  *
149  */
150
151 static int
152 imath_rsa_public_encrypt(int flen, const unsigned char* from,
153                         unsigned char* to, RSA* rsa, int padding)
154 {
155     unsigned char *p, *p0;
156     mp_result res;
157     size_t size, padlen;
158     mpz_t enc, dec, n, e;
159
160     if (padding != RSA_PKCS1_PADDING)
161         return -1;
162
163     size = RSA_size(rsa);
164
165     if (size < RSA_PKCS1_PADDING_SIZE || size - RSA_PKCS1_PADDING_SIZE < flen)
166         return -2;
167
168     BN2mpz(&n, rsa->n);
169     BN2mpz(&e, rsa->e);
170
171     p = p0 = malloc(size - 1);
172     if (p0 == NULL) {
173         mp_int_clear(&e);
174         mp_int_clear(&n);
175         return -3;
176     }
177
178     padlen = size - flen - 3;
179
180     *p++ = 2;
181     if (RAND_bytes(p, padlen) != 1) {
182         mp_int_clear(&e);
183         mp_int_clear(&n);
184         free(p0);
185         return -4;
186     }
187     while(padlen) {
188         if (*p == 0)
189             *p = 1;
190         padlen--;
191         p++;
192     }
193     *p++ = 0;
194     memcpy(p, from, flen);
195     p += flen;
196     assert((p - p0) == size - 1);
197
198     mp_int_init(&enc);
199     mp_int_init(&dec);
200     mp_int_read_unsigned(&dec, p0, size - 1);
201     free(p0);
202
203     res = mp_int_exptmod(&dec, &e, &n, &enc);
204
205     mp_int_clear(&dec);
206     mp_int_clear(&e);
207     mp_int_clear(&n);
208
209     if (res != MP_OK)
210         return -4;
211
212     {
213         size_t ssize;
214         ssize = mp_int_unsigned_len(&enc);
215         assert(size >= ssize);
216         mp_int_to_unsigned(&enc, to, ssize);
217         size = ssize;
218     }
219     mp_int_clear(&enc);
220
221     return size;
222 }
223
224 static int
225 imath_rsa_public_decrypt(int flen, const unsigned char* from,
226                          unsigned char* to, RSA* rsa, int padding)
227 {
228     unsigned char *p;
229     mp_result res;
230     size_t size;
231     mpz_t s, us, n, e;
232
233     if (padding != RSA_PKCS1_PADDING)
234         return -1;
235
236     if (flen > RSA_size(rsa))
237         return -2;
238
239     BN2mpz(&n, rsa->n);
240     BN2mpz(&e, rsa->e);
241
242 #if 0
243     /* Check that the exponent is larger then 3 */
244     if (mp_int_compare_value(&e, 3) <= 0) {
245         mp_int_clear(&n);
246         mp_int_clear(&e);
247         return -3;
248     }
249 #endif
250
251     mp_int_init(&s);
252     mp_int_init(&us);
253     mp_int_read_unsigned(&s, rk_UNCONST(from), flen);
254
255     if (mp_int_compare(&s, &n) >= 0) {
256         mp_int_clear(&n);
257         mp_int_clear(&e);
258         return -4;
259     }
260
261     res = mp_int_exptmod(&s, &e, &n, &us);
262
263     mp_int_clear(&s);
264     mp_int_clear(&n);
265     mp_int_clear(&e);
266
267     if (res != MP_OK)
268         return -5;
269     p = to;
270
271
272     size = mp_int_unsigned_len(&us);
273     assert(size <= RSA_size(rsa));
274     mp_int_to_unsigned(&us, p, size);
275
276     mp_int_clear(&us);
277
278     /* head zero was skipped by mp_int_to_unsigned */
279     if (*p == 0)
280         return -6;
281     if (*p != 1)
282         return -7;
283     size--; p++;
284     while (size && *p == 0xff) {
285         size--; p++;
286     }
287     if (size == 0 || *p != 0)
288         return -8;
289     size--; p++;
290
291     memmove(to, p, size);
292
293     return size;
294 }
295
296 static int
297 imath_rsa_private_encrypt(int flen, const unsigned char* from,
298                           unsigned char* to, RSA* rsa, int padding)
299 {
300     unsigned char *p, *p0;
301     mp_result res;
302     int size;
303     mpz_t in, out, n, e, b, bi;
304     int blinding = (rsa->flags & RSA_FLAG_NO_BLINDING) == 0;
305     int do_unblind = 0;
306
307     if (padding != RSA_PKCS1_PADDING)
308         return -1;
309
310     size = RSA_size(rsa);
311
312     if (size < RSA_PKCS1_PADDING_SIZE || size - RSA_PKCS1_PADDING_SIZE < flen)
313         return -2;
314
315     p0 = p = malloc(size);
316     *p++ = 0;
317     *p++ = 1;
318     memset(p, 0xff, size - flen - 3);
319     p += size - flen - 3;
320     *p++ = 0;
321     memcpy(p, from, flen);
322     p += flen;
323     assert((p - p0) == size);
324
325     BN2mpz(&n, rsa->n);
326     BN2mpz(&e, rsa->e);
327
328     mp_int_init(&in);
329     mp_int_init(&out);
330     mp_int_read_unsigned(&in, p0, size);
331     free(p0);
332
333     if(mp_int_compare_zero(&in) < 0 ||
334        mp_int_compare(&in, &n) >= 0) {
335         size = -3;
336         goto out;
337     }
338
339     if (blinding) {
340         setup_blind(&n, &b, &bi);
341         blind(&in, &b, &e, &n);
342         do_unblind = 1;
343     }
344
345     if (rsa->p && rsa->q && rsa->dmp1 && rsa->dmq1 && rsa->iqmp) {
346         mpz_t p, q, dmp1, dmq1, iqmp;
347
348         BN2mpz(&p, rsa->p);
349         BN2mpz(&q, rsa->q);
350         BN2mpz(&dmp1, rsa->dmp1);
351         BN2mpz(&dmq1, rsa->dmq1);
352         BN2mpz(&iqmp, rsa->iqmp);
353
354         res = rsa_private_calculate(&in, &p, &q, &dmp1, &dmq1, &iqmp, &out);
355
356         mp_int_clear(&p);
357         mp_int_clear(&q);
358         mp_int_clear(&dmp1);
359         mp_int_clear(&dmq1);
360         mp_int_clear(&iqmp);
361
362         if (res != MP_OK) {
363             size = -4;
364             goto out;
365         }
366     } else {
367         mpz_t d;
368
369         BN2mpz(&d, rsa->d);
370         res = mp_int_exptmod(&in, &d, &n, &out);
371         mp_int_clear(&d);
372         if (res != MP_OK) {
373             size = -5;
374             goto out;
375         }
376     }
377
378     if (do_unblind)
379         unblind(&out, &bi, &n);
380
381     if (size > 0) {
382         size_t ssize;
383         ssize = mp_int_unsigned_len(&out);
384         assert(size >= ssize);
385         mp_int_to_unsigned(&out, to, size);
386         size = ssize;
387     }
388
389  out:
390     if (do_unblind) {
391         mp_int_clear(&b);
392         mp_int_clear(&bi);
393     }
394
395     mp_int_clear(&e);
396     mp_int_clear(&n);
397     mp_int_clear(&in);
398     mp_int_clear(&out);
399
400     return size;
401 }
402
403 static int
404 imath_rsa_private_decrypt(int flen, const unsigned char* from,
405                           unsigned char* to, RSA* rsa, int padding)
406 {
407     unsigned char *ptr;
408     mp_result res;
409     size_t size;
410     mpz_t in, out, n, e, b, bi;
411     int blinding = (rsa->flags & RSA_FLAG_NO_BLINDING) == 0;
412     int do_unblind = 0;
413
414     if (padding != RSA_PKCS1_PADDING)
415         return -1;
416
417     size = RSA_size(rsa);
418     if (flen > size)
419         return -2;
420
421     mp_int_init(&in);
422     mp_int_init(&out);
423
424     BN2mpz(&n, rsa->n);
425     BN2mpz(&e, rsa->e);
426
427     res = mp_int_read_unsigned(&in, rk_UNCONST(from), flen);
428     if (res != MP_OK) {
429         size = -1;
430         goto out;
431     }
432
433     if(mp_int_compare_zero(&in) < 0 ||
434        mp_int_compare(&in, &n) >= 0) {
435         size = -2;
436         goto out;
437     }
438
439     if (blinding) {
440         setup_blind(&n, &b, &bi);
441         blind(&in, &b, &e, &n);
442         do_unblind = 1;
443     }
444
445     if (rsa->p && rsa->q && rsa->dmp1 && rsa->dmq1 && rsa->iqmp) {
446         mpz_t p, q, dmp1, dmq1, iqmp;
447
448         BN2mpz(&p, rsa->p);
449         BN2mpz(&q, rsa->q);
450         BN2mpz(&dmp1, rsa->dmp1);
451         BN2mpz(&dmq1, rsa->dmq1);
452         BN2mpz(&iqmp, rsa->iqmp);
453
454         res = rsa_private_calculate(&in, &p, &q, &dmp1, &dmq1, &iqmp, &out);
455
456         mp_int_clear(&p);
457         mp_int_clear(&q);
458         mp_int_clear(&dmp1);
459         mp_int_clear(&dmq1);
460         mp_int_clear(&iqmp);
461
462         if (res != MP_OK) {
463             size = -3;
464             goto out;
465         }
466
467     } else {
468         mpz_t d;
469
470         if(mp_int_compare_zero(&in) < 0 ||
471            mp_int_compare(&in, &n) >= 0)
472             return MP_RANGE;
473
474         BN2mpz(&d, rsa->d);
475         res = mp_int_exptmod(&in, &d, &n, &out);
476         mp_int_clear(&d);
477         if (res != MP_OK) {
478             size = -4;
479             goto out;
480         }
481     }
482
483     if (do_unblind)
484         unblind(&out, &bi, &n);
485
486     ptr = to;
487     {
488         size_t ssize;
489         ssize = mp_int_unsigned_len(&out);
490         assert(size >= ssize);
491         mp_int_to_unsigned(&out, ptr, ssize);
492         size = ssize;
493     }
494
495     /* head zero was skipped by mp_int_to_unsigned */
496     if (*ptr != 2) {
497         size = -5;
498         goto out;
499     }
500     size--; ptr++;
501     while (size && *ptr != 0) {
502         size--; ptr++;
503     }
504     if (size == 0)
505         return -6;
506     size--; ptr++;
507
508     memmove(to, ptr, size);
509
510  out:
511     if (do_unblind) {
512         mp_int_clear(&b);
513         mp_int_clear(&bi);
514     }
515
516     mp_int_clear(&e);
517     mp_int_clear(&n);
518     mp_int_clear(&in);
519     mp_int_clear(&out);
520
521     return size;
522 }
523
524 static int
525 random_num(mp_int num, size_t len)
526 {
527     unsigned char *p;
528     mp_result res;
529
530     len = (len + 7) / 8;
531     p = malloc(len);
532     if (p == NULL)
533         return 1;
534     if (RAND_bytes(p, len) != 1) {
535         free(p);
536         return 1;
537     }
538     res = mp_int_read_unsigned(num, p, len);
539     free(p);
540     if (res != MP_OK)
541         return 1;
542     return 0;
543 }
544
545 #define CHECK(f, v) if ((f) != (v)) { goto out; }
546
547 static int
548 imath_rsa_generate_key(RSA *rsa, int bits, BIGNUM *e, BN_GENCB *cb)
549 {
550     mpz_t el, p, q, n, d, dmp1, dmq1, iqmp, t1, t2, t3;
551     int counter, ret;
552
553     if (bits < 789)
554         return -1;
555
556     ret = -1;
557
558     mp_int_init(&el);
559     mp_int_init(&p);
560     mp_int_init(&q);
561     mp_int_init(&n);
562     mp_int_init(&d);
563     mp_int_init(&dmp1);
564     mp_int_init(&dmq1);
565     mp_int_init(&iqmp);
566     mp_int_init(&t1);
567     mp_int_init(&t2);
568     mp_int_init(&t3);
569
570     BN2mpz(&el, e);
571
572     /* generate p and q so that p != q and bits(pq) ~ bits */
573     counter = 0;
574     do {
575         BN_GENCB_call(cb, 2, counter++);
576         CHECK(random_num(&p, bits / 2 + 1), 0);
577         CHECK(mp_int_find_prime(&p), MP_TRUE);
578
579         CHECK(mp_int_sub_value(&p, 1, &t1), MP_OK);
580         CHECK(mp_int_gcd(&t1, &el, &t2), MP_OK);
581     } while(mp_int_compare_value(&t2, 1) != 0);
582
583     BN_GENCB_call(cb, 3, 0);
584
585     counter = 0;
586     do {
587         BN_GENCB_call(cb, 2, counter++);
588         CHECK(random_num(&q, bits / 2 + 1), 0);
589         CHECK(mp_int_find_prime(&q), MP_TRUE);
590
591         if (mp_int_compare(&p, &q) == 0) /* don't let p and q be the same */
592             continue;
593
594         CHECK(mp_int_sub_value(&q, 1, &t1), MP_OK);
595         CHECK(mp_int_gcd(&t1, &el, &t2), MP_OK);
596     } while(mp_int_compare_value(&t2, 1) != 0);
597
598     /* make p > q */
599     if (mp_int_compare(&p, &q) < 0)
600         mp_int_swap(&p, &q);
601
602     BN_GENCB_call(cb, 3, 1);
603
604     /* calculate n,             n = p * q */
605     CHECK(mp_int_mul(&p, &q, &n), MP_OK);
606
607     /* calculate d,             d = 1/e mod (p - 1)(q - 1) */
608     CHECK(mp_int_sub_value(&p, 1, &t1), MP_OK);
609     CHECK(mp_int_sub_value(&q, 1, &t2), MP_OK);
610     CHECK(mp_int_mul(&t1, &t2, &t3), MP_OK);
611     CHECK(mp_int_invmod(&el, &t3, &d), MP_OK);
612
613     /* calculate dmp1           dmp1 = d mod (p-1) */
614     CHECK(mp_int_mod(&d, &t1, &dmp1), MP_OK);
615     /* calculate dmq1           dmq1 = d mod (q-1) */
616     CHECK(mp_int_mod(&d, &t2, &dmq1), MP_OK);
617     /* calculate iqmp           iqmp = 1/q mod p */
618     CHECK(mp_int_invmod(&q, &p, &iqmp), MP_OK);
619
620     /* fill in RSA key */
621
622     rsa->e = mpz2BN(&el);
623     rsa->p = mpz2BN(&p);
624     rsa->q = mpz2BN(&q);
625     rsa->n = mpz2BN(&n);
626     rsa->d = mpz2BN(&d);
627     rsa->dmp1 = mpz2BN(&dmp1);
628     rsa->dmq1 = mpz2BN(&dmq1);
629     rsa->iqmp = mpz2BN(&iqmp);
630
631     ret = 1;
632 out:
633     mp_int_clear(&el);
634     mp_int_clear(&p);
635     mp_int_clear(&q);
636     mp_int_clear(&n);
637     mp_int_clear(&d);
638     mp_int_clear(&dmp1);
639     mp_int_clear(&dmq1);
640     mp_int_clear(&iqmp);
641     mp_int_clear(&t1);
642     mp_int_clear(&t2);
643     mp_int_clear(&t3);
644
645     return ret;
646 }
647
648 static int
649 imath_rsa_init(RSA *rsa)
650 {
651     return 1;
652 }
653
654 static int
655 imath_rsa_finish(RSA *rsa)
656 {
657     return 1;
658 }
659
660 const RSA_METHOD hc_rsa_imath_method = {
661     "hcrypto imath RSA",
662     imath_rsa_public_encrypt,
663     imath_rsa_public_decrypt,
664     imath_rsa_private_encrypt,
665     imath_rsa_private_decrypt,
666     NULL,
667     NULL,
668     imath_rsa_init,
669     imath_rsa_finish,
670     0,
671     NULL,
672     NULL,
673     NULL,
674     imath_rsa_generate_key
675 };
676
677 const RSA_METHOD *
678 RSA_imath_method(void)
679 {
680     return &hc_rsa_imath_method;
681 }