r7565: fixed handling of sasl data in ldap server
[jelmer/samba4-debian.git] / source / ldap_server / ldap_server.c
index 9338baa165f8597fc76be44c62676587445fc3d8..9f62d72e2c911405022aac04430660b920ad1794 100644 (file)
 */
 
 #include "includes.h"
-#include "events.h"
+#include "lib/events/events.h"
 #include "auth/auth.h"
 #include "dlinklist.h"
 #include "asn_1.h"
 #include "ldap_server/ldap_server.h"
+#include "smbd/service_stream.h"
+#include "lib/socket/socket.h"
 
 /*
   close the socket and shutdown a server_context
 */
 static void ldapsrv_terminate_connection(struct ldapsrv_connection *ldap_conn, const char *reason)
 {
-       server_terminate_connection(ldap_conn->connection, reason);
-}
-
-/*
-  add a socket address to the list of events, one event per port
-*/
-static void add_socket(struct server_service *service, 
-                      const struct model_ops *model_ops, 
-                      struct ipv4_addr *ifip)
-{
-       struct server_socket *srv_sock;
-       uint16_t port = 389;
-       char *ip_str = talloc_strdup(service, sys_inet_ntoa(*ifip));
-
-       srv_sock = service_setup_socket(service, model_ops, "ipv4", ip_str, &port);
-
-       port = 3268;
-       srv_sock = service_setup_socket(service, model_ops, "ipv4", ip_str, &port);
-
-       talloc_free(ip_str);
-}
-
-/****************************************************************************
- Open the socket communication.
-****************************************************************************/
-static void ldapsrv_init(struct server_service *service,
-                        const struct model_ops *model_ops)
-{      
-       struct ldapsrv_service *ldap_service;
-       struct ldapsrv_partition *part;
-
-       DEBUG(10,("ldapsrv_init\n"));
-
-       ldap_service = talloc_p(service, struct ldapsrv_service);
-       if (!ldap_service) {
-               DEBUG(0,("talloc_p(service, struct ldapsrv_service) failed\n"));
-               return;
-       }
-       ZERO_STRUCTP(ldap_service);
-
-       part = talloc_p(ldap_service, struct ldapsrv_partition);
-       if (!ldap_service) {
-               DEBUG(0,("talloc_p(ldap_service, struct ldapsrv_partition) failed\n"));
-               return;
-       }
-       part->base_dn = ""; /* RootDSE */
-       part->ops = ldapsrv_get_rootdse_partition_ops();
-
-       ldap_service->rootDSE = part;
-       DLIST_ADD_END(ldap_service->partitions, part, struct ldapsrv_partition *);
-
-       part = talloc_p(ldap_service, struct ldapsrv_partition);
-       if (!ldap_service) {
-               DEBUG(0,("talloc_p(ldap_service, struct ldapsrv_partition) failed\n"));
-               return;
-       }
-       part->base_dn = "*"; /* default partition */
-       part->ops = ldapsrv_get_sldb_partition_ops();
-
-       ldap_service->default_partition = part;
-       DLIST_ADD_END(ldap_service->partitions, part, struct ldapsrv_partition *);
-
-       service->private_data = ldap_service;
-
-       if (lp_interfaces() && lp_bind_interfaces_only()) {
-               int num_interfaces = iface_count();
-               int i;
-
-               /* We have been given an interfaces line, and been 
-                  told to only bind to those interfaces. Create a
-                  socket per interface and bind to only these.
-               */
-               for(i = 0; i < num_interfaces; i++) {
-                       struct ipv4_addr *ifip = iface_n_ip(i);
-
-                       if (ifip == NULL) {
-                               DEBUG(0,("ldapsrv_init: interface %d has NULL "
-                                        "IP address !\n", i));
-                               continue;
-                       }
-
-                       add_socket(service, model_ops, ifip);
-               }
-       } else {
-               struct ipv4_addr ifip;
-
-               /* Just bind to lp_socket_address() (usually 0.0.0.0) */
-               ifip = interpret_addr2(lp_socket_address());
-               add_socket(service, model_ops, &ifip);
-       }
+       stream_terminate_connection(ldap_conn->connection, reason);
 }
 
 /* This rw-buf api is made to avoid memcpy. For now do that like mad...  The
@@ -131,7 +44,7 @@ static void ldapsrv_init(struct server_service *service,
 void ldapsrv_consumed_from_buf(struct rw_buffer *buf,
                                   size_t length)
 {
-       memcpy(buf->data, buf->data+length, buf->length-length);
+       memmove(buf->data, buf->data+length, buf->length-length);
        buf->length -= length;
 }
 
@@ -173,6 +86,7 @@ static BOOL read_into_buf(struct socket_context *sock, struct rw_buffer *buf)
                talloc_free(tmp_blob.data);
                return False;
        }
+       tmp_blob.length = nread;
 
        ret = ldapsrv_append_to_buf(buf, tmp_blob.data, tmp_blob.length);
 
@@ -185,21 +99,27 @@ static BOOL ldapsrv_read_buf(struct ldapsrv_connection *conn)
 {
        NTSTATUS status;
        DATA_BLOB tmp_blob;
-       DATA_BLOB creds;
+       DATA_BLOB wrapped;
+       DATA_BLOB unwrapped;
        BOOL ret;
        uint8_t *buf;
-       int buf_length, sasl_length;
+       size_t buf_length, sasl_length;
        struct socket_context *sock = conn->connection->socket;
        TALLOC_CTX *mem_ctx;
        size_t nread;
 
-       if (!conn->gensec || !conn->session_info ||
-          !(gensec_have_feature(conn->gensec, GENSEC_WANT_SIGN) &&
-            gensec_have_feature(conn->gensec, GENSEC_WANT_SEAL))) {
+       if (!conn->gensec) {
+               return read_into_buf(sock, &conn->in_buffer);
+       }
+       if (!conn->session_info) {
+               return read_into_buf(sock, &conn->in_buffer);
+       }
+       if (!(gensec_have_feature(conn->gensec, GENSEC_FEATURE_SIGN) ||
+             gensec_have_feature(conn->gensec, GENSEC_FEATURE_SEAL))) {
                return read_into_buf(sock, &conn->in_buffer);
        }
 
-       mem_ctx = talloc(conn, 0);
+       mem_ctx = talloc_new(conn);
        if (!mem_ctx) {
                DEBUG(0,("no memory\n"));
                return False;
@@ -235,47 +155,25 @@ static BOOL ldapsrv_read_buf(struct ldapsrv_connection *conn)
 
        sasl_length = RIVAL(buf, 0);
 
-       if (buf_length < (4 + sasl_length)) {
+       if ((buf_length - 4) < sasl_length) {
                /* not enough yet */
                talloc_free(mem_ctx);
                return True;
        }
 
-       creds.data = buf + 4;
-       creds.length = gensec_sig_size(conn->gensec);
+       wrapped.data = buf + 4;
+       wrapped.length = sasl_length;
 
-       if (creds.length > sasl_length) {
-               /* invalid packet? */
+       status = gensec_unwrap(conn->gensec, mem_ctx,
+                              &wrapped, 
+                              &unwrapped);
+       if (!NT_STATUS_IS_OK(status)) {
+               DEBUG(0,("gensec_unwrap: %s\n",nt_errstr(status)));
                talloc_free(mem_ctx);
                return False;
        }
 
-       tmp_blob.data = buf + (4 + creds.length);
-       tmp_blob.length = (4 + sasl_length) - (4 + creds.length);
-
-       if (gensec_have_feature(conn->gensec, GENSEC_WANT_SEAL)) {
-               status = gensec_unseal_packet(conn->gensec, mem_ctx,
-                                             tmp_blob.data, tmp_blob.length,
-                                             tmp_blob.data, tmp_blob.length,
-                                             &creds);
-               if (!NT_STATUS_IS_OK(status)) {
-                       DEBUG(0,("gensec_unseal_packet: %s\n",nt_errstr(status)));
-                       talloc_free(mem_ctx);
-                       return False;
-               }
-       } else {
-               status = gensec_check_packet(conn->gensec, mem_ctx,
-                                             tmp_blob.data, tmp_blob.length,
-                                             tmp_blob.data, tmp_blob.length,
-                                             &creds);
-               if (!NT_STATUS_IS_OK(status)) {
-                       DEBUG(0,("gensec_check_packet: %s\n",nt_errstr(status)));
-                       talloc_free(mem_ctx);
-                       return False;
-               }
-       }
-
-       ret = ldapsrv_append_to_buf(&conn->in_buffer, tmp_blob.data, tmp_blob.length);
+       ret = ldapsrv_append_to_buf(&conn->in_buffer, unwrapped.data, unwrapped.length);
        if (!ret) {
                talloc_free(mem_ctx);
                return False;
@@ -310,72 +208,64 @@ static BOOL write_from_buf(struct socket_context *sock, struct rw_buffer *buf)
 static BOOL ldapsrv_write_buf(struct ldapsrv_connection *conn)
 {
        NTSTATUS status;
+       DATA_BLOB wrapped;
        DATA_BLOB tmp_blob;
-       DATA_BLOB creds;
        DATA_BLOB sasl;
        size_t sendlen;
        BOOL ret;
        struct socket_context *sock = conn->connection->socket;
        TALLOC_CTX *mem_ctx;
 
-       if (!conn->gensec || !conn->session_info ||
-          !(gensec_have_feature(conn->gensec, GENSEC_WANT_SIGN) &&
-            gensec_have_feature(conn->gensec, GENSEC_WANT_SEAL))) {
+
+       if (!conn->gensec) {
+               return write_from_buf(sock, &conn->out_buffer);
+       }
+       if (!conn->session_info) {
+               return write_from_buf(sock, &conn->out_buffer);
+       }
+       if (conn->sasl_out_buffer.length == 0 &&
+           !(gensec_have_feature(conn->gensec, GENSEC_FEATURE_SIGN) ||
+             gensec_have_feature(conn->gensec, GENSEC_FEATURE_SEAL))) {
                return write_from_buf(sock, &conn->out_buffer);
        }
 
-       mem_ctx = talloc(conn, 0);
+       mem_ctx = talloc_new(conn);
        if (!mem_ctx) {
                DEBUG(0,("no memory\n"));
                return False;
        }
 
-       tmp_blob.data = conn->out_buffer.data;
-       tmp_blob.length = conn->out_buffer.length;
-
-       if (tmp_blob.length == 0) {
+       if (conn->out_buffer.length == 0) {
                goto nodata;
        }
 
-       if (gensec_have_feature(conn->gensec, GENSEC_WANT_SEAL)) {
-               status = gensec_seal_packet(conn->gensec, mem_ctx,
-                                           tmp_blob.data, tmp_blob.length,
-                                           tmp_blob.data, tmp_blob.length,
-                                           &creds);
-               if (!NT_STATUS_IS_OK(status)) {
-                       DEBUG(0,("gensec_seal_packet: %s\n",nt_errstr(status)));
-                       talloc_free(mem_ctx);
-                       return False;
-               }
-       } else {
-               status = gensec_sign_packet(conn->gensec, mem_ctx,
-                                           tmp_blob.data, tmp_blob.length,
-                                           tmp_blob.data, tmp_blob.length,
-                                           &creds);
-               if (!NT_STATUS_IS_OK(status)) {
-                       DEBUG(0,("gensec_sign_packet: %s\n",nt_errstr(status)));
-                       talloc_free(mem_ctx);
-                       return False;
-               }               
+       tmp_blob.data = conn->out_buffer.data;
+       tmp_blob.length = conn->out_buffer.length;
+       status = gensec_wrap(conn->gensec, mem_ctx,
+                            &tmp_blob,
+                            &wrapped);
+       if (!NT_STATUS_IS_OK(status)) {
+               DEBUG(0,("gensec_wrap: %s\n",nt_errstr(status)));
+               talloc_free(mem_ctx);
+               return False;
        }
 
-       sasl = data_blob_talloc(mem_ctx, NULL, 4 + creds.length + tmp_blob.length);
+       sasl = data_blob_talloc(mem_ctx, NULL, 4 + wrapped.length);
        if (!sasl.data) {
                DEBUG(0,("no memory\n"));
                talloc_free(mem_ctx);
                return False;
        }
 
-       RSIVAL(sasl.data, 0, creds.length + tmp_blob.length);
-       memcpy(sasl.data + 4, creds.data, creds.length);
-       memcpy(sasl.data + 4 + creds.length, tmp_blob.data, tmp_blob.length);
+       RSIVAL(sasl.data, 0, wrapped.length);
+       memcpy(sasl.data + 4, wrapped.data, wrapped.length);
 
        ret = ldapsrv_append_to_buf(&conn->sasl_out_buffer, sasl.data, sasl.length);
        if (!ret) {
                talloc_free(mem_ctx);
                return False;
        }
-       ldapsrv_consumed_from_buf(&conn->out_buffer, tmp_blob.length);
+       ldapsrv_consumed_from_buf(&conn->out_buffer, conn->out_buffer.length);
 nodata:
        tmp_blob.data = conn->sasl_out_buffer.data;
        tmp_blob.length = conn->sasl_out_buffer.length;
@@ -440,12 +330,11 @@ NTSTATUS ldapsrv_flush_responses(struct ldapsrv_connection *conn)
 /*
   called when a LDAP socket becomes readable
 */
-static void ldapsrv_recv(struct server_connection *conn, struct timeval t,
-                        uint16_t flags)
+static void ldapsrv_recv(struct stream_connection *conn, uint16_t flags)
 {
-       struct ldapsrv_connection *ldap_conn = conn->private_data;
+       struct ldapsrv_connection *ldap_conn = talloc_get_type(conn->private, struct ldapsrv_connection);
        uint8_t *buf;
-       int buf_length, msg_length;
+       size_t buf_length, msg_length;
        DATA_BLOB blob;
        struct asn1_data data;
        struct ldapsrv_call *call;
@@ -486,7 +375,7 @@ static void ldapsrv_recv(struct server_connection *conn, struct timeval t,
                        return;
                }
 
-               call = talloc_p(ldap_conn, struct ldapsrv_call);
+               call = talloc(ldap_conn, struct ldapsrv_call);
                if (!call) {
                        ldapsrv_terminate_connection(ldap_conn, "no memory");
                        return;         
@@ -527,7 +416,7 @@ static void ldapsrv_recv(struct server_connection *conn, struct timeval t,
        }
 
        if ((ldap_conn->out_buffer.length > 0)||(ldap_conn->sasl_out_buffer.length > 0)) {
-               conn->event.fde->flags |= EVENT_FD_WRITE;
+               EVENT_FD_WRITEABLE(conn->event.fde);
        }
 
        return;
@@ -536,10 +425,9 @@ static void ldapsrv_recv(struct server_connection *conn, struct timeval t,
 /*
   called when a LDAP socket becomes writable
 */
-static void ldapsrv_send(struct server_connection *conn, struct timeval t,
-                        uint16_t flags)
+static void ldapsrv_send(struct stream_connection *conn, uint16_t flags)
 {
-       struct ldapsrv_connection *ldap_conn = conn->private_data;
+       struct ldapsrv_connection *ldap_conn = talloc_get_type(conn->private, struct ldapsrv_connection);
 
        DEBUG(10,("ldapsrv_send\n"));
 
@@ -549,76 +437,118 @@ static void ldapsrv_send(struct server_connection *conn, struct timeval t,
        }
 
        if (ldap_conn->out_buffer.length == 0 && ldap_conn->sasl_out_buffer.length == 0) {
-               conn->event.fde->flags &= ~EVENT_FD_WRITE;
+               EVENT_FD_NOT_WRITEABLE(conn->event.fde);
        }
 
        return;
 }
 
-/*
-  called when connection is idle
-*/
-static void ldapsrv_idle(struct server_connection *conn, struct timeval t)
-{
-       DEBUG(10,("ldapsrv_idle: not implemented!\n"));
-       return;
-}
-
-static void ldapsrv_close(struct server_connection *conn, const char *reason)
-{
-       return;
-}
-
 /*
   initialise a server_context from a open socket and register a event handler
   for reading from that socket
 */
-static void ldapsrv_accept(struct server_connection *conn)
+static void ldapsrv_accept(struct stream_connection *conn)
 {
        struct ldapsrv_connection *ldap_conn;
 
        DEBUG(10, ("ldapsrv_accept\n"));
 
-       ldap_conn = talloc_p(conn, struct ldapsrv_connection);
+       ldap_conn = talloc_zero(conn, struct ldapsrv_connection);
 
        if (ldap_conn == NULL)
                return;
 
-       ZERO_STRUCTP(ldap_conn);
        ldap_conn->connection = conn;
-       ldap_conn->service = talloc_reference(ldap_conn, conn->service->private_data);
-
-       conn->private_data = ldap_conn;
-
-       return;
+       ldap_conn->service = talloc_get_type(conn->private, struct ldapsrv_service);
+       conn->private = ldap_conn;
 }
 
-/*
-  called on a fatal error that should cause this server to terminate
-*/
-static void ldapsrv_exit(struct server_service *service, const char *reason)
-{
-       DEBUG(10,("ldapsrv_exit\n"));
-       return;
-}
-
-static const struct server_service_ops ldap_server_ops = {
+static const struct stream_server_ops ldap_stream_ops = {
        .name                   = "ldap",
-       .service_init           = ldapsrv_init,
        .accept_connection      = ldapsrv_accept,
        .recv_handler           = ldapsrv_recv,
        .send_handler           = ldapsrv_send,
-       .idle_handler           = ldapsrv_idle,
-       .close_connection       = ldapsrv_close,
-       .service_exit           = ldapsrv_exit, 
 };
 
-const struct server_service_ops *ldapsrv_get_ops(void)
+/*
+  add a socket address to the list of events, one event per port
+*/
+static NTSTATUS add_socket(struct event_context *event_context, const struct model_ops *model_ops,
+                          const char *address, struct ldapsrv_service *ldap_service)
 {
-       return &ldap_server_ops;
+       uint16_t port = 389;
+       NTSTATUS status;
+
+       status = stream_setup_socket(event_context, model_ops, &ldap_stream_ops, 
+                                    "ipv4", address, &port, ldap_service);
+       NT_STATUS_NOT_OK_RETURN(status);
+
+       port = 3268;
+
+       return stream_setup_socket(event_context, model_ops, &ldap_stream_ops, 
+                                  "ipv4", address, &port, ldap_service);
+}
+
+/*
+  open the ldap server sockets
+*/
+static NTSTATUS ldapsrv_init(struct event_context *event_context, const struct model_ops *model_ops)
+{      
+       struct ldapsrv_service *ldap_service;
+       struct ldapsrv_partition *rootDSE_part;
+       struct ldapsrv_partition *part;
+       NTSTATUS status;
+
+       DEBUG(10,("ldapsrv_init\n"));
+
+       ldap_service = talloc_zero(event_context, struct ldapsrv_service);
+       NT_STATUS_HAVE_NO_MEMORY(ldap_service);
+
+       rootDSE_part = talloc(ldap_service, struct ldapsrv_partition);
+       NT_STATUS_HAVE_NO_MEMORY(rootDSE_part);
+
+       rootDSE_part->base_dn = ""; /* RootDSE */
+       rootDSE_part->ops = ldapsrv_get_rootdse_partition_ops();
+
+       ldap_service->rootDSE = rootDSE_part;
+       DLIST_ADD_END(ldap_service->partitions, rootDSE_part, struct ldapsrv_partition *);
+
+       part = talloc(ldap_service, struct ldapsrv_partition);
+       NT_STATUS_HAVE_NO_MEMORY(part);
+
+       part->base_dn = "*"; /* default partition */
+       if (lp_parm_bool(-1, "ldapsrv", "hacked", False)) {
+               part->ops = ldapsrv_get_hldb_partition_ops();
+       } else {
+               part->ops = ldapsrv_get_sldb_partition_ops();
+       }
+
+       ldap_service->default_partition = part;
+       DLIST_ADD_END(ldap_service->partitions, part, struct ldapsrv_partition *);
+
+       if (lp_interfaces() && lp_bind_interfaces_only()) {
+               int num_interfaces = iface_count();
+               int i;
+
+               /* We have been given an interfaces line, and been 
+                  told to only bind to those interfaces. Create a
+                  socket per interface and bind to only these.
+               */
+               for(i = 0; i < num_interfaces; i++) {
+                       const char *address = iface_n_ip(i);
+                       status = add_socket(event_context, model_ops, address, ldap_service);
+                       NT_STATUS_NOT_OK_RETURN(status);
+               }
+       } else {
+               status = add_socket(event_context, model_ops, lp_socket_address(), ldap_service);
+               NT_STATUS_NOT_OK_RETURN(status);
+       }
+
+       return NT_STATUS_OK;
 }
 
+
 NTSTATUS server_service_ldap_init(void)
 {
-       return NT_STATUS_OK;    
+       return register_server_service("ldap", ldapsrv_init);
 }