a6e09fe2834bb0ddcc38e499dfe146682a687233
[kai/samba.git] / source4 / heimdal / lib / hcrypto / rsa.c
1 /*
2  * Copyright (c) 2006 - 2008 Kungliga Tekniska Högskolan
3  * (Royal Institute of Technology, Stockholm, Sweden).
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions
8  * are met:
9  *
10  * 1. Redistributions of source code must retain the above copyright
11  *    notice, this list of conditions and the following disclaimer.
12  *
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in the
15  *    documentation and/or other materials provided with the distribution.
16  *
17  * 3. Neither the name of the Institute nor the names of its contributors
18  *    may be used to endorse or promote products derived from this software
19  *    without specific prior written permission.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE INSTITUTE AND CONTRIBUTORS ``AS IS'' AND
22  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
24  * ARE DISCLAIMED.  IN NO EVENT SHALL THE INSTITUTE OR CONTRIBUTORS BE LIABLE
25  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
27  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
28  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
30  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
31  * SUCH DAMAGE.
32  */
33
34 #include <config.h>
35
36 #include <stdio.h>
37 #include <stdlib.h>
38 #include <krb5-types.h>
39 #include <rfc2459_asn1.h>
40
41 #include <der.h>
42
43 #include <rsa.h>
44
45 #include "common.h"
46
47 #include <roken.h>
48
49 /**
50  * @page page_rsa RSA - public-key cryptography
51  *
52  * RSA is named by its inventors (Ron Rivest, Adi Shamir, and Leonard
53  * Adleman) (published in 1977), patented expired in 21 September 2000.
54  *
55  *
56  * Speed for RSA in seconds
57  *   no key blinding
58  *   1000 iteration, 
59  *   same rsa key
60  *   operation performed each eteration sign, verify, encrypt, decrypt on a random bit pattern
61  *
62  * gmp:          0.733615
63  * tfm:          2.450173
64  * ltm:          3.79 (default in hcrypto)
65  * openssl:      4.04
66  * cdsa:        15.89
67  * imath:       40.62
68  *
69  * See the library functions here: @ref hcrypto_rsa
70  */
71
72 /**
73  * Same as RSA_new_method() using NULL as engine.
74  *
75  * @return a newly allocated RSA object. Free with RSA_free().
76  *
77  * @ingroup hcrypto_rsa
78  */
79
80 RSA *
81 RSA_new(void)
82 {
83     return RSA_new_method(NULL);
84 }
85
86 /**
87  * Allocate a new RSA object using the engine, if NULL is specified as
88  * the engine, use the default RSA engine as returned by
89  * ENGINE_get_default_RSA().
90  *
91  * @param engine Specific what ENGINE RSA provider should be used.
92  *
93  * @return a newly allocated RSA object. Free with RSA_free().
94  *
95  * @ingroup hcrypto_rsa
96  */
97
98 RSA *
99 RSA_new_method(ENGINE *engine)
100 {
101     RSA *rsa;
102
103     rsa = calloc(1, sizeof(*rsa));
104     if (rsa == NULL)
105         return NULL;
106
107     rsa->references = 1;
108
109     if (engine) {
110         ENGINE_up_ref(engine);
111         rsa->engine = engine;
112     } else {
113         rsa->engine = ENGINE_get_default_RSA();
114     }
115
116     if (rsa->engine) {
117         rsa->meth = ENGINE_get_RSA(rsa->engine);
118         if (rsa->meth == NULL) {
119             ENGINE_finish(engine);
120             free(rsa);
121             return 0;
122         }
123     }
124
125     if (rsa->meth == NULL)
126         rsa->meth = rk_UNCONST(RSA_get_default_method());
127
128     (*rsa->meth->init)(rsa);
129
130     return rsa;
131 }
132
133 /**
134  * Free an allocation RSA object.
135  *
136  * @param rsa the RSA object to free.
137  * @ingroup hcrypto_rsa
138  */
139
140 void
141 RSA_free(RSA *rsa)
142 {
143     if (rsa->references <= 0)
144         abort();
145
146     if (--rsa->references > 0)
147         return;
148
149     (*rsa->meth->finish)(rsa);
150
151     if (rsa->engine)
152         ENGINE_finish(rsa->engine);
153
154 #define free_if(f) if (f) { BN_free(f); }
155     free_if(rsa->n);
156     free_if(rsa->e);
157     free_if(rsa->d);
158     free_if(rsa->p);
159     free_if(rsa->q);
160     free_if(rsa->dmp1);
161     free_if(rsa->dmq1);
162     free_if(rsa->iqmp);
163 #undef free_if
164
165     memset(rsa, 0, sizeof(*rsa));
166     free(rsa);
167 }
168
169 /**
170  * Add an extra reference to the RSA object. The object should be free
171  * with RSA_free() to drop the reference.
172  *
173  * @param rsa the object to add reference counting too.
174  *
175  * @return the current reference count, can't safely be used except
176  * for debug printing.
177  *
178  * @ingroup hcrypto_rsa
179  */
180
181 int
182 RSA_up_ref(RSA *rsa)
183 {
184     return ++rsa->references;
185 }
186
187 /**
188  * Return the RSA_METHOD used for this RSA object.
189  *
190  * @param rsa the object to get the method from.
191  *
192  * @return the method used for this RSA object.
193  *
194  * @ingroup hcrypto_rsa
195  */
196
197 const RSA_METHOD *
198 RSA_get_method(const RSA *rsa)
199 {
200     return rsa->meth;
201 }
202
203 /**
204  * Set a new method for the RSA keypair.
205  *
206  * @param rsa rsa parameter.
207  * @param method the new method for the RSA parameter.
208  *
209  * @return 1 on success.
210  *
211  * @ingroup hcrypto_rsa
212  */
213
214 int
215 RSA_set_method(RSA *rsa, const RSA_METHOD *method)
216 {
217     (*rsa->meth->finish)(rsa);
218
219     if (rsa->engine) {
220         ENGINE_finish(rsa->engine);
221         rsa->engine = NULL;
222     }
223
224     rsa->meth = method;
225     (*rsa->meth->init)(rsa);
226     return 1;
227 }
228
229 /**
230  * Set the application data for the RSA object.
231  *
232  * @param rsa the rsa object to set the parameter for
233  * @param arg the data object to store
234  *
235  * @return 1 on success.
236  *
237  * @ingroup hcrypto_rsa
238  */
239
240 int
241 RSA_set_app_data(RSA *rsa, void *arg)
242 {
243     rsa->ex_data.sk = arg;
244     return 1;
245 }
246
247 /**
248  * Get the application data for the RSA object.
249  *
250  * @param rsa the rsa object to get the parameter for
251  *
252  * @return the data object
253  *
254  * @ingroup hcrypto_rsa
255  */
256
257 void *
258 RSA_get_app_data(const RSA *rsa)
259 {
260     return rsa->ex_data.sk;
261 }
262
263 int
264 RSA_check_key(const RSA *key)
265 {
266     static const unsigned char inbuf[] = "hello, world!";
267     RSA *rsa = rk_UNCONST(key);
268     void *buffer;
269     int ret;
270
271     /*
272      * XXX I have no clue how to implement this w/o a bignum library.
273      * Well, when we have a RSA key pair, we can try to encrypt/sign
274      * and then decrypt/verify.
275      */
276
277     if ((rsa->d == NULL || rsa->n == NULL) &&
278         (rsa->p == NULL || rsa->q || rsa->dmp1 == NULL || rsa->dmq1 == NULL || rsa->iqmp == NULL))
279         return 0;
280
281     buffer = malloc(RSA_size(rsa));
282     if (buffer == NULL)
283         return 0;
284
285     ret = RSA_private_encrypt(sizeof(inbuf), inbuf, buffer,
286                              rsa, RSA_PKCS1_PADDING);
287     if (ret == -1) {
288         free(buffer);
289         return 0;
290     }
291
292     ret = RSA_public_decrypt(ret, buffer, buffer,
293                               rsa, RSA_PKCS1_PADDING);
294     if (ret == -1) {
295         free(buffer);
296         return 0;
297     }
298
299     if (ret == sizeof(inbuf) && ct_memcmp(buffer, inbuf, sizeof(inbuf)) == 0) {
300         free(buffer);
301         return 1;
302     }
303     free(buffer);
304     return 0;
305 }
306
307 int
308 RSA_size(const RSA *rsa)
309 {
310     return BN_num_bytes(rsa->n);
311 }
312
313 #define RSAFUNC(name, body) \
314 int \
315 name(int flen,const unsigned char* f, unsigned char* t, RSA* r, int p){\
316     return body; \
317 }
318
319 RSAFUNC(RSA_public_encrypt, (r)->meth->rsa_pub_enc(flen, f, t, r, p))
320 RSAFUNC(RSA_public_decrypt, (r)->meth->rsa_pub_dec(flen, f, t, r, p))
321 RSAFUNC(RSA_private_encrypt, (r)->meth->rsa_priv_enc(flen, f, t, r, p))
322 RSAFUNC(RSA_private_decrypt, (r)->meth->rsa_priv_dec(flen, f, t, r, p))
323
324 static const heim_octet_string null_entry_oid = { 2, rk_UNCONST("\x05\x00") };
325
326 static const unsigned sha1_oid_tree[] = { 1, 3, 14, 3, 2, 26 };
327 static const AlgorithmIdentifier _signature_sha1_data = {
328     { 6, rk_UNCONST(sha1_oid_tree) }, rk_UNCONST(&null_entry_oid)
329 };
330 static const unsigned sha256_oid_tree[] = { 2, 16, 840, 1, 101, 3, 4, 2, 1 };
331 static const AlgorithmIdentifier _signature_sha256_data = {
332     { 9, rk_UNCONST(sha256_oid_tree) }, rk_UNCONST(&null_entry_oid)
333 };
334 static const unsigned md5_oid_tree[] = { 1, 2, 840, 113549, 2, 5 };
335 static const AlgorithmIdentifier _signature_md5_data = {
336     { 6, rk_UNCONST(md5_oid_tree) }, rk_UNCONST(&null_entry_oid)
337 };
338
339
340 int
341 RSA_sign(int type, const unsigned char *from, unsigned int flen,
342          unsigned char *to, unsigned int *tlen, RSA *rsa)
343 {
344     if (rsa->meth->rsa_sign)
345         return rsa->meth->rsa_sign(type, from, flen, to, tlen, rsa);
346
347     if (rsa->meth->rsa_priv_enc) {
348         heim_octet_string indata;
349         DigestInfo di;
350         size_t size;
351         int ret;
352
353         memset(&di, 0, sizeof(di));
354
355         if (type == NID_sha1) {
356             di.digestAlgorithm = _signature_sha1_data;
357         } else if (type == NID_md5) {
358             di.digestAlgorithm = _signature_md5_data;
359         } else if (type == NID_sha256) {
360             di.digestAlgorithm = _signature_sha256_data;
361         } else
362             return -1;
363
364         di.digest.data = rk_UNCONST(from);
365         di.digest.length = flen;
366
367         ASN1_MALLOC_ENCODE(DigestInfo,
368                            indata.data,
369                            indata.length,
370                            &di,
371                            &size,
372                            ret);
373         if (ret)
374             return ret;
375         if (indata.length != size)
376             abort();
377
378         ret = rsa->meth->rsa_priv_enc(indata.length, indata.data, to,
379                                       rsa, RSA_PKCS1_PADDING);
380         free(indata.data);
381         if (ret > 0) {
382             *tlen = ret;
383             ret = 1;
384         } else
385             ret = 0;
386
387         return ret;
388     }
389
390     return 0;
391 }
392
393 int
394 RSA_verify(int type, const unsigned char *from, unsigned int flen,
395            unsigned char *sigbuf, unsigned int siglen, RSA *rsa)
396 {
397     if (rsa->meth->rsa_verify)
398         return rsa->meth->rsa_verify(type, from, flen, sigbuf, siglen, rsa);
399
400     if (rsa->meth->rsa_pub_dec) {
401         const AlgorithmIdentifier *digest_alg;
402         void *data;
403         DigestInfo di;
404         size_t size;
405         int ret, ret2;
406
407         data = malloc(RSA_size(rsa));
408         if (data == NULL)
409             return -1;
410
411         memset(&di, 0, sizeof(di));
412
413         ret = rsa->meth->rsa_pub_dec(siglen, sigbuf, data, rsa, RSA_PKCS1_PADDING);
414         if (ret <= 0) {
415             free(data);
416             return -2;
417         }
418
419         ret2 = decode_DigestInfo(data, ret, &di, &size);
420         free(data);
421         if (ret2 != 0)
422             return -3;
423         if (ret != size) {
424             free_DigestInfo(&di);
425             return -4;
426         }
427
428         if (flen != di.digest.length || memcmp(di.digest.data, from, flen) != 0) {
429             free_DigestInfo(&di);
430             return -5;
431         }
432
433         if (type == NID_sha1) {
434             digest_alg = &_signature_sha1_data;
435         } else if (type == NID_md5) {
436             digest_alg = &_signature_md5_data;
437         } else if (type == NID_sha256) {
438             digest_alg = &_signature_sha256_data;
439         } else {
440             free_DigestInfo(&di);
441             return -1;
442         }
443         
444         ret = der_heim_oid_cmp(&digest_alg->algorithm,
445                                &di.digestAlgorithm.algorithm);
446         free_DigestInfo(&di);
447         
448         if (ret != 0)
449             return 0;
450         return 1;
451     }
452
453     return 0;
454 }
455
456 /*
457  * A NULL RSA_METHOD that returns failure for all operations. This is
458  * used as the default RSA method if we don't have any native
459  * support.
460  */
461
462 static RSAFUNC(null_rsa_public_encrypt, -1)
463 static RSAFUNC(null_rsa_public_decrypt, -1)
464 static RSAFUNC(null_rsa_private_encrypt, -1)
465 static RSAFUNC(null_rsa_private_decrypt, -1)
466
467 /*
468  *
469  */
470
471 int
472 RSA_generate_key_ex(RSA *r, int bits, BIGNUM *e, BN_GENCB *cb)
473 {
474     if (r->meth->rsa_keygen)
475         return (*r->meth->rsa_keygen)(r, bits, e, cb);
476     return 0;
477 }
478
479
480 /*
481  *
482  */
483
484 static int
485 null_rsa_init(RSA *rsa)
486 {
487     return 1;
488 }
489
490 static int
491 null_rsa_finish(RSA *rsa)
492 {
493     return 1;
494 }
495
496 static const RSA_METHOD rsa_null_method = {
497     "hcrypto null RSA",
498     null_rsa_public_encrypt,
499     null_rsa_public_decrypt,
500     null_rsa_private_encrypt,
501     null_rsa_private_decrypt,
502     NULL,
503     NULL,
504     null_rsa_init,
505     null_rsa_finish,
506     0,
507     NULL,
508     NULL,
509     NULL
510 };
511
512 const RSA_METHOD *
513 RSA_null_method(void)
514 {
515     return &rsa_null_method;
516 }
517
518 extern const RSA_METHOD hc_rsa_gmp_method;
519 extern const RSA_METHOD hc_rsa_imath_method;
520 extern const RSA_METHOD hc_rsa_tfm_method;
521 extern const RSA_METHOD hc_rsa_ltm_method;
522 static const RSA_METHOD *default_rsa_method = &hc_rsa_ltm_method;
523
524
525 const RSA_METHOD *
526 RSA_get_default_method(void)
527 {
528     return default_rsa_method;
529 }
530
531 void
532 RSA_set_default_method(const RSA_METHOD *meth)
533 {
534     default_rsa_method = meth;
535 }
536
537 /*
538  *
539  */
540
541 RSA *
542 d2i_RSAPrivateKey(RSA *rsa, const unsigned char **pp, size_t len)
543 {
544     RSAPrivateKey data;
545     RSA *k = rsa;
546     size_t size;
547     int ret;
548
549     ret = decode_RSAPrivateKey(*pp, len, &data, &size);
550     if (ret)
551         return NULL;
552
553     *pp += size;
554
555     if (k == NULL) {
556         k = RSA_new();
557         if (k == NULL) {
558             free_RSAPrivateKey(&data);
559             return NULL;
560         }
561     }
562
563     k->n = _hc_integer_to_BN(&data.modulus, NULL);
564     k->e = _hc_integer_to_BN(&data.publicExponent, NULL);
565     k->d = _hc_integer_to_BN(&data.privateExponent, NULL);
566     k->p = _hc_integer_to_BN(&data.prime1, NULL);
567     k->q = _hc_integer_to_BN(&data.prime2, NULL);
568     k->dmp1 = _hc_integer_to_BN(&data.exponent1, NULL);
569     k->dmq1 = _hc_integer_to_BN(&data.exponent2, NULL);
570     k->iqmp = _hc_integer_to_BN(&data.coefficient, NULL);
571     free_RSAPrivateKey(&data);
572
573     if (k->n == NULL || k->e == NULL || k->d == NULL || k->p == NULL ||
574         k->q == NULL || k->dmp1 == NULL || k->dmq1 == NULL || k->iqmp == NULL)
575     {
576         RSA_free(k);
577         return NULL;
578     }
579         
580     return k;
581 }
582
583 int
584 i2d_RSAPrivateKey(RSA *rsa, unsigned char **pp)
585 {
586     RSAPrivateKey data;
587     size_t size;
588     int ret;
589
590     if (rsa->n == NULL || rsa->e == NULL || rsa->d == NULL || rsa->p == NULL ||
591         rsa->q == NULL || rsa->dmp1 == NULL || rsa->dmq1 == NULL ||
592         rsa->iqmp == NULL)
593         return -1;
594
595     memset(&data, 0, sizeof(data));
596
597     ret  = _hc_BN_to_integer(rsa->n, &data.modulus);
598     ret |= _hc_BN_to_integer(rsa->e, &data.publicExponent);
599     ret |= _hc_BN_to_integer(rsa->d, &data.privateExponent);
600     ret |= _hc_BN_to_integer(rsa->p, &data.prime1);
601     ret |= _hc_BN_to_integer(rsa->q, &data.prime2);
602     ret |= _hc_BN_to_integer(rsa->dmp1, &data.exponent1);
603     ret |= _hc_BN_to_integer(rsa->dmq1, &data.exponent2);
604     ret |= _hc_BN_to_integer(rsa->iqmp, &data.coefficient);
605     if (ret) {
606         free_RSAPrivateKey(&data);
607         return -1;
608     }
609
610     if (pp == NULL) {
611         size = length_RSAPrivateKey(&data);
612         free_RSAPrivateKey(&data);
613     } else {
614         void *p;
615         size_t len;
616
617         ASN1_MALLOC_ENCODE(RSAPrivateKey, p, len, &data, &size, ret);
618         free_RSAPrivateKey(&data);
619         if (ret)
620             return -1;
621         if (len != size)
622             abort();
623
624         memcpy(*pp, p, size);
625         free(p);
626
627         *pp += size;
628
629     }
630     return size;
631 }
632
633 int
634 i2d_RSAPublicKey(RSA *rsa, unsigned char **pp)
635 {
636     RSAPublicKey data;
637     size_t size;
638     int ret;
639
640     memset(&data, 0, sizeof(data));
641
642     if (_hc_BN_to_integer(rsa->n, &data.modulus) ||
643         _hc_BN_to_integer(rsa->e, &data.publicExponent))
644     {
645         free_RSAPublicKey(&data);
646         return -1;
647     }
648
649     if (pp == NULL) {
650         size = length_RSAPublicKey(&data);
651         free_RSAPublicKey(&data);
652     } else {
653         void *p;
654         size_t len;
655
656         ASN1_MALLOC_ENCODE(RSAPublicKey, p, len, &data, &size, ret);
657         free_RSAPublicKey(&data);
658         if (ret)
659             return -1;
660         if (len != size)
661             abort();
662
663         memcpy(*pp, p, size);
664         free(p);
665
666         *pp += size;
667     }
668
669     return size;
670 }
671
672 RSA *
673 d2i_RSAPublicKey(RSA *rsa, const unsigned char **pp, size_t len)
674 {
675     RSAPublicKey data;
676     RSA *k = rsa;
677     size_t size;
678     int ret;
679
680     ret = decode_RSAPublicKey(*pp, len, &data, &size);
681     if (ret)
682         return NULL;
683
684     *pp += size;
685
686     if (k == NULL) {
687         k = RSA_new();
688         if (k == NULL) {
689             free_RSAPublicKey(&data);
690             return NULL;
691         }
692     }
693
694     k->n = _hc_integer_to_BN(&data.modulus, NULL);
695     k->e = _hc_integer_to_BN(&data.publicExponent, NULL);
696
697     free_RSAPublicKey(&data);
698
699     if (k->n == NULL || k->e == NULL) {
700         RSA_free(k);
701         return NULL;
702     }
703         
704     return k;
705 }