r7016: - added smb.conf parm 'web tls = true/false'
[samba.git] / source / web_server / tls.c
1 /* 
2    Unix SMB/CIFS implementation.
3
4    transport layer security handling code
5
6    Copyright (C) Andrew Tridgell 2005
7    
8    This program is free software; you can redistribute it and/or modify
9    it under the terms of the GNU General Public License as published by
10    the Free Software Foundation; either version 2 of the License, or
11    (at your option) any later version.
12    
13    This program is distributed in the hope that it will be useful,
14    but WITHOUT ANY WARRANTY; without even the implied warranty of
15    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16    GNU General Public License for more details.
17    
18    You should have received a copy of the GNU General Public License
19    along with this program; if not, write to the Free Software
20    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
21 */
22
23 #include "includes.h"
24 #include "smbd/service_task.h"
25 #include "smbd/service_stream.h"
26 #include "web_server/web_server.h"
27 #include "lib/events/events.h"
28 #include "system/network.h"
29
30 #if HAVE_LIBGNUTLS
31 #include "gnutls/gnutls.h"
32
33 #define DH_BITS 1024
34
35 /* hold per connection tls data */
36 struct tls_session {
37         gnutls_session session;
38         BOOL done_handshake;
39 };
40
41 /* hold persistent tls data */
42 struct tls_data {
43         gnutls_certificate_credentials x509_cred;
44         gnutls_dh_params dh_params;
45 };
46
47 /*
48   initialise global tls state
49 */
50 void tls_initialise(struct task_server *task)
51 {
52         struct esp_data *edata = talloc_get_type(task->private, struct esp_data);
53         struct tls_data *tls;
54         int ret;
55         const char *keyfile = lp_web_keyfile();
56         const char *certfile = lp_web_certfile();
57         const char *cafile = lp_web_cafile();
58         const char *crlfile = lp_web_crlfile();
59
60         if (!lp_web_tls() || keyfile == NULL || *keyfile == 0) {
61                 return;
62         }
63
64         tls = talloc_zero(edata, struct tls_data);
65         edata->tls_data = tls;
66
67         ret = gnutls_global_init();
68         if (ret < 0) goto init_failed;
69
70         gnutls_certificate_allocate_credentials(&tls->x509_cred);
71         if (ret < 0) goto init_failed;
72
73         ret = gnutls_certificate_set_x509_trust_file(tls->x509_cred, cafile, 
74                                                      GNUTLS_X509_FMT_PEM);      
75         if (ret < 0) {
76                 DEBUG(0,("TLS failed to initialise cafile %s\n", cafile));
77                 goto init_failed;
78         }
79
80         if (crlfile && *crlfile) {
81                 ret = gnutls_certificate_set_x509_crl_file(tls->x509_cred, 
82                                                            crlfile, 
83                                                            GNUTLS_X509_FMT_PEM);
84                 if (ret < 0) {
85                         DEBUG(0,("TLS failed to initialise crlfile %s\n", cafile));
86                         goto init_failed;
87                 }
88         }
89         
90         ret = gnutls_certificate_set_x509_key_file(tls->x509_cred, 
91                                                    certfile, keyfile,
92                                                    GNUTLS_X509_FMT_PEM);
93         if (ret < 0) {
94                 DEBUG(0,("TLS failed to initialise certfile %s and keyfile %s\n", 
95                          lp_web_certfile(), lp_web_keyfile()));
96                 goto init_failed;
97         }
98         
99         ret = gnutls_dh_params_init(&tls->dh_params);
100         if (ret < 0) goto init_failed;
101
102         ret = gnutls_dh_params_generate2(tls->dh_params, DH_BITS);
103         if (ret < 0) goto init_failed;
104
105         gnutls_certificate_set_dh_params(tls->x509_cred, tls->dh_params);
106         return;
107
108 init_failed:
109         DEBUG(0,("GNUTLS failed to initialise with code %d - disabling\n", ret));
110         talloc_free(tls);
111         edata->tls_data = NULL;
112 }
113
114
115 /*
116   callback for reading from a socket
117 */
118 static ssize_t tls_pull(gnutls_transport_ptr ptr, void *buf, size_t size)
119 {
120         struct websrv_context *web = talloc_get_type(ptr, struct websrv_context);
121         NTSTATUS status;
122         size_t nread;
123         
124         if (web->input.tls_first_char) {
125                 *(uint8_t *)buf = web->input.first_byte;
126                 web->input.tls_first_char = False;
127                 return 1;
128         }
129
130         status = socket_recv(web->conn->socket, buf, size, &nread, 0);
131         if (!NT_STATUS_IS_OK(status)) {
132                 EVENT_FD_READABLE(web->conn->event.fde);
133                 EVENT_FD_NOT_WRITEABLE(web->conn->event.fde);
134                 return -1;
135         }
136         if (web->output.output_pending) {
137                 EVENT_FD_WRITEABLE(web->conn->event.fde);
138         }
139         if (size != nread) {
140                 EVENT_FD_READABLE(web->conn->event.fde);
141         }
142         return nread;
143 }
144
145 /*
146   callback for writing to a socket
147 */
148 static ssize_t tls_push(gnutls_transport_ptr ptr, const void *buf, size_t size)
149 {
150         struct websrv_context *web = talloc_get_type(ptr, struct websrv_context);
151         NTSTATUS status;
152         size_t nwritten;
153         DATA_BLOB b;
154
155         if (web->tls_session == NULL) {
156                 return size;
157         }
158
159         b.data = discard_const(buf);
160         b.length = size;
161
162         status = socket_send(web->conn->socket, &b, &nwritten, 0);
163         if (!NT_STATUS_IS_OK(status)) {
164                 EVENT_FD_WRITEABLE(web->conn->event.fde);
165                 return -1;
166         }
167         if (size != nwritten) {
168                 EVENT_FD_WRITEABLE(web->conn->event.fde);
169         }
170         return nwritten;
171 }
172
173 /*
174   destroy a tls session
175  */
176 static int tls_destructor(void *ptr)
177 {
178         struct tls_session *tls_session = talloc_get_type(ptr, struct tls_session);
179         gnutls_bye(tls_session->session, GNUTLS_SHUT_WR);
180         return 0;
181 }
182
183
184 /*
185   setup for a new connection
186 */
187 NTSTATUS tls_init_connection(struct websrv_context *web)
188 {
189         struct esp_data *edata = talloc_get_type(web->task->private, struct esp_data);
190         struct tls_data *tls_data = talloc_get_type(edata->tls_data, struct tls_data);
191         struct tls_session *tls_session;
192         int ret;
193
194         if (edata->tls_data == NULL) {
195                 web->tls_session = NULL;
196                 return NT_STATUS_OK;
197         }
198
199 #define TLSCHECK(call) do { \
200         ret = call; \
201         if (ret < 0) { \
202                 DEBUG(0,("TLS failed with code %d - %s\n", ret, #call)); \
203                 goto failed; \
204         } \
205 } while (0)
206
207         tls_session = talloc_zero(web, struct tls_session);
208         web->tls_session = tls_session;
209
210         TLSCHECK(gnutls_init(&tls_session->session, GNUTLS_SERVER));
211
212         talloc_set_destructor(tls_session, tls_destructor);
213
214         TLSCHECK(gnutls_set_default_priority(tls_session->session));
215         TLSCHECK(gnutls_credentials_set(tls_session->session, GNUTLS_CRD_CERTIFICATE, tls_data->x509_cred));
216         gnutls_certificate_server_set_request(tls_session->session, GNUTLS_CERT_REQUEST);
217         gnutls_dh_set_prime_bits(tls_session->session, DH_BITS);
218         gnutls_transport_set_ptr(tls_session->session, (gnutls_transport_ptr)web);
219         gnutls_transport_set_pull_function(tls_session->session, (gnutls_pull_func)tls_pull);
220         gnutls_transport_set_push_function(tls_session->session, (gnutls_push_func)tls_push);
221         gnutls_transport_set_lowat(tls_session->session, 0);
222
223         web->input.tls_detect = True;
224         
225         return NT_STATUS_OK;
226
227 failed:
228         web->tls_session = NULL;
229         talloc_free(tls_session);
230         return NT_STATUS_OK;
231 }
232
233 /*
234   possibly continue the handshake process
235 */
236 static NTSTATUS tls_handshake(struct tls_session *tls_session)
237 {
238         int ret;
239
240         if (tls_session->done_handshake) {
241                 return NT_STATUS_OK;
242         }
243         
244         ret = gnutls_handshake(tls_session->session);
245         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
246                 return STATUS_MORE_ENTRIES;
247         }
248         if (ret < 0) {
249                 return NT_STATUS_UNEXPECTED_NETWORK_ERROR;
250         }
251         tls_session->done_handshake = True;
252         return NT_STATUS_OK;
253 }
254
255
256 /*
257   receive data either by tls or normal socket_recv
258 */
259 NTSTATUS tls_socket_recv(struct websrv_context *web, void *buf, size_t wantlen, 
260                          size_t *nread)
261 {
262         int ret;
263         NTSTATUS status;
264         struct tls_session *tls_session = talloc_get_type(web->tls_session, 
265                                                           struct tls_session);
266
267         if (web->tls_session != NULL && web->input.tls_detect) {
268                 status = socket_recv(web->conn->socket, &web->input.first_byte, 
269                                      1, nread, 0);
270                 NT_STATUS_NOT_OK_RETURN(status);
271                 if (*nread == 0) return NT_STATUS_OK;
272                 web->input.tls_detect = False;
273                 /* look for the first byte of a valid HTTP operation */
274                 if (strchr("GPHO", web->input.first_byte)) {
275                         /* not a tls link */
276                         web->tls_session = NULL;
277                         talloc_free(tls_session);
278                         *(uint8_t *)buf = web->input.first_byte;
279                         return NT_STATUS_OK;
280                 }
281                 web->input.tls_first_char = True;
282         }
283
284         if (web->tls_session == NULL) {
285                 return socket_recv(web->conn->socket, buf, wantlen, nread, 0);
286         }
287
288         status = tls_handshake(tls_session);
289         NT_STATUS_NOT_OK_RETURN(status);
290
291         ret = gnutls_record_recv(tls_session->session, buf, wantlen);
292         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
293                 return STATUS_MORE_ENTRIES;
294         }
295         if (ret < 0) {
296                 return NT_STATUS_UNEXPECTED_NETWORK_ERROR;
297         }
298         *nread = ret;
299         return NT_STATUS_OK;
300 }
301
302
303 /*
304   send data either by tls or normal socket_recv
305 */
306 NTSTATUS tls_socket_send(struct websrv_context *web, const DATA_BLOB *blob, 
307                          size_t *sendlen)
308 {
309         NTSTATUS status;
310         int ret;
311         struct tls_session *tls_session = talloc_get_type(web->tls_session, 
312                                                           struct tls_session);
313
314         if (web->tls_session == NULL) {
315                 return socket_send(web->conn->socket, blob, sendlen, 0);
316         }
317
318         status = tls_handshake(tls_session);
319         NT_STATUS_NOT_OK_RETURN(status);
320
321         ret = gnutls_record_send(tls_session->session, blob->data, blob->length);
322         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
323                 return STATUS_MORE_ENTRIES;
324         }
325         if (ret < 0) {
326                 return NT_STATUS_UNEXPECTED_NETWORK_ERROR;
327         }
328         *sendlen = ret;
329         return NT_STATUS_OK;
330 }
331 #else
332
333 /* for systems without tls */
334 NTSTATUS tls_socket_recv(struct websrv_context *web, void *buf, size_t wantlen, 
335                          size_t *nread)
336 {
337         return socket_recv(web->conn->socket, buf, wantlen, nread, 0);
338 }
339
340 NTSTATUS tls_socket_send(struct websrv_context *web, const DATA_BLOB *blob, 
341                          size_t *sendlen)
342 {
343         return socket_send(web->conn->socket, blob, sendlen, 0);
344 }
345
346 NTSTATUS tls_init_connection(struct websrv_context *web)
347 {
348         web->tls_session = NULL;
349         return NT_STATUS_OK;
350 }
351
352 void tls_initialise(struct task_server *task)
353 {
354         struct esp_data *edata = talloc_get_type(task->private, struct esp_data);
355         edata->tls_data = NULL;
356 }
357
358 #endif