r19217: Merge from SAMBA_4_0_RELEASE:
[samba.git] / source4 / lib / tls / tls.c
1 /* 
2    Unix SMB/CIFS implementation.
3
4    transport layer security handling code
5
6    Copyright (C) Andrew Tridgell 2004-2005
7    Copyright (C) Stefan Metzmacher 2004
8    Copyright (C) Andrew Bartlett 2006
9  
10    This program is free software; you can redistribute it and/or modify
11    it under the terms of the GNU General Public License as published by
12    the Free Software Foundation; either version 2 of the License, or
13    (at your option) any later version.
14    
15    This program is distributed in the hope that it will be useful,
16    but WITHOUT ANY WARRANTY; without even the implied warranty of
17    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18    GNU General Public License for more details.
19    
20    You should have received a copy of the GNU General Public License
21    along with this program; if not, write to the Free Software
22    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
23 */
24
25 #include "includes.h"
26 #include "lib/events/events.h"
27 #include "lib/socket/socket.h"
28
29 #if ENABLE_GNUTLS
30 #include "gnutls/gnutls.h"
31
/* Diffie-Hellman prime size in bits: used when generating DH parameters
   in tls_initialise() (no dhpfile configured) and passed to
   gnutls_dh_set_prime_bits() for every server session.
   NOTE(review): 1024-bit DH is below commonly recommended minimums;
   consider raising it (parameter generation cost grows with the size). */
#define DH_BITS 1024

/* when the gnutls headers only provide the old gnutls_datum name, alias
   it to the gnutls_datum_t name used in this file */
#if defined(HAVE_GNUTLS_DATUM) && !defined(HAVE_GNUTLS_DATUM_T)
typedef gnutls_datum gnutls_datum_t;
#endif

/* hold persistent tls data, created once by tls_initialise() */
struct tls_params {
        /* server certificate/key plus optional CA list and CRL */
        gnutls_certificate_credentials x509_cred;
        /* DH parameters, imported from dhpfile or generated */
        gnutls_dh_params dh_params;
        /* True only if tls_initialise() completed every setup step */
        BOOL tls_enabled;
};
44 #endif
45
/* hold per connection tls data; stored as private_data of the "tls"
   socket returned by tls_init_server()/tls_init_client() */
struct tls_context {
        /* underlying transport socket that carries the TLS records */
        struct socket_context *socket;
        /* event for the underlying fd; its read/write flags are toggled
           by the pull/push callbacks and the record/handshake code */
        struct fd_event *fde;
        /* False means data passes through unencrypted (plain-text peer
           detected, or client setup failed) */
        BOOL tls_enabled;
#if ENABLE_GNUTLS
        gnutls_session session;
        /* set once gnutls_handshake() has succeeded */
        BOOL done_handshake;
        /* a byte consumed by TLS detection in tls_socket_recv() that the
           next tls_pull() call must hand to gnutls */
        BOOL have_first_byte;
        uint8_t first_byte;
        /* True until the first received byte has been checked against
           plain_chars */
        BOOL tls_detect;
        /* bytes that mark an unencrypted connection (set by
           tls_init_server) */
        const char *plain_chars;
        /* the last gnutls_record_send() accepted fewer bytes than asked;
           tls_pull() then re-arms write events */
        BOOL output_pending;
        /* client credentials (set by tls_init_client only) */
        gnutls_certificate_credentials xcred;
        /* a gnutls record send/recv returned AGAIN/INTERRUPTED and must be
           resumed by tls_interrupted() before new record I/O */
        BOOL interrupted;
#endif
};
63
64 BOOL tls_enabled(struct socket_context *sock)
65 {
66         struct tls_context *tls;
67         if (!sock) {
68                 return False;
69         }
70         if (strcmp(sock->backend_name, "tls") != 0) {
71                 return False;
72         }
73         tls = talloc_get_type(sock->private_data, struct tls_context);
74         if (!tls) {
75                 return False;
76         }
77         return tls->tls_enabled;
78 }
79
80
81 #if ENABLE_GNUTLS
82
/* forward declaration: the ops table is defined near the end of this
   file, but tls_init_server()/tls_init_client() need its address */
static const struct socket_ops tls_socket_ops;
84
85 static NTSTATUS tls_socket_init(struct socket_context *sock)
86 {
87         switch (sock->type) {
88         case SOCKET_TYPE_STREAM:
89                 break;
90         default:
91                 return NT_STATUS_INVALID_PARAMETER;
92         }
93
94         sock->backend_name = "tls";
95
96         return NT_STATUS_OK;
97 }
98
/*
  evaluate a gnutls call exactly once, storing its result in the caller's
  'ret'; on a negative result log the failing call text and jump to the
  caller's 'failed' label.
*/
#define TLSCHECK(call) do { \
        if ((ret = (call)) < 0) { \
                DEBUG(0,("TLS %s - %s\n", #call, gnutls_strerror(ret))); \
                goto failed; \
        } \
} while (0)
106
107
108 /*
109   callback for reading from a socket
110 */
111 static ssize_t tls_pull(gnutls_transport_ptr ptr, void *buf, size_t size)
112 {
113         struct tls_context *tls = talloc_get_type(ptr, struct tls_context);
114         NTSTATUS status;
115         size_t nread;
116         
117         if (tls->have_first_byte) {
118                 *(uint8_t *)buf = tls->first_byte;
119                 tls->have_first_byte = False;
120                 return 1;
121         }
122
123         status = socket_recv(tls->socket, buf, size, &nread);
124         if (NT_STATUS_EQUAL(status, NT_STATUS_END_OF_FILE)) {
125                 return 0;
126         }
127         if (NT_STATUS_IS_ERR(status)) {
128                 EVENT_FD_NOT_READABLE(tls->fde);
129                 EVENT_FD_NOT_WRITEABLE(tls->fde);
130                 errno = EBADF;
131                 return -1;
132         }
133         if (!NT_STATUS_IS_OK(status)) {
134                 EVENT_FD_READABLE(tls->fde);
135                 errno = EAGAIN;
136                 return -1;
137         }
138         if (tls->output_pending) {
139                 EVENT_FD_WRITEABLE(tls->fde);
140         }
141         if (size != nread) {
142                 EVENT_FD_READABLE(tls->fde);
143         }
144         return nread;
145 }
146
147 /*
148   callback for writing to a socket
149 */
150 static ssize_t tls_push(gnutls_transport_ptr ptr, const void *buf, size_t size)
151 {
152         struct tls_context *tls = talloc_get_type(ptr, struct tls_context);
153         NTSTATUS status;
154         size_t nwritten;
155         DATA_BLOB b;
156
157         if (!tls->tls_enabled) {
158                 return size;
159         }
160
161         b.data = discard_const(buf);
162         b.length = size;
163
164         status = socket_send(tls->socket, &b, &nwritten);
165         if (NT_STATUS_EQUAL(status, STATUS_MORE_ENTRIES)) {
166                 errno = EAGAIN;
167                 return -1;
168         }
169         if (!NT_STATUS_IS_OK(status)) {
170                 EVENT_FD_WRITEABLE(tls->fde);
171                 return -1;
172         }
173         if (size != nwritten) {
174                 EVENT_FD_WRITEABLE(tls->fde);
175         }
176         return nwritten;
177 }
178
179 /*
180   destroy a tls session
181  */
182 static int tls_destructor(struct tls_context *tls)
183 {
184         int ret;
185         ret = gnutls_bye(tls->session, GNUTLS_SHUT_WR);
186         if (ret < 0) {
187                 DEBUG(0,("TLS gnutls_bye failed - %s\n", gnutls_strerror(ret)));
188         }
189         return 0;
190 }
191
192
193 /*
194   possibly continue the handshake process
195 */
196 static NTSTATUS tls_handshake(struct tls_context *tls)
197 {
198         int ret;
199
200         if (tls->done_handshake) {
201                 return NT_STATUS_OK;
202         }
203         
204         ret = gnutls_handshake(tls->session);
205         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
206                 if (gnutls_record_get_direction(tls->session) == 1) {
207                         EVENT_FD_WRITEABLE(tls->fde);
208                 }
209                 return STATUS_MORE_ENTRIES;
210         }
211         if (ret < 0) {
212                 DEBUG(0,("TLS gnutls_handshake failed - %s\n", gnutls_strerror(ret)));
213                 return NT_STATUS_UNEXPECTED_NETWORK_ERROR;
214         }
215         tls->done_handshake = True;
216         return NT_STATUS_OK;
217 }
218
219 /*
220   possibly continue an interrupted operation
221 */
222 static NTSTATUS tls_interrupted(struct tls_context *tls)
223 {
224         int ret;
225
226         if (!tls->interrupted) {
227                 return NT_STATUS_OK;
228         }
229         if (gnutls_record_get_direction(tls->session) == 1) {
230                 ret = gnutls_record_send(tls->session, NULL, 0);
231         } else {
232                 ret = gnutls_record_recv(tls->session, NULL, 0);
233         }
234         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
235                 return STATUS_MORE_ENTRIES;
236         }
237         tls->interrupted = False;
238         return NT_STATUS_OK;
239 }
240
241 /*
242   see how many bytes are pending on the connection
243 */
244 static NTSTATUS tls_socket_pending(struct socket_context *sock, size_t *npending)
245 {
246         struct tls_context *tls = talloc_get_type(sock->private_data, struct tls_context);
247         if (!tls->tls_enabled || tls->tls_detect) {
248                 return socket_pending(tls->socket, npending);
249         }
250         *npending = gnutls_record_check_pending(tls->session);
251         if (*npending == 0) {
252                 NTSTATUS status = socket_pending(tls->socket, npending);
253                 if (*npending == 0) {
254                         /* seems to be a gnutls bug */
255                         (*npending) = 100;
256                 }
257                 return status;
258         }
259         return NT_STATUS_OK;
260 }
261
262 /*
263   receive data either by tls or normal socket_recv
264 */
265 static NTSTATUS tls_socket_recv(struct socket_context *sock, void *buf, 
266                                 size_t wantlen, size_t *nread)
267 {
268         int ret;
269         NTSTATUS status;
270         struct tls_context *tls = talloc_get_type(sock->private_data, struct tls_context);
271
272         if (tls->tls_enabled && tls->tls_detect) {
273                 status = socket_recv(tls->socket, &tls->first_byte, 1, nread);
274                 NT_STATUS_NOT_OK_RETURN(status);
275                 if (*nread == 0) return NT_STATUS_OK;
276                 tls->tls_detect = False;
277                 /* look for the first byte of a valid HTTP operation */
278                 if (strchr(tls->plain_chars, tls->first_byte)) {
279                         /* not a tls link */
280                         tls->tls_enabled = False;
281                         *(uint8_t *)buf = tls->first_byte;
282                         return NT_STATUS_OK;
283                 }
284                 tls->have_first_byte = True;
285         }
286
287         if (!tls->tls_enabled) {
288                 return socket_recv(tls->socket, buf, wantlen, nread);
289         }
290
291         status = tls_handshake(tls);
292         NT_STATUS_NOT_OK_RETURN(status);
293
294         status = tls_interrupted(tls);
295         NT_STATUS_NOT_OK_RETURN(status);
296
297         ret = gnutls_record_recv(tls->session, buf, wantlen);
298         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
299                 if (gnutls_record_get_direction(tls->session) == 1) {
300                         EVENT_FD_WRITEABLE(tls->fde);
301                 }
302                 tls->interrupted = True;
303                 return STATUS_MORE_ENTRIES;
304         }
305         if (ret < 0) {
306                 return NT_STATUS_UNEXPECTED_NETWORK_ERROR;
307         }
308         *nread = ret;
309         return NT_STATUS_OK;
310 }
311
312
313 /*
314   send data either by tls or normal socket_recv
315 */
316 static NTSTATUS tls_socket_send(struct socket_context *sock, 
317                                 const DATA_BLOB *blob, size_t *sendlen)
318 {
319         NTSTATUS status;
320         int ret;
321         struct tls_context *tls = talloc_get_type(sock->private_data, struct tls_context);
322
323         if (!tls->tls_enabled) {
324                 return socket_send(tls->socket, blob, sendlen);
325         }
326
327         status = tls_handshake(tls);
328         NT_STATUS_NOT_OK_RETURN(status);
329
330         status = tls_interrupted(tls);
331         NT_STATUS_NOT_OK_RETURN(status);
332
333         ret = gnutls_record_send(tls->session, blob->data, blob->length);
334         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
335                 if (gnutls_record_get_direction(tls->session) == 1) {
336                         EVENT_FD_WRITEABLE(tls->fde);
337                 }
338                 tls->interrupted = True;
339                 return STATUS_MORE_ENTRIES;
340         }
341         if (ret < 0) {
342                 DEBUG(0,("gnutls_record_send of %d failed - %s\n", (int)blob->length, gnutls_strerror(ret)));
343                 return NT_STATUS_UNEXPECTED_NETWORK_ERROR;
344         }
345         *sendlen = ret;
346         tls->output_pending = (ret < blob->length);
347         return NT_STATUS_OK;
348 }
349
350
351 /*
352   initialise global tls state
353 */
354 struct tls_params *tls_initialise(TALLOC_CTX *mem_ctx)
355 {
356         struct tls_params *params;
357         int ret;
358         TALLOC_CTX *tmp_ctx = talloc_new(mem_ctx);
359         const char *keyfile = private_path(tmp_ctx, lp_tls_keyfile());
360         const char *certfile = private_path(tmp_ctx, lp_tls_certfile());
361         const char *cafile = private_path(tmp_ctx, lp_tls_cafile());
362         const char *crlfile = private_path(tmp_ctx, lp_tls_crlfile());
363         const char *dhpfile = private_path(tmp_ctx, lp_tls_dhpfile());
364         void tls_cert_generate(TALLOC_CTX *, const char *, const char *, const char *);
365
366         params = talloc(mem_ctx, struct tls_params);
367         if (params == NULL) {
368                 talloc_free(tmp_ctx);
369                 return NULL;
370         }
371
372         if (!lp_tls_enabled() || keyfile == NULL || *keyfile == 0) {
373                 params->tls_enabled = False;
374                 talloc_free(tmp_ctx);
375                 return params;
376         }
377
378         if (!file_exist(cafile)) {
379                 tls_cert_generate(params, keyfile, certfile, cafile);
380         }
381
382         ret = gnutls_global_init();
383         if (ret < 0) goto init_failed;
384
385         gnutls_certificate_allocate_credentials(&params->x509_cred);
386         if (ret < 0) goto init_failed;
387
388         if (cafile && *cafile) {
389                 ret = gnutls_certificate_set_x509_trust_file(params->x509_cred, cafile, 
390                                                              GNUTLS_X509_FMT_PEM);      
391                 if (ret < 0) {
392                         DEBUG(0,("TLS failed to initialise cafile %s\n", cafile));
393                         goto init_failed;
394                 }
395         }
396
397         if (crlfile && *crlfile) {
398                 ret = gnutls_certificate_set_x509_crl_file(params->x509_cred, 
399                                                            crlfile, 
400                                                            GNUTLS_X509_FMT_PEM);
401                 if (ret < 0) {
402                         DEBUG(0,("TLS failed to initialise crlfile %s\n", crlfile));
403                         goto init_failed;
404                 }
405         }
406         
407         ret = gnutls_certificate_set_x509_key_file(params->x509_cred, 
408                                                    certfile, keyfile,
409                                                    GNUTLS_X509_FMT_PEM);
410         if (ret < 0) {
411                 DEBUG(0,("TLS failed to initialise certfile %s and keyfile %s\n", 
412                          certfile, keyfile));
413                 goto init_failed;
414         }
415         
416         
417         ret = gnutls_dh_params_init(&params->dh_params);
418         if (ret < 0) goto init_failed;
419
420         if (dhpfile && *dhpfile) {
421                 gnutls_datum_t dhparms;
422                 size_t size;
423                 dhparms.data = (uint8_t *)file_load(dhpfile, &size, mem_ctx);
424
425                 if (!dhparms.data) {
426                         DEBUG(0,("Failed to read DH Parms from %s\n", dhpfile));
427                         goto init_failed;
428                 }
429                 dhparms.size = size;
430                         
431                 ret = gnutls_dh_params_import_pkcs3(params->dh_params, &dhparms, GNUTLS_X509_FMT_PEM);
432                 if (ret < 0) goto init_failed;
433         } else {
434                 ret = gnutls_dh_params_generate2(params->dh_params, DH_BITS);
435                 if (ret < 0) goto init_failed;
436         }
437                 
438         gnutls_certificate_set_dh_params(params->x509_cred, params->dh_params);
439
440         params->tls_enabled = True;
441
442         talloc_free(tmp_ctx);
443         return params;
444
445 init_failed:
446         DEBUG(0,("GNUTLS failed to initialise - %s\n", gnutls_strerror(ret)));
447         params->tls_enabled = False;
448         talloc_free(tmp_ctx);
449         return params;
450 }
451
452
453 /*
454   setup for a new connection
455 */
456 struct socket_context *tls_init_server(struct tls_params *params, 
457                                        struct socket_context *socket,
458                                        struct fd_event *fde, 
459                                        const char *plain_chars)
460 {
461         struct tls_context *tls;
462         int ret;
463         struct socket_context *new_sock;
464         NTSTATUS nt_status;
465         
466         nt_status = socket_create_with_ops(socket, &tls_socket_ops, &new_sock, 
467                                            SOCKET_TYPE_STREAM, 
468                                            socket->flags | SOCKET_FLAG_ENCRYPT);
469         if (!NT_STATUS_IS_OK(nt_status)) {
470                 return NULL;
471         }
472
473         tls = talloc(new_sock, struct tls_context);
474         if (tls == NULL) {
475                 return NULL;
476         }
477
478         tls->socket          = socket;
479         tls->fde             = fde;
480         if (talloc_reference(tls, fde) == NULL) {
481                 talloc_free(new_sock);
482                 return NULL;
483         }
484         if (talloc_reference(tls, socket) == NULL) {
485                 talloc_free(new_sock);
486                 return NULL;
487         }
488
489         new_sock->private_data    = tls;
490
491         if (!params->tls_enabled) {
492                 talloc_free(new_sock);
493                 return NULL;
494         }
495
496         TLSCHECK(gnutls_init(&tls->session, GNUTLS_SERVER));
497
498         talloc_set_destructor(tls, tls_destructor);
499
500         TLSCHECK(gnutls_set_default_priority(tls->session));
501         TLSCHECK(gnutls_credentials_set(tls->session, GNUTLS_CRD_CERTIFICATE, 
502                                         params->x509_cred));
503         gnutls_certificate_server_set_request(tls->session, GNUTLS_CERT_REQUEST);
504         gnutls_dh_set_prime_bits(tls->session, DH_BITS);
505         gnutls_transport_set_ptr(tls->session, (gnutls_transport_ptr)tls);
506         gnutls_transport_set_pull_function(tls->session, (gnutls_pull_func)tls_pull);
507         gnutls_transport_set_push_function(tls->session, (gnutls_push_func)tls_push);
508         gnutls_transport_set_lowat(tls->session, 0);
509
510         tls->plain_chars = plain_chars;
511         if (plain_chars) {
512                 tls->tls_detect = True;
513         } else {
514                 tls->tls_detect = False;
515         }
516
517         tls->output_pending  = False;
518         tls->done_handshake  = False;
519         tls->have_first_byte = False;
520         tls->tls_enabled     = True;
521         tls->interrupted     = False;
522         
523         new_sock->state = SOCKET_STATE_SERVER_CONNECTED;
524
525         return new_sock;
526
527 failed:
528         DEBUG(0,("TLS init connection failed - %s\n", gnutls_strerror(ret)));
529         talloc_free(new_sock);
530         return NULL;
531 }
532
533
534 /*
535   setup for a new client connection
536 */
537 struct socket_context *tls_init_client(struct socket_context *socket,
538                                        struct fd_event *fde)
539 {
540         struct tls_context *tls;
541         int ret = 0;
542         const int cert_type_priority[] = { GNUTLS_CRT_X509, GNUTLS_CRT_OPENPGP, 0 };
543         char *cafile;
544         struct socket_context *new_sock;
545         NTSTATUS nt_status;
546         
547         nt_status = socket_create_with_ops(socket, &tls_socket_ops, &new_sock, 
548                                            SOCKET_TYPE_STREAM, 
549                                            socket->flags | SOCKET_FLAG_ENCRYPT);
550         if (!NT_STATUS_IS_OK(nt_status)) {
551                 return NULL;
552         }
553
554         tls = talloc(new_sock, struct tls_context);
555         if (tls == NULL) return NULL;
556
557         tls->socket          = socket;
558         tls->fde             = fde;
559         if (talloc_reference(tls, fde) == NULL) {
560                 return NULL;
561         }
562         if (talloc_reference(tls, socket) == NULL) {
563                 return NULL;
564         }
565         new_sock->private_data    = tls;
566
567         cafile = private_path(tls, lp_tls_cafile());
568         if (!cafile || !*cafile) {
569                 goto failed;
570         }
571
572         gnutls_global_init();
573
574         gnutls_certificate_allocate_credentials(&tls->xcred);
575         gnutls_certificate_set_x509_trust_file(tls->xcred, cafile, GNUTLS_X509_FMT_PEM);
576         talloc_free(cafile);
577         TLSCHECK(gnutls_init(&tls->session, GNUTLS_CLIENT));
578         TLSCHECK(gnutls_set_default_priority(tls->session));
579         gnutls_certificate_type_set_priority(tls->session, cert_type_priority);
580         TLSCHECK(gnutls_credentials_set(tls->session, GNUTLS_CRD_CERTIFICATE, tls->xcred));
581
582         talloc_set_destructor(tls, tls_destructor);
583
584         gnutls_transport_set_ptr(tls->session, (gnutls_transport_ptr)tls);
585         gnutls_transport_set_pull_function(tls->session, (gnutls_pull_func)tls_pull);
586         gnutls_transport_set_push_function(tls->session, (gnutls_push_func)tls_push);
587         gnutls_transport_set_lowat(tls->session, 0);
588         tls->tls_detect = False;
589
590         tls->output_pending  = False;
591         tls->done_handshake  = False;
592         tls->have_first_byte = False;
593         tls->tls_enabled     = True;
594         tls->interrupted     = False;
595         
596         new_sock->state = SOCKET_STATE_CLIENT_CONNECTED;
597
598         return new_sock;
599
600 failed:
601         DEBUG(0,("TLS init connection failed - %s\n", gnutls_strerror(ret)));
602         tls->tls_enabled = False;
603         return new_sock;
604 }
605
606 static NTSTATUS tls_socket_set_option(struct socket_context *sock, const char *option, const char *val)
607 {
608         set_socket_options(socket_get_fd(sock), option);
609         return NT_STATUS_OK;
610 }
611
612 static char *tls_socket_get_peer_name(struct socket_context *sock, TALLOC_CTX *mem_ctx)
613 {
614         struct tls_context *tls = talloc_get_type(sock->private_data, struct tls_context);
615         return socket_get_peer_name(tls->socket, mem_ctx);
616 }
617
618 static struct socket_address *tls_socket_get_peer_addr(struct socket_context *sock, TALLOC_CTX *mem_ctx)
619 {
620         struct tls_context *tls = talloc_get_type(sock->private_data, struct tls_context);
621         return socket_get_peer_addr(tls->socket, mem_ctx);
622 }
623
624 static struct socket_address *tls_socket_get_my_addr(struct socket_context *sock, TALLOC_CTX *mem_ctx)
625 {
626         struct tls_context *tls = talloc_get_type(sock->private_data, struct tls_context);
627         return socket_get_my_addr(tls->socket, mem_ctx);
628 }
629
630 static int tls_socket_get_fd(struct socket_context *sock)
631 {
632         struct tls_context *tls = talloc_get_type(sock->private_data, struct tls_context);
633         return socket_get_fd(tls->socket);
634 }
635
636 static const struct socket_ops tls_socket_ops = {
637         .name                   = "tls",
638         .fn_init                = tls_socket_init,
639         .fn_recv                = tls_socket_recv,
640         .fn_send                = tls_socket_send,
641         .fn_pending             = tls_socket_pending,
642
643         .fn_set_option          = tls_socket_set_option,
644
645         .fn_get_peer_name       = tls_socket_get_peer_name,
646         .fn_get_peer_addr       = tls_socket_get_peer_addr,
647         .fn_get_my_addr         = tls_socket_get_my_addr,
648         .fn_get_fd              = tls_socket_get_fd
649 };
650
651 BOOL tls_support(struct tls_params *params)
652 {
653         return params->tls_enabled;
654 }
655
656 #else
657
658 /* for systems without tls we just fail the operations, and the caller
659  * will retain the original socket */
660
661 struct tls_params *tls_initialise(TALLOC_CTX *mem_ctx)
662 {
663         return talloc_new(mem_ctx);
664 }
665
/*
  setup for a new connection: always fails without GnuTLS
*/
struct socket_context *tls_init_server(struct tls_params *params, 
                                    struct socket_context *socket,
                                    struct fd_event *fde, 
                                    const char *plain_chars)
{
        (void)params;
        (void)socket;
        (void)fde;
        (void)plain_chars;
        return NULL;
}
676
677
/*
  setup for a new client connection: always fails without GnuTLS
*/
struct socket_context *tls_init_client(struct socket_context *socket,
                                       struct fd_event *fde)
{
        (void)socket;
        (void)fde;
        return NULL;
}
686
687 BOOL tls_support(struct tls_params *params)
688 {
689         return False;
690 }
691
692 #endif
693