Split up async_req into a generic and a NTSTATUS specific part
[tprouty/samba.git] / source3 / libsmb / async_smb.c
index d371e057e370f7d2ab837183114763669ab98e85..e579d1c9f0bcc32b7cdfda27ae2fbe6390f721ea 100644 (file)
@@ -127,7 +127,7 @@ static char *cli_request_print(TALLOC_CTX *mem_ctx, struct async_req *req)
 static int cli_request_destructor(struct cli_request *req)
 {
        if (req->enc_state != NULL) {
 static int cli_request_destructor(struct cli_request *req)
 {
        if (req->enc_state != NULL) {
-               common_free_enc_buffer(req->enc_state, req->outbuf);
+               common_free_enc_buffer(req->enc_state, (char *)req->outbuf);
        }
        DLIST_REMOVE(req->cli->outstanding_requests, req);
        if (req->cli->outstanding_requests == NULL) {
        }
        DLIST_REMOVE(req->cli->outstanding_requests, req);
        if (req->cli->outstanding_requests == NULL) {
@@ -151,32 +151,6 @@ bool cli_in_chain(struct cli_state *cli)
        return (cli->chain_accumulator->num_async != 0);
 }
 
        return (cli->chain_accumulator->num_async != 0);
 }
 
-/**
- * Is the SMB command able to hold an AND_X successor
- * @param[in] cmd      The SMB command in question
- * @retval Can we add a chained request after "cmd"?
- */
-
-static bool is_andx_req(uint8_t cmd)
-{
-       switch (cmd) {
-       case SMBtconX:
-       case SMBlockingX:
-       case SMBopenX:
-       case SMBreadX:
-       case SMBwriteX:
-       case SMBsesssetupX:
-       case SMBulogoffX:
-       case SMBntcreateX:
-               return true;
-               break;
-       default:
-               break;
-       }
-
-       return false;
-}
-
 /**
  * @brief Find the smb_cmd offset of the last command pushed
  * @param[in] buf      The buffer we're building up
 /**
  * @brief Find the smb_cmd offset of the last command pushed
  * @param[in] buf      The buffer we're building up
@@ -187,7 +161,7 @@ static bool is_andx_req(uint8_t cmd)
  * to the chain. Find the offset to the place where we have to put our cmd.
  */
 
  * to the chain. Find the offset to the place where we have to put our cmd.
  */
 
-static bool find_andx_cmd_ofs(char *buf, size_t *pofs)
+static bool find_andx_cmd_ofs(uint8_t *buf, size_t *pofs)
 {
        uint8_t cmd;
        size_t ofs;
 {
        uint8_t cmd;
        size_t ofs;
@@ -217,6 +191,140 @@ static bool find_andx_cmd_ofs(char *buf, size_t *pofs)
        return true;
 }
 
        return true;
 }
 
+/**
+ * @brief Do the smb chaining at a buffer level
+ * @param[in] poutbuf          Pointer to the talloc'ed buffer to be modified
+ * @param[in] smb_command      The command that we want to issue
+ * @param[in] wct              How many words?
+ * @param[in] vwv              The words, already in network order
+ * @param[in] bytes_alignment  How shall we align "bytes"?
+ * @param[in] num_bytes                How many bytes?
+ * @param[in] bytes            The data the request ships
+ *
+ * smb_splice_chain() adds the vwv and bytes to the request already present in
+ * *poutbuf.
+ */
+
+bool smb_splice_chain(uint8_t **poutbuf, uint8_t smb_command,
+                     uint8_t wct, const uint16_t *vwv,
+                     size_t bytes_alignment,
+                     uint32_t num_bytes, const uint8_t *bytes)
+{
+       uint8_t *outbuf;
+       size_t old_size, new_size;
+       size_t ofs;
+       size_t chain_padding = 0;
+       size_t bytes_padding = 0;
+       bool first_request;
+
+       old_size = talloc_get_size(*poutbuf);
+
+       /*
+        * old_size == smb_wct means we're pushing the first request in for
+        * libsmb/
+        */
+
+       first_request = (old_size == smb_wct);
+
+       if (!first_request && ((old_size % 4) != 0)) {
+               /*
+                * Align the wct field of subsequent requests to a 4-byte
+                * boundary
+                */
+               chain_padding = 4 - (old_size % 4);
+       }
+
+       /*
+        * After the old request comes the new wct field (1 byte), the vwv's
+        * and the num_bytes field. After at we might need to align the bytes
+        * given to us to "bytes_alignment", increasing the num_bytes value.
+        */
+
+       new_size = old_size + chain_padding + 1 + wct * sizeof(uint16_t) + 2;
+
+       if ((bytes_alignment != 0) && ((new_size % bytes_alignment) != 0)) {
+               bytes_padding = bytes_alignment - (new_size % bytes_alignment);
+       }
+
+       new_size += bytes_padding + num_bytes;
+
+       if ((smb_command != SMBwriteX) && (new_size > 0xffff)) {
+               DEBUG(1, ("splice_chain: %u bytes won't fit\n",
+                         (unsigned)new_size));
+               return false;
+       }
+
+       outbuf = TALLOC_REALLOC_ARRAY(NULL, *poutbuf, uint8_t, new_size);
+       if (outbuf == NULL) {
+               DEBUG(0, ("talloc failed\n"));
+               return false;
+       }
+       *poutbuf = outbuf;
+
+       if (first_request) {
+               SCVAL(outbuf, smb_com, smb_command);
+       } else {
+               size_t andx_cmd_ofs;
+
+               if (!find_andx_cmd_ofs(outbuf, &andx_cmd_ofs)) {
+                       DEBUG(1, ("invalid command chain\n"));
+                       *poutbuf = TALLOC_REALLOC_ARRAY(
+                               NULL, *poutbuf, uint8_t, old_size);
+                       return false;
+               }
+
+               if (chain_padding != 0) {
+                       memset(outbuf + old_size, 0, chain_padding);
+                       old_size += chain_padding;
+               }
+
+               SCVAL(outbuf, andx_cmd_ofs, smb_command);
+               SSVAL(outbuf, andx_cmd_ofs + 2, old_size - 4);
+       }
+
+       ofs = old_size;
+
+       /*
+        * Push the chained request:
+        *
+        * wct field
+        */
+
+       SCVAL(outbuf, ofs, wct);
+       ofs += 1;
+
+       /*
+        * vwv array
+        */
+
+       memcpy(outbuf + ofs, vwv, sizeof(uint16_t) * wct);
+       ofs += sizeof(uint16_t) * wct;
+
+       /*
+        * bcc (byte count)
+        */
+
+       SSVAL(outbuf, ofs, num_bytes + bytes_padding);
+       ofs += sizeof(uint16_t);
+
+       /*
+        * padding
+        */
+
+       if (bytes_padding != 0) {
+               memset(outbuf + ofs, 0, bytes_padding);
+               ofs += bytes_padding;
+       }
+
+       /*
+        * The bytes field
+        */
+
+       memcpy(outbuf + ofs, bytes, num_bytes);
+
+       return true;
+}
+
 /**
  * @brief Destroy an async_req that is the visible part of a cli_request
  * @param[in] req      The request to kill
 /**
  * @brief Destroy an async_req that is the visible part of a cli_request
  * @param[in] req      The request to kill
@@ -267,6 +375,7 @@ static int cli_async_req_destructor(struct async_req *req)
  * @param[in] additional_flags open_and_x wants to add oplock header flags
  * @param[in] wct              How many words?
  * @param[in] vwv              The words, already in network order
  * @param[in] additional_flags open_and_x wants to add oplock header flags
  * @param[in] wct              How many words?
  * @param[in] vwv              The words, already in network order
+ * @param[in] bytes_alignment  How shall we align "bytes"?
  * @param[in] num_bytes                How many bytes?
  * @param[in] bytes            The data the request ships
  *
  * @param[in] num_bytes                How many bytes?
  * @param[in] bytes            The data the request ships
  *
@@ -282,14 +391,12 @@ static struct async_req *cli_request_chain(TALLOC_CTX *mem_ctx,
                                           uint8_t smb_command,
                                           uint8_t additional_flags,
                                           uint8_t wct, const uint16_t *vwv,
                                           uint8_t smb_command,
                                           uint8_t additional_flags,
                                           uint8_t wct, const uint16_t *vwv,
-                                          uint16_t num_bytes,
+                                          size_t bytes_alignment,
+                                          uint32_t num_bytes,
                                           const uint8_t *bytes)
 {
        struct async_req **tmp_reqs;
                                           const uint8_t *bytes)
 {
        struct async_req **tmp_reqs;
-       char *tmp_buf;
        struct cli_request *req;
        struct cli_request *req;
-       size_t old_size, new_size;
-       size_t ofs;
 
        req = cli->chain_accumulator;
 
 
        req = cli->chain_accumulator;
 
@@ -302,7 +409,7 @@ static struct async_req *cli_request_chain(TALLOC_CTX *mem_ctx,
        req->async = tmp_reqs;
        req->num_async += 1;
 
        req->async = tmp_reqs;
        req->num_async += 1;
 
-       req->async[req->num_async-1] = async_req_new(mem_ctx, ev);
+       req->async[req->num_async-1] = async_req_new(mem_ctx);
        if (req->async[req->num_async-1] == NULL) {
                DEBUG(0, ("async_req_new failed\n"));
                req->num_async -= 1;
        if (req->async[req->num_async-1] == NULL) {
                DEBUG(0, ("async_req_new failed\n"));
                req->num_async -= 1;
@@ -313,51 +420,10 @@ static struct async_req *cli_request_chain(TALLOC_CTX *mem_ctx,
        talloc_set_destructor(req->async[req->num_async-1],
                              cli_async_req_destructor);
 
        talloc_set_destructor(req->async[req->num_async-1],
                              cli_async_req_destructor);
 
-       old_size = talloc_get_size(req->outbuf);
-
-       /*
-        * We need space for the wct field, the words, the byte count field
-        * and the bytes themselves.
-        */
-       new_size = old_size + 1 + wct * sizeof(uint16_t) + 2 + num_bytes;
-
-       if (new_size > 0xffff) {
-               DEBUG(1, ("cli_request_chain: %u bytes won't fit\n",
-                         (unsigned)new_size));
-               goto fail;
-       }
-
-       tmp_buf = TALLOC_REALLOC_ARRAY(NULL, req->outbuf, char, new_size);
-       if (tmp_buf == NULL) {
-               DEBUG(0, ("talloc failed\n"));
+       if (!smb_splice_chain(&req->outbuf, smb_command, wct, vwv,
+                             bytes_alignment, num_bytes, bytes)) {
                goto fail;
        }
                goto fail;
        }
-       req->outbuf = tmp_buf;
-
-       if (old_size == smb_wct) {
-               SCVAL(req->outbuf, smb_com, smb_command);
-       } else {
-               size_t andx_cmd_ofs;
-               if (!find_andx_cmd_ofs(req->outbuf, &andx_cmd_ofs)) {
-                       DEBUG(1, ("invalid command chain\n"));
-                       goto fail;
-               }
-               SCVAL(req->outbuf, andx_cmd_ofs, smb_command);
-               SSVAL(req->outbuf, andx_cmd_ofs + 2, old_size - 4);
-       }
-
-       ofs = old_size;
-
-       SCVAL(req->outbuf, ofs, wct);
-       ofs += 1;
-
-       memcpy(req->outbuf + ofs, vwv, sizeof(uint16_t) * wct);
-       ofs += sizeof(uint16_t) * wct;
-
-       SSVAL(req->outbuf, ofs, num_bytes);
-       ofs += sizeof(uint16_t);
-
-       memcpy(req->outbuf + ofs, bytes, num_bytes);
 
        return req->async[req->num_async-1];
 
 
        return req->async[req->num_async-1];
 
@@ -421,11 +487,12 @@ bool cli_chain_cork(struct cli_state *cli, struct event_context *ev,
        if (size_hint == 0) {
                size_hint = 100;
        }
        if (size_hint == 0) {
                size_hint = 100;
        }
-       req->outbuf = talloc_array(req, char, smb_wct + size_hint);
+       req->outbuf = talloc_array(req, uint8_t, smb_wct + size_hint);
        if (req->outbuf == NULL) {
                goto fail;
        }
        if (req->outbuf == NULL) {
                goto fail;
        }
-       req->outbuf = TALLOC_REALLOC_ARRAY(NULL, req->outbuf, char, smb_wct);
+       req->outbuf = TALLOC_REALLOC_ARRAY(NULL, req->outbuf, uint8_t,
+                                          smb_wct);
 
        req->num_async = 0;
        req->async = NULL;
 
        req->num_async = 0;
        req->async = NULL;
@@ -434,7 +501,7 @@ bool cli_chain_cork(struct cli_state *cli, struct event_context *ev,
        req->recv_helper.fn = NULL;
 
        SSVAL(req->outbuf, smb_tid, cli->cnum);
        req->recv_helper.fn = NULL;
 
        SSVAL(req->outbuf, smb_tid, cli->cnum);
-       cli_setup_packet_buf(cli, req->outbuf);
+       cli_setup_packet_buf(cli, (char *)req->outbuf);
 
        req->mid = cli_new_mid(cli);
 
 
        req->mid = cli_new_mid(cli);
 
@@ -459,6 +526,7 @@ bool cli_chain_cork(struct cli_state *cli, struct event_context *ev,
 void cli_chain_uncork(struct cli_state *cli)
 {
        struct cli_request *req = cli->chain_accumulator;
 void cli_chain_uncork(struct cli_state *cli)
 {
        struct cli_request *req = cli->chain_accumulator;
+       size_t smblen;
 
        SMB_ASSERT(req != NULL);
 
 
        SMB_ASSERT(req != NULL);
 
@@ -468,22 +536,35 @@ void cli_chain_uncork(struct cli_state *cli)
        cli->chain_accumulator = NULL;
 
        SSVAL(req->outbuf, smb_mid, req->mid);
        cli->chain_accumulator = NULL;
 
        SSVAL(req->outbuf, smb_mid, req->mid);
-       smb_setlen(req->outbuf, talloc_get_size(req->outbuf) - 4);
 
 
-       cli_calculate_sign_mac(cli, req->outbuf);
+       smblen = talloc_get_size(req->outbuf) - 4;
+
+       smb_setlen((char *)req->outbuf, smblen);
+
+       if (smblen > 0x1ffff) {
+               /*
+                * This is a POSIX 14 word large write. Overwrite just the
+                * size field, the '0xFFSMB' has been set by smb_setlen which
+                * _smb_setlen_large does not do.
+                */
+               _smb_setlen_large(((char *)req->outbuf), smblen);
+       }
+
+       cli_calculate_sign_mac(cli, (char *)req->outbuf);
 
        if (cli_encryption_on(cli)) {
                NTSTATUS status;
                char *enc_buf;
 
 
        if (cli_encryption_on(cli)) {
                NTSTATUS status;
                char *enc_buf;
 
-               status = cli_encrypt_message(cli, req->outbuf, &enc_buf);
+               status = cli_encrypt_message(cli, (char *)req->outbuf,
+                                            &enc_buf);
                if (!NT_STATUS_IS_OK(status)) {
                        DEBUG(0, ("Error in encrypting client message. "
                                  "Error %s\n", nt_errstr(status)));
                        TALLOC_FREE(req);
                        return;
                }
                if (!NT_STATUS_IS_OK(status)) {
                        DEBUG(0, ("Error in encrypting client message. "
                                  "Error %s\n", nt_errstr(status)));
                        TALLOC_FREE(req);
                        return;
                }
-               req->outbuf = enc_buf;
+               req->outbuf = (uint8_t *)enc_buf;
                req->enc_state = cli->trans_enc_state;
        }
 
                req->enc_state = cli->trans_enc_state;
        }
 
@@ -501,6 +582,7 @@ void cli_chain_uncork(struct cli_state *cli)
  * @param[in] additional_flags open_and_x wants to add oplock header flags
  * @param[in] wct              How many words?
  * @param[in] vwv              The words, already in network order
  * @param[in] additional_flags open_and_x wants to add oplock header flags
  * @param[in] wct              How many words?
  * @param[in] vwv              The words, already in network order
+ * @param[in] bytes_alignment  How shall we align "bytes"?
  * @param[in] num_bytes                How many bytes?
  * @param[in] bytes            The data the request ships
  *
  * @param[in] num_bytes                How many bytes?
  * @param[in] bytes            The data the request ships
  *
@@ -513,7 +595,8 @@ struct async_req *cli_request_send(TALLOC_CTX *mem_ctx,
                                   uint8_t smb_command,
                                   uint8_t additional_flags,
                                   uint8_t wct, const uint16_t *vwv,
                                   uint8_t smb_command,
                                   uint8_t additional_flags,
                                   uint8_t wct, const uint16_t *vwv,
-                                  uint16_t num_bytes, const uint8_t *bytes)
+                                  size_t bytes_alignment,
+                                  uint32_t num_bytes, const uint8_t *bytes)
 {
        struct async_req *result;
        bool uncork = false;
 {
        struct async_req *result;
        bool uncork = false;
@@ -528,7 +611,7 @@ struct async_req *cli_request_send(TALLOC_CTX *mem_ctx,
        }
 
        result = cli_request_chain(mem_ctx, ev, cli, smb_command,
        }
 
        result = cli_request_chain(mem_ctx, ev, cli, smb_command,
-                                  additional_flags, wct, vwv,
+                                  additional_flags, wct, vwv, bytes_alignment,
                                   num_bytes, bytes);
 
        if (result == NULL) {
                                   num_bytes, bytes);
 
        if (result == NULL) {
@@ -542,6 +625,37 @@ struct async_req *cli_request_send(TALLOC_CTX *mem_ctx,
        return result;
 }
 
        return result;
 }
 
+/**
+ * Calculate the current ofs to wct for requests like write&x
+ * @param[in] req      The smb request we're currently building
+ * @retval how many bytes offset have we accumulated?
+ */
+
+uint16_t cli_wct_ofs(const struct cli_state *cli)
+{
+       size_t buf_size;
+
+       if (cli->chain_accumulator == NULL) {
+               return smb_wct - 4;
+       }
+
+       buf_size = talloc_get_size(cli->chain_accumulator->outbuf);
+
+       if (buf_size == smb_wct) {
+               return smb_wct - 4;
+       }
+
+       /*
+        * Add alignment for subsequent requests
+        */
+
+       if ((buf_size % 4) != 0) {
+               buf_size += (4 - (buf_size % 4));
+       }
+
+       return buf_size - 4;
+}
+
 /**
  * Figure out if there is an andx command behind the current one
  * @param[in] buf      The smb buffer to look at
 /**
  * Figure out if there is an andx command behind the current one
  * @param[in] buf      The smb buffer to look at
@@ -865,7 +979,7 @@ static void handle_incoming_pdu(struct cli_state *cli)
                   nt_errstr(status)));
 
        for (req = cli->outstanding_requests; req; req = req->next) {
                   nt_errstr(status)));
 
        for (req = cli->outstanding_requests; req; req = req->next) {
-               async_req_error(req->async[0], status);
+               async_req_nterror(req->async[0], status);
        }
        return;
 }
        }
        return;
 }
@@ -882,11 +996,44 @@ static void cli_state_handler(struct event_context *event_ctx,
                              struct fd_event *event, uint16 flags, void *p)
 {
        struct cli_state *cli = (struct cli_state *)p;
                              struct fd_event *event, uint16 flags, void *p)
 {
        struct cli_state *cli = (struct cli_state *)p;
-       struct cli_request *req;
+       struct cli_request *req, *next;
        NTSTATUS status;
 
        DEBUG(11, ("cli_state_handler called with flags %d\n", flags));
 
        NTSTATUS status;
 
        DEBUG(11, ("cli_state_handler called with flags %d\n", flags));
 
+       if (flags & EVENT_FD_WRITE) {
+               size_t to_send;
+               ssize_t sent;
+
+               for (req = cli->outstanding_requests; req; req = req->next) {
+                       to_send = smb_len(req->outbuf)+4;
+                       if (to_send > req->sent) {
+                               break;
+                       }
+               }
+
+               if (req == NULL) {
+                       if (cli->fd_event != NULL) {
+                               event_fd_set_not_writeable(cli->fd_event);
+                       }
+                       return;
+               }
+
+               sent = sys_send(cli->fd, req->outbuf + req->sent,
+                           to_send - req->sent, 0);
+
+               if (sent < 0) {
+                       status = map_nt_error_from_unix(errno);
+                       goto sock_error;
+               }
+
+               req->sent += sent;
+
+               if (req->sent == to_send) {
+                       return;
+               }
+       }
+
        if (flags & EVENT_FD_READ) {
                int res, available;
                size_t old_size, new_size;
        if (flags & EVENT_FD_READ) {
                int res, available;
                size_t old_size, new_size;
@@ -952,45 +1099,18 @@ static void cli_state_handler(struct event_context *event_ctx,
                }
        }
 
                }
        }
 
-       if (flags & EVENT_FD_WRITE) {
-               size_t to_send;
-               ssize_t sent;
-
-               for (req = cli->outstanding_requests; req; req = req->next) {
-                       to_send = smb_len(req->outbuf)+4;
-                       if (to_send > req->sent) {
-                               break;
-                       }
-               }
-
-               if (req == NULL) {
-                       if (cli->fd_event != NULL) {
-                               event_fd_set_not_writeable(cli->fd_event);
-                       }
-                       return;
-               }
-
-               sent = sys_send(cli->fd, req->outbuf + req->sent,
-                           to_send - req->sent, 0);
+       return;
 
 
-               if (sent < 0) {
-                       status = map_nt_error_from_unix(errno);
-                       goto sock_error;
-               }
+ sock_error:
 
 
-               req->sent += sent;
+       for (req = cli->outstanding_requests; req; req = next) {
+               int i, num_async;
 
 
-               if (req->sent == to_send) {
-                       return;
-               }
-       }
-       return;
+               next = req->next;
+               num_async = req->num_async;
 
 
- sock_error:
-       for (req = cli->outstanding_requests; req; req = req->next) {
-               int i;
-               for (i=0; i<req->num_async; i++) {
-                       async_req_error(req->async[i], status);
+               for (i=0; i<num_async; i++) {
+                       async_req_nterror(req->async[i], status);
                }
        }
        TALLOC_FREE(cli->fd_event);
                }
        }
        TALLOC_FREE(cli->fd_event);