py:dcerpc/raw_testcase: use generate_request_auth() in do_single_request()
[samba.git] / python / samba / tests / dcerpc / raw_testcase.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 # Copyright (C) Stefan Metzmacher 2014,2015
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18
19 import sys
20 import socket
21 import samba.dcerpc.dcerpc as dcerpc
22 import samba.dcerpc.base
23 import samba.dcerpc.epmapper
24 import samba.tests
25 from samba import gensec
26 from samba.credentials import Credentials
27 from samba.tests import TestCase
28 from samba.ndr import ndr_pack, ndr_unpack, ndr_unpack_out
29 from samba.compat import text_type
30
31
32 class RawDCERPCTest(TestCase):
33     """A raw DCE/RPC Test case."""
34
35     def _disconnect(self, reason):
36         if self.s is None:
37             return
38         self.s.close()
39         self.s = None
40         if self.do_hexdump:
41             sys.stderr.write("disconnect[%s]\n" % reason)
42
43     def connect(self):
44         try:
45             self.a = socket.getaddrinfo(self.host, self.tcp_port, socket.AF_UNSPEC,
46                                         socket.SOCK_STREAM, socket.SOL_TCP,
47                                         0)
48             self.s = socket.socket(self.a[0][0], self.a[0][1], self.a[0][2])
49             self.s.settimeout(10)
50             self.s.connect(self.a[0][4])
51         except socket.error as e:
52             self.s.close()
53             raise
54         except IOError as e:
55             self.s.close()
56             raise
57         except Exception as e:
58             raise
59         finally:
60             pass
61
62     def setUp(self):
63         super(RawDCERPCTest, self).setUp()
64         self.do_ndr_print = False
65         self.do_hexdump = False
66
67         self.ignore_random_pad = samba.tests.env_get_var_value('IGNORE_RANDOM_PAD',
68                                                                allow_missing=True)
69         self.host = samba.tests.env_get_var_value('SERVER')
70         self.target_hostname = samba.tests.env_get_var_value('TARGET_HOSTNAME', allow_missing=True)
71         if self.target_hostname is None:
72             self.target_hostname = self.host
73         self.tcp_port = 135
74
75         self.settings = {}
76         self.settings["lp_ctx"] = self.lp_ctx = samba.tests.env_loadparm()
77         self.settings["target_hostname"] = self.target_hostname
78
79         self.connect()
80
81     def tearDown(self):
82         self._disconnect("tearDown")
83         super(TestCase, self).tearDown()
84
85     def noop(self):
86         return
87
88     def second_connection(self, tcp_port=None):
89         c = RawDCERPCTest(methodName='noop')
90         c.do_ndr_print = self.do_ndr_print
91         c.do_hexdump = self.do_hexdump
92         c.ignore_random_pad = self.ignore_random_pad
93
94         c.host = self.host
95         c.target_hostname = self.target_hostname
96         if tcp_port is not None:
97             c.tcp_port = tcp_port
98         else:
99             c.tcp_port = self.tcp_port
100
101         c.settings = self.settings
102
103         c.connect()
104         return c
105
106     def get_user_creds(self):
107         c = Credentials()
108         c.guess()
109         username = samba.tests.env_get_var_value('USERNAME')
110         password = samba.tests.env_get_var_value('PASSWORD')
111         c.set_username(username)
112         c.set_password(password)
113         return c
114
115     def get_anon_creds(self):
116         c = Credentials()
117         c.set_anonymous()
118         return c
119
120     def get_auth_context_creds(self, creds, auth_type, auth_level,
121                                auth_context_id,
122                                g_auth_level=None):
123
124         if g_auth_level is None:
125             g_auth_level = auth_level
126
127         g = gensec.Security.start_client(self.settings)
128         g.set_credentials(creds)
129         g.want_feature(gensec.FEATURE_DCE_STYLE)
130         g.start_mech_by_authtype(auth_type, g_auth_level)
131
132         auth_context = {}
133         auth_context["auth_type"] = auth_type
134         auth_context["auth_level"] = auth_level
135         auth_context["auth_context_id"] = auth_context_id
136         auth_context["g_auth_level"] = g_auth_level
137         auth_context["gensec"] = g
138
139         return auth_context
140
    def do_generic_bind(self, ctx, auth_context=None,
                        pfc_flags=samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
                        samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST,
                        assoc_group_id=0, call_id=0,
                        nak_reason=None, alter_fault=None):
        """Bind the presentation context *ctx* and verify the server's reply.

        :param ctx: the ctx_list entry to bind.
        :param auth_context: dict from get_auth_context_creds(), or None.
            When given, the first gensec token is sent in the BIND and the
            second in a follow-up ALTER_CONTEXT; gensec must report the
            exchange finished after the ALTER_RESP.
        :param pfc_flags: flags for the BIND; also expected on the BIND_ACK.
        :param assoc_group_id: requested association group; with 0 the
            server must assign a non-zero one.
        :param call_id: call_id for the BIND and the ALTER_CONTEXT.
        :param nak_reason: if set, a BIND_NAK with this reject reason is
            expected and None is returned.
        :param alter_fault: if set, the ALTER_CONTEXT is expected to be
            answered with a FAULT carrying this status and None is returned.
        :return: the BIND_ACK packet, or None in the NAK / fault cases.
        """
        ctx_list = [ctx]

        if auth_context is not None:
            # First gensec leg: there is no server token yet.
            from_server = b""
            (finished, to_server) = auth_context["gensec"].update(from_server)
            self.assertFalse(finished)

            auth_info = self.generate_auth(auth_type=auth_context["auth_type"],
                                           auth_level=auth_context["auth_level"],
                                           auth_context_id=auth_context["auth_context_id"],
                                           auth_blob=to_server)
        else:
            auth_info = b""

        req = self.generate_bind(call_id=call_id,
                                 pfc_flags=pfc_flags,
                                 ctx_list=ctx_list,
                                 assoc_group_id=assoc_group_id,
                                 auth_info=auth_info)
        self.send_pdu(req)
        rep = self.recv_pdu()
        if nak_reason is not None:
            # Rejection expected: check the NAK and stop here.
            self.verify_pdu(rep, samba.dcerpc.dcerpc.DCERPC_PKT_BIND_NAK, req.call_id,
                            auth_length=0)
            self.assertEquals(rep.u.reject_reason, nak_reason)
            self.assertEquals(rep.u.num_versions, 1)
            self.assertEquals(rep.u.versions[0].rpc_vers, req.rpc_vers)
            self.assertEquals(rep.u.versions[0].rpc_vers_minor, req.rpc_vers_minor)
            self.assertPadding(rep.u._pad, 3)
            return
        self.verify_pdu(rep, samba.dcerpc.dcerpc.DCERPC_PKT_BIND_ACK, req.call_id,
                        pfc_flags=pfc_flags)
        self.assertEquals(rep.u.max_xmit_frag, req.u.max_xmit_frag)
        self.assertEquals(rep.u.max_recv_frag, req.u.max_recv_frag)
        if assoc_group_id != 0:
            self.assertEquals(rep.u.assoc_group_id, assoc_group_id)
        else:
            # Server-assigned group; remember it for the ALTER_RESP check.
            self.assertNotEquals(rep.u.assoc_group_id, 0)
            assoc_group_id = rep.u.assoc_group_id
        # The BIND_ACK's secondary address must be our TCP port as a string.
        # Its size counts one extra byte (presumably a NUL terminator), and
        # together with the 2-byte size field it is padded to a multiple of 4.
        port_str = "%d" % self.tcp_port
        port_len = len(port_str) + 1
        mod_len = (2 + port_len) % 4
        if mod_len != 0:
            port_pad = 4 - mod_len
        else:
            port_pad = 0
        self.assertEquals(rep.u.secondary_address_size, port_len)
        self.assertEquals(rep.u.secondary_address, port_str)
        self.assertPadding(rep.u._pad1, port_pad)
        self.assertEquals(rep.u.num_results, 1)
        self.assertEquals(rep.u.ctx_list[0].result,
                          samba.dcerpc.dcerpc.DCERPC_BIND_ACK_RESULT_ACCEPTANCE)
        self.assertEquals(rep.u.ctx_list[0].reason,
                          samba.dcerpc.dcerpc.DCERPC_BIND_ACK_REASON_NOT_SPECIFIED)
        self.assertNDRSyntaxEquals(rep.u.ctx_list[0].syntax, ctx.transfer_syntaxes[0])
        ack = rep
        if auth_context is None:
            # Unauthenticated bind: the ACK must carry no auth data.
            self.assertEquals(rep.auth_length, 0)
            self.assertEquals(len(rep.u.auth_info), 0)
            return ack
        self.assertNotEquals(rep.auth_length, 0)
        self.assertGreater(len(rep.u.auth_info), samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH)
        self.assertEquals(rep.auth_length, len(rep.u.auth_info) - samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH)

        a = self.parse_auth(rep.u.auth_info, auth_context=auth_context)

        # Second gensec leg: feed the server's token, expect more to send.
        from_server = a.credentials
        (finished, to_server) = auth_context["gensec"].update(from_server)
        self.assertFalse(finished)

        auth_info = self.generate_auth(auth_type=auth_context["auth_type"],
                                       auth_level=auth_context["auth_level"],
                                       auth_context_id=auth_context["auth_context_id"],
                                       auth_blob=to_server)
        # The ALTER_CONTEXT deliberately carries a different assoc_group_id;
        # the ALTER_RESP below must still report the bind's group id.
        req = self.generate_alter(call_id=call_id,
                                  ctx_list=ctx_list,
                                  assoc_group_id=0xffffffff - assoc_group_id,
                                  auth_info=auth_info)
        self.send_pdu(req)
        rep = self.recv_pdu()
        if alter_fault is not None:
            # Failure expected: check the FAULT and stop here.
            self.verify_pdu(rep, samba.dcerpc.dcerpc.DCERPC_PKT_FAULT, req.call_id,
                            pfc_flags=req.pfc_flags |
                            samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_DID_NOT_EXECUTE,
                            auth_length=0)
            self.assertNotEquals(rep.u.alloc_hint, 0)
            self.assertEquals(rep.u.context_id, 0)
            self.assertEquals(rep.u.cancel_count, 0)
            self.assertEquals(rep.u.flags, 0)
            self.assertEquals(rep.u.status, alter_fault)
            self.assertEquals(rep.u.reserved, 0)
            self.assertEquals(len(rep.u.error_and_verifier), 0)
            return None
        self.verify_pdu(rep, samba.dcerpc.dcerpc.DCERPC_PKT_ALTER_RESP, req.call_id)
        self.assertEquals(rep.u.max_xmit_frag, req.u.max_xmit_frag)
        self.assertEquals(rep.u.max_recv_frag, req.u.max_recv_frag)
        self.assertEquals(rep.u.assoc_group_id, assoc_group_id)
        self.assertEquals(rep.u.secondary_address_size, 0)
        self.assertEquals(rep.u.secondary_address, '')
        self.assertPadding(rep.u._pad1, 2)
        self.assertEquals(rep.u.num_results, 1)
        self.assertEquals(rep.u.ctx_list[0].result,
                          samba.dcerpc.dcerpc.DCERPC_BIND_ACK_RESULT_ACCEPTANCE)
        self.assertEquals(rep.u.ctx_list[0].reason,
                          samba.dcerpc.dcerpc.DCERPC_BIND_ACK_REASON_NOT_SPECIFIED)
        self.assertNDRSyntaxEquals(rep.u.ctx_list[0].syntax, ctx.transfer_syntaxes[0])
        self.assertNotEquals(rep.auth_length, 0)
        self.assertGreater(len(rep.u.auth_info), samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH)
        self.assertEquals(rep.auth_length, len(rep.u.auth_info) - samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH)

        a = self.parse_auth(rep.u.auth_info, auth_context=auth_context)

        # Final gensec leg: the exchange must now be complete.
        from_server = a.credentials
        (finished, to_server) = auth_context["gensec"].update(from_server)
        self.assertTrue(finished)

        return ack
263
264     def prepare_presentation(self, abstract, transfer, object=None,
265                              context_id=0xffff, epmap=False, auth_context=None,
266                              pfc_flags=samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
267                              samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST,
268                              assoc_group_id=0,
269                              return_ack=False):
270         if epmap:
271             self.epmap_reconnect(abstract, transfer=transfer, object=object)
272
273         tsf1_list = [transfer]
274         ctx = samba.dcerpc.dcerpc.ctx_list()
275         ctx.context_id = context_id
276         ctx.num_transfer_syntaxes = len(tsf1_list)
277         ctx.abstract_syntax = abstract
278         ctx.transfer_syntaxes = tsf1_list
279
280         ack = self.do_generic_bind(ctx=ctx,
281                                    auth_context=auth_context,
282                                    pfc_flags=pfc_flags,
283                                    assoc_group_id=assoc_group_id)
284         if ack is None:
285             ctx = None
286
287         if return_ack:
288             return (ctx, ack)
289         return ctx
290
291     def do_single_request(self, call_id, ctx, io,
292                           auth_context=None,
293                           object=None,
294                           bigendian=False, ndr64=False,
295                           allow_remaining=False,
296                           send_req=True,
297                           recv_rep=True,
298                           fault_pfc_flags=(
299                               samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
300                               samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST),
301                           fault_status=None,
302                           fault_context_id=None,
303                           timeout=None,
304                           ndr_print=None,
305                           hexdump=None):
306
307         if fault_context_id is None:
308             fault_context_id = ctx.context_id
309
310         if ndr_print is None:
311             ndr_print = self.do_ndr_print
312         if hexdump is None:
313             hexdump = self.do_hexdump
314
315         if send_req:
316             if ndr_print:
317                 sys.stderr.write("in: %s" % samba.ndr.ndr_print_in(io))
318             stub_in = samba.ndr.ndr_pack_in(io, bigendian=bigendian, ndr64=ndr64)
319             if hexdump:
320                 sys.stderr.write("stub_in: %d\n%s" % (len(stub_in), self.hexdump(stub_in)))
321
322         pfc_flags = samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST
323         pfc_flags |= samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST
324         if object is not None:
325             pfc_flags |= samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_OBJECT_UUID
326
327         req = self.generate_request_auth(call_id=call_id,
328                                          context_id=ctx.context_id,
329                                          pfc_flags=pfc_flags,
330                                          object=object,
331                                          opnum=io.opnum(),
332                                          stub=stub_in,
333                                          auth_context=auth_context)
334         if send_req:
335             self.send_pdu(req, ndr_print=ndr_print, hexdump=hexdump)
336         if recv_rep:
337             (rep, rep_blob) = self.recv_pdu_raw(timeout=timeout,
338                                                 ndr_print=ndr_print,
339                                                 hexdump=hexdump)
340             if fault_status:
341                 self.verify_pdu(rep, samba.dcerpc.dcerpc.DCERPC_PKT_FAULT, req.call_id,
342                                 pfc_flags=fault_pfc_flags, auth_length=0)
343                 self.assertNotEquals(rep.u.alloc_hint, 0)
344                 self.assertEquals(rep.u.context_id, fault_context_id)
345                 self.assertEquals(rep.u.cancel_count, 0)
346                 self.assertEquals(rep.u.flags, 0)
347                 self.assertEquals(rep.u.status, fault_status)
348                 self.assertEquals(rep.u.reserved, 0)
349                 self.assertEquals(len(rep.u.error_and_verifier), 0)
350                 return
351
352             self.verify_pdu(rep, samba.dcerpc.dcerpc.DCERPC_PKT_RESPONSE, req.call_id,
353                             auth_length=req.auth_length)
354             self.assertNotEquals(rep.u.alloc_hint, 0)
355             self.assertEquals(rep.u.context_id, req.u.context_id & 0xff)
356             self.assertEquals(rep.u.cancel_count, 0)
357             self.assertGreaterEqual(len(rep.u.stub_and_verifier), rep.u.alloc_hint)
358             stub_out = self.check_response_auth(rep, rep_blob, auth_context)
359             self.assertEqual(len(stub_out), rep.u.alloc_hint)
360
361             if hexdump:
362                 sys.stderr.write("stub_out: %d\n%s" % (len(stub_out), self.hexdump(stub_out)))
363             ndr_unpack_out(io, stub_out, bigendian=bigendian, ndr64=ndr64,
364                            allow_remaining=allow_remaining)
365             if ndr_print:
366                 sys.stderr.write("out: %s" % samba.ndr.ndr_print_out(io))
367
368     def epmap_reconnect(self, abstract, transfer=None, object=None):
369         ndr32 = samba.dcerpc.base.transfer_syntax_ndr()
370
371         if transfer is None:
372             transfer = ndr32
373
374         if object is None:
375             object = samba.dcerpc.misc.GUID()
376
377         ctx = self.prepare_presentation(samba.dcerpc.epmapper.abstract_syntax(),
378                                         transfer, context_id=0)
379
380         data1 = ndr_pack(abstract)
381         lhs1 = samba.dcerpc.epmapper.epm_lhs()
382         lhs1.protocol = samba.dcerpc.epmapper.EPM_PROTOCOL_UUID
383         lhs1.lhs_data = data1[:18]
384         rhs1 = samba.dcerpc.epmapper.epm_rhs_uuid()
385         rhs1.unknown = data1[18:]
386         floor1 = samba.dcerpc.epmapper.epm_floor()
387         floor1.lhs = lhs1
388         floor1.rhs = rhs1
389         data2 = ndr_pack(transfer)
390         lhs2 = samba.dcerpc.epmapper.epm_lhs()
391         lhs2.protocol = samba.dcerpc.epmapper.EPM_PROTOCOL_UUID
392         lhs2.lhs_data = data2[:18]
393         rhs2 = samba.dcerpc.epmapper.epm_rhs_uuid()
394         rhs2.unknown = data1[18:]
395         floor2 = samba.dcerpc.epmapper.epm_floor()
396         floor2.lhs = lhs2
397         floor2.rhs = rhs2
398         lhs3 = samba.dcerpc.epmapper.epm_lhs()
399         lhs3.protocol = samba.dcerpc.epmapper.EPM_PROTOCOL_NCACN
400         lhs3.lhs_data = b""
401         floor3 = samba.dcerpc.epmapper.epm_floor()
402         floor3.lhs = lhs3
403         floor3.rhs.minor_version = 0
404         lhs4 = samba.dcerpc.epmapper.epm_lhs()
405         lhs4.protocol = samba.dcerpc.epmapper.EPM_PROTOCOL_TCP
406         lhs4.lhs_data = b""
407         floor4 = samba.dcerpc.epmapper.epm_floor()
408         floor4.lhs = lhs4
409         floor4.rhs.port = self.tcp_port
410         lhs5 = samba.dcerpc.epmapper.epm_lhs()
411         lhs5.protocol = samba.dcerpc.epmapper.EPM_PROTOCOL_IP
412         lhs5.lhs_data = b""
413         floor5 = samba.dcerpc.epmapper.epm_floor()
414         floor5.lhs = lhs5
415         floor5.rhs.ipaddr = "0.0.0.0"
416
417         floors = [floor1, floor2, floor3, floor4, floor5]
418         req_tower = samba.dcerpc.epmapper.epm_tower()
419         req_tower.num_floors = len(floors)
420         req_tower.floors = floors
421         req_twr = samba.dcerpc.epmapper.epm_twr_t()
422         req_twr.tower = req_tower
423
424         epm_map = samba.dcerpc.epmapper.epm_Map()
425         epm_map.in_object = object
426         epm_map.in_map_tower = req_twr
427         epm_map.in_entry_handle = samba.dcerpc.misc.policy_handle()
428         epm_map.in_max_towers = 4
429
430         self.do_single_request(call_id=2, ctx=ctx, io=epm_map)
431
432         self.assertGreaterEqual(epm_map.out_num_towers, 1)
433         rep_twr = epm_map.out_towers[0].twr
434         self.assertIsNotNone(rep_twr)
435         self.assertEqual(rep_twr.tower_length, 75)
436         self.assertEqual(rep_twr.tower.num_floors, 5)
437         self.assertEqual(len(rep_twr.tower.floors), 5)
438         self.assertEqual(rep_twr.tower.floors[3].lhs.protocol,
439                          samba.dcerpc.epmapper.EPM_PROTOCOL_TCP)
440         self.assertEqual(rep_twr.tower.floors[3].lhs.protocol,
441                          samba.dcerpc.epmapper.EPM_PROTOCOL_TCP)
442
443         # reconnect to the given port
444         self._disconnect("epmap_reconnect")
445         self.tcp_port = rep_twr.tower.floors[3].rhs.port
446         self.connect()
447
448     def send_pdu(self, req, ndr_print=None, hexdump=None):
449         if ndr_print is None:
450             ndr_print = self.do_ndr_print
451         if hexdump is None:
452             hexdump = self.do_hexdump
453         try:
454             req_pdu = ndr_pack(req)
455             if ndr_print:
456                 sys.stderr.write("send_pdu: %s" % samba.ndr.ndr_print(req))
457             if hexdump:
458                 sys.stderr.write("send_pdu: %d\n%s" % (len(req_pdu), self.hexdump(req_pdu)))
459             while True:
460                 sent = self.s.send(req_pdu, 0)
461                 if sent == len(req_pdu):
462                     break
463                 req_pdu = req_pdu[sent:]
464         except socket.error as e:
465             self._disconnect("send_pdu: %s" % e)
466             raise
467         except IOError as e:
468             self._disconnect("send_pdu: %s" % e)
469             raise
470         finally:
471             pass
472
473     def recv_raw(self, hexdump=None, timeout=None):
474         rep_pdu = None
475         if hexdump is None:
476             hexdump = self.do_hexdump
477         try:
478             if timeout is not None:
479                 self.s.settimeout(timeout)
480             rep_pdu = self.s.recv(0xffff, 0)
481             self.s.settimeout(10)
482             if len(rep_pdu) == 0:
483                 self._disconnect("recv_raw: EOF")
484                 return None
485             if hexdump:
486                 sys.stderr.write("recv_raw: %d\n%s" % (len(rep_pdu), self.hexdump(rep_pdu)))
487         except socket.timeout as e:
488             self.s.settimeout(10)
489             sys.stderr.write("recv_raw: TIMEOUT\n")
490             pass
491         except socket.error as e:
492             self._disconnect("recv_raw: %s" % e)
493             raise
494         except IOError as e:
495             self._disconnect("recv_raw: %s" % e)
496             raise
497         finally:
498             pass
499         return rep_pdu
500
501     def recv_pdu_raw(self, ndr_print=None, hexdump=None, timeout=None):
502         rep_pdu = None
503         rep = None
504         if ndr_print is None:
505             ndr_print = self.do_ndr_print
506         if hexdump is None:
507             hexdump = self.do_hexdump
508         try:
509             rep_pdu = self.recv_raw(hexdump=hexdump, timeout=timeout)
510             if rep_pdu is None:
511                 return (None, None)
512             rep = ndr_unpack(samba.dcerpc.dcerpc.ncacn_packet, rep_pdu, allow_remaining=True)
513             if ndr_print:
514                 sys.stderr.write("recv_pdu: %s" % samba.ndr.ndr_print(rep))
515             self.assertEqual(rep.frag_length, len(rep_pdu))
516         finally:
517             pass
518         return (rep, rep_pdu)
519
520     def recv_pdu(self, ndr_print=None, hexdump=None, timeout=None):
521         (rep, rep_pdu) = self.recv_pdu_raw(ndr_print=ndr_print,
522                                            hexdump=hexdump,
523                                            timeout=timeout)
524         return rep
525
526     def generate_auth(self,
527                       auth_type=None,
528                       auth_level=None,
529                       auth_pad_length=0,
530                       auth_context_id=None,
531                       auth_blob=None,
532                       ndr_print=None, hexdump=None):
533         if ndr_print is None:
534             ndr_print = self.do_ndr_print
535         if hexdump is None:
536             hexdump = self.do_hexdump
537
538         if auth_type is not None:
539             a = samba.dcerpc.dcerpc.auth()
540             a.auth_type = auth_type
541             a.auth_level = auth_level
542             a.auth_pad_length = auth_pad_length
543             a.auth_context_id = auth_context_id
544             a.credentials = auth_blob
545
546             ai = ndr_pack(a)
547             if ndr_print:
548                 sys.stderr.write("generate_auth: %s" % samba.ndr.ndr_print(a))
549             if hexdump:
550                 sys.stderr.write("generate_auth: %d\n%s" % (len(ai), self.hexdump(ai)))
551         else:
552             ai = b""
553
554         return ai
555
556     def parse_auth(self, auth_info, ndr_print=None, hexdump=None,
557                    auth_context=None, stub_len=0):
558         if ndr_print is None:
559             ndr_print = self.do_ndr_print
560         if hexdump is None:
561             hexdump = self.do_hexdump
562
563         if (len(auth_info) <= samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH):
564             return None
565
566         if hexdump:
567             sys.stderr.write("parse_auth: %d\n%s" % (len(auth_info), self.hexdump(auth_info)))
568         a = ndr_unpack(samba.dcerpc.dcerpc.auth, auth_info, allow_remaining=True)
569         if ndr_print:
570             sys.stderr.write("parse_auth: %s" % samba.ndr.ndr_print(a))
571
572         if auth_context is not None:
573             self.assertEquals(a.auth_type, auth_context["auth_type"])
574             self.assertEquals(a.auth_level, auth_context["auth_level"])
575             self.assertEquals(a.auth_reserved, 0)
576             self.assertEquals(a.auth_context_id, auth_context["auth_context_id"])
577
578             self.assertLessEqual(a.auth_pad_length, dcerpc.DCERPC_AUTH_PAD_ALIGNMENT)
579             self.assertLessEqual(a.auth_pad_length, stub_len)
580
581         return a
582
583     def check_response_auth(self, rep, rep_blob, auth_context=None,
584                             auth_pad_length=None):
585
586         if auth_context is None:
587             self.assertEquals(rep.auth_length, 0)
588             return rep.u.stub_and_verifier
589
590         ofs_stub = dcerpc.DCERPC_REQUEST_LENGTH
591         ofs_sig = rep.frag_length - rep.auth_length
592         ofs_trailer = ofs_sig - dcerpc.DCERPC_AUTH_TRAILER_LENGTH
593         rep_data = rep_blob[ofs_stub:ofs_trailer]
594         rep_whole = rep_blob[0:ofs_sig]
595         rep_sig = rep_blob[ofs_sig:]
596         rep_auth_info_blob = rep_blob[ofs_trailer:]
597
598         rep_auth_info = self.parse_auth(rep_auth_info_blob,
599                                         auth_context=auth_context,
600                                         stub_len=len(rep_data))
601         if auth_pad_length is not None:
602             self.assertEquals(rep_auth_info.auth_pad_length, auth_pad_length)
603         self.assertEquals(rep_auth_info.credentials, rep_sig)
604
605         if auth_context["auth_level"] >= dcerpc.DCERPC_AUTH_LEVEL_PRIVACY:
606             # TODO: not yet supported here
607             self.assertTrue(False)
608         elif auth_context["auth_level"] >= dcerpc.DCERPC_AUTH_LEVEL_PACKET:
609             auth_context["gensec"].check_packet(rep_data, rep_whole, rep_sig)
610
611         stub_out = rep_data[0:len(rep_data)-rep_auth_info.auth_pad_length]
612
613         return stub_out
614
615     def generate_pdu(self, ptype, call_id, payload,
616                      rpc_vers=5,
617                      rpc_vers_minor=0,
618                      pfc_flags=(samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
619                                 samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST),
620                      drep=[samba.dcerpc.dcerpc.DCERPC_DREP_LE, 0, 0, 0],
621                      ndr_print=None, hexdump=None):
622
623         if getattr(payload, 'auth_info', None):
624             ai = payload.auth_info
625         else:
626             ai = b""
627
628         p = samba.dcerpc.dcerpc.ncacn_packet()
629         p.rpc_vers = rpc_vers
630         p.rpc_vers_minor = rpc_vers_minor
631         p.ptype = ptype
632         p.pfc_flags = pfc_flags
633         p.drep = drep
634         p.frag_length = 0
635         if len(ai) > samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH:
636             p.auth_length = len(ai) - samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH
637         else:
638             p.auth_length = 0
639         p.call_id = call_id
640         p.u = payload
641
642         pdu = ndr_pack(p)
643         p.frag_length = len(pdu)
644
645         return p
646
    def generate_request_auth(self, call_id,
                              pfc_flags=(dcerpc.DCERPC_PFC_FLAG_FIRST |
                                         dcerpc.DCERPC_PFC_FLAG_LAST),
                              alloc_hint=None,
                              context_id=None,
                              opnum=None,
                              object=None,
                              stub=None,
                              auth_context=None,
                              ndr_print=None, hexdump=None):
        """Build a REQUEST PDU, optionally with an auth trailer and signature.

        Without *auth_context* this is just generate_request() with *stub*
        (``b""`` if None).  With an auth context:

        - the stub is zero-padded to a multiple of DCERPC_AUTH_PAD_ALIGNMENT;
        - an auth trailer holding a zero-filled placeholder signature is
          appended;
        - at auth_level >= PACKET the packed request is signed with gensec
          and the placeholder is replaced by the real signature;
        - below PACKET the placeholder stays;
        - PRIVACY is not supported and fails the test.

        The placeholder size is gensec ``sig_size()`` when g_auth_level >=
        PACKET, otherwise a fixed 16 bytes.

        :return: the (possibly signed) ncacn_packet.
        """

        if stub is None:
            stub = b""

        sig_size = 0
        if auth_context is not None:
            # Pad the stub so the auth trailer starts on an aligned offset.
            mod_len = len(stub) % dcerpc.DCERPC_AUTH_PAD_ALIGNMENT
            auth_pad_length = 0
            if mod_len > 0:
                auth_pad_length = dcerpc.DCERPC_AUTH_PAD_ALIGNMENT - mod_len
            stub += b'\x00' * auth_pad_length

            # The size is chosen by g_auth_level (what gensec was started
            # with), while signing below is chosen by the on-wire auth_level.
            if auth_context["g_auth_level"] >= samba.dcerpc.dcerpc.DCERPC_AUTH_LEVEL_PACKET:
                sig_size = auth_context["gensec"].sig_size(len(stub))
            else:
                sig_size = 16

            # Zero placeholder so the packed request already has its final
            # length (and auth_length) before the signature is computed.
            zero_sig = b"\x00" * sig_size
            auth_info = self.generate_auth(auth_type=auth_context["auth_type"],
                                           auth_level=auth_context["auth_level"],
                                           auth_pad_length=auth_pad_length,
                                           auth_context_id=auth_context["auth_context_id"],
                                           auth_blob=zero_sig)
        else:
            auth_info = b""

        req = self.generate_request(call_id=call_id,
                                    pfc_flags=pfc_flags,
                                    alloc_hint=alloc_hint,
                                    context_id=context_id,
                                    opnum=opnum,
                                    object=object,
                                    stub=stub,
                                    auth_info=auth_info,
                                    ndr_print=ndr_print,
                                    hexdump=hexdump)
        if auth_context is None:
            return req

        # Split the packed request:
        #   header | stub + pad | auth trailer | signature
        # req_data is the padded stub; req_whole is everything before the
        # signature.
        # NOTE(review): ofs_stub assumes a header without an object UUID; with
        # DCERPC_PFC_FLAG_OBJECT_UUID set the slice would start inside it.
        req_blob = samba.ndr.ndr_pack(req)
        ofs_stub = dcerpc.DCERPC_REQUEST_LENGTH
        ofs_sig = len(req_blob) - req.auth_length
        ofs_trailer = ofs_sig - dcerpc.DCERPC_AUTH_TRAILER_LENGTH
        req_data = req_blob[ofs_stub:ofs_trailer]
        req_whole = req_blob[0:ofs_sig]

        if auth_context["auth_level"] >= dcerpc.DCERPC_AUTH_LEVEL_PRIVACY:
            # TODO: not yet supported here
            self.assertTrue(False)
        elif auth_context["auth_level"] >= dcerpc.DCERPC_AUTH_LEVEL_PACKET:
            req_sig = auth_context["gensec"].sign_packet(req_data, req_whole)
        else:
            # Below PACKET level the zero placeholder is sent as-is.
            return req
        self.assertEquals(len(req_sig), req.auth_length)
        self.assertEquals(len(req_sig), sig_size)

        # Overwrite the trailing placeholder with the real signature.
        stub_sig_ofs = len(req.u.stub_and_verifier) - sig_size
        stub = req.u.stub_and_verifier[0:stub_sig_ofs] + req_sig
        req.u.stub_and_verifier = stub

        return req
718
719     def verify_pdu(self, p, ptype, call_id,
720                    rpc_vers=5,
721                    rpc_vers_minor=0,
722                    pfc_flags=(samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
723                               samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST),
724                    drep=[samba.dcerpc.dcerpc.DCERPC_DREP_LE, 0, 0, 0],
725                    auth_length=None):
726
727         self.assertIsNotNone(p, "No valid pdu")
728
729         if getattr(p.u, 'auth_info', None):
730             ai = p.u.auth_info
731         else:
732             ai = b""
733
734         self.assertEqual(p.rpc_vers, rpc_vers)
735         self.assertEqual(p.rpc_vers_minor, rpc_vers_minor)
736         self.assertEqual(p.ptype, ptype)
737         self.assertEqual(p.pfc_flags, pfc_flags)
738         self.assertEqual(p.drep, drep)
739         self.assertGreaterEqual(p.frag_length,
740                                 samba.dcerpc.dcerpc.DCERPC_NCACN_PAYLOAD_OFFSET)
741         if len(ai) > samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH:
742             self.assertEqual(p.auth_length,
743                              len(ai) - samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH)
744         elif auth_length is not None:
745             self.assertEqual(p.auth_length, auth_length)
746         else:
747             self.assertEqual(p.auth_length, 0)
748         self.assertEqual(p.call_id, call_id)
749
750         return
751
752     def generate_bind(self, call_id,
753                       pfc_flags=(samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
754                                  samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST),
755                       max_xmit_frag=5840,
756                       max_recv_frag=5840,
757                       assoc_group_id=0,
758                       ctx_list=[],
759                       auth_info=b"",
760                       ndr_print=None, hexdump=None):
761
762         b = samba.dcerpc.dcerpc.bind()
763         b.max_xmit_frag = max_xmit_frag
764         b.max_recv_frag = max_recv_frag
765         b.assoc_group_id = assoc_group_id
766         b.num_contexts = len(ctx_list)
767         b.ctx_list = ctx_list
768         b.auth_info = auth_info
769
770         p = self.generate_pdu(ptype=samba.dcerpc.dcerpc.DCERPC_PKT_BIND,
771                               pfc_flags=pfc_flags,
772                               call_id=call_id,
773                               payload=b,
774                               ndr_print=ndr_print, hexdump=hexdump)
775
776         return p
777
778     def generate_alter(self, call_id,
779                        pfc_flags=(samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
780                                   samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST),
781                        max_xmit_frag=5840,
782                        max_recv_frag=5840,
783                        assoc_group_id=0,
784                        ctx_list=[],
785                        auth_info=b"",
786                        ndr_print=None, hexdump=None):
787
788         a = samba.dcerpc.dcerpc.bind()
789         a.max_xmit_frag = max_xmit_frag
790         a.max_recv_frag = max_recv_frag
791         a.assoc_group_id = assoc_group_id
792         a.num_contexts = len(ctx_list)
793         a.ctx_list = ctx_list
794         a.auth_info = auth_info
795
796         p = self.generate_pdu(ptype=samba.dcerpc.dcerpc.DCERPC_PKT_ALTER,
797                               pfc_flags=pfc_flags,
798                               call_id=call_id,
799                               payload=a,
800                               ndr_print=ndr_print, hexdump=hexdump)
801
802         return p
803
804     def generate_auth3(self, call_id,
805                        pfc_flags=(samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
806                                   samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST),
807                        auth_info=b"",
808                        ndr_print=None, hexdump=None):
809
810         a = samba.dcerpc.dcerpc.auth3()
811         a.auth_info = auth_info
812
813         p = self.generate_pdu(ptype=samba.dcerpc.dcerpc.DCERPC_PKT_AUTH3,
814                               pfc_flags=pfc_flags,
815                               call_id=call_id,
816                               payload=a,
817                               ndr_print=ndr_print, hexdump=hexdump)
818
819         return p
820
821     def generate_request(self, call_id,
822                          pfc_flags=(samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
823                                     samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST),
824                          alloc_hint=None,
825                          context_id=None,
826                          opnum=None,
827                          object=None,
828                          stub=None,
829                          auth_info=b"",
830                          ndr_print=None, hexdump=None):
831
832         if alloc_hint is None:
833             alloc_hint = len(stub)
834
835         r = samba.dcerpc.dcerpc.request()
836         r.alloc_hint = alloc_hint
837         r.context_id = context_id
838         r.opnum = opnum
839         if object is not None:
840             r.object = object
841         r.stub_and_verifier = stub + auth_info
842
843         p = self.generate_pdu(ptype=samba.dcerpc.dcerpc.DCERPC_PKT_REQUEST,
844                               pfc_flags=pfc_flags,
845                               call_id=call_id,
846                               payload=r,
847                               ndr_print=ndr_print, hexdump=hexdump)
848
849         if len(auth_info) > samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH:
850             p.auth_length = len(auth_info) - samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH
851
852         return p
853
854     def generate_co_cancel(self, call_id,
855                            pfc_flags=(samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
856                                       samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST),
857                            auth_info=b"",
858                            ndr_print=None, hexdump=None):
859
860         c = samba.dcerpc.dcerpc.co_cancel()
861         c.auth_info = auth_info
862
863         p = self.generate_pdu(ptype=samba.dcerpc.dcerpc.DCERPC_PKT_CO_CANCEL,
864                               pfc_flags=pfc_flags,
865                               call_id=call_id,
866                               payload=c,
867                               ndr_print=ndr_print, hexdump=hexdump)
868
869         return p
870
871     def generate_orphaned(self, call_id,
872                           pfc_flags=(samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
873                                      samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST),
874                           auth_info=b"",
875                           ndr_print=None, hexdump=None):
876
877         o = samba.dcerpc.dcerpc.orphaned()
878         o.auth_info = auth_info
879
880         p = self.generate_pdu(ptype=samba.dcerpc.dcerpc.DCERPC_PKT_ORPHANED,
881                               pfc_flags=pfc_flags,
882                               call_id=call_id,
883                               payload=o,
884                               ndr_print=ndr_print, hexdump=hexdump)
885
886         return p
887
888     def generate_shutdown(self, call_id,
889                           pfc_flags=(samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
890                                      samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST),
891                           ndr_print=None, hexdump=None):
892
893         s = samba.dcerpc.dcerpc.shutdown()
894
895         p = self.generate_pdu(ptype=samba.dcerpc.dcerpc.DCERPC_PKT_SHUTDOWN,
896                               pfc_flags=pfc_flags,
897                               call_id=call_id,
898                               payload=s,
899                               ndr_print=ndr_print, hexdump=hexdump)
900
901         return p
902
903     def assertIsConnected(self):
904         self.assertIsNotNone(self.s, msg="Not connected")
905         return
906
907     def assertNotConnected(self):
908         self.assertIsNone(self.s, msg="Is connected")
909         return
910
911     def assertNDRSyntaxEquals(self, s1, s2):
912         self.assertEqual(s1.uuid, s2.uuid)
913         self.assertEqual(s1.if_version, s2.if_version)
914         return
915
916     def assertPadding(self, pad, length):
917         self.assertEquals(len(pad), length)
918         #
919         # sometimes windows sends random bytes
920         #
921         # we have IGNORE_RANDOM_PAD=1 to
922         # disable the check
923         #
924         if self.ignore_random_pad:
925             return
926         zero_pad = b'\0' * length
927         self.assertEquals(pad, zero_pad)