2641dc1bc484689c0db97e0fd4c6a29946904367
[amitay/samba.git] / source4 / heimdal / lib / hcrypto / rsa-imath.c
1 /*
2  * Copyright (c) 2006 - 2007 Kungliga Tekniska Högskolan
3  * (Royal Institute of Technology, Stockholm, Sweden).
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions
8  * are met:
9  *
10  * 1. Redistributions of source code must retain the above copyright
11  *    notice, this list of conditions and the following disclaimer.
12  *
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in the
15  *    documentation and/or other materials provided with the distribution.
16  *
17  * 3. Neither the name of the Institute nor the names of its contributors
18  *    may be used to endorse or promote products derived from this software
19  *    without specific prior written permission.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE INSTITUTE AND CONTRIBUTORS ``AS IS'' AND
22  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
24  * ARE DISCLAIMED.  IN NO EVENT SHALL THE INSTITUTE OR CONTRIBUTORS BE LIABLE
25  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
27  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
28  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
30  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
31  * SUCH DAMAGE.
32  */
33
34 #include <config.h>
35
36 #include <stdio.h>
37 #include <stdlib.h>
38 #include <krb5-types.h>
39 #include <assert.h>
40
41 #include <rsa.h>
42
43 #include <roken.h>
44
45 #include "imath/imath.h"
46 #include "imath/iprime.h"
47
48 static void
49 BN2mpz(mpz_t *s, const BIGNUM *bn)
50 {
51     size_t len;
52     void *p;
53
54     mp_int_init(s);
55
56     len = BN_num_bytes(bn);
57     p = malloc(len);
58     BN_bn2bin(bn, p);
59     mp_int_read_unsigned(s, p, len);
60     free(p);
61 }
62
63 static BIGNUM *
64 mpz2BN(mpz_t *s)
65 {
66     size_t size;
67     BIGNUM *bn;
68     void *p;
69
70     size = mp_int_unsigned_len(s);
71     p = malloc(size);
72     if (p == NULL && size != 0)
73         return NULL;
74     mp_int_to_unsigned(s, p, size);
75
76     bn = BN_bin2bn(p, size, NULL);
77     free(p);
78     return bn;
79 }
80
81 static int random_num(mp_int, size_t);
82
83 static void
84 setup_blind(mp_int n, mp_int b, mp_int bi)
85 {
86     mp_int_init(b);
87     mp_int_init(bi);
88     random_num(b, mp_int_count_bits(n));
89     mp_int_mod(b, n, b);
90     mp_int_invmod(b, n, bi);
91 }
92
93 static void
94 blind(mp_int in, mp_int b, mp_int e, mp_int n)
95 {
96     mpz_t t1;
97     mp_int_init(&t1);
98     /* in' = (in * b^e) mod n */
99     mp_int_exptmod(b, e, n, &t1);
100     mp_int_mul(&t1, in, in);
101     mp_int_mod(in, n, in);
102     mp_int_clear(&t1);
103 }
104
105 static void
106 unblind(mp_int out, mp_int bi, mp_int n)
107 {
108     /* out' = (out * 1/b) mod n */
109     mp_int_mul(out, bi, out);
110     mp_int_mod(out, n, out);
111 }
112
113 static mp_result
114 rsa_private_calculate(mp_int in, mp_int p,  mp_int q,
115                       mp_int dmp1, mp_int dmq1, mp_int iqmp,
116                       mp_int out)
117 {
118     mpz_t vp, vq, u;
119     mp_int_init(&vp); mp_int_init(&vq); mp_int_init(&u);
120
121     /* vq = c ^ (d mod (q - 1)) mod q */
122     /* vp = c ^ (d mod (p - 1)) mod p */
123     mp_int_mod(in, p, &u);
124     mp_int_exptmod(&u, dmp1, p, &vp);
125     mp_int_mod(in, q, &u);
126     mp_int_exptmod(&u, dmq1, q, &vq);
127
128     /* C2 = 1/q mod p  (iqmp) */
129     /* u = (vp - vq)C2 mod p. */
130     mp_int_sub(&vp, &vq, &u);
131     if (mp_int_compare_zero(&u) < 0)
132         mp_int_add(&u, p, &u);
133     mp_int_mul(&u, iqmp, &u);
134     mp_int_mod(&u, p, &u);
135
136     /* c ^ d mod n = vq + u q */
137     mp_int_mul(&u, q, &u);
138     mp_int_add(&u, &vq, out);
139
140     mp_int_clear(&vp);
141     mp_int_clear(&vq);
142     mp_int_clear(&u);
143
144     return MP_OK;
145 }
146
147 /*
148  *
149  */
150
151 static int
152 imath_rsa_public_encrypt(int flen, const unsigned char* from,
153                         unsigned char* to, RSA* rsa, int padding)
154 {
155     unsigned char *p, *p0;
156     mp_result res;
157     size_t size, padlen;
158     mpz_t enc, dec, n, e;
159
160     if (padding != RSA_PKCS1_PADDING)
161         return -1;
162
163     size = RSA_size(rsa);
164
165     if (size < RSA_PKCS1_PADDING_SIZE || size - RSA_PKCS1_PADDING_SIZE < flen)
166         return -2;
167
168     BN2mpz(&n, rsa->n);
169     BN2mpz(&e, rsa->e);
170
171     p = p0 = malloc(size - 1);
172     if (p0 == NULL) {
173         mp_int_clear(&e);
174         mp_int_clear(&n);
175         return -3;
176     }
177
178     padlen = size - flen - 3;
179
180     *p++ = 2;
181     if (RAND_bytes(p, padlen) != 1) {
182         mp_int_clear(&e);
183         mp_int_clear(&n);
184         free(p0);
185         return -4;
186     }
187     while(padlen) {
188         if (*p == 0)
189             *p = 1;
190         padlen--;
191         p++;
192     }
193     *p++ = 0;
194     memcpy(p, from, flen);
195     p += flen;
196     assert((p - p0) == size - 1);
197
198     mp_int_init(&enc);
199     mp_int_init(&dec);
200     mp_int_read_unsigned(&dec, p0, size - 1);
201     free(p0);
202
203     res = mp_int_exptmod(&dec, &e, &n, &enc);
204
205     mp_int_clear(&dec);
206     mp_int_clear(&e);
207     mp_int_clear(&n);
208     {
209         size_t ssize;
210         ssize = mp_int_unsigned_len(&enc);
211         assert(size >= ssize);
212         mp_int_to_unsigned(&enc, to, ssize);
213         size = ssize;
214     }
215     mp_int_clear(&enc);
216
217     return size;
218 }
219
220 static int
221 imath_rsa_public_decrypt(int flen, const unsigned char* from,
222                          unsigned char* to, RSA* rsa, int padding)
223 {
224     unsigned char *p;
225     mp_result res;
226     size_t size;
227     mpz_t s, us, n, e;
228
229     if (padding != RSA_PKCS1_PADDING)
230         return -1;
231
232     if (flen > RSA_size(rsa))
233         return -2;
234
235     BN2mpz(&n, rsa->n);
236     BN2mpz(&e, rsa->e);
237
238 #if 0
239     /* Check that the exponent is larger then 3 */
240     if (mp_int_compare_value(&e, 3) <= 0) {
241         mp_int_clear(&n);
242         mp_int_clear(&e);
243         return -3;
244     }
245 #endif
246
247     mp_int_init(&s);
248     mp_int_init(&us);
249     mp_int_read_unsigned(&s, rk_UNCONST(from), flen);
250
251     if (mp_int_compare(&s, &n) >= 0) {
252         mp_int_clear(&n);
253         mp_int_clear(&e);
254         return -4;
255     }
256
257     res = mp_int_exptmod(&s, &e, &n, &us);
258
259     mp_int_clear(&s);
260     mp_int_clear(&n);
261     mp_int_clear(&e);
262
263     if (res != MP_OK)
264         return -5;
265     p = to;
266
267
268     size = mp_int_unsigned_len(&us);
269     assert(size <= RSA_size(rsa));
270     mp_int_to_unsigned(&us, p, size);
271
272     mp_int_clear(&us);
273
274     /* head zero was skipped by mp_int_to_unsigned */
275     if (*p == 0)
276         return -6;
277     if (*p != 1)
278         return -7;
279     size--; p++;
280     while (size && *p == 0xff) {
281         size--; p++;
282     }
283     if (size == 0 || *p != 0)
284         return -8;
285     size--; p++;
286
287     memmove(to, p, size);
288
289     return size;
290 }
291
292 static int
293 imath_rsa_private_encrypt(int flen, const unsigned char* from,
294                           unsigned char* to, RSA* rsa, int padding)
295 {
296     unsigned char *p, *p0;
297     mp_result res;
298     size_t size;
299     mpz_t in, out, n, e, b, bi;
300     int blinding = (rsa->flags & RSA_FLAG_NO_BLINDING) == 0;
301
302     if (padding != RSA_PKCS1_PADDING)
303         return -1;
304
305     size = RSA_size(rsa);
306
307     if (size < RSA_PKCS1_PADDING_SIZE || size - RSA_PKCS1_PADDING_SIZE < flen)
308         return -2;
309
310     p0 = p = malloc(size);
311     *p++ = 0;
312     *p++ = 1;
313     memset(p, 0xff, size - flen - 3);
314     p += size - flen - 3;
315     *p++ = 0;
316     memcpy(p, from, flen);
317     p += flen;
318     assert((p - p0) == size);
319
320     BN2mpz(&n, rsa->n);
321     BN2mpz(&e, rsa->e);
322
323     mp_int_init(&in);
324     mp_int_init(&out);
325     mp_int_read_unsigned(&in, p0, size);
326     free(p0);
327
328     if(mp_int_compare_zero(&in) < 0 ||
329        mp_int_compare(&in, &n) >= 0) {
330         size = 0;
331         goto out;
332     }
333
334     if (blinding) {
335         setup_blind(&n, &b, &bi);
336         blind(&in, &b, &e, &n);
337     }
338
339     if (rsa->p && rsa->q && rsa->dmp1 && rsa->dmq1 && rsa->iqmp) {
340         mpz_t p, q, dmp1, dmq1, iqmp;
341
342         BN2mpz(&p, rsa->p);
343         BN2mpz(&q, rsa->q);
344         BN2mpz(&dmp1, rsa->dmp1);
345         BN2mpz(&dmq1, rsa->dmq1);
346         BN2mpz(&iqmp, rsa->iqmp);
347
348         res = rsa_private_calculate(&in, &p, &q, &dmp1, &dmq1, &iqmp, &out);
349
350         mp_int_clear(&p);
351         mp_int_clear(&q);
352         mp_int_clear(&dmp1);
353         mp_int_clear(&dmq1);
354         mp_int_clear(&iqmp);
355     } else {
356         mpz_t d;
357
358         BN2mpz(&d, rsa->d);
359         res = mp_int_exptmod(&in, &d, &n, &out);
360         mp_int_clear(&d);
361         if (res != MP_OK) {
362             size = 0;
363             goto out;
364         }
365     }
366
367     if (blinding) {
368         unblind(&out, &bi, &n);
369         mp_int_clear(&b);
370         mp_int_clear(&bi);
371     }
372
373     {
374         size_t ssize;
375         ssize = mp_int_unsigned_len(&out);
376         assert(size >= ssize);
377         mp_int_to_unsigned(&out, to, size);
378         size = ssize;
379     }
380
381 out:
382     mp_int_clear(&e);
383     mp_int_clear(&n);
384     mp_int_clear(&in);
385     mp_int_clear(&out);
386
387     return size;
388 }
389
390 static int
391 imath_rsa_private_decrypt(int flen, const unsigned char* from,
392                           unsigned char* to, RSA* rsa, int padding)
393 {
394     unsigned char *ptr;
395     mp_result res;
396     size_t size;
397     mpz_t in, out, n, e, b, bi;
398     int blinding = (rsa->flags & RSA_FLAG_NO_BLINDING) == 0;
399
400     if (padding != RSA_PKCS1_PADDING)
401         return -1;
402
403     size = RSA_size(rsa);
404     if (flen > size)
405         return -2;
406
407     mp_int_init(&in);
408     mp_int_init(&out);
409
410     BN2mpz(&n, rsa->n);
411     BN2mpz(&e, rsa->e);
412
413     res = mp_int_read_unsigned(&in, rk_UNCONST(from), flen);
414     if (res != MP_OK) {
415         size = -1;
416         goto out;
417     }
418
419     if(mp_int_compare_zero(&in) < 0 ||
420        mp_int_compare(&in, &n) >= 0) {
421         size = 0;
422         goto out;
423     }
424
425     if (blinding) {
426         setup_blind(&n, &b, &bi);
427         blind(&in, &b, &e, &n);
428     }
429
430     if (rsa->p && rsa->q && rsa->dmp1 && rsa->dmq1 && rsa->iqmp) {
431         mpz_t p, q, dmp1, dmq1, iqmp;
432
433         BN2mpz(&p, rsa->p);
434         BN2mpz(&q, rsa->q);
435         BN2mpz(&dmp1, rsa->dmp1);
436         BN2mpz(&dmq1, rsa->dmq1);
437         BN2mpz(&iqmp, rsa->iqmp);
438
439         res = rsa_private_calculate(&in, &p, &q, &dmp1, &dmq1, &iqmp, &out);
440
441         mp_int_clear(&p);
442         mp_int_clear(&q);
443         mp_int_clear(&dmp1);
444         mp_int_clear(&dmq1);
445         mp_int_clear(&iqmp);
446     } else {
447         mpz_t d;
448
449         if(mp_int_compare_zero(&in) < 0 ||
450            mp_int_compare(&in, &n) >= 0)
451             return MP_RANGE;
452
453         BN2mpz(&d, rsa->d);
454         res = mp_int_exptmod(&in, &d, &n, &out);
455         mp_int_clear(&d);
456         if (res != MP_OK) {
457             size = 0;
458             goto out;
459         }
460     }
461
462     if (blinding) {
463         unblind(&out, &bi, &n);
464         mp_int_clear(&b);
465         mp_int_clear(&bi);
466     }
467
468     ptr = to;
469     {
470         size_t ssize;
471         ssize = mp_int_unsigned_len(&out);
472         assert(size >= ssize);
473         mp_int_to_unsigned(&out, ptr, ssize);
474         size = ssize;
475     }
476
477     /* head zero was skipped by mp_int_to_unsigned */
478     if (*ptr != 2)
479         return -3;
480     size--; ptr++;
481     while (size && *ptr != 0) {
482         size--; ptr++;
483     }
484     if (size == 0)
485         return -4;
486     size--; ptr++;
487
488     memmove(to, ptr, size);
489
490 out:
491     mp_int_clear(&e);
492     mp_int_clear(&n);
493     mp_int_clear(&in);
494     mp_int_clear(&out);
495
496     return size;
497 }
498
499 static int
500 random_num(mp_int num, size_t len)
501 {
502     unsigned char *p;
503     mp_result res;
504
505     len = (len + 7) / 8;
506     p = malloc(len);
507     if (p == NULL)
508         return 1;
509     if (RAND_bytes(p, len) != 1) {
510         free(p);
511         return 1;
512     }
513     res = mp_int_read_unsigned(num, p, len);
514     free(p);
515     if (res != MP_OK)
516         return 1;
517     return 0;
518 }
519
520 #define CHECK(f, v) if ((f) != (v)) { goto out; }
521
522 static int
523 imath_rsa_generate_key(RSA *rsa, int bits, BIGNUM *e, BN_GENCB *cb)
524 {
525     mpz_t el, p, q, n, d, dmp1, dmq1, iqmp, t1, t2, t3;
526     int counter, ret;
527
528     if (bits < 789)
529         return -1;
530
531     ret = -1;
532
533     mp_int_init(&el);
534     mp_int_init(&p);
535     mp_int_init(&q);
536     mp_int_init(&n);
537     mp_int_init(&d);
538     mp_int_init(&dmp1);
539     mp_int_init(&dmq1);
540     mp_int_init(&iqmp);
541     mp_int_init(&t1);
542     mp_int_init(&t2);
543     mp_int_init(&t3);
544
545     BN2mpz(&el, e);
546
547     /* generate p and q so that p != q and bits(pq) ~ bits */
548     counter = 0;
549     do {
550         BN_GENCB_call(cb, 2, counter++);
551         CHECK(random_num(&p, bits / 2 + 1), 0);
552         CHECK(mp_int_find_prime(&p), MP_TRUE);
553
554         CHECK(mp_int_sub_value(&p, 1, &t1), MP_OK);
555         CHECK(mp_int_gcd(&t1, &el, &t2), MP_OK);
556     } while(mp_int_compare_value(&t2, 1) != 0);
557
558     BN_GENCB_call(cb, 3, 0);
559
560     counter = 0;
561     do {
562         BN_GENCB_call(cb, 2, counter++);
563         CHECK(random_num(&q, bits / 2 + 1), 0);
564         CHECK(mp_int_find_prime(&q), MP_TRUE);
565
566         if (mp_int_compare(&p, &q) == 0) /* don't let p and q be the same */
567             continue;
568
569         CHECK(mp_int_sub_value(&q, 1, &t1), MP_OK);
570         CHECK(mp_int_gcd(&t1, &el, &t2), MP_OK);
571     } while(mp_int_compare_value(&t2, 1) != 0);
572
573     /* make p > q */
574     if (mp_int_compare(&p, &q) < 0)
575         mp_int_swap(&p, &q);
576
577     BN_GENCB_call(cb, 3, 1);
578
579     /* calculate n,             n = p * q */
580     CHECK(mp_int_mul(&p, &q, &n), MP_OK);
581
582     /* calculate d,             d = 1/e mod (p - 1)(q - 1) */
583     CHECK(mp_int_sub_value(&p, 1, &t1), MP_OK);
584     CHECK(mp_int_sub_value(&q, 1, &t2), MP_OK);
585     CHECK(mp_int_mul(&t1, &t2, &t3), MP_OK);
586     CHECK(mp_int_invmod(&el, &t3, &d), MP_OK);
587
588     /* calculate dmp1           dmp1 = d mod (p-1) */
589     CHECK(mp_int_mod(&d, &t1, &dmp1), MP_OK);
590     /* calculate dmq1           dmq1 = d mod (q-1) */
591     CHECK(mp_int_mod(&d, &t2, &dmq1), MP_OK);
592     /* calculate iqmp           iqmp = 1/q mod p */
593     CHECK(mp_int_invmod(&q, &p, &iqmp), MP_OK);
594
595     /* fill in RSA key */
596
597     rsa->e = mpz2BN(&el);
598     rsa->p = mpz2BN(&p);
599     rsa->q = mpz2BN(&q);
600     rsa->n = mpz2BN(&n);
601     rsa->d = mpz2BN(&d);
602     rsa->dmp1 = mpz2BN(&dmp1);
603     rsa->dmq1 = mpz2BN(&dmq1);
604     rsa->iqmp = mpz2BN(&iqmp);
605
606     ret = 1;
607 out:
608     mp_int_clear(&el);
609     mp_int_clear(&p);
610     mp_int_clear(&q);
611     mp_int_clear(&n);
612     mp_int_clear(&d);
613     mp_int_clear(&dmp1);
614     mp_int_clear(&dmq1);
615     mp_int_clear(&iqmp);
616     mp_int_clear(&t1);
617     mp_int_clear(&t2);
618     mp_int_clear(&t3);
619
620     return ret;
621 }
622
623 static int
624 imath_rsa_init(RSA *rsa)
625 {
626     return 1;
627 }
628
629 static int
630 imath_rsa_finish(RSA *rsa)
631 {
632     return 1;
633 }
634
635 const RSA_METHOD hc_rsa_imath_method = {
636     "hcrypto imath RSA",
637     imath_rsa_public_encrypt,
638     imath_rsa_public_decrypt,
639     imath_rsa_private_encrypt,
640     imath_rsa_private_decrypt,
641     NULL,
642     NULL,
643     imath_rsa_init,
644     imath_rsa_finish,
645     0,
646     NULL,
647     NULL,
648     NULL,
649     imath_rsa_generate_key
650 };
651
652 const RSA_METHOD *
653 RSA_imath_method(void)
654 {
655     return &hc_rsa_imath_method;
656 }