libsocket: Add "mem_ctx" to socket_create()
[samba.git] / source4 / lib / socket / connect_multi.c
index 2b926c8bd99b62e0cec50b21fbb170bb78acf36a..b29fffb33b4009a05898b17417d1856d46cb5af4 100644 (file)
@@ -42,6 +42,8 @@ struct connect_multi_state {
        uint16_t result_port;
 
        int num_connects_sent, num_connects_recv;
+
+       struct socket_connect_multi_ex *ex;
 };
 
 /*
@@ -59,17 +61,19 @@ static void connect_multi_timer(struct tevent_context *ev,
                                    struct timeval tv, void *p);
 static void connect_multi_next_socket(struct composite_context *result);
 static void continue_one(struct composite_context *creq);
+static void continue_one_ex(struct tevent_req *subreq);
 
 /*
   setup an async socket_connect, with multiple ports
 */
-_PUBLIC_ struct composite_context *socket_connect_multi_send(
+_PUBLIC_ struct composite_context *socket_connect_multi_ex_send(
                                                    TALLOC_CTX *mem_ctx,
                                                    const char *server_name,
                                                    int num_server_ports,
                                                    uint16_t *server_ports,
                                                    struct resolve_context *resolve_ctx,
-                                                   struct tevent_context *event_ctx)
+                                                   struct tevent_context *event_ctx,
+                                                   struct socket_connect_multi_ex *ex)
 {
        struct composite_context *result;
        struct connect_multi_state *multi;
@@ -95,6 +99,8 @@ _PUBLIC_ struct composite_context *socket_connect_multi_send(
                multi->ports[i] = server_ports[i];
        }
 
+       multi->ex = ex;
+
        /*  
            we don't want to do the name resolution separately
                    for each port, so start it now, then only start on
@@ -146,8 +152,9 @@ static void connect_multi_next_socket(struct composite_context *result)
        if (composite_nomem(state, result)) return;
 
        state->result = result;
-       result->status = socket_create(multi->server_address[multi->current_address]->family,
-                                       SOCKET_TYPE_STREAM, &state->sock, 0);
+       result->status = socket_create(
+               state, multi->server_address[multi->current_address]->family,
+               SOCKET_TYPE_STREAM, &state->sock, 0);
        if (!composite_is_ok(result)) return;
 
        state->addr = socket_address_copy(state, multi->server_address[multi->current_address]);
@@ -155,8 +162,6 @@ static void connect_multi_next_socket(struct composite_context *result)
 
        socket_address_set_port(state->addr, multi->ports[multi->current_port]);
 
-       talloc_steal(state, state->sock);
-
        creq = socket_connect_send(state->sock, NULL, 
                                   state->addr, 0,
                                   result->event_ctx);
@@ -225,10 +230,61 @@ static void continue_one(struct composite_context *creq)
        struct connect_multi_state *multi = talloc_get_type(result->private_data, 
                                                            struct connect_multi_state);
        NTSTATUS status;
-       multi->num_connects_recv++;
 
        status = socket_connect_recv(creq);
 
+       if (multi->ex) {
+               struct tevent_req *subreq;
+
+               subreq = multi->ex->establish_send(state,
+                                                  result->event_ctx,
+                                                  state->sock,
+                                                  state->addr,
+                                                  multi->ex->private_data);
+               if (composite_nomem(subreq, result)) return;
+               tevent_req_set_callback(subreq, continue_one_ex, state);
+               return;
+       }
+
+       multi->num_connects_recv++;
+
+       if (NT_STATUS_IS_OK(status)) {
+               multi->sock = talloc_steal(multi, state->sock);
+               multi->result_port = state->addr->port;
+       }
+
+       talloc_free(state);
+
+       if (NT_STATUS_IS_OK(status) ||
+           multi->num_connects_recv == (multi->num_address * multi->num_ports)) {
+               result->status = status;
+               composite_done(result);
+               return;
+       }
+
+       /* try the next port */
+       connect_multi_next_socket(result);
+}
+
+/*
+  one of our multi->ex->establish_send() calls hash finished. If it got a
+  connection or there are none left then we are done
+*/
+static void continue_one_ex(struct tevent_req *subreq)
+{
+       struct connect_one_state *state =
+               tevent_req_callback_data(subreq,
+               struct connect_one_state);
+       struct composite_context *result = state->result;
+       struct connect_multi_state *multi =
+               talloc_get_type_abort(result->private_data,
+               struct connect_multi_state);
+       NTSTATUS status;
+       multi->num_connects_recv++;
+
+       status = multi->ex->establish_recv(subreq);
+       TALLOC_FREE(subreq);
+
        if (NT_STATUS_IS_OK(status)) {
                multi->sock = talloc_steal(multi, state->sock);
                multi->result_port = state->addr->port;
@@ -250,7 +306,7 @@ static void continue_one(struct composite_context *creq)
 /*
   async recv routine for socket_connect_multi()
  */
-_PUBLIC_ NTSTATUS socket_connect_multi_recv(struct composite_context *ctx,
+_PUBLIC_ NTSTATUS socket_connect_multi_ex_recv(struct composite_context *ctx,
                                   TALLOC_CTX *mem_ctx,
                                   struct socket_context **sock,
                                   uint16_t *port)
@@ -267,6 +323,55 @@ _PUBLIC_ NTSTATUS socket_connect_multi_recv(struct composite_context *ctx,
        return status;
 }
 
+NTSTATUS socket_connect_multi_ex(TALLOC_CTX *mem_ctx,
+                                const char *server_address,
+                                int num_server_ports, uint16_t *server_ports,
+                                struct resolve_context *resolve_ctx,
+                                struct tevent_context *event_ctx,
+                                struct socket_connect_multi_ex *ex,
+                                struct socket_context **result,
+                                uint16_t *result_port)
+{
+       struct composite_context *ctx =
+               socket_connect_multi_ex_send(mem_ctx, server_address,
+                                            num_server_ports, server_ports,
+                                            resolve_ctx,
+                                            event_ctx,
+                                            ex);
+       return socket_connect_multi_ex_recv(ctx, mem_ctx, result, result_port);
+}
+
+/*
+  setup an async socket_connect, with multiple ports
+*/
+_PUBLIC_ struct composite_context *socket_connect_multi_send(
+                                                   TALLOC_CTX *mem_ctx,
+                                                   const char *server_name,
+                                                   int num_server_ports,
+                                                   uint16_t *server_ports,
+                                                   struct resolve_context *resolve_ctx,
+                                                   struct tevent_context *event_ctx)
+{
+       return socket_connect_multi_ex_send(mem_ctx,
+                                           server_name,
+                                           num_server_ports,
+                                           server_ports,
+                                           resolve_ctx,
+                                           event_ctx,
+                                           NULL); /* ex */
+}
+
+/*
+  async recv routine for socket_connect_multi()
+ */
+_PUBLIC_ NTSTATUS socket_connect_multi_recv(struct composite_context *ctx,
+                                  TALLOC_CTX *mem_ctx,
+                                  struct socket_context **sock,
+                                  uint16_t *port)
+{
+       return socket_connect_multi_ex_recv(ctx, mem_ctx, sock, port);
+}
+
 NTSTATUS socket_connect_multi(TALLOC_CTX *mem_ctx,
                              const char *server_address,
                              int num_server_ports, uint16_t *server_ports,
@@ -275,10 +380,13 @@ NTSTATUS socket_connect_multi(TALLOC_CTX *mem_ctx,
                              struct socket_context **result,
                              uint16_t *result_port)
 {
-       struct composite_context *ctx =
-               socket_connect_multi_send(mem_ctx, server_address,
-                                         num_server_ports, server_ports,
-                                         resolve_ctx,
-                                         event_ctx);
-       return socket_connect_multi_recv(ctx, mem_ctx, result, result_port);
+       return socket_connect_multi_ex(mem_ctx,
+                                      server_address,
+                                      num_server_ports,
+                                      server_ports,
+                                      resolve_ctx,
+                                      event_ctx,
+                                      NULL, /* ex */
+                                      result,
+                                      result_port);
 }