python/samba/tests/dcerpc/raw_protocol.py
authorStefan Metzmacher <metze@samba.org>
Wed, 2 Apr 2014 20:01:36 +0000 (22:01 +0200)
committerMatthieu Patou <mat@matws.net>
Fri, 3 Oct 2014 19:16:32 +0000 (12:16 -0700)
python/samba/tests/dcerpc/raw_protocol.py

index 851a053372a3d013281ca27ff08e0293be649977..ded785800bfbdb3645ac9738be84fa024f72505a 100755 (executable)
@@ -49,7 +49,43 @@ class DCERPCTest(TestCase):
         super(DCERPCTest, self).setUp()
         self.settings = {}
         self.settings["lp_ctx"] = self.lp_ctx = samba.tests.env_loadparm()
-        self.settings["target_hostname"] = host=os.getenv('SERVER_IP')
+        self.settings["target_hostname"] = host =os.getenv('SERVER_IP')
+        try:
+            self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+            self.s.connect((host, 135))
+        finally:
+            pass
+
+    def send_pdu(self, req, dump=False):
+        try:
+            req_pdu = ndr.ndr_pack(req)
+            if dump:
+                print "s: %s" % ndr.ndr_print(req)
+                #print self.hexdump(req_pdu)
+            self.s.send(req_pdu, 0)
+        finally:
+            pass
+
+    def recv_pdu(self, dump=False):
+        try:
+            rep_pdu = self.s.recv(0xffff, 0)
+            rep = ndr.ndr_unpack(dcerpc.ncacn_packet, rep_pdu, allow_remaining=True)
+            if dump:
+                #print self.hexdump(rep_pdu)
+                print "r: %s" % ndr.ndr_print(rep)
+        finally:
+            pass
+        return rep
+
+    def hexdump(self, src, length=8):
+        N=0; result=''
+        while src:
+           s,src = src[:length],src[length:]
+           hexa = ' '.join(["%02X"%ord(x) for x in s])
+           s = s.translate(FILTER)
+           result += "%04X   %-*s   %s\n" % (N, length*3, hexa, s)
+           N+=length
+        return result
 
     def generate_auth(self,
                       auth_type=None,
@@ -76,6 +112,17 @@ class DCERPCTest(TestCase):
 
         return ai
 
+    def parse_auth(self, auth_info, dump=False):
+        if (len(auth_info) <= 8):
+            return None
+
+        a = ndr.ndr_unpack(dcerpc.auth, auth_info, allow_remaining=True)
+        if dump:
+            #print self.hexdump(rep_pdu)
+            print ndr.ndr_print(a)
+
+        return a
+
     def generate_pdu(self, ptype, call_id, payload,
                      rpc_vers=5,
                      rpc_vers_minor=0,
@@ -104,12 +151,8 @@ class DCERPCTest(TestCase):
 
         pdu = ndr.ndr_pack(p)
         p.frag_length = len(pdu)
-        pdu = ndr.ndr_pack(p)
-        if dump:
-            print ndr.ndr_print(p)
-            print self.hexdump(pdu)
 
-        return pdu
+        return p
 
     def generate_bind(self, call_id,
                       pfc_flags = dcerpc.DCERPC_PFC_FLAG_FIRST | dcerpc.DCERPC_PFC_FLAG_LAST,
@@ -128,12 +171,13 @@ class DCERPCTest(TestCase):
         b.ctx_list = ctx_list
         b.auth_info = auth_info
 
-        pdu = self.generate_pdu(ptype=dcerpc.DCERPC_PKT_BIND,
-                                pfc_flags=pfc_flags,
-                                call_id=call_id,
-                                payload=b,
-                                dump=dump)
-        return pdu
+        p = self.generate_pdu(ptype=dcerpc.DCERPC_PKT_BIND,
+                              pfc_flags=pfc_flags,
+                              call_id=call_id,
+                              payload=b,
+                              dump=dump)
+
+        return p
 
     def generate_alter(self, call_id,
                        pfc_flags = dcerpc.DCERPC_PFC_FLAG_FIRST | dcerpc.DCERPC_PFC_FLAG_LAST,
@@ -152,12 +196,13 @@ class DCERPCTest(TestCase):
         a.ctx_list = ctx_list
         a.auth_info = auth_info
 
-        pdu = self.generate_pdu(ptype=dcerpc.DCERPC_PKT_ALTER,
-                                pfc_flags=pfc_flags,
-                                call_id=call_id,
-                                payload=a,
-                                dump=dump)
-        return pdu
+        p = self.generate_pdu(ptype=dcerpc.DCERPC_PKT_ALTER,
+                              pfc_flags=pfc_flags,
+                              call_id=call_id,
+                              payload=a,
+                              dump=dump)
+
+        return p
 
     def make_bind_pdu(self, dump=False):
         ndrINV = misc.ndr_syntax_id()
@@ -208,45 +253,111 @@ class DCERPCTest(TestCase):
                                        auth_blob=to_server,
                                        dump=dump)
 
-        pdu = self.generate_bind(call_id=0,
+        req = self.generate_bind(call_id=0,
                                  ctx_list=ctx_list,
                                  auth_info=auth_info,
                                  dump=dump)
 
-        return pdu
+        return req
 
-    def dcerpc_transaction(self, req_pdu, host=os.getenv('SERVER_IP'), dump=False):
-        "send a RPC pdu and read the reply"
-        s = None
-        try:
-            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
-            s.connect((host, 135))
-            s.send(req_pdu, 0)
-            rep_pdu = s.recv(0xffff, 0)
-            rep = ndr.ndr_unpack(dcerpc.ncacn_packet, rep_pdu, allow_remaining=True)
-            if dump:
-                print self.hexdump(rep_pdu)
-                print ndr.ndr_print(rep)
-        finally:
-                if s is not None:
-                    s.close()
-
-    def hexdump(self, src, length=8):
-        N=0; result=''
-        while src:
-           s,src = src[:length],src[length:]
-           hexa = ' '.join(["%02X"%ord(x) for x in s])
-           s = s.translate(FILTER)
-           result += "%04X   %-*s   %s\n" % (N, length*3, hexa, s)
-           N+=length
-        return result
 
 class TestDCERPC_BIND(DCERPCTest):
 
-    def test_one(self):
-        p = self.make_bind_pdu(dump=True)
-        r = self.dcerpc_transaction(p, dump=True)
+    def _test_one(self):
+        req = self.make_bind_pdu(dump=False)
+        self.send_pdu(req, dump=True)
+        rep = self.recv_pdu(dump=True)
+        self.send_pdu(req, dump=True)
+        rep = self.recv_pdu(dump=True)
+
+    def test_two(self):
+        ndr32 = base.transfer_syntax_ndr()
+        ndr64 = base.transfer_syntax_ndr64()
+        features = 0
+        features |= dcerpc.DCERPC_BIND_TIME_SECURITY_CONTEXT_MULTIPLEXING
+        features |= dcerpc.DCERPC_BIND_TIME_KEEP_CONNECTION_ON_ORPHAN
+        bt_features = base.bind_time_features_syntax(features)
+        tsf0_list = [ndr32]
+        tsf1_list = [ndr64]
+        tsf2_list = [bt_features]
+
+        ctx0 = dcerpc.ctx_list()
+        ctx0.context_id = 0
+        ctx0.num_transfer_syntaxes = len(tsf0_list)
+        ctx0.abstract_syntax = samba.dcerpc.epmapper.abstract_syntax()
+        ctx0.transfer_syntaxes = tsf0_list
+        ctx1 = dcerpc.ctx_list()
+        ctx1.context_id = 1
+        ctx1.num_transfer_syntaxes = len(tsf1_list)
+        ctx1.abstract_syntax = samba.dcerpc.epmapper.abstract_syntax()
+        ctx1.transfer_syntaxes = tsf1_list
+        ctx2 = dcerpc.ctx_list()
+        ctx2.context_id = 2
+        ctx2.num_transfer_syntaxes = len(tsf2_list)
+        ctx2.abstract_syntax = samba.dcerpc.epmapper.abstract_syntax()
+        ctx2.transfer_syntaxes = tsf2_list
+        ctx_list = [ctx0,ctx1,ctx2]
+
+        c = Credentials()
+        c.set_anonymous()
+        g = gensec.Security.start_client(self.settings)
+        g.set_credentials(c)
+        g.want_feature(gensec.FEATURE_DCE_STYLE)
+        #g.set_max_update_size(5)
+        auth_type = dcerpc.DCERPC_AUTH_TYPE_SPNEGO
+        auth_level = dcerpc.DCERPC_AUTH_LEVEL_CONNECT
+        auth_context_id = 0
+        g.start_mech_by_authtype(auth_type, auth_level)
+        from_server = ""
+        (finished, to_server) = g.update(from_server)
+        self.assertFalse(finished)
+
+        auth_info = self.generate_auth(auth_type=auth_type,
+                                       auth_level=auth_level,
+                                       auth_context_id=auth_context_id,
+                                       auth_blob=to_server)
+
+        req = self.generate_bind(call_id=0,
+                                 ctx_list=ctx_list,
+                                 auth_info=auth_info)
 
+        self.send_pdu(req, dump=True)
+        rep = self.recv_pdu(dump=True)
+        self.assertEquals(rep.ptype, dcerpc.DCERPC_PKT_BIND_ACK)
+
+        self.assertTrue(rep.u.num_results >= 1)
+        self.assertTrue(rep.u.num_results <= len(ctx_list))
+        for i in range(0, rep.u.num_results):
+            r = rep.u.ctx_list[i]
+            if r.result == dcerpc.DCERPC_BIND_ACK_RESULT_ACCEPTANCE:
+                ctx_list = [ctx_list[i]]
+                break
+
+        a = self.parse_auth(rep.u.auth_info, dump=True)
+
+        #for ctx in rep.u.
+        from_server = a.credentials
+        (finished, to_server) = g.update(from_server)
+        self.assertFalse(finished)
+
+        auth_context_id = 1
+        auth_info = self.generate_auth(auth_type=auth_type,
+                                       auth_level=auth_level,
+                                       auth_context_id=auth_context_id,
+                                       auth_blob=to_server)
+
+        req = self.generate_alter(call_id=0,
+                                  ctx_list=ctx_list,
+                                  auth_info=auth_info)
+
+        self.send_pdu(req, dump=True)
+        rep = self.recv_pdu(dump=True)
+        self.assertEquals(rep.ptype, dcerpc.DCERPC_PKT_ALTER_RESP)
+        a = self.parse_auth(rep.u.auth_info, dump=True)
+
+        from_server = a.credentials
+        (finished, to_server) = g.update(from_server)
+        self.assertTrue(finished)
 
 if __name__ == "__main__":
     import unittest