r7742: abstracted out the tls code from the web server, so that our other servers
[kai/samba.git] / source4 / lib / tls / tls.c
1 /* 
2    Unix SMB/CIFS implementation.
3
4    transport layer security handling code
5
6    Copyright (C) Andrew Tridgell 2005
7    
8    This program is free software; you can redistribute it and/or modify
9    it under the terms of the GNU General Public License as published by
10    the Free Software Foundation; either version 2 of the License, or
11    (at your option) any later version.
12    
13    This program is distributed in the hope that it will be useful,
14    but WITHOUT ANY WARRANTY; without even the implied warranty of
15    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16    GNU General Public License for more details.
17    
18    You should have received a copy of the GNU General Public License
19    along with this program; if not, write to the Free Software
20    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
21 */
22
23 #include "includes.h"
24 #include "lib/events/events.h"
25 #include "lib/socket/socket.h"
26 #include "lib/tls/tls.h"
27
28 #if HAVE_LIBGNUTLS
29 #include "gnutls/gnutls.h"
30
31 #define DH_BITS 1024
32
33 /* hold persistent tls data */
34 struct tls_params {
35         gnutls_certificate_credentials x509_cred;
36         gnutls_dh_params dh_params;
37         BOOL tls_enabled;
38 };
39
40 /* hold per connection tls data */
41 struct tls_context {
42         struct tls_params *params;
43         struct socket_context *socket;
44         struct fd_event *fde;
45         gnutls_session session;
46         BOOL done_handshake;
47         BOOL have_first_byte;
48         uint8_t first_byte;
49         BOOL tls_enabled;
50         BOOL tls_detect;
51         const char *plain_chars;
52         BOOL output_pending;
53 };
54
55
56 /*
57   callback for reading from a socket
58 */
59 static ssize_t tls_pull(gnutls_transport_ptr ptr, void *buf, size_t size)
60 {
61         struct tls_context *tls = talloc_get_type(ptr, struct tls_context);
62         NTSTATUS status;
63         size_t nread;
64         
65         if (tls->have_first_byte) {
66                 *(uint8_t *)buf = tls->first_byte;
67                 tls->have_first_byte = False;
68                 return 1;
69         }
70
71         status = socket_recv(tls->socket, buf, size, &nread, 0);
72         if (!NT_STATUS_IS_OK(status)) {
73                 EVENT_FD_READABLE(tls->fde);
74                 EVENT_FD_NOT_WRITEABLE(tls->fde);
75                 return -1;
76         }
77         if (tls->output_pending) {
78                 EVENT_FD_WRITEABLE(tls->fde);
79         }
80         if (size != nread) {
81                 EVENT_FD_READABLE(tls->fde);
82         }
83         return nread;
84 }
85
86 /*
87   callback for writing to a socket
88 */
89 static ssize_t tls_push(gnutls_transport_ptr ptr, const void *buf, size_t size)
90 {
91         struct tls_context *tls = talloc_get_type(ptr, struct tls_context);
92         NTSTATUS status;
93         size_t nwritten;
94         DATA_BLOB b;
95
96         if (!tls->tls_enabled) {
97                 return size;
98         }
99
100         b.data = discard_const(buf);
101         b.length = size;
102
103         status = socket_send(tls->socket, &b, &nwritten, 0);
104         if (!NT_STATUS_IS_OK(status)) {
105                 EVENT_FD_WRITEABLE(tls->fde);
106                 return -1;
107         }
108         if (size != nwritten) {
109                 EVENT_FD_WRITEABLE(tls->fde);
110         }
111         return nwritten;
112 }
113
114 /*
115   destroy a tls session
116  */
117 static int tls_destructor(void *ptr)
118 {
119         struct tls_context *tls = talloc_get_type(ptr, struct tls_context);
120         int ret;
121         ret = gnutls_bye(tls->session, GNUTLS_SHUT_WR);
122         if (ret < 0) {
123                 DEBUG(0,("TLS gnutls_bye failed - %s\n", gnutls_strerror(ret)));
124         }
125         return 0;
126 }
127
128
129 /*
130   possibly continue the handshake process
131 */
132 static NTSTATUS tls_handshake(struct tls_context *tls)
133 {
134         int ret;
135
136         if (tls->done_handshake) {
137                 return NT_STATUS_OK;
138         }
139         
140         ret = gnutls_handshake(tls->session);
141         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
142                 return STATUS_MORE_ENTRIES;
143         }
144         if (ret < 0) {
145                 DEBUG(0,("TLS gnutls_handshake failed - %s\n", gnutls_strerror(ret)));
146                 return NT_STATUS_UNEXPECTED_NETWORK_ERROR;
147         }
148         tls->done_handshake = True;
149         return NT_STATUS_OK;
150 }
151
152
153 /*
154   receive data either by tls or normal socket_recv
155 */
156 NTSTATUS tls_socket_recv(struct tls_context *tls, void *buf, size_t wantlen, 
157                          size_t *nread)
158 {
159         int ret;
160         NTSTATUS status;
161         if (tls->tls_enabled && tls->tls_detect) {
162                 status = socket_recv(tls->socket, &tls->first_byte, 1, nread, 0);
163                 NT_STATUS_NOT_OK_RETURN(status);
164                 if (*nread == 0) return NT_STATUS_OK;
165                 tls->tls_detect = False;
166                 /* look for the first byte of a valid HTTP operation */
167                 if (strchr(tls->plain_chars, tls->first_byte)) {
168                         /* not a tls link */
169                         tls->tls_enabled = False;
170                         *(uint8_t *)buf = tls->first_byte;
171                         return NT_STATUS_OK;
172                 }
173                 tls->have_first_byte = True;
174         }
175
176         if (!tls->tls_enabled) {
177                 return socket_recv(tls->socket, buf, wantlen, nread, 0);
178         }
179
180         status = tls_handshake(tls);
181         NT_STATUS_NOT_OK_RETURN(status);
182
183         ret = gnutls_record_recv(tls->session, buf, wantlen);
184         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
185                 return STATUS_MORE_ENTRIES;
186         }
187         if (ret < 0) {
188                 DEBUG(0,("gnutls_record_recv failed - %s\n", gnutls_strerror(ret)));
189                 return NT_STATUS_UNEXPECTED_NETWORK_ERROR;
190         }
191         *nread = ret;
192         return NT_STATUS_OK;
193 }
194
195
196 /*
197   send data either by tls or normal socket_recv
198 */
199 NTSTATUS tls_socket_send(struct tls_context *tls, const DATA_BLOB *blob, size_t *sendlen)
200 {
201         NTSTATUS status;
202         int ret;
203
204         if (!tls->tls_enabled) {
205                 return socket_send(tls->socket, blob, sendlen, 0);
206         }
207
208         status = tls_handshake(tls);
209         NT_STATUS_NOT_OK_RETURN(status);
210
211         ret = gnutls_record_send(tls->session, blob->data, blob->length);
212         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
213                 return STATUS_MORE_ENTRIES;
214         }
215         if (ret < 0) {
216                 DEBUG(0,("gnutls_record_send failed - %s\n", gnutls_strerror(ret)));
217                 return NT_STATUS_UNEXPECTED_NETWORK_ERROR;
218         }
219         *sendlen = ret;
220         tls->output_pending = (ret < blob->length);
221         return NT_STATUS_OK;
222 }
223
224
225 /*
226   initialise global tls state
227 */
228 struct tls_params *tls_initialise(TALLOC_CTX *mem_ctx)
229 {
230         struct tls_params *params;
231         int ret;
232         const char *keyfile = lp_tls_keyfile();
233         const char *certfile = lp_tls_certfile();
234         const char *cafile = lp_tls_cafile();
235         const char *crlfile = lp_tls_crlfile();
236         void tls_cert_generate(TALLOC_CTX *, const char *, const char *, const char *);
237
238         params = talloc(mem_ctx, struct tls_params);
239         if (params == NULL) return NULL;
240
241         if (!lp_tls_enabled() || keyfile == NULL || *keyfile == 0) {
242                 params->tls_enabled = False;
243                 return params;
244         }
245
246         if (!file_exist(cafile)) {
247                 tls_cert_generate(params, keyfile, certfile, cafile);
248         }
249
250         ret = gnutls_global_init();
251         if (ret < 0) goto init_failed;
252
253         gnutls_certificate_allocate_credentials(&params->x509_cred);
254         if (ret < 0) goto init_failed;
255
256         if (cafile && *cafile) {
257                 ret = gnutls_certificate_set_x509_trust_file(params->x509_cred, cafile, 
258                                                              GNUTLS_X509_FMT_PEM);      
259                 if (ret < 0) {
260                         DEBUG(0,("TLS failed to initialise cafile %s\n", cafile));
261                         goto init_failed;
262                 }
263         }
264
265         if (crlfile && *crlfile) {
266                 ret = gnutls_certificate_set_x509_crl_file(params->x509_cred, 
267                                                            crlfile, 
268                                                            GNUTLS_X509_FMT_PEM);
269                 if (ret < 0) {
270                         DEBUG(0,("TLS failed to initialise crlfile %s\n", crlfile));
271                         goto init_failed;
272                 }
273         }
274         
275         ret = gnutls_certificate_set_x509_key_file(params->x509_cred, 
276                                                    certfile, keyfile,
277                                                    GNUTLS_X509_FMT_PEM);
278         if (ret < 0) {
279                 DEBUG(0,("TLS failed to initialise certfile %s and keyfile %s\n", 
280                          certfile, keyfile));
281                 goto init_failed;
282         }
283         
284         ret = gnutls_dh_params_init(&params->dh_params);
285         if (ret < 0) goto init_failed;
286
287         ret = gnutls_dh_params_generate2(params->dh_params, DH_BITS);
288         if (ret < 0) goto init_failed;
289
290         gnutls_certificate_set_dh_params(params->x509_cred, params->dh_params);
291
292         params->tls_enabled = True;
293         return params;
294
295 init_failed:
296         DEBUG(0,("GNUTLS failed to initialise - %s\n", gnutls_strerror(ret)));
297         params->tls_enabled = False;
298         return params;
299 }
300
301
302 /*
303   setup for a new connection
304 */
305 struct tls_context *tls_init_server(struct tls_params *params, 
306                                     struct socket_context *socket,
307                                     struct fd_event *fde, 
308                                     const char *plain_chars)
309 {
310         struct tls_context *tls;
311         int ret;
312
313         tls = talloc(socket, struct tls_context);
314         if (tls == NULL) return NULL;
315
316         tls->socket          = socket;
317         tls->fde             = fde;
318
319         if (!params->tls_enabled) {
320                 tls->tls_enabled = False;
321                 return tls;
322         }
323
324 #define TLSCHECK(call) do { \
325         ret = call; \
326         if (ret < 0) { \
327                 DEBUG(0,("TLS %s - %s\n", #call, gnutls_strerror(ret))); \
328                 goto failed; \
329         } \
330 } while (0)
331
332         TLSCHECK(gnutls_init(&tls->session, GNUTLS_SERVER));
333
334         talloc_set_destructor(tls, tls_destructor);
335
336         TLSCHECK(gnutls_set_default_priority(tls->session));
337         TLSCHECK(gnutls_credentials_set(tls->session, GNUTLS_CRD_CERTIFICATE, 
338                                         params->x509_cred));
339         gnutls_certificate_server_set_request(tls->session, GNUTLS_CERT_REQUEST);
340         gnutls_dh_set_prime_bits(tls->session, DH_BITS);
341         gnutls_transport_set_ptr(tls->session, (gnutls_transport_ptr)tls);
342         gnutls_transport_set_pull_function(tls->session, (gnutls_pull_func)tls_pull);
343         gnutls_transport_set_push_function(tls->session, (gnutls_push_func)tls_push);
344         gnutls_transport_set_lowat(tls->session, 0);
345
346         tls->plain_chars = plain_chars;
347         if (plain_chars) {
348                 tls->tls_detect = True;
349         } else {
350                 tls->tls_detect = False;
351         }
352
353         tls->output_pending  = False;
354         tls->params          = params;
355         tls->done_handshake  = False;
356         tls->have_first_byte = False;
357         tls->tls_enabled     = True;
358         
359         return tls;
360
361 failed:
362         DEBUG(0,("TLS init connection failed - %s\n", gnutls_strerror(ret)));
363         tls->tls_enabled = False;
364         params->tls_enabled = False;
365         return tls;
366 }
367
368 BOOL tls_enabled(struct tls_context *tls)
369 {
370         return tls->tls_enabled;
371 }
372
373 BOOL tls_support(struct tls_params *params)
374 {
375         return params->tls_enabled;
376 }
377
378
379 #else
380
381 /* for systems without tls we just map the tls socket calls to the
382    normal socket calls */
383
384 struct tls_params *tls_initialise(TALLOC_CTX *mem_ctx)
385 {
386         return talloc_new(mem_ctx);
387 }
388
/*
  without gnutls the socket itself stands in for the tls context.
  A NULL plain_chars leaves no plaintext alternative (TLS-only), which
  cannot be served here, so NULL is returned in that case.
*/
struct tls_context *tls_init_server(struct tls_params *params, 
				    struct socket_context *sock, 
				    struct fd_event *fde,
				    const char *plain_chars)
{
	if (plain_chars != NULL) {
		return (struct tls_context *)sock;
	}
	return NULL;
}
397
398
399 NTSTATUS tls_socket_recv(struct tls_context *tls, void *buf, size_t wantlen, 
400                          size_t *nread)
401 {
402         return socket_recv((struct socket_context *)tls, buf, wantlen, nread, 0);
403 }
404
405 NTSTATUS tls_socket_send(struct tls_context *tls, const DATA_BLOB *blob, size_t *sendlen)
406 {
407         return socket_send((struct socket_context *)tls, blob, sendlen, 0);
408 }
409
410 BOOL tls_enabled(struct tls_context *tls)
411 {
412         return False;
413 }
414
415 BOOL tls_support(struct tls_params *params)
416 {
417         return False;
418 }
419
420 #endif