First part of fix for bug #7331 - Compound async SMB 2 requests don't work right.
[samba.git] / source3 / smbd / smb2_server.c
index f5e37659f88dc2066d6124fdffd49593086a92fe..64f9eaba1465617a21e27993b9db2eb8a18610ef 100644 (file)
 #include "../libcli/smb/smb_common.h"
 #include "../lib/tsocket/tsocket.h"
 
+static const char *smb2_names[] = {
+       "SMB2_NEGPROT",
+       "SMB2_SESSSETUP",
+       "SMB2_LOGOFF",
+       "SMB2_TCON",
+       "SMB2_TDIS",
+       "SMB2_CREATE",
+       "SMB2_CLOSE",
+       "SMB2_FLUSH",
+       "SMB2_READ",
+       "SMB2_WRITE",
+       "SMB2_LOCK",
+       "SMB2_IOCTL",
+       "SMB2_CANCEL",
+       "SMB2_KEEPALIVE",
+       "SMB2_FIND",
+       "SMB2_NOTIFY",
+       "SMB2_GETINFO",
+       "SMB2_SETINFO",
+       "SMB2_BREAK"
+};
+
+const char *smb2_opcode_name(uint16_t opcode)
+{
+       if (opcode >= 0x12) {
+               return "Bad SMB2 opcode";
+       }
+       return smb2_names[opcode];
+}
+
+static void print_req_vectors(struct smbd_smb2_request *req)
+{
+       int i;
+
+       for (i = 0; i < req->in.vector_count; i++) {
+               dbgtext("\treq->in.vector[%u].iov_len = %u\n",
+                       (unsigned int)i,
+                       (unsigned int)req->in.vector[i].iov_len);
+       }
+       for (i = 0; i < req->out.vector_count; i++) {
+               dbgtext("\treq->out.vector[%u].iov_len = %u\n",
+                       (unsigned int)i,
+                       (unsigned int)req->out.vector[i].iov_len);
+       }
+}
+
 bool smbd_is_smb2_header(const uint8_t *inbuf, size_t size)
 {
        if (size < (4 + SMB2_HDR_BODY)) {
@@ -421,34 +467,123 @@ void smbd_server_connection_terminate_ex(struct smbd_server_connection *sconn,
        exit_server_cleanly(reason);
 }
 
-struct smbd_smb2_request_pending_state {
-       struct smbd_server_connection *sconn;
-       uint8_t buf[4 + SMB2_HDR_BODY + 0x08];
-       struct iovec vector;
-};
+static bool dup_smb2_vec(struct iovec *dstvec,
+                       const struct iovec *srcvec,
+                       int offset)
+{
 
-static void smbd_smb2_request_pending_writev_done(struct tevent_req *subreq);
+       if (srcvec[offset].iov_len &&
+                       srcvec[offset].iov_base) {
+               dstvec[offset].iov_base = talloc_memdup(dstvec,
+                                       srcvec[offset].iov_base,
+                                       srcvec[offset].iov_len);
+               if (!dstvec[offset].iov_base) {
+                       return false;
+               }
+               dstvec[offset].iov_len = srcvec[offset].iov_len;
+       } else {
+               dstvec[offset].iov_base = NULL;
+               dstvec[offset].iov_len = 0;
+       }
+       return true;
+}
+
+static struct smbd_smb2_request *dup_smb2_req(struct smbd_smb2_request *req)
+{
+       struct smbd_smb2_request *newreq = NULL;
+       struct iovec *outvec = NULL;
+       int count = req->out.vector_count;
+       int i;
+
+       newreq = smbd_smb2_request_allocate(req->sconn);
+       if (!newreq) {
+               return NULL;
+       }
+
+       newreq->sconn = req->sconn;
+       newreq->do_signing = req->do_signing;
+       newreq->current_idx = req->current_idx;
+       newreq->async = false;
+       newreq->cancelled = false;
+
+       outvec = talloc_array(newreq, struct iovec, count);
+       if (!outvec) {
+               TALLOC_FREE(newreq);
+               return NULL;
+       }
+       newreq->out.vector = outvec;
+       newreq->out.vector_count = count;
+
+       /* Setup the outvec's identically to req. */
+       outvec[0].iov_base = newreq->out.nbt_hdr;
+       outvec[0].iov_len = 4;
+       memcpy(newreq->out.nbt_hdr, req->out.nbt_hdr, 4);
+               
+       for (i = 1; i < count; i++) {
+               if (!dup_smb2_vec(outvec,
+                               req->out.vector,
+                               i)) {
+                       TALLOC_FREE(newreq);
+                       return NULL;
+               }
+       }
+
+       smb2_setup_nbt_length(newreq->out.vector,
+               newreq->out.vector_count);
+
+       return newreq;
+}
+
+static void smbd_smb2_request_writev_done(struct tevent_req *subreq);
 
 NTSTATUS smbd_smb2_request_pending_queue(struct smbd_smb2_request *req,
                                         struct tevent_req *subreq)
 {
-       struct smbd_smb2_request_pending_state *state;
-       uint8_t *outhdr;
        int i = req->current_idx;
-       uint32_t flags;
-       uint64_t message_id;
-       uint64_t async_id;
-       uint8_t *hdr;
-       uint8_t *body;
+       struct smbd_smb2_request *nreq = NULL;
+       uint8_t *outhdr = NULL;
+       uint8_t *outbody = NULL;
+       uint32_t flags = 0;
+       uint64_t message_id = 0;
+       uint64_t async_id = 0;
+       struct iovec *outvec = NULL;
 
        if (!tevent_req_is_in_progress(subreq)) {
                return NT_STATUS_OK;
        }
 
+       if (req->async) {
+               /* We're already async. */
+               return NT_STATUS_OK;
+       }
+
+       if (req->in.vector_count > i + 3) {
+               /*
+                * We're trying to go async in a compound
+                * request chain. This is not allowed.
+                * Cancel the outstanding request.
+                */
+               tevent_req_cancel(subreq);
+               return smbd_smb2_request_error(req,
+                       NT_STATUS_INSUFFICIENT_RESOURCES);
+       }
+
        req->subreq = subreq;
        subreq = NULL;
 
-       outhdr = (uint8_t *)req->out.vector[i].iov_base;
+       if (DEBUGLEVEL >= 10) {
+               dbgtext("smbd_smb2_request_pending_queue: req->current_idx = %u\n",
+                       (unsigned int)req->current_idx );
+               print_req_vectors(req);
+       }
+
+       /* Create a new smb2 request we'll use to return. */
+       nreq = dup_smb2_req(req);
+       if (!nreq) {
+               return NT_STATUS_NO_MEMORY;
+       }
+
+       outhdr = (uint8_t *)nreq->out.vector[i].iov_base;
 
        flags = IVAL(outhdr, SMB2_HDR_FLAGS);
        message_id = BVAL(outhdr, SMB2_HDR_MESSAGE_ID);
@@ -456,78 +591,144 @@ NTSTATUS smbd_smb2_request_pending_queue(struct smbd_smb2_request *req,
        async_id = message_id; /* keep it simple for now... */
        SIVAL(outhdr, SMB2_HDR_FLAGS,   flags | SMB2_HDR_FLAG_ASYNC);
        SBVAL(outhdr, SMB2_HDR_PID,     async_id);
+       SIVAL(outhdr, SMB2_HDR_STATUS,  NT_STATUS_V(STATUS_PENDING));
 
-       /* TODO: add a paramter to delay this */
-       state = talloc(req->sconn, struct smbd_smb2_request_pending_state);
-       if (state == NULL) {
+       nreq->out.vector[i+1].iov_base = talloc_zero_array(nreq->out.vector,
+                                                       uint8_t,
+                                                       9);
+       if (!nreq->out.vector[i+1].iov_base) {
                return NT_STATUS_NO_MEMORY;
        }
-       state->sconn = req->sconn;
+       nreq->out.vector[i+1].iov_len = 9;
+       outbody = (uint8_t *)nreq->out.vector[i+1].iov_base;
 
-       state->vector.iov_base = (void *)state->buf;
-       state->vector.iov_len = sizeof(state->buf);
+       /* setup error body header */
+       SSVAL(outbody, 0x00, 0x08 + 1);
+       SSVAL(outbody, 0x02, 0);
+       SIVAL(outbody, 0x04, 0);
+       /* Match W2K8R2... */
+       SCVAL(outbody, 8, 0x21);
 
-       _smb2_setlen(state->buf, sizeof(state->buf) - 4);
-       hdr = state->buf + 4;
-       body = hdr + SMB2_HDR_BODY;
+       nreq->out.vector[i+2].iov_base = NULL;
+       nreq->out.vector[i+2].iov_len = 0;
 
-       SIVAL(hdr, SMB2_HDR_PROTOCOL_ID,        SMB2_MAGIC);
-       SSVAL(hdr, SMB2_HDR_LENGTH,             SMB2_HDR_BODY);
-       SSVAL(hdr, SMB2_HDR_EPOCH,              0);
-       SIVAL(hdr, SMB2_HDR_STATUS,             NT_STATUS_V(STATUS_PENDING));
-       SSVAL(hdr, SMB2_HDR_OPCODE,
-             SVAL(outhdr, SMB2_HDR_OPCODE));
-       SSVAL(hdr, SMB2_HDR_CREDIT,             1);
-       SIVAL(hdr, SMB2_HDR_FLAGS,
-             IVAL(outhdr, SMB2_HDR_FLAGS));
-       SIVAL(hdr, SMB2_HDR_NEXT_COMMAND,       0);
-       SBVAL(hdr, SMB2_HDR_MESSAGE_ID,
-             BVAL(outhdr, SMB2_HDR_MESSAGE_ID));
-       SBVAL(hdr, SMB2_HDR_PID,
-             BVAL(outhdr, SMB2_HDR_PID));
-       SBVAL(hdr, SMB2_HDR_SESSION_ID,
-             BVAL(outhdr, SMB2_HDR_SESSION_ID));
-       memset(hdr+SMB2_HDR_SIGNATURE, 0, 16);
+       smb2_setup_nbt_length(nreq->out.vector,
+               nreq->out.vector_count);
 
-       SSVAL(body, 0x00, 0x08 + 1);
+       if (nreq->do_signing) {
+               NTSTATUS status;
+               status = smb2_signing_sign_pdu(nreq->session->session_key,
+                                       &nreq->out.vector[i], 3);
+               if (!NT_STATUS_IS_OK(status)) {
+                       return status;
+               }
+       }
 
-       SCVAL(body, 0x02, 0);
-       SCVAL(body, 0x03, 0);
-       SIVAL(body, 0x04, 0);
+       if (DEBUGLEVEL >= 10) {
+               dbgtext("smbd_smb2_request_pending_queue: nreq->current_idx = %u\n",
+                       (unsigned int)nreq->current_idx );
+               dbgtext("smbd_smb2_request_pending_queue: returning %u vectors\n",
+                       (unsigned int)nreq->out.vector_count );
+               print_req_vectors(nreq);
+       }
 
-       subreq = tstream_writev_queue_send(state,
-                                          req->sconn->smb2.event_ctx,
-                                          req->sconn->smb2.stream,
-                                          req->sconn->smb2.send_queue,
-                                          &state->vector, 1);
-       if (subreq == NULL) {
+       nreq->subreq = tstream_writev_queue_send(nreq,
+                                       nreq->sconn->smb2.event_ctx,
+                                       nreq->sconn->smb2.stream,
+                                       nreq->sconn->smb2.send_queue,
+                                       nreq->out.vector,
+                                       nreq->out.vector_count);
+
+       if (nreq->subreq == NULL) {
                return NT_STATUS_NO_MEMORY;
        }
-       tevent_req_set_callback(subreq,
-                               smbd_smb2_request_pending_writev_done,
-                               state);
 
-       return NT_STATUS_OK;
-}
+       tevent_req_set_callback(nreq->subreq,
+                       smbd_smb2_request_writev_done,
+                       nreq);
 
-static void smbd_smb2_request_pending_writev_done(struct tevent_req *subreq)
-{
-       struct smbd_smb2_request_pending_state *state =
-               tevent_req_callback_data(subreq,
-               struct smbd_smb2_request_pending_state);
-       struct smbd_server_connection *sconn = state->sconn;
-       int ret;
-       int sys_errno;
+       /* Note we're going async with this request. */
+       req->async = true;
 
-       ret = tstream_writev_queue_recv(subreq, &sys_errno);
-       TALLOC_FREE(subreq);
-       if (ret == -1) {
-               NTSTATUS status = map_nt_error_from_unix(sys_errno);
-               smbd_server_connection_terminate(sconn, nt_errstr(status));
-               return;
+       /*
+        * Now manipulate req so that the outstanding async request
+        * is the only one left in the struct smbd_smb2_request.
+        */
+
+       if (req->current_idx == 1) {
+               /* There was only one. */
+               goto out;
        }
 
-       TALLOC_FREE(state);
+       /* Re-arrange the in.vectors. */
+       req->in.vector[1] = req->in.vector[i];
+       req->in.vector[2] = req->in.vector[i+1];
+       req->in.vector[3] = req->in.vector[i+2];
+       req->in.vector_count = 4;
+       /* Reset the new in size. */
+       smb2_setup_nbt_length(req->in.vector, 4);
+
+       /* Now recreate the out.vectors. */
+       outvec = talloc_array(req, struct iovec, 4);
+       if (!outvec) {
+               return NT_STATUS_NO_MEMORY;
+       }
+       outvec[0].iov_base = req->out.nbt_hdr;
+       outvec[0].iov_len = 4;
+       SIVAL(req->out.nbt_hdr, 0, 0);
+
+       outvec[1].iov_base = talloc_memdup(outvec,
+                               req->out.vector[i].iov_base,
+                               SMB2_HDR_BODY + 8);
+       if (!outvec[1].iov_base) {
+               return NT_STATUS_NO_MEMORY;
+       }
+       outvec[1].iov_len = SMB2_HDR_BODY;
+
+       outvec[2].iov_base = ((uint8_t *)outvec[1].iov_base) +
+                               SMB2_HDR_BODY;
+       outvec[2].iov_len = 8;
+
+       if (req->out.vector[i+2].iov_base &&
+                       req->out.vector[i+2].iov_len) {
+               outvec[3].iov_base = talloc_memdup(outvec,
+                                       req->out.vector[i+2].iov_base,
+                                       req->out.vector[i+2].iov_len);
+               if (!outvec[3].iov_base) {
+                       return NT_STATUS_NO_MEMORY;
+               }
+               outvec[3].iov_len = req->out.vector[i+2].iov_len;
+       } else {
+               outvec[3].iov_base = NULL;
+               outvec[3].iov_len = 0;
+       }
+
+       TALLOC_FREE(req->out.vector);
+
+       req->out.vector = outvec;
+
+       req->current_idx = 1;
+       req->out.vector_count = 4;
+
+  out:
+
+       smb2_setup_nbt_length(req->out.vector,
+               req->out.vector_count);
+
+       /* Ensure our final reply matches the interim one. */
+       outhdr = (uint8_t *)req->out.vector[1].iov_base;
+       SIVAL(outhdr, SMB2_HDR_FLAGS,   flags | SMB2_HDR_FLAG_ASYNC);
+       SBVAL(outhdr, SMB2_HDR_PID,     async_id);
+
+       {
+               const uint8_t *inhdr =
+                       (const uint8_t *)req->in.vector[1].iov_base;
+               DEBUG(10,("smbd_smb2_request_pending_queue: opcode[%s] mid %llu "
+                       "going async\n",
+                       smb2_opcode_name((uint16_t)IVAL(inhdr, SMB2_HDR_OPCODE)),
+                       (unsigned long long)async_id ));
+       }
+       return NT_STATUS_OK;
 }
 
 static NTSTATUS smbd_smb2_request_process_cancel(struct smbd_smb2_request *req)
@@ -539,6 +740,7 @@ static NTSTATUS smbd_smb2_request_process_cancel(struct smbd_smb2_request *req)
        uint32_t flags;
        uint64_t search_message_id;
        uint64_t search_async_id;
+       uint64_t found_id;
 
        inhdr = (const uint8_t *)req->in.vector[i].iov_base;
 
@@ -566,17 +768,26 @@ static NTSTATUS smbd_smb2_request_process_cancel(struct smbd_smb2_request *req)
 
                if (flags & SMB2_HDR_FLAG_ASYNC) {
                        if (search_async_id == async_id) {
+                               found_id = async_id;
                                break;
                        }
                } else {
                        if (search_message_id == message_id) {
+                               found_id = message_id;
                                break;
                        }
                }
        }
 
        if (cur && cur->subreq) {
+               inhdr = (const uint8_t *)cur->in.vector[i].iov_base;
+               DEBUG(10,("smbd_smb2_request_process_cancel: attempting to "
+                       "cancel opcode[%s] mid %llu\n",
+                       smb2_opcode_name((uint16_t)IVAL(inhdr, SMB2_HDR_OPCODE)),
+                        (unsigned long long)found_id ));
                tevent_req_cancel(cur->subreq);
+               TALLOC_FREE(cur->subreq);
+               TALLOC_FREE(cur);
        }
 
        return NT_STATUS_OK;
@@ -588,6 +799,7 @@ static NTSTATUS smbd_smb2_request_dispatch(struct smbd_smb2_request *req)
        int i = req->current_idx;
        uint16_t opcode;
        uint32_t flags;
+       uint64_t mid;
        NTSTATUS status;
        NTSTATUS session_status;
        uint32_t allowed_flags;
@@ -598,7 +810,10 @@ static NTSTATUS smbd_smb2_request_dispatch(struct smbd_smb2_request *req)
 
        flags = IVAL(inhdr, SMB2_HDR_FLAGS);
        opcode = IVAL(inhdr, SMB2_HDR_OPCODE);
-       DEBUG(10,("smbd_smb2_request_dispatch: opcode[%u]\n", opcode));
+       mid = BVAL(inhdr, SMB2_HDR_MESSAGE_ID);
+       DEBUG(10,("smbd_smb2_request_dispatch: opcode[%s] mid = %llu\n",
+               smb2_opcode_name(opcode),
+               (unsigned long long)mid));
 
        allowed_flags = SMB2_HDR_FLAG_CHAINED |
                        SMB2_HDR_FLAG_SIGNED |
@@ -806,8 +1021,9 @@ static NTSTATUS smbd_smb2_request_dispatch(struct smbd_smb2_request *req)
        return smbd_smb2_request_error(req, NT_STATUS_INVALID_PARAMETER);
 }
 
-static void smbd_smb2_request_dispatch_compound(struct tevent_req *subreq);
-static void smbd_smb2_request_writev_done(struct tevent_req *subreq);
+static void smbd_smb2_request_dispatch_compound(struct tevent_context *ctx,
+                                       struct tevent_immediate *im,
+                                       void *private_data);
 
 static NTSTATUS smbd_smb2_request_reply(struct smbd_smb2_request *req)
 {
@@ -830,20 +1046,30 @@ static NTSTATUS smbd_smb2_request_reply(struct smbd_smb2_request *req)
        req->current_idx += 3;
 
        if (req->current_idx < req->out.vector_count) {
-               struct timeval zero = timeval_zero();
-               subreq = tevent_wakeup_send(req,
-                                           req->sconn->smb2.event_ctx,
-                                           zero);
-               if (subreq == NULL) {
+               /*
+                * We must process the remaining compound
+                * SMB2 requests before any new incoming SMB2
+                * requests. This is because incoming SMB2
+                * requests may include a cancel for a
+                * compound request we haven't processed
+                * yet.
+                */
+               struct tevent_immediate *im = tevent_create_immediate(req);
+               if (!im) {
                        return NT_STATUS_NO_MEMORY;
                }
-               tevent_req_set_callback(subreq,
+               tevent_schedule_immediate(im,
+                                       req->sconn->smb2.event_ctx,
                                        smbd_smb2_request_dispatch_compound,
                                        req);
-
                return NT_STATUS_OK;
        }
 
+       if (DEBUGLEVEL >= 10) {
+               dbgtext("smbd_smb2_request_reply: sending...\n");
+               print_req_vectors(req);
+       }
+
        subreq = tstream_writev_queue_send(req,
                                           req->sconn->smb2.event_ctx,
                                           req->sconn->smb2.stream,
@@ -858,18 +1084,22 @@ static NTSTATUS smbd_smb2_request_reply(struct smbd_smb2_request *req)
        return NT_STATUS_OK;
 }
 
-static void smbd_smb2_request_dispatch_compound(struct tevent_req *subreq)
+static void smbd_smb2_request_dispatch_compound(struct tevent_context *ctx,
+                                       struct tevent_immediate *im,
+                                       void *private_data)
 {
-       struct smbd_smb2_request *req = tevent_req_callback_data(subreq,
+       struct smbd_smb2_request *req = talloc_get_type_abort(private_data,
                                        struct smbd_smb2_request);
        struct smbd_server_connection *sconn = req->sconn;
        NTSTATUS status;
 
-       tevent_wakeup_recv(subreq);
-       TALLOC_FREE(subreq);
+       TALLOC_FREE(im);
 
-       DEBUG(10,("smbd_smb2_request_dispatch_compound: idx[%d] of %d vectors\n",
-                 req->current_idx, req->in.vector_count));
+       if (DEBUGLEVEL >= 10) {
+               DEBUG(10,("smbd_smb2_request_dispatch_compound: idx[%d] of %d vectors\n",
+                       req->current_idx, req->in.vector_count));
+               print_req_vectors(req);
+       }
 
        status = smbd_smb2_request_dispatch(req);
        if (!NT_STATUS_IS_OK(status)) {
@@ -891,6 +1121,8 @@ static void smbd_smb2_request_writev_done(struct tevent_req *subreq)
        TALLOC_FREE(req);
        if (ret == -1) {
                NTSTATUS status = map_nt_error_from_unix(sys_errno);
+               DEBUG(2,("smbd_smb2_request_writev_done: client write error %s\n",
+                       nt_errstr(status)));
                smbd_server_connection_terminate(sconn, nt_errstr(status));
                return;
        }
@@ -1551,6 +1783,8 @@ static void smbd_smb2_request_incoming(struct tevent_req *subreq)
        status = smbd_smb2_request_read_recv(subreq, sconn, &req);
        TALLOC_FREE(subreq);
        if (!NT_STATUS_IS_OK(status)) {
+               DEBUG(2,("smbd_smb2_request_incoming: client read error %s\n",
+                       nt_errstr(status)));
                smbd_server_connection_terminate(sconn, nt_errstr(status));
                return;
        }