r23456: Update Samba4 to current lorikeet-heimdal.
[jelmer/samba4-debian.git] / source / heimdal / lib / hx509 / ks_file.c
1 /*
2  * Copyright (c) 2005 - 2007 Kungliga Tekniska Högskolan
3  * (Royal Institute of Technology, Stockholm, Sweden). 
4  * All rights reserved. 
5  *
6  * Redistribution and use in source and binary forms, with or without 
7  * modification, are permitted provided that the following conditions 
8  * are met: 
9  *
10  * 1. Redistributions of source code must retain the above copyright 
11  *    notice, this list of conditions and the following disclaimer. 
12  *
13  * 2. Redistributions in binary form must reproduce the above copyright 
14  *    notice, this list of conditions and the following disclaimer in the 
15  *    documentation and/or other materials provided with the distribution. 
16  *
17  * 3. Neither the name of the Institute nor the names of its contributors 
18  *    may be used to endorse or promote products derived from this software 
19  *    without specific prior written permission. 
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE INSTITUTE AND CONTRIBUTORS ``AS IS'' AND 
22  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 
23  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 
24  * ARE DISCLAIMED.  IN NO EVENT SHALL THE INSTITUTE OR CONTRIBUTORS BE LIABLE 
25  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 
26  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS 
27  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 
28  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 
29  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY 
30  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF 
31  * SUCH DAMAGE. 
32  */
33
34 #include "hx_locl.h"
35 RCSID("$Id: ks_file.c 20776 2007-06-01 22:02:01Z lha $");
36
37 struct ks_file {
38     hx509_certs certs;
39     char *fn;
40 };
41
42 struct header {
43     char *header;
44     char *value;
45     struct header *next;
46 };
47
48 static int
49 add_headers(struct header **headers, const char *header, const char *value)
50 {
51     struct header *h;
52     h = calloc(1, sizeof(*h));
53     if (h == NULL)
54         return ENOMEM;
55     h->header = strdup(header);
56     if (h->header == NULL) {
57         free(h);
58         return ENOMEM;
59     }
60     h->value = strdup(value);
61     if (h->value == NULL) {
62         free(h->header);
63         free(h);
64         return ENOMEM;
65     }
66
67     h->next = *headers;
68     *headers = h;
69
70     return 0;
71 }
72
73 static void
74 free_headers(struct header *headers)
75 {
76     struct header *h;
77     while (headers) {
78         h = headers;
79         headers = headers->next;
80         free(h->header);
81         free(h->value);
82         free(h);
83     }
84 }
85
86 static const char *
87 find_header(const struct header *headers, const char *header)
88 {
89     while(headers) {
90         if (strcmp(header, headers->header) == 0)
91             return headers->value;
92         headers = headers->next;
93     }
94     return NULL;
95 }
96
97 /*
98  *
99  */
100
101 static int
102 parse_certificate(hx509_context context, const char *fn, 
103                   struct hx509_collector *c, 
104                   const struct header *headers,
105                   const void *data, size_t len)
106 {
107     hx509_cert cert;
108     Certificate t;
109     size_t size;
110     int ret;
111
112     ret = decode_Certificate(data, len, &t, &size);
113     if (ret) {
114         hx509_set_error_string(context, 0, ret, 
115                                "Failed to parse certificate in %s",
116                                fn);
117         return ret;
118     }
119
120     ret = hx509_cert_init(context, &t, &cert);
121     free_Certificate(&t);
122     if (ret)
123         return ret;
124
125     ret = _hx509_collector_certs_add(context, c, cert);
126     hx509_cert_free(cert);
127     return ret;
128 }
129
130 static int
131 try_decrypt(hx509_context context,
132             struct hx509_collector *collector,
133             const AlgorithmIdentifier *alg,
134             const EVP_CIPHER *c,
135             const void *ivdata,
136             const void *password,
137             size_t passwordlen,
138             const void *cipher,
139             size_t len)
140 {
141     heim_octet_string clear;
142     size_t keylen;
143     void *key;
144     int ret;
145
146     keylen = EVP_CIPHER_key_length(c);
147
148     key = malloc(keylen);
149     if (key == NULL) {
150         hx509_clear_error_string(context);
151         return ENOMEM;
152     }
153
154     ret = EVP_BytesToKey(c, EVP_md5(), ivdata,
155                          password, passwordlen,
156                          1, key, NULL);
157     if (ret <= 0) {
158         hx509_set_error_string(context, 0, HX509_CRYPTO_INTERNAL_ERROR,
159                                "Failed to do string2key for private key");
160         return HX509_CRYPTO_INTERNAL_ERROR;
161     }
162
163     clear.data = malloc(len);
164     if (clear.data == NULL) {
165         hx509_set_error_string(context, 0, ENOMEM,
166                                "Out of memory to decrypt for private key");
167         ret = ENOMEM;
168         goto out;
169     }
170     clear.length = len;
171
172     {
173         EVP_CIPHER_CTX ctx;
174         EVP_CIPHER_CTX_init(&ctx);
175         EVP_CipherInit_ex(&ctx, c, NULL, key, ivdata, 0);
176         EVP_Cipher(&ctx, clear.data, cipher, len);
177         EVP_CIPHER_CTX_cleanup(&ctx);
178     }   
179
180     ret = _hx509_collector_private_key_add(context,
181                                            collector,
182                                            alg,
183                                            NULL,
184                                            &clear,
185                                            NULL);
186
187     memset(clear.data, 0, clear.length);
188     free(clear.data);
189 out:
190     memset(key, 0, keylen);
191     free(key);
192     return ret;
193 }
194
195 static int
196 parse_rsa_private_key(hx509_context context, const char *fn,
197                       struct hx509_collector *c, 
198                       const struct header *headers,
199                       const void *data, size_t len)
200 {
201     int ret = 0;
202     const char *enc;
203
204     enc = find_header(headers, "Proc-Type");
205     if (enc) {
206         const char *dek;
207         char *type, *iv;
208         ssize_t ssize, size;
209         void *ivdata;
210         const EVP_CIPHER *cipher;
211         const struct _hx509_password *pw;
212         hx509_lock lock;
213         int i, decrypted = 0;
214
215         lock = _hx509_collector_get_lock(c);
216         if (lock == NULL) {
217             hx509_set_error_string(context, 0, HX509_ALG_NOT_SUPP,
218                                    "Failed to get password for "
219                                    "password protected file %s", fn);
220             return HX509_ALG_NOT_SUPP;
221         }
222
223         if (strcmp(enc, "4,ENCRYPTED") != 0) {
224             hx509_set_error_string(context, 0, HX509_PARSING_KEY_FAILED,
225                                    "RSA key encrypted in unknown method %s "
226                                    "in file",
227                                    enc, fn);
228             hx509_clear_error_string(context);
229             return HX509_PARSING_KEY_FAILED;
230         }
231
232         dek = find_header(headers, "DEK-Info");
233         if (dek == NULL) {
234             hx509_set_error_string(context, 0, HX509_PARSING_KEY_FAILED,
235                                    "Encrypted RSA missing DEK-Info");
236             return HX509_PARSING_KEY_FAILED;
237         }
238
239         type = strdup(dek);
240         if (type == NULL) {
241             hx509_clear_error_string(context);
242             return ENOMEM;
243         }
244
245         iv = strchr(type, ',');
246         if (iv)
247             *iv++ = '\0';
248
249         size = strlen(iv);
250         ivdata = malloc(size);
251         if (ivdata == NULL) {
252             hx509_clear_error_string(context);
253             free(type);
254             return ENOMEM;
255         }
256
257         cipher = EVP_get_cipherbyname(type);
258         if (cipher == NULL) {
259             free(ivdata);
260             hx509_set_error_string(context, 0, HX509_ALG_NOT_SUPP,
261                                    "RSA key encrypted with "
262                                    "unsupported cipher: %s",
263                                    type);
264             free(type);
265             return HX509_ALG_NOT_SUPP;
266         }
267
268 #define PKCS5_SALT_LEN 8
269
270         ssize = hex_decode(iv, ivdata, size);
271         free(type);
272         type = NULL;
273         iv = NULL;
274
275         if (ssize < 0 || ssize < PKCS5_SALT_LEN || ssize < EVP_CIPHER_iv_length(cipher)) {
276             free(ivdata);
277             hx509_set_error_string(context, 0, HX509_PARSING_KEY_FAILED,
278                                    "Salt have wrong length in RSA key file");
279             return HX509_PARSING_KEY_FAILED;
280         }
281         
282         pw = _hx509_lock_get_passwords(lock);
283         if (pw != NULL) {
284             const void *password;
285             size_t passwordlen;
286
287             for (i = 0; i < pw->len; i++) {
288                 password = pw->val[i];
289                 passwordlen = strlen(password);
290                 
291                 ret = try_decrypt(context, c, hx509_signature_rsa(),
292                                   cipher, ivdata, password, passwordlen,
293                                   data, len);
294                 if (ret == 0) {
295                     decrypted = 1;
296                     break;
297                 }
298             }
299         }
300         if (!decrypted) {
301             hx509_prompt prompt;
302             char password[128];
303
304             memset(&prompt, 0, sizeof(prompt));
305
306             prompt.prompt = "Password for keyfile: ";
307             prompt.type = HX509_PROMPT_TYPE_PASSWORD;
308             prompt.reply.data = password;
309             prompt.reply.length = sizeof(password);
310
311             ret = hx509_lock_prompt(lock, &prompt);
312             if (ret == 0)
313                 ret = try_decrypt(context, c, hx509_signature_rsa(),
314                                   cipher, ivdata, password, strlen(password),
315                                   data, len);
316             /* XXX add password to lock password collection ? */
317             memset(password, 0, sizeof(password));
318         }
319         free(ivdata);
320
321     } else {
322         heim_octet_string keydata;
323
324         keydata.data = rk_UNCONST(data);
325         keydata.length = len;
326
327         ret = _hx509_collector_private_key_add(context,
328                                                c,
329                                                hx509_signature_rsa(),
330                                                NULL,
331                                                &keydata,
332                                                NULL);
333     }
334
335     return ret;
336 }
337
338
339 struct pem_formats {
340     const char *name;
341     int (*func)(hx509_context, const char *, struct hx509_collector *, 
342                 const struct header *, const void *, size_t);
343 } formats[] = {
344     { "CERTIFICATE", parse_certificate },
345     { "RSA PRIVATE KEY", parse_rsa_private_key }
346 };
347
348
349 static int
350 parse_pem_file(hx509_context context, 
351                const char *fn,
352                struct hx509_collector *c,
353                int *found_data)
354 {
355     struct header *headers = NULL;
356     char *type = NULL;
357     void *data = NULL;
358     size_t len = 0;
359     char buf[1024];
360     int ret;
361     FILE *f;
362
363
364     enum { BEFORE, SEARCHHEADER, INHEADER, INDATA, DONE } where;
365
366     where = BEFORE;
367     *found_data = 0;
368
369     if ((f = fopen(fn, "r")) == NULL) {
370         hx509_set_error_string(context, 0, ENOENT, 
371                                "Failed to open PEM file \"%s\": %s", 
372                                fn, strerror(errno));
373         return ENOENT;
374     }
375     ret = 0;
376
377     while (fgets(buf, sizeof(buf), f) != NULL) {
378         char *p;
379         int i;
380
381         i = strcspn(buf, "\n");
382         if (buf[i] == '\n') {
383             buf[i] = '\0';
384             if (i > 0)
385                 i--;
386         }
387         if (buf[i] == '\r') {
388             buf[i] = '\0';
389             if (i > 0)
390                 i--;
391         }
392             
393         switch (where) {
394         case BEFORE:
395             if (strncmp("-----BEGIN ", buf, 11) == 0) {
396                 type = strdup(buf + 11);
397                 if (type == NULL)
398                     break;
399                 p = strchr(type, '-');
400                 if (p)
401                     *p = '\0';
402                 *found_data = 1;
403                 where = SEARCHHEADER;
404             }
405             break;
406         case SEARCHHEADER:
407             p = strchr(buf, ':');
408             if (p == NULL) {
409                 where = INDATA;
410                 goto indata;
411             }
412             /* FALLTHOUGH */
413         case INHEADER:
414             if (buf[0] == '\0') {
415                 where = INDATA;
416                 break;
417             }
418             p = strchr(buf, ':');
419             if (p) {
420                 *p++ = '\0';
421                 while (isspace((int)*p))
422                     p++;
423                 add_headers(&headers, buf, p);
424             }
425             break;
426         case INDATA:
427         indata:
428
429             if (strncmp("-----END ", buf, 9) == 0) {
430                 where = DONE;
431                 break;
432             }
433
434             p = emalloc(i);
435             i = base64_decode(buf, p);
436             if (i < 0) {
437                 free(p);
438                 goto out;
439             }
440             
441             data = erealloc(data, len + i);
442             memcpy(((char *)data) + len, p, i);
443             free(p);
444             len += i;
445             break;
446         case DONE:
447             abort();
448         }
449
450         if (where == DONE) {
451             int j;
452
453             for (j = 0; j < sizeof(formats)/sizeof(formats[0]); j++) {
454                 const char *q = formats[j].name;
455                 if (strcasecmp(type, q) == 0) {
456                     ret = (*formats[j].func)(context, fn, c, 
457                                              headers, data, len);
458                     break;
459                 }
460             }
461             if (j == sizeof(formats)/sizeof(formats[0])) {
462                 ret = HX509_UNSUPPORTED_OPERATION;
463                 hx509_set_error_string(context, 0, ret,
464                                        "Found no matching PEM format for %s",
465                                        type);
466             }
467         out:
468             free(data);
469             data = NULL;
470             len = 0;
471             free(type);
472             type = NULL;
473             where = BEFORE;
474             free_headers(headers);
475             headers = NULL;
476             if (ret)
477                 break;
478         }
479     }
480
481     fclose(f);
482
483     if (where != BEFORE) {
484         hx509_set_error_string(context, 0, HX509_PARSING_KEY_FAILED,
485                                "File ends before end of PEM end tag");
486         ret = HX509_PARSING_KEY_FAILED;
487     }
488     if (data)
489         free(data);
490     if (type)
491         free(type);
492     if (headers)
493         free_headers(headers);
494
495     return ret;
496 }
497
498 /*
499  *
500  */
501
502 static int
503 file_init(hx509_context context,
504           hx509_certs certs, void **data, int flags, 
505           const char *residue, hx509_lock lock)
506 {
507     char *p, *pnext;
508     struct ks_file *f = NULL;
509     struct hx509_collector *c = NULL;
510     hx509_private_key *keys = NULL;
511     int ret;
512
513     *data = NULL;
514
515     if (lock == NULL)
516         lock = _hx509_empty_lock;
517
518     f = calloc(1, sizeof(*f));
519     if (f == NULL) {
520         hx509_clear_error_string(context);
521         return ENOMEM;
522     }
523
524     f->fn = strdup(residue);
525     if (f->fn == NULL) {
526         hx509_clear_error_string(context);
527         ret = ENOMEM;
528         goto out;
529     }
530
531     /* 
532      * XXX this is broken, the function should parse the file before
533      * overwriting it
534      */
535
536     if (flags & HX509_CERTS_CREATE) {
537         ret = hx509_certs_init(context, "MEMORY:ks-file-create", 
538                                0, lock, &f->certs);
539         if (ret)
540             goto out;
541         *data = f;
542         return 0;
543     }
544
545     ret = _hx509_collector_alloc(context, lock, &c);
546     if (ret)
547         goto out;
548
549     for (p = f->fn; p != NULL; p = pnext) {
550         int found_data;
551
552         pnext = strchr(p, ',');
553         if (pnext)
554             *pnext++ = '\0';
555         
556         ret = parse_pem_file(context, p, c, &found_data);
557         if (ret)
558             goto out;
559
560         if (!found_data) {
561             size_t length;
562             void *ptr;
563             int i;
564
565             ret = _hx509_map_file(p, &ptr, &length, NULL);
566             if (ret) {
567                 hx509_clear_error_string(context);
568                 goto out;
569             }
570
571             for (i = 0; i < sizeof(formats)/sizeof(formats[0]); i++) {
572                 ret = (*formats[i].func)(context, p, c, NULL, ptr, length);
573                 if (ret == 0)
574                     break;
575             }
576             _hx509_unmap_file(ptr, length);
577             if (ret)
578                 goto out;
579         }
580     }
581
582     ret = _hx509_collector_collect_certs(context, c, &f->certs);
583     if (ret)
584         goto out;
585
586     ret = _hx509_collector_collect_private_keys(context, c, &keys);
587     if (ret == 0) {
588         int i;
589
590         for (i = 0; keys[i]; i++)
591             _hx509_certs_keys_add(context, f->certs, keys[i]);
592         _hx509_certs_keys_free(context, keys);
593     }
594
595 out:
596     if (ret == 0)
597         *data = f;
598     else {
599         if (f->fn)
600             free(f->fn);
601         free(f);
602     }
603     if (c)
604         _hx509_collector_free(c);
605     return ret;
606 }
607
608 static int
609 file_free(hx509_certs certs, void *data)
610 {
611     struct ks_file *f = data;
612     hx509_certs_free(&f->certs);
613     free(f->fn);
614     free(f);
615     return 0;
616 }
617
618 static void
619 pem_header(FILE *f, const char *type, const char *str)
620 {
621     fprintf(f, "-----%s %s-----\n", type, str);
622 }
623
624 static int
625 dump_pem_file(hx509_context context, const char *header,
626               FILE *f, const void *data, size_t size)
627 {
628     const char *p = data;
629     size_t length;
630     char *line;
631
632 #define ENCODE_LINE_LENGTH      54
633     
634     pem_header(f, "BEGIN", header);
635
636     while (size > 0) {
637         ssize_t l;
638         
639         length = size;
640         if (length > ENCODE_LINE_LENGTH)
641             length = ENCODE_LINE_LENGTH;
642         
643         l = base64_encode(p, length, &line);
644         if (l < 0) {
645             hx509_set_error_string(context, 0, ENOMEM,
646                                    "malloc - out of memory");
647             return ENOMEM;
648         }
649         size -= length;
650         fprintf(f, "%s\n", line);
651         p += length;
652         free(line);
653     }
654
655     pem_header(f, "END", header);
656
657     return 0;
658 }
659
660 static int
661 store_private_key(hx509_context context, FILE *f, hx509_private_key key)
662 {
663     heim_octet_string data;
664     int ret;
665
666     ret = _hx509_private_key_export(context, key, &data);
667     if (ret == 0)
668         dump_pem_file(context, _hx509_private_pem_name(key), f,
669                       data.data, data.length);
670     free(data.data);
671     return ret;
672 }
673
674 static int
675 store_func(hx509_context context, void *ctx, hx509_cert c)
676 {
677     FILE *f = (FILE *)ctx;
678     heim_octet_string data;
679     int ret;
680
681     ret = hx509_cert_binary(context, c, &data);
682     if (ret)
683         return ret;
684     
685     dump_pem_file(context, "CERTIFICATE", f, data.data, data.length);
686     free(data.data);
687
688     if (_hx509_cert_private_key_exportable(c))
689         store_private_key(context, f, _hx509_cert_private_key(c));
690
691     return 0;
692 }
693
694 static int
695 file_store(hx509_context context, 
696            hx509_certs certs, void *data, int flags, hx509_lock lock)
697 {
698     struct ks_file *f = data;
699     FILE *fh;
700     int ret;
701
702     fh = fopen(f->fn, "w");
703     if (fh == NULL) {
704         hx509_set_error_string(context, 0, ENOENT,
705                                "Failed to open file %s for writing");
706         return ENOENT;
707     }
708
709     ret = hx509_certs_iter(context, f->certs, store_func, fh);
710     fclose(fh);
711     return ret;
712 }
713
714 static int 
715 file_add(hx509_context context, hx509_certs certs, void *data, hx509_cert c)
716 {
717     struct ks_file *f = data;
718     return hx509_certs_add(context, f->certs, c);
719 }
720
721 static int 
722 file_iter_start(hx509_context context,
723                 hx509_certs certs, void *data, void **cursor)
724 {
725     struct ks_file *f = data;
726     return hx509_certs_start_seq(context, f->certs, cursor);
727 }
728
729 static int
730 file_iter(hx509_context context,
731           hx509_certs certs, void *data, void *iter, hx509_cert *cert)
732 {
733     struct ks_file *f = data;
734     return hx509_certs_next_cert(context, f->certs, iter, cert);
735 }
736
737 static int
738 file_iter_end(hx509_context context,
739               hx509_certs certs,
740               void *data,
741               void *cursor)
742 {
743     struct ks_file *f = data;
744     return hx509_certs_end_seq(context, f->certs, cursor);
745 }
746
747 static int
748 file_getkeys(hx509_context context,
749              hx509_certs certs,
750              void *data,
751              hx509_private_key **keys)
752 {
753     struct ks_file *f = data;
754     return _hx509_certs_keys_get(context, f->certs, keys);
755 }
756
757 static int
758 file_addkey(hx509_context context,
759              hx509_certs certs,
760              void *data,
761              hx509_private_key key)
762 {
763     struct ks_file *f = data;
764     return _hx509_certs_keys_add(context, f->certs, key);
765 }
766
767 static struct hx509_keyset_ops keyset_file = {
768     "FILE",
769     0,
770     file_init,
771     file_store,
772     file_free,
773     file_add,
774     NULL,
775     file_iter_start,
776     file_iter,
777     file_iter_end,
778     NULL,
779     file_getkeys,
780     file_addkey
781 };
782
783 void
784 _hx509_ks_file_register(hx509_context context)
785 {
786     _hx509_ks_register(context, &keyset_file);
787 }