r12608: Remove some unused #include lines.
[sfrench/samba-autobuild/.git] / source4 / librpc / rpc / dcerpc_sock.c
index 4d2a66174993d3c5cffad7ad914c129079c985cd..2867a8eaccbeebf4cea4c526c2a05728d52baf52 100644 (file)
 */
 
 #include "includes.h"
-#include "dlinklist.h"
 #include "lib/events/events.h"
-#include "librpc/gen_ndr/ndr_epmapper.h"
 #include "lib/socket/socket.h"
-
-#define MIN_HDR_SIZE 16
-
-struct sock_blob {
-       struct sock_blob *next, *prev;
-       DATA_BLOB data;
-};
+#include "lib/stream/packet.h"
+#include "libcli/composite/composite.h"
 
 /* transport private information used by general socket pipe transports */
 struct sock_private {
-       struct event_context *event_ctx;
        struct fd_event *fde;
        struct socket_context *sock;
        char *server_name;
-       uint32_t port;
 
-       struct sock_blob *pending_send;
-
-       struct {
-               size_t received;
-               DATA_BLOB data;
-               uint_t pending_count;
-       } recv;
+       struct packet_context *packet;
+       uint32_t pending_reads;
 };
 
 
@@ -60,133 +46,56 @@ static void sock_dead(struct dcerpc_connection *p, NTSTATUS status)
        struct sock_private *sock = p->transport.private;
 
        if (sock && sock->sock != NULL) {
+               talloc_free(sock->fde);
                talloc_free(sock->sock);
                sock->sock = NULL;
        }
 
-       /* wipe any pending sends */
-       while (sock->pending_send) {
-               struct sock_blob *blob = sock->pending_send;
-               DLIST_REMOVE(sock->pending_send, blob);
-               talloc_free(blob);
-       }
-
        if (!NT_STATUS_IS_OK(status)) {
                p->transport.recv_data(p, NULL, status);
        }
-
-       talloc_free(sock->fde);
 }
 
+
 /*
-  process send requests
+  handle socket recv errors
 */
-static void sock_process_send(struct dcerpc_connection *p)
+static void sock_error_handler(void *private, NTSTATUS status)
 {
-       struct sock_private *sock = p->transport.private;
+       struct dcerpc_connection *p = talloc_get_type(private, 
+                                                     struct dcerpc_connection);
+       sock_dead(p, status);
+}
 
-       while (sock->pending_send) {
-               struct sock_blob *blob = sock->pending_send;
-               NTSTATUS status;
-               size_t sent;
-               status = socket_send(sock->sock, &blob->data, &sent, 0);
-               if (NT_STATUS_IS_ERR(status)) {
-                       sock_dead(p, NT_STATUS_NET_WRITE_FAULT);
-                       break;
-               }
-               if (sent == 0) {
-                       break;
-               }
-
-               blob->data.data += sent;
-               blob->data.length -= sent;
-
-               if (blob->data.length != 0) {
-                       break;
-               }
-
-               DLIST_REMOVE(sock->pending_send, blob);
-               talloc_free(blob);
+/*
+  check if a blob is a complete packet
+*/
+static NTSTATUS sock_complete_packet(void *private, DATA_BLOB blob, size_t *size)
+{
+       if (blob.length < DCERPC_FRAG_LEN_OFFSET+2) {
+               return STATUS_MORE_ENTRIES;
        }
-
-       if (sock->pending_send == NULL) {
-               EVENT_FD_NOT_WRITEABLE(sock->fde);
+       *size = dcerpc_get_frag_length(&blob);
+       if (*size > blob.length) {
+               return STATUS_MORE_ENTRIES;
        }
+       return NT_STATUS_OK;
 }
 
-
 /*
   process recv requests
 */
-static void sock_process_recv(struct dcerpc_connection *p)
+static NTSTATUS sock_process_recv(void *private, DATA_BLOB blob)
 {
+       struct dcerpc_connection *p = talloc_get_type(private, 
+                                                     struct dcerpc_connection);
        struct sock_private *sock = p->transport.private;
-       NTSTATUS status;
-       size_t nread;
-
-       if (sock->recv.data.data == NULL) {
-               sock->recv.data = data_blob_talloc(sock, NULL, MIN_HDR_SIZE);
-       }
-
-       /* read in the base header to get the fragment length */
-       if (sock->recv.received < MIN_HDR_SIZE) {
-               uint32_t frag_length;
-
-               status = socket_recv(sock->sock, 
-                                    sock->recv.data.data + sock->recv.received, 
-                                    MIN_HDR_SIZE - sock->recv.received, 
-                                    &nread, 0);
-               if (NT_STATUS_IS_ERR(status)) {
-                       sock_dead(p, NT_STATUS_NET_WRITE_FAULT);
-                       return;
-               }
-               if (nread == 0) {
-                       return;
-               }
-               
-               sock->recv.received += nread;
-
-               if (sock->recv.received != MIN_HDR_SIZE) {
-                       return;
-               }
-               frag_length = dcerpc_get_frag_length(&sock->recv.data);
-
-               sock->recv.data.data = talloc_realloc(sock, sock->recv.data.data,
-                                                     uint8_t, frag_length);
-               if (sock->recv.data.data == NULL) {
-                       sock_dead(p, NT_STATUS_NO_MEMORY);
-                       return;
-               }
-               sock->recv.data.length = frag_length;
-       }
-
-       /* read in the rest of the packet */
-       status = socket_recv(sock->sock, 
-                            sock->recv.data.data + sock->recv.received, 
-                            sock->recv.data.length - sock->recv.received, 
-                            &nread, 0);
-       if (NT_STATUS_IS_ERR(status)) {
-               sock_dead(p, NT_STATUS_NET_WRITE_FAULT);
-               return;
-       }
-       if (nread == 0) {
-               return;
-       }
-       sock->recv.received += nread;
-
-       if (sock->recv.received != sock->recv.data.length) {
-               return;
-       }
-
-       /* we have a full packet */
-       p->transport.recv_data(p, &sock->recv.data, NT_STATUS_OK);
-       talloc_free(sock->recv.data.data);
-       sock->recv.data = data_blob(NULL, 0);
-       sock->recv.received = 0;
-       sock->recv.pending_count--;
-       if (sock->recv.pending_count == 0) {
-               EVENT_FD_NOT_READABLE(sock->fde);
+       sock->pending_reads--;
+       if (sock->pending_reads == 0) {
+               packet_recv_disable(sock->packet);
        }
+       p->transport.recv_data(p, &blob, NT_STATUS_OK);
+       return NT_STATUS_OK;
 }
 
 /*
@@ -195,11 +104,12 @@ static void sock_process_recv(struct dcerpc_connection *p)
 static void sock_io_handler(struct event_context *ev, struct fd_event *fde, 
                            uint16_t flags, void *private)
 {
-       struct dcerpc_connection *p = talloc_get_type(private, struct dcerpc_connection);
+       struct dcerpc_connection *p = talloc_get_type(private, 
+                                                     struct dcerpc_connection);
        struct sock_private *sock = p->transport.private;
 
        if (flags & EVENT_FD_WRITE) {
-               sock_process_send(p);
+               packet_queue_run(sock->packet);
                return;
        }
 
@@ -208,20 +118,19 @@ static void sock_io_handler(struct event_context *ev, struct fd_event *fde,
        }
 
        if (flags & EVENT_FD_READ) {
-               sock_process_recv(p);
+               packet_recv(sock->packet);
        }
 }
 
 /* 
-   initiate a read request 
+   initiate a read request - not needed for dcerpc sockets
 */
 static NTSTATUS sock_send_read(struct dcerpc_connection *p)
 {
        struct sock_private *sock = p->transport.private;
-
-       sock->recv.pending_count++;
-       if (sock->recv.pending_count == 1) {
-               EVENT_FD_READABLE(sock->fde);
+       sock->pending_reads++;
+       if (sock->pending_reads == 1) {
+               packet_recv_enable(sock->packet);
        }
        return NT_STATUS_OK;
 }
@@ -229,30 +138,27 @@ static NTSTATUS sock_send_read(struct dcerpc_connection *p)
 /* 
    send an initial pdu in a multi-pdu sequence
 */
-static NTSTATUS sock_send_request(struct dcerpc_connection *p, DATA_BLOB *data, BOOL trigger_read)
+static NTSTATUS sock_send_request(struct dcerpc_connection *p, DATA_BLOB *data, 
+                                 BOOL trigger_read)
 {
        struct sock_private *sock = p->transport.private;
-       struct sock_blob *blob;
+       DATA_BLOB blob;
+       NTSTATUS status;
 
        if (sock->sock == NULL) {
                return NT_STATUS_CONNECTION_DISCONNECTED;
        }
 
-       blob = talloc(sock, struct sock_blob);
-       if (blob == NULL) {
+       blob = data_blob_talloc(sock->packet, data->data, data->length);
+       if (blob.data == NULL) {
                return NT_STATUS_NO_MEMORY;
        }
 
-       blob->data = data_blob_talloc(blob, data->data, data->length);
-       if (blob->data.data == NULL) {
-               talloc_free(blob);
-               return NT_STATUS_NO_MEMORY;
+       status = packet_send(sock->packet, blob);
+       if (!NT_STATUS_IS_OK(status)) {
+               return status;
        }
 
-       DLIST_ADD_END(sock->pending_send, blob, struct sock_blob *);
-
-       EVENT_FD_WRITEABLE(sock->fde);
-
        if (trigger_read) {
                sock_send_read(p);
        }
@@ -260,22 +166,16 @@ static NTSTATUS sock_send_request(struct dcerpc_connection *p, DATA_BLOB *data,
        return NT_STATUS_OK;
 }
 
-/* 
-   return the event context so the caller can process asynchronously
-*/
-static struct event_context *sock_event_context(struct dcerpc_connection *p)
-{
-       struct sock_private *sock = p->transport.private;
-
-       return sock->event_ctx;
-}
-
 /* 
    shutdown sock pipe connection
 */
 static NTSTATUS sock_shutdown_pipe(struct dcerpc_connection *p)
 {
-       sock_dead(p, NT_STATUS_OK);
+       struct sock_private *sock = p->transport.private;
+
+       if (sock && sock->sock) {
+               sock_dead(p, NT_STATUS_OK);
+       }
 
        return NT_STATUS_OK;
 }
@@ -289,70 +189,171 @@ static const char *sock_peer_name(struct dcerpc_connection *p)
        return sock->server_name;
 }
 
-/* 
-   open a rpc connection using the generic socket library
-*/
-static NTSTATUS dcerpc_pipe_open_socket(struct dcerpc_connection *c, 
-                                       const char *server,
-                                       uint32_t port, 
-                                       const char *type,
-                                       enum dcerpc_transport_t transport)
-{
-       struct sock_private *sock;
-       struct socket_context *socket_ctx;
-       NTSTATUS status;
 
-       sock = talloc(c, struct sock_private);
-       if (!sock) {
-               return NT_STATUS_NO_MEMORY;
-       }
+struct pipe_open_socket_state {
+       struct dcerpc_connection *conn;
+       struct socket_context *socket_ctx;
+       struct sock_private *sock;
+       const char *server;
+       uint32_t port;
+       enum dcerpc_transport_t transport;
+};
 
-       status = socket_create(type, SOCKET_TYPE_STREAM, &socket_ctx, 0);
-       if (!NT_STATUS_IS_OK(status)) {
-               talloc_free(sock);
-               return status;
-       }
-       talloc_steal(sock, socket_ctx);
 
-       status = socket_connect(socket_ctx, NULL, 0, server, port, 0);
-       if (!NT_STATUS_IS_OK(status)) {
-               talloc_free(sock);
-               return status;
+static void continue_socket_connect(struct composite_context *ctx)
+{
+       struct dcerpc_connection *conn;
+       struct sock_private *sock;
+       struct composite_context *c = talloc_get_type(ctx->async.private_data,
+                                                     struct composite_context);
+       struct pipe_open_socket_state *s = talloc_get_type(c->private_data,
+                                                          struct pipe_open_socket_state);
+
+       /* make it easier to write a function calls */
+       conn = s->conn;
+       sock = s->sock;
+
+       c->status = socket_connect_recv(ctx);
+       if (!NT_STATUS_IS_OK(c->status)) {
+               DEBUG(0, ("Failed to connect host %s on port %d - %s\n", s->server, s->port,
+                         nt_errstr(c->status)));
+               composite_error(c, c->status);
+               return;
        }
 
        /*
          fill in the transport methods
        */
-       c->transport.transport = transport;
-       c->transport.private = NULL;
+       conn->transport.transport     = s->transport;
+       conn->transport.private       = NULL;
+
+       conn->transport.send_request  = sock_send_request;
+       conn->transport.send_read     = sock_send_read;
+       conn->transport.recv_data     = NULL;
+
+       conn->transport.shutdown_pipe = sock_shutdown_pipe;
+       conn->transport.peer_name     = sock_peer_name;
 
-       c->transport.send_request = sock_send_request;
-       c->transport.send_read = sock_send_read;
-       c->transport.event_context = sock_event_context;
-       c->transport.recv_data = NULL;
+       sock->sock          = s->socket_ctx;
+       sock->pending_reads = 0;
+       sock->server_name   = strupper_talloc(sock, s->server);
 
-       c->transport.shutdown_pipe = sock_shutdown_pipe;
-       c->transport.peer_name = sock_peer_name;
+       sock->fde = event_add_fd(conn->event_ctx, sock->sock, socket_get_fd(sock->sock),
+                                0, sock_io_handler, conn);
        
-       sock->sock = socket_ctx;
-       sock->server_name = talloc_strdup(sock, server);
-       sock->event_ctx = event_context_init(sock);
-       sock->pending_send = NULL;
-       sock->recv.received = 0;
-       sock->recv.data = data_blob(NULL, 0);
-       sock->recv.pending_count = 0;
+       conn->transport.private = sock;
 
-       sock->fde = event_add_fd(sock->event_ctx, sock, socket_get_fd(sock->sock), 
-                                0, sock_io_handler, c);
+       sock->packet = packet_init(sock);
+       if (sock->packet == NULL) {
+               composite_error(c, NT_STATUS_NO_MEMORY);
+               talloc_free(sock);
+               return;
+       }
 
-       c->transport.private = sock;
+       packet_set_private(sock->packet, conn);
+       packet_set_socket(sock->packet, sock->sock);
+       packet_set_callback(sock->packet, sock_process_recv);
+       packet_set_full_request(sock->packet, sock_complete_packet);
+       packet_set_error_handler(sock->packet, sock_error_handler);
+       packet_set_event_context(sock->packet, conn->event_ctx);
+       packet_set_fde(sock->packet, sock->fde);
+       packet_set_serialise(sock->packet);
+       packet_recv_disable(sock->packet);
+       packet_set_initial_read(sock->packet, 16);
 
        /* ensure we don't get SIGPIPE */
        BlockSignals(True,SIGPIPE);
 
-       return NT_STATUS_OK;
+       composite_done(c);
 }
 
+
+struct composite_context *dcerpc_pipe_open_socket_send(TALLOC_CTX *mem_ctx,
+                                                      struct dcerpc_connection *cn,
+                                                      const char *server,
+                                                      uint32_t port, 
+                                                      const char *type,
+                                                      enum dcerpc_transport_t transport)
+{
+       NTSTATUS status;
+       struct composite_context *c;
+       struct pipe_open_socket_state *s;
+       struct composite_context *conn_req;
+
+       c = talloc_zero(mem_ctx, struct composite_context);
+       if (c == NULL) return NULL;
+
+       s = talloc_zero(c, struct pipe_open_socket_state);
+       if (s == NULL) {
+               composite_error(c, NT_STATUS_NO_MEMORY);
+               goto done;
+       }
+
+       c->state = COMPOSITE_STATE_IN_PROGRESS;
+       c->private_data = s;
+       c->event_ctx = cn->event_ctx;
+
+       s->conn      = cn;
+       s->transport = transport;
+       s->port      = port;
+       s->server    = talloc_strdup(c, server);
+       if (s->server == NULL) {
+               composite_error(c, NT_STATUS_NO_MEMORY);
+               goto done;
+       }
+
+       s->sock = talloc(cn, struct sock_private);
+       if (s->sock == NULL) {
+               composite_error(c, NT_STATUS_NO_MEMORY);
+               goto done;
+       }
+
+       status = socket_create(type, SOCKET_TYPE_STREAM, &s->socket_ctx, 0);
+       if (!NT_STATUS_IS_OK(status)) {
+               composite_error(c, status);
+               talloc_free(s->sock);
+               goto done;
+       }
+       talloc_steal(s->sock, s->socket_ctx);
+
+       conn_req = socket_connect_send(s->socket_ctx, NULL, 0, s->server, s->port, 0, c->event_ctx);
+       if (conn_req == NULL) {
+               composite_error(c, NT_STATUS_NO_MEMORY);
+               goto done;
+       }
+       
+       composite_continue(c, conn_req, continue_socket_connect, c);
+
+done:
+       return c;
+}
+
+
+NTSTATUS dcerpc_pipe_open_socket_recv(struct composite_context *c)
+{
+       NTSTATUS status = composite_wait(c);
+
+       talloc_free(c);
+       return status;
+}
+
+/* 
+   open a rpc connection using the generic socket library
+*/
+NTSTATUS dcerpc_pipe_open_socket(struct dcerpc_connection *conn,
+                                const char *server,
+                                uint32_t port, 
+                                const char *type,
+                                enum dcerpc_transport_t transport)
+{
+       struct composite_context *c;
+       
+       c = dcerpc_pipe_open_socket_send(conn, conn, server, port,
+                                        type, transport);
+       return dcerpc_pipe_open_socket_recv(c);
+}
+
+
 /* 
    open a rpc connection using tcp
 */