librpc/rpc: add dcerpc_sec_vt_header2_[from_ncacn_packet|equal]()
[sfrench/samba-autobuild/.git] / librpc / rpc / dcerpc_util.c
index a405ca8966af4752603d5fd8ab1a8c919ef427c1..425e748116a91aada48857e770ba8c1b70bd1c10 100644 (file)
@@ -209,20 +209,35 @@ static int dcerpc_read_ncacn_packet_next_vector(struct tstream_context *stream,
        off_t ofs = 0;
 
        if (state->buffer.length == 0) {
-               /* first get enough to read the fragment length */
+               /*
+                * first get enough to read the fragment length
+                *
+                * We read the full fixed ncacn_packet header
+                * in order to make wireshark happy with
+                * pcap files from socket_wrapper.
+                */
                ofs = 0;
-               state->buffer.length = DCERPC_FRAG_LEN_OFFSET + 2;
+               state->buffer.length = DCERPC_NCACN_PAYLOAD_OFFSET;
                state->buffer.data = talloc_array(state, uint8_t,
                                                  state->buffer.length);
                if (!state->buffer.data) {
                        return -1;
                }
-       } else if (state->buffer.length == (DCERPC_FRAG_LEN_OFFSET + 2)) {
+       } else if (state->buffer.length == DCERPC_NCACN_PAYLOAD_OFFSET) {
                /* now read the fragment length and allocate the full buffer */
                size_t frag_len = dcerpc_get_frag_length(&state->buffer);
 
                ofs = state->buffer.length;
 
+               if (frag_len < ofs) {
+                       /*
+                        * something is wrong, let the caller deal with it
+                        */
+                       *_vector = NULL;
+                       *_count = 0;
+                       return 0;
+               }
+
                state->buffer.data = talloc_realloc(state,
                                                    state->buffer.data,
                                                    uint8_t, frag_len);
@@ -266,7 +281,7 @@ static void dcerpc_read_ncacn_packet_done(struct tevent_req *subreq)
        ret = tstream_readv_pdu_recv(subreq, &sys_errno);
        TALLOC_FREE(subreq);
        if (ret == -1) {
-               status = map_nt_error_from_unix(sys_errno);
+               status = map_nt_error_from_unix_common(sys_errno);
                tevent_req_nterror(req, status);
                return;
        }
@@ -292,6 +307,11 @@ static void dcerpc_read_ncacn_packet_done(struct tevent_req *subreq)
                return;
        }
 
+       if (state->pkt->frag_length != state->buffer.length) {
+               tevent_req_nterror(req, NT_STATUS_RPC_PROTOCOL_ERROR);
+               return;
+       }
+
        tevent_req_done(req);
 }
 
@@ -318,3 +338,114 @@ NTSTATUS dcerpc_read_ncacn_packet_recv(struct tevent_req *req,
        tevent_req_received(req);
        return NT_STATUS_OK;
 }
+
+const char *dcerpc_default_transport_endpoint(TALLOC_CTX *mem_ctx,
+                                             enum dcerpc_transport_t transport,
+                                             const struct ndr_interface_table *table)
+{
+       NTSTATUS status;
+       const char *p = NULL;
+       const char *endpoint = NULL;
+       int i;
+       struct dcerpc_binding *default_binding = NULL;
+       TALLOC_CTX *frame = talloc_stackframe();
+
+       /* Find one of the default pipes for this interface */
+
+       for (i = 0; i < table->endpoints->count; i++) {
+
+               status = dcerpc_parse_binding(frame, table->endpoints->names[i],
+                                             &default_binding);
+               if (NT_STATUS_IS_OK(status)) {
+                       if (transport == NCA_UNKNOWN &&
+                           default_binding->endpoint != NULL) {
+                               p = default_binding->endpoint;
+                               break;
+                       }
+                       if (default_binding->transport == transport &&
+                           default_binding->endpoint != NULL) {
+                               p = default_binding->endpoint;
+                               break;
+                       }
+               }
+       }
+
+       if (i == table->endpoints->count || p == NULL) {
+               goto done;
+       }
+
+       /*
+        * extract the pipe name without \\pipe from for example
+        * ncacn_np:[\\pipe\\epmapper]
+        */
+       if (default_binding->transport == NCACN_NP) {
+               if (strncasecmp(p, "\\pipe\\", 6) == 0) {
+                       p += 6;
+               }
+               if (strncmp(p, "\\", 1) == 0) {
+                       p += 1;
+               }
+       }
+
+       endpoint = talloc_strdup(mem_ctx, p);
+
+ done:
+       talloc_free(frame);
+       return endpoint;
+}
+
+struct dcerpc_sec_vt_header2 dcerpc_sec_vt_header2_from_ncacn_packet(const struct ncacn_packet *pkt)
+{
+       struct dcerpc_sec_vt_header2 ret;
+
+       ZERO_STRUCT(ret);
+       ret.ptype = pkt->ptype;
+       memcpy(&ret.drep, pkt->drep, sizeof(ret.drep));
+       ret.call_id = pkt->call_id;
+
+       switch (pkt->ptype) {
+       case DCERPC_PKT_REQUEST:
+               ret.context_id = pkt->u.request.context_id;
+               ret.opnum      = pkt->u.request.opnum;
+               break;
+
+       case DCERPC_PKT_RESPONSE:
+               ret.context_id = pkt->u.response.context_id;
+               break;
+
+       case DCERPC_PKT_FAULT:
+               ret.context_id = pkt->u.fault.context_id;
+               break;
+
+       default:
+               break;
+       }
+
+       return ret;
+}
+
+bool dcerpc_sec_vt_header2_equal(const struct dcerpc_sec_vt_header2 *v1,
+                                const struct dcerpc_sec_vt_header2 *v2)
+{
+       if (v1->ptype != v2->ptype) {
+               return false;
+       }
+
+       if (memcmp(v1->drep, v2->drep, sizeof(v1->drep)) != 0) {
+               return false;
+       }
+
+       if (v1->call_id != v2->call_id) {
+               return false;
+       }
+
+       if (v1->context_id != v2->context_id) {
+               return false;
+       }
+
+       if (v1->opnum != v2->opnum) {
+               return false;
+       }
+
+       return true;
+}