r7746: - added TLS support to our ldap server
[kai/samba.git] / source / ldap_server / ldap_server.c
1 /* 
2    Unix SMB/CIFS implementation.
3
4    LDAP server
5
6    Copyright (C) Andrew Tridgell 2005
7    Copyright (C) Volker Lendecke 2004
8    Copyright (C) Stefan Metzmacher 2004
9    
10    This program is free software; you can redistribute it and/or modify
11    it under the terms of the GNU General Public License as published by
12    the Free Software Foundation; either version 2 of the License, or
13    (at your option) any later version.
14    
15    This program is distributed in the hope that it will be useful,
16    but WITHOUT ANY WARRANTY; without even the implied warranty of
17    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18    GNU General Public License for more details.
19    
20    You should have received a copy of the GNU General Public License
21    along with this program; if not, write to the Free Software
22    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
23 */
24
25 #include "includes.h"
26 #include "lib/events/events.h"
27 #include "auth/auth.h"
28 #include "dlinklist.h"
29 #include "asn_1.h"
30 #include "ldap_server/ldap_server.h"
31 #include "smbd/service_task.h"
32 #include "smbd/service_stream.h"
33 #include "lib/socket/socket.h"
34 #include "lib/tls/tls.h"
35
36 /*
37   close the socket and shutdown a server_context
38 */
39 static void ldapsrv_terminate_connection(struct ldapsrv_connection *ldap_conn, const char *reason)
40 {
41         talloc_free(ldap_conn->tls);
42         ldap_conn->tls = NULL;
43         stream_terminate_connection(ldap_conn->connection, reason);
44 }
45
46 /* This rw-buf api is made to avoid memcpy. For now do that like mad...  The
47    idea is to write into a circular list of buffers where the ideal case is
48    that a read(2) holds a complete request that is then thrown away
49    completely. */
50
51 void ldapsrv_consumed_from_buf(struct rw_buffer *buf,
52                                    size_t length)
53 {
54         memmove(buf->data, buf->data+length, buf->length-length);
55         buf->length -= length;
56 }
57
58 static void peek_into_read_buf(struct rw_buffer *buf, uint8_t **out,
59                                size_t *out_length)
60 {
61         *out = buf->data;
62         *out_length = buf->length;
63 }
64
65 BOOL ldapsrv_append_to_buf(struct rw_buffer *buf, uint8_t *data, size_t length)
66 {
67         buf->data = realloc(buf->data, buf->length+length);
68
69         if (buf->data == NULL)
70                 return False;
71
72         memcpy(buf->data+buf->length, data, length);
73
74         buf->length += length;
75
76         return True;
77 }
78
79 static BOOL read_into_buf(struct ldapsrv_connection *conn, struct rw_buffer *buf)
80 {
81         NTSTATUS status;
82         DATA_BLOB tmp_blob;
83         BOOL ret;
84         size_t nread;
85
86         tmp_blob = data_blob_talloc(conn, NULL, 1024);
87         if (tmp_blob.data == NULL) {
88                 return False;
89         }
90
91         status = tls_socket_recv(conn->tls, tmp_blob.data, tmp_blob.length, &nread);
92         if (NT_STATUS_IS_OK(status) && nread == 0) {
93                 return False;
94         }
95         if (NT_STATUS_IS_ERR(status)) {
96                 DEBUG(10,("socket_recv: %s\n",nt_errstr(status)));
97                 talloc_free(tmp_blob.data);
98                 return False;
99         }
100         if (!NT_STATUS_IS_OK(status)) {
101                 talloc_free(tmp_blob.data);
102                 return True;
103         }
104
105         tmp_blob.length = nread;
106
107         ret = ldapsrv_append_to_buf(buf, tmp_blob.data, tmp_blob.length);
108
109         talloc_free(tmp_blob.data);
110
111         return ret;
112 }
113
114 static BOOL ldapsrv_read_buf(struct ldapsrv_connection *conn)
115 {
116         NTSTATUS status;
117         DATA_BLOB tmp_blob;
118         DATA_BLOB wrapped;
119         DATA_BLOB unwrapped;
120         BOOL ret;
121         uint8_t *buf;
122         size_t buf_length, sasl_length;
123         TALLOC_CTX *mem_ctx;
124         size_t nread;
125
126         if (!conn->gensec || !conn->session_info ||
127             !(gensec_have_feature(conn->gensec, GENSEC_FEATURE_SIGN) ||
128               gensec_have_feature(conn->gensec, GENSEC_FEATURE_SEAL))) {
129                 return read_into_buf(conn, &conn->in_buffer);
130         }
131
132         mem_ctx = talloc_new(conn);
133         if (!mem_ctx) {
134                 DEBUG(0,("no memory\n"));
135                 return False;
136         }
137
138         tmp_blob = data_blob_talloc(mem_ctx, NULL, 1024);
139         if (tmp_blob.data == NULL) {
140                 talloc_free(mem_ctx);
141                 return False;
142         }
143
144         status = tls_socket_recv(conn->tls, tmp_blob.data, tmp_blob.length, &nread);
145         if (NT_STATUS_IS_OK(status) && nread == 0) {
146                 talloc_free(conn->tls);
147                 return False;
148         }
149         if (NT_STATUS_IS_ERR(status)) {
150                 DEBUG(10,("socket_recv: %s\n",nt_errstr(status)));
151                 talloc_free(mem_ctx);
152                 return False;
153         }
154         if (!NT_STATUS_IS_OK(status)) {
155                 talloc_free(mem_ctx);
156                 return True;
157         }
158         tmp_blob.length = nread;
159
160         ret = ldapsrv_append_to_buf(&conn->sasl_in_buffer, tmp_blob.data, tmp_blob.length);
161         if (!ret) {
162                 talloc_free(mem_ctx);
163                 return False;
164         }
165
166         peek_into_read_buf(&conn->sasl_in_buffer, &buf, &buf_length);
167
168         if (buf_length < 4) {
169                 /* not enough yet */
170                 talloc_free(mem_ctx);
171                 return True;
172         }
173
174         sasl_length = RIVAL(buf, 0);
175
176         if ((buf_length - 4) < sasl_length) {
177                 /* not enough yet */
178                 talloc_free(mem_ctx);
179                 return True;
180         }
181
182         wrapped.data = buf + 4;
183         wrapped.length = sasl_length;
184
185         status = gensec_unwrap(conn->gensec, mem_ctx,
186                                &wrapped, 
187                                &unwrapped);
188         if (!NT_STATUS_IS_OK(status)) {
189                 DEBUG(0,("gensec_unwrap: %s\n",nt_errstr(status)));
190                 talloc_free(mem_ctx);
191                 return False;
192         }
193
194         ret = ldapsrv_append_to_buf(&conn->in_buffer, unwrapped.data, unwrapped.length);
195         if (!ret) {
196                 talloc_free(mem_ctx);
197                 return False;
198         }
199
200         ldapsrv_consumed_from_buf(&conn->sasl_in_buffer, 4 + sasl_length);
201
202         talloc_free(mem_ctx);
203         return ret;
204 }
205
206 static BOOL write_from_buf(struct ldapsrv_connection *conn, struct rw_buffer *buf)
207 {
208         NTSTATUS status;
209         DATA_BLOB tmp_blob;
210         size_t sendlen;
211
212         tmp_blob.data = buf->data;
213         tmp_blob.length = buf->length;
214
215         status = tls_socket_send(conn->tls, &tmp_blob, &sendlen);
216         if (!NT_STATUS_IS_OK(status)) {
217                 DEBUG(10,("socket_send() %s\n",nt_errstr(status)));
218                 return False;
219         }
220
221         ldapsrv_consumed_from_buf(buf, sendlen);
222
223         return True;
224 }
225
226 static BOOL ldapsrv_write_buf(struct ldapsrv_connection *conn)
227 {
228         NTSTATUS status;
229         DATA_BLOB wrapped;
230         DATA_BLOB tmp_blob;
231         DATA_BLOB sasl;
232         size_t sendlen;
233         BOOL ret;
234         TALLOC_CTX *mem_ctx;
235
236
237         if (!conn->gensec) {
238                 return write_from_buf(conn, &conn->out_buffer);
239         }
240         if (!conn->session_info) {
241                 return write_from_buf(conn, &conn->out_buffer);
242         }
243         if (conn->sasl_out_buffer.length == 0 &&
244             !(gensec_have_feature(conn->gensec, GENSEC_FEATURE_SIGN) ||
245               gensec_have_feature(conn->gensec, GENSEC_FEATURE_SEAL))) {
246                 return write_from_buf(conn, &conn->out_buffer);
247         }
248
249         mem_ctx = talloc_new(conn);
250         if (!mem_ctx) {
251                 DEBUG(0,("no memory\n"));
252                 return False;
253         }
254
255         if (conn->out_buffer.length == 0) {
256                 goto nodata;
257         }
258
259         tmp_blob.data = conn->out_buffer.data;
260         tmp_blob.length = conn->out_buffer.length;
261         status = gensec_wrap(conn->gensec, mem_ctx,
262                              &tmp_blob,
263                              &wrapped);
264         if (!NT_STATUS_IS_OK(status)) {
265                 DEBUG(0,("gensec_wrap: %s\n",nt_errstr(status)));
266                 talloc_free(mem_ctx);
267                 return False;
268         }
269
270         sasl = data_blob_talloc(mem_ctx, NULL, 4 + wrapped.length);
271         if (!sasl.data) {
272                 DEBUG(0,("no memory\n"));
273                 talloc_free(mem_ctx);
274                 return False;
275         }
276
277         RSIVAL(sasl.data, 0, wrapped.length);
278         memcpy(sasl.data + 4, wrapped.data, wrapped.length);
279
280         ret = ldapsrv_append_to_buf(&conn->sasl_out_buffer, sasl.data, sasl.length);
281         if (!ret) {
282                 talloc_free(mem_ctx);
283                 return False;
284         }
285         ldapsrv_consumed_from_buf(&conn->out_buffer, conn->out_buffer.length);
286 nodata:
287         tmp_blob.data = conn->sasl_out_buffer.data;
288         tmp_blob.length = conn->sasl_out_buffer.length;
289
290         status = tls_socket_send(conn->tls, &tmp_blob, &sendlen);
291         if (!NT_STATUS_IS_OK(status)) {
292                 DEBUG(10,("socket_send() %s\n",nt_errstr(status)));
293                 talloc_free(mem_ctx);
294                 return False;
295         }
296
297         ldapsrv_consumed_from_buf(&conn->sasl_out_buffer, sendlen);
298
299         talloc_free(mem_ctx);
300
301         return True;
302 }
303
304 static BOOL ldap_encode_to_buf(struct ldap_message *msg, struct rw_buffer *buf)
305 {
306         DATA_BLOB blob;
307         BOOL res;
308
309         if (!ldap_encode(msg, &blob))
310                 return False;
311
312         res = ldapsrv_append_to_buf(buf, blob.data, blob.length);
313
314         data_blob_free(&blob);
315         return res;
316 }
317
318 NTSTATUS ldapsrv_do_responses(struct ldapsrv_connection *conn)
319 {
320         struct ldapsrv_call *call, *next_call = NULL;
321         struct ldapsrv_reply *reply, *next_reply = NULL;
322
323         for (call=conn->calls; call; call=next_call) {
324                 for (reply=call->replies; reply; reply=next_reply) {
325                         if (!ldap_encode_to_buf(reply->msg, &conn->out_buffer)) {
326                                 return NT_STATUS_FOOBAR;
327                         }
328                         next_reply = reply->next;
329                         DLIST_REMOVE(call->replies, reply);
330                         reply->state = LDAPSRV_REPLY_STATE_SEND;
331                         talloc_free(reply);
332                 }
333                 next_call = call->next;
334                 DLIST_REMOVE(conn->calls, call);
335                 call->state = LDAPSRV_CALL_STATE_COMPLETE;
336                 talloc_free(call);
337         }
338
339         return NT_STATUS_OK;
340 }
341
342 NTSTATUS ldapsrv_flush_responses(struct ldapsrv_connection *conn)
343 {
344         return NT_STATUS_OK;
345 }
346
347 /*
348   called when a LDAP socket becomes readable
349 */
350 static void ldapsrv_recv(struct stream_connection *conn, uint16_t flags)
351 {
352         struct ldapsrv_connection *ldap_conn = talloc_get_type(conn->private, struct ldapsrv_connection);
353         uint8_t *buf;
354         size_t buf_length;
355         struct ldapsrv_call *call;
356         NTSTATUS status;
357
358         if (!ldapsrv_read_buf(ldap_conn)) {
359                 ldapsrv_terminate_connection(ldap_conn, "ldapsrv_read_buf() failed");
360                 return;
361         }
362
363         peek_into_read_buf(&ldap_conn->in_buffer, &buf, &buf_length);
364
365         while (buf_length > 0) {
366                 DATA_BLOB blob;
367                 struct asn1_data data;
368                 struct ldap_message *msg = talloc(conn, struct ldap_message);
369
370                 blob.data = buf;
371                 blob.length = buf_length;
372
373                 if (!asn1_load(&data, blob)) {
374                         ldapsrv_terminate_connection(ldap_conn, "asn1_load() failed");
375                         return;
376                 }
377
378                 if (!ldap_decode(&data, msg)) {
379                         if (data.ofs == data.length) {
380                                 /* we don't have a complete msg yet */
381                                 talloc_free(msg);
382                                 asn1_free(&data);
383                                 return;
384                         }
385                         asn1_free(&data);
386                         talloc_free(msg);
387                         ldapsrv_terminate_connection(ldap_conn, "ldap_decode() failed");
388                         return;
389                 }
390
391                 ldapsrv_consumed_from_buf(&ldap_conn->in_buffer, data.ofs);
392                 asn1_free(&data);
393
394                 call = talloc_zero(ldap_conn, struct ldapsrv_call);
395                 if (!call) {
396                         ldapsrv_terminate_connection(ldap_conn, "no memory");
397                         return;         
398                 }
399
400                 call->request = talloc_steal(call, msg);
401                 call->state = LDAPSRV_CALL_STATE_NEW;
402                 call->conn = ldap_conn;
403
404                 DLIST_ADD_END(ldap_conn->calls, call, struct ldapsrv_call *);
405
406                 status = ldapsrv_do_call(call);
407                 if (!NT_STATUS_IS_OK(status)) {
408                         ldapsrv_terminate_connection(ldap_conn, "ldapsrv_do_call() failed");
409                         return;
410                 }
411
412                 peek_into_read_buf(&ldap_conn->in_buffer, &buf, &buf_length);
413         }
414
415         status = ldapsrv_do_responses(ldap_conn);
416         if (!NT_STATUS_IS_OK(status)) {
417                 ldapsrv_terminate_connection(ldap_conn, "ldapsrv_do_responses() failed");
418                 return;
419         }
420
421         if ((ldap_conn->out_buffer.length > 0)||(ldap_conn->sasl_out_buffer.length > 0)) {
422                 EVENT_FD_WRITEABLE(conn->event.fde);
423         }
424
425         return;
426 }
427         
428 /*
429   called when a LDAP socket becomes writable
430 */
431 static void ldapsrv_send(struct stream_connection *conn, uint16_t flags)
432 {
433         struct ldapsrv_connection *ldap_conn = talloc_get_type(conn->private, struct ldapsrv_connection);
434
435         DEBUG(10,("ldapsrv_send\n"));
436
437         if (!ldapsrv_write_buf(ldap_conn)) {
438                 ldapsrv_terminate_connection(ldap_conn, "ldapsrv_write_buf() failed");
439                 return;
440         }
441
442         if (ldap_conn->out_buffer.length == 0 && ldap_conn->sasl_out_buffer.length == 0) {
443                 EVENT_FD_NOT_WRITEABLE(conn->event.fde);
444         }
445
446         return;
447 }
448
449 /*
450   initialise a server_context from a open socket and register a event handler
451   for reading from that socket
452 */
453 static void ldapsrv_accept(struct stream_connection *conn)
454 {
455         struct ldapsrv_service *ldapsrv_service = 
456                 talloc_get_type(conn->private, struct ldapsrv_service);
457         struct ldapsrv_connection *ldap_conn;
458
459         ldap_conn = talloc_zero(conn, struct ldapsrv_connection);
460         if (ldap_conn == NULL) goto failed;
461
462         ldap_conn->connection = conn;
463         ldap_conn->service = talloc_get_type(conn->private, struct ldapsrv_service);
464         conn->private = ldap_conn;
465
466         /* note that '0' is a ASN1_SEQUENCE(0), which is the first byte on
467            any ldap connection */
468         ldap_conn->tls = tls_init_server(ldapsrv_service->tls_params, conn->socket, 
469                                          conn->event.fde, "0");
470         if (ldap_conn->tls == NULL) goto failed;
471
472         return;
473
474 failed:
475         talloc_free(conn);
476 }
477
478 static const struct stream_server_ops ldap_stream_ops = {
479         .name                   = "ldap",
480         .accept_connection      = ldapsrv_accept,
481         .recv_handler           = ldapsrv_recv,
482         .send_handler           = ldapsrv_send,
483 };
484
485 /*
486   add a socket address to the list of events, one event per port
487 */
488 static NTSTATUS add_socket(struct event_context *event_context, const struct model_ops *model_ops,
489                            const char *address, struct ldapsrv_service *ldap_service)
490 {
491         uint16_t port = 389;
492         NTSTATUS status;
493
494         status = stream_setup_socket(event_context, model_ops, &ldap_stream_ops, 
495                                      "ipv4", address, &port, ldap_service);
496         if (!NT_STATUS_IS_OK(status)) {
497                 DEBUG(0,("ldapsrv failed to bind to %s:%u - %s\n",
498                          address, port, nt_errstr(status)));
499         }
500
501         /* add ldaps server */
502         port = 636;
503         status = stream_setup_socket(event_context, model_ops, &ldap_stream_ops, 
504                                      "ipv4", address, &port, ldap_service);
505         if (!NT_STATUS_IS_OK(status)) {
506                 DEBUG(0,("ldapsrv failed to bind to %s:%u - %s\n",
507                          address, port, nt_errstr(status)));
508         }
509         return status;
510 }
511
512 /*
513   open the ldap server sockets
514 */
515 static void ldapsrv_task_init(struct task_server *task)
516 {       
517         struct ldapsrv_service *ldap_service;
518         struct ldapsrv_partition *rootDSE_part;
519         struct ldapsrv_partition *part;
520         NTSTATUS status;
521
522         ldap_service = talloc_zero(task, struct ldapsrv_service);
523         if (ldap_service == NULL) goto failed;
524
525         ldap_service->tls_params = tls_initialise(ldap_service);
526         if (ldap_service->tls_params == NULL) goto failed;
527
528         rootDSE_part = talloc(ldap_service, struct ldapsrv_partition);
529         if (rootDSE_part == NULL) goto failed;
530
531         rootDSE_part->base_dn = ""; /* RootDSE */
532         rootDSE_part->ops = ldapsrv_get_rootdse_partition_ops();
533
534         ldap_service->rootDSE = rootDSE_part;
535         DLIST_ADD_END(ldap_service->partitions, rootDSE_part, struct ldapsrv_partition *);
536
537         part = talloc(ldap_service, struct ldapsrv_partition);
538         if (part == NULL) goto failed;
539
540         part->base_dn = "*"; /* default partition */
541         if (lp_parm_bool(-1, "ldapsrv", "hacked", False)) {
542                 part->ops = ldapsrv_get_hldb_partition_ops();
543         } else {
544                 part->ops = ldapsrv_get_sldb_partition_ops();
545         }
546
547         ldap_service->default_partition = part;
548         DLIST_ADD_END(ldap_service->partitions, part, struct ldapsrv_partition *);
549
550         if (lp_interfaces() && lp_bind_interfaces_only()) {
551                 int num_interfaces = iface_count();
552                 int i;
553
554                 /* We have been given an interfaces line, and been 
555                    told to only bind to those interfaces. Create a
556                    socket per interface and bind to only these.
557                 */
558                 for(i = 0; i < num_interfaces; i++) {
559                         const char *address = iface_n_ip(i);
560                         status = add_socket(task->event_ctx, task->model_ops, address, ldap_service);
561                         if (!NT_STATUS_IS_OK(status)) goto failed;
562                 }
563         } else {
564                 status = add_socket(task->event_ctx, task->model_ops, lp_socket_address(), ldap_service);
565                 if (!NT_STATUS_IS_OK(status)) goto failed;
566         }
567
568         return;
569
570 failed:
571         task_terminate(task, "Failed to startup ldap server task");     
572 }
573
574 /*
575   called on startup of the web server service It's job is to start
576   listening on all configured sockets
577 */
578 static NTSTATUS ldapsrv_init(struct event_context *event_context, 
579                              const struct model_ops *model_ops)
580 {       
581         return task_server_startup(event_context, model_ops, ldapsrv_task_init);
582 }
583
584
585 NTSTATUS server_service_ldap_init(void)
586 {
587         return register_server_service("ldap", ldapsrv_init);
588 }