r7912: make private_path() recognise a non-relative filename, so we can have
[jelmer/samba4-debian.git] / source / lib / tls / tls.c
1 /* 
2    Unix SMB/CIFS implementation.
3
4    transport layer security handling code
5
6    Copyright (C) Andrew Tridgell 2005
7    
8    This program is free software; you can redistribute it and/or modify
9    it under the terms of the GNU General Public License as published by
10    the Free Software Foundation; either version 2 of the License, or
11    (at your option) any later version.
12    
13    This program is distributed in the hope that it will be useful,
14    but WITHOUT ANY WARRANTY; without even the implied warranty of
15    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16    GNU General Public License for more details.
17    
18    You should have received a copy of the GNU General Public License
19    along with this program; if not, write to the Free Software
20    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
21 */
22
23 #include "includes.h"
24 #include "lib/events/events.h"
25 #include "lib/socket/socket.h"
26 #include "lib/tls/tls.h"
27
28 #if HAVE_LIBGNUTLS
29 #include "gnutls/gnutls.h"
30
/* size in bits of the Diffie-Hellman prime generated in tls_initialise()
   and requested per session in tls_init_server() */
#define DH_BITS 1024

/* hold persistent tls data */
/*
  global tls state built once by tls_initialise() and passed to
  tls_init_server() for each new server connection
*/
struct tls_params {
        /* server key/cert plus optional CA and CRL, loaded in tls_initialise() */
        gnutls_certificate_credentials x509_cred;
        /* DH parameters attached to x509_cred */
        gnutls_dh_params dh_params;
        /* False when tls is disabled in the configuration or setup failed;
           tls_init_server() also clears it after a per-connection failure */
        BOOL tls_enabled;
};
39
/* hold per connection tls data */
struct tls_context {
        /* underlying transport used by tls_pull()/tls_push() and plain I/O */
        struct socket_context *socket;
        /* event whose readable/writeable interest is toggled by this file */
        struct fd_event *fde;
        /* gnutls session; only initialised when tls setup got that far */
        gnutls_session session;
        /* set once gnutls_handshake() has completed successfully */
        BOOL done_handshake;
        /* True while first_byte (consumed during detection) still has to be
           handed to gnutls by tls_pull() */
        BOOL have_first_byte;
        uint8_t first_byte;
        /* False for plain connections: I/O goes straight to the socket */
        BOOL tls_enabled;
        /* True until the first received byte has been classified as
           plain text or tls (only when plain_chars is non-NULL) */
        BOOL tls_detect;
        /* first bytes that mark a connection as plain text */
        const char *plain_chars;
        /* set when gnutls_record_send() accepted less than requested */
        BOOL output_pending;
        /* client credentials, allocated by tls_init_client() only */
        gnutls_certificate_credentials xcred;
        /* a record send/recv returned AGAIN/INTERRUPTED and must be resumed
           by tls_interrupted() before the next operation */
        BOOL interrupted;
};
55
/*
  evaluate a gnutls call and store its result in the caller's local 'ret';
  on a negative (error) result, log the call text with gnutls_strerror()
  and jump to the caller's 'failed' label.  The caller must declare
  'int ret' and provide a 'failed:' label.
*/
#define TLSCHECK(call) do { \
        ret = call; \
        if (ret < 0) { \
                DEBUG(0,("TLS %s - %s\n", #call, gnutls_strerror(ret))); \
                goto failed; \
        } \
} while (0)
63
64
65
66 /*
67   callback for reading from a socket
68 */
69 static ssize_t tls_pull(gnutls_transport_ptr ptr, void *buf, size_t size)
70 {
71         struct tls_context *tls = talloc_get_type(ptr, struct tls_context);
72         NTSTATUS status;
73         size_t nread;
74         
75         if (tls->have_first_byte) {
76                 *(uint8_t *)buf = tls->first_byte;
77                 tls->have_first_byte = False;
78                 return 1;
79         }
80
81         status = socket_recv(tls->socket, buf, size, &nread, 0);
82         if (NT_STATUS_EQUAL(status, NT_STATUS_END_OF_FILE)) {
83                 return 0;
84         }
85         if (NT_STATUS_IS_ERR(status)) {
86                 EVENT_FD_NOT_READABLE(tls->fde);
87                 EVENT_FD_NOT_WRITEABLE(tls->fde);
88                 errno = EBADF;
89                 return -1;
90         }
91         if (!NT_STATUS_IS_OK(status)) {
92                 EVENT_FD_READABLE(tls->fde);
93                 errno = EAGAIN;
94                 return -1;
95         }
96         if (tls->output_pending) {
97                 EVENT_FD_WRITEABLE(tls->fde);
98         }
99         if (size != nread) {
100                 EVENT_FD_READABLE(tls->fde);
101         }
102         return nread;
103 }
104
105 /*
106   callback for writing to a socket
107 */
108 static ssize_t tls_push(gnutls_transport_ptr ptr, const void *buf, size_t size)
109 {
110         struct tls_context *tls = talloc_get_type(ptr, struct tls_context);
111         NTSTATUS status;
112         size_t nwritten;
113         DATA_BLOB b;
114
115         if (!tls->tls_enabled) {
116                 return size;
117         }
118
119         b.data = discard_const(buf);
120         b.length = size;
121
122         status = socket_send(tls->socket, &b, &nwritten, 0);
123         if (NT_STATUS_EQUAL(status, STATUS_MORE_ENTRIES)) {
124                 errno = EAGAIN;
125                 return -1;
126         }
127         if (!NT_STATUS_IS_OK(status)) {
128                 EVENT_FD_WRITEABLE(tls->fde);
129                 return -1;
130         }
131         if (size != nwritten) {
132                 EVENT_FD_WRITEABLE(tls->fde);
133         }
134         return nwritten;
135 }
136
137 /*
138   destroy a tls session
139  */
140 static int tls_destructor(void *ptr)
141 {
142         struct tls_context *tls = talloc_get_type(ptr, struct tls_context);
143         int ret;
144         ret = gnutls_bye(tls->session, GNUTLS_SHUT_WR);
145         if (ret < 0) {
146                 DEBUG(0,("TLS gnutls_bye failed - %s\n", gnutls_strerror(ret)));
147         }
148         return 0;
149 }
150
151
152 /*
153   possibly continue the handshake process
154 */
155 static NTSTATUS tls_handshake(struct tls_context *tls)
156 {
157         int ret;
158
159         if (tls->done_handshake) {
160                 return NT_STATUS_OK;
161         }
162         
163         ret = gnutls_handshake(tls->session);
164         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
165                 if (gnutls_record_get_direction(tls->session) == 1) {
166                         EVENT_FD_WRITEABLE(tls->fde);
167                 }
168                 return STATUS_MORE_ENTRIES;
169         }
170         if (ret < 0) {
171                 DEBUG(0,("TLS gnutls_handshake failed - %s\n", gnutls_strerror(ret)));
172                 return NT_STATUS_UNEXPECTED_NETWORK_ERROR;
173         }
174         tls->done_handshake = True;
175         return NT_STATUS_OK;
176 }
177
178 /*
179   possibly continue an interrupted operation
180 */
181 static NTSTATUS tls_interrupted(struct tls_context *tls)
182 {
183         int ret;
184
185         if (!tls->interrupted) {
186                 return NT_STATUS_OK;
187         }
188         if (gnutls_record_get_direction(tls->session) == 1) {
189                 ret = gnutls_record_send(tls->session, NULL, 0);
190         } else {
191                 ret = gnutls_record_recv(tls->session, NULL, 0);
192         }
193         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
194                 return STATUS_MORE_ENTRIES;
195         }
196         tls->interrupted = False;
197         return NT_STATUS_OK;
198 }
199
200 /*
201   see how many bytes are pending on the connection
202 */
203 NTSTATUS tls_socket_pending(struct tls_context *tls, size_t *npending)
204 {
205         if (!tls->tls_enabled || tls->tls_detect) {
206                 return socket_pending(tls->socket, npending);
207         }
208         *npending = gnutls_record_check_pending(tls->session);
209         if (*npending == 0) {
210                 NTSTATUS status = socket_pending(tls->socket, npending);
211                 if (*npending == 0) {
212                         /* seems to be a gnutls bug */
213                         (*npending) = 100;
214                 }
215                 return status;
216         }
217         return NT_STATUS_OK;
218 }
219
220 /*
221   receive data either by tls or normal socket_recv
222 */
223 NTSTATUS tls_socket_recv(struct tls_context *tls, void *buf, size_t wantlen, 
224                          size_t *nread)
225 {
226         int ret;
227         NTSTATUS status;
228         if (tls->tls_enabled && tls->tls_detect) {
229                 status = socket_recv(tls->socket, &tls->first_byte, 1, nread, 0);
230                 NT_STATUS_NOT_OK_RETURN(status);
231                 if (*nread == 0) return NT_STATUS_OK;
232                 tls->tls_detect = False;
233                 /* look for the first byte of a valid HTTP operation */
234                 if (strchr(tls->plain_chars, tls->first_byte)) {
235                         /* not a tls link */
236                         tls->tls_enabled = False;
237                         *(uint8_t *)buf = tls->first_byte;
238                         return NT_STATUS_OK;
239                 }
240                 tls->have_first_byte = True;
241         }
242
243         if (!tls->tls_enabled) {
244                 return socket_recv(tls->socket, buf, wantlen, nread, 0);
245         }
246
247         status = tls_handshake(tls);
248         NT_STATUS_NOT_OK_RETURN(status);
249
250         status = tls_interrupted(tls);
251         NT_STATUS_NOT_OK_RETURN(status);
252
253         ret = gnutls_record_recv(tls->session, buf, wantlen);
254         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
255                 if (gnutls_record_get_direction(tls->session) == 1) {
256                         EVENT_FD_WRITEABLE(tls->fde);
257                 }
258                 tls->interrupted = True;
259                 return STATUS_MORE_ENTRIES;
260         }
261         if (ret < 0) {
262                 return NT_STATUS_UNEXPECTED_NETWORK_ERROR;
263         }
264         *nread = ret;
265         return NT_STATUS_OK;
266 }
267
268
269 /*
270   send data either by tls or normal socket_recv
271 */
272 NTSTATUS tls_socket_send(struct tls_context *tls, const DATA_BLOB *blob, size_t *sendlen)
273 {
274         NTSTATUS status;
275         int ret;
276
277         if (!tls->tls_enabled) {
278                 return socket_send(tls->socket, blob, sendlen, 0);
279         }
280
281         status = tls_handshake(tls);
282         NT_STATUS_NOT_OK_RETURN(status);
283
284         status = tls_interrupted(tls);
285         NT_STATUS_NOT_OK_RETURN(status);
286
287         ret = gnutls_record_send(tls->session, blob->data, blob->length);
288         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
289                 if (gnutls_record_get_direction(tls->session) == 1) {
290                         EVENT_FD_WRITEABLE(tls->fde);
291                 }
292                 tls->interrupted = True;
293                 return STATUS_MORE_ENTRIES;
294         }
295         if (ret < 0) {
296                 DEBUG(0,("gnutls_record_send of %d failed - %s\n", blob->length, gnutls_strerror(ret)));
297                 return NT_STATUS_UNEXPECTED_NETWORK_ERROR;
298         }
299         *sendlen = ret;
300         tls->output_pending = (ret < blob->length);
301         return NT_STATUS_OK;
302 }
303
304
305 /*
306   initialise global tls state
307 */
308 struct tls_params *tls_initialise(TALLOC_CTX *mem_ctx)
309 {
310         struct tls_params *params;
311         int ret;
312         TALLOC_CTX *tmp_ctx = talloc_new(mem_ctx);
313         const char *keyfile = private_path(tmp_ctx, lp_tls_keyfile());
314         const char *certfile = private_path(tmp_ctx, lp_tls_certfile());
315         const char *cafile = private_path(tmp_ctx, lp_tls_cafile());
316         const char *crlfile = private_path(tmp_ctx, lp_tls_crlfile());
317         void tls_cert_generate(TALLOC_CTX *, const char *, const char *, const char *);
318
319         params = talloc(mem_ctx, struct tls_params);
320         if (params == NULL) {
321                 talloc_free(tmp_ctx);
322                 return NULL;
323         }
324
325         if (!lp_tls_enabled() || keyfile == NULL || *keyfile == 0) {
326                 params->tls_enabled = False;
327                 talloc_free(tmp_ctx);
328                 return params;
329         }
330
331         if (!file_exist(cafile)) {
332                 tls_cert_generate(params, keyfile, certfile, cafile);
333         }
334
335         ret = gnutls_global_init();
336         if (ret < 0) goto init_failed;
337
338         gnutls_certificate_allocate_credentials(&params->x509_cred);
339         if (ret < 0) goto init_failed;
340
341         if (cafile && *cafile) {
342                 ret = gnutls_certificate_set_x509_trust_file(params->x509_cred, cafile, 
343                                                              GNUTLS_X509_FMT_PEM);      
344                 if (ret < 0) {
345                         DEBUG(0,("TLS failed to initialise cafile %s\n", cafile));
346                         goto init_failed;
347                 }
348         }
349
350         if (crlfile && *crlfile) {
351                 ret = gnutls_certificate_set_x509_crl_file(params->x509_cred, 
352                                                            crlfile, 
353                                                            GNUTLS_X509_FMT_PEM);
354                 if (ret < 0) {
355                         DEBUG(0,("TLS failed to initialise crlfile %s\n", crlfile));
356                         goto init_failed;
357                 }
358         }
359         
360         ret = gnutls_certificate_set_x509_key_file(params->x509_cred, 
361                                                    certfile, keyfile,
362                                                    GNUTLS_X509_FMT_PEM);
363         if (ret < 0) {
364                 DEBUG(0,("TLS failed to initialise certfile %s and keyfile %s\n", 
365                          certfile, keyfile));
366                 goto init_failed;
367         }
368         
369         ret = gnutls_dh_params_init(&params->dh_params);
370         if (ret < 0) goto init_failed;
371
372         ret = gnutls_dh_params_generate2(params->dh_params, DH_BITS);
373         if (ret < 0) goto init_failed;
374
375         gnutls_certificate_set_dh_params(params->x509_cred, params->dh_params);
376
377         params->tls_enabled = True;
378
379         talloc_free(tmp_ctx);
380         return params;
381
382 init_failed:
383         DEBUG(0,("GNUTLS failed to initialise - %s\n", gnutls_strerror(ret)));
384         params->tls_enabled = False;
385         talloc_free(tmp_ctx);
386         return params;
387 }
388
389
390 /*
391   setup for a new connection
392 */
393 struct tls_context *tls_init_server(struct tls_params *params, 
394                                     struct socket_context *socket,
395                                     struct fd_event *fde, 
396                                     const char *plain_chars,
397                                     BOOL tls_enable)
398 {
399         struct tls_context *tls;
400         int ret;
401
402         tls = talloc(socket, struct tls_context);
403         if (tls == NULL) return NULL;
404
405         tls->socket          = socket;
406         tls->fde             = fde;
407
408         if (!params->tls_enabled || !tls_enable) {
409                 tls->tls_enabled = False;
410                 return tls;
411         }
412
413         TLSCHECK(gnutls_init(&tls->session, GNUTLS_SERVER));
414
415         talloc_set_destructor(tls, tls_destructor);
416
417         TLSCHECK(gnutls_set_default_priority(tls->session));
418         TLSCHECK(gnutls_credentials_set(tls->session, GNUTLS_CRD_CERTIFICATE, 
419                                         params->x509_cred));
420         gnutls_certificate_server_set_request(tls->session, GNUTLS_CERT_REQUEST);
421         gnutls_dh_set_prime_bits(tls->session, DH_BITS);
422         gnutls_transport_set_ptr(tls->session, (gnutls_transport_ptr)tls);
423         gnutls_transport_set_pull_function(tls->session, (gnutls_pull_func)tls_pull);
424         gnutls_transport_set_push_function(tls->session, (gnutls_push_func)tls_push);
425         gnutls_transport_set_lowat(tls->session, 0);
426
427         tls->plain_chars = plain_chars;
428         if (plain_chars) {
429                 tls->tls_detect = True;
430         } else {
431                 tls->tls_detect = False;
432         }
433
434         tls->output_pending  = False;
435         tls->done_handshake  = False;
436         tls->have_first_byte = False;
437         tls->tls_enabled     = True;
438         tls->interrupted     = False;
439         
440         return tls;
441
442 failed:
443         DEBUG(0,("TLS init connection failed - %s\n", gnutls_strerror(ret)));
444         tls->tls_enabled = False;
445         params->tls_enabled = False;
446         return tls;
447 }
448
449
450 /*
451   setup for a new client connection
452 */
453 struct tls_context *tls_init_client(struct socket_context *socket,
454                                     struct fd_event *fde, 
455                                     BOOL tls_enable)
456 {
457         struct tls_context *tls;
458         int ret;
459         const int cert_type_priority[] = { GNUTLS_CRT_X509, GNUTLS_CRT_OPENPGP, 0 };
460         char *cafile;
461
462         tls = talloc(socket, struct tls_context);
463         if (tls == NULL) return NULL;
464
465         tls->socket          = socket;
466         tls->fde             = fde;
467         tls->tls_enabled     = tls_enable;
468
469         if (!tls->tls_enabled) {
470                 return tls;
471         }
472
473         cafile = private_path(tls, lp_tls_cafile());
474         if (!cafile || !*cafile) {
475                 goto failed;
476         }
477
478         gnutls_global_init();
479
480         gnutls_certificate_allocate_credentials(&tls->xcred);
481         gnutls_certificate_set_x509_trust_file(tls->xcred, cafile, GNUTLS_X509_FMT_PEM);
482         talloc_free(cafile);
483         TLSCHECK(gnutls_init(&tls->session, GNUTLS_CLIENT));
484         TLSCHECK(gnutls_set_default_priority(tls->session));
485         gnutls_certificate_type_set_priority(tls->session, cert_type_priority);
486         TLSCHECK(gnutls_credentials_set(tls->session, GNUTLS_CRD_CERTIFICATE, tls->xcred));
487
488         talloc_set_destructor(tls, tls_destructor);
489
490         gnutls_transport_set_ptr(tls->session, (gnutls_transport_ptr)tls);
491         gnutls_transport_set_pull_function(tls->session, (gnutls_pull_func)tls_pull);
492         gnutls_transport_set_push_function(tls->session, (gnutls_push_func)tls_push);
493         gnutls_transport_set_lowat(tls->session, 0);
494         tls->tls_detect = False;
495
496         tls->output_pending  = False;
497         tls->done_handshake  = False;
498         tls->have_first_byte = False;
499         tls->tls_enabled     = True;
500         tls->interrupted     = False;
501         
502         return tls;
503
504 failed:
505         DEBUG(0,("TLS init connection failed - %s\n", gnutls_strerror(ret)));
506         tls->tls_enabled = False;
507         return tls;
508 }
509
510 BOOL tls_enabled(struct tls_context *tls)
511 {
512         return tls->tls_enabled;
513 }
514
515 BOOL tls_support(struct tls_params *params)
516 {
517         return params->tls_enabled;
518 }
519
520 #else
521
522 /* for systems without tls we just map the tls socket calls to the
523    normal socket calls */
524
525 struct tls_params *tls_initialise(TALLOC_CTX *mem_ctx)
526 {
527         return talloc_new(mem_ctx);
528 }
529
530 struct tls_context *tls_init_server(struct tls_params *params, 
531                                     struct socket_context *sock, 
532                                     struct fd_event *fde,
533                                     const char *plain_chars,
534                                     BOOL tls_enable)
535 {
536         if (tls_enable && plain_chars == NULL) return NULL;
537         return (struct tls_context *)sock;
538 }
539
540 struct tls_context *tls_init_client(struct socket_context *sock, 
541                                     struct fd_event *fde,
542                                     BOOL tls_enable)
543 {
544         return (struct tls_context *)sock;
545 }
546
547
548 NTSTATUS tls_socket_recv(struct tls_context *tls, void *buf, size_t wantlen, 
549                          size_t *nread)
550 {
551         return socket_recv((struct socket_context *)tls, buf, wantlen, nread, 0);
552 }
553
554 NTSTATUS tls_socket_send(struct tls_context *tls, const DATA_BLOB *blob, size_t *sendlen)
555 {
556         return socket_send((struct socket_context *)tls, blob, sendlen, 0);
557 }
558
559 BOOL tls_enabled(struct tls_context *tls)
560 {
561         return False;
562 }
563
564 BOOL tls_support(struct tls_params *params)
565 {
566         return False;
567 }
568
569 NTSTATUS tls_socket_pending(struct tls_context *tls, size_t *npending)
570 {
571         return socket_pending((struct socket_context *)tls, npending);
572 }
573
574 #endif