32d626ec90b17ef022f10b4384f5a4fe560e842a
[samba.git] / python / samba / tests / dcerpc / raw_testcase.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 # Copyright (C) Stefan Metzmacher 2014,2015
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18
19 import sys
20 import socket
21 import samba.dcerpc.dcerpc as dcerpc
22 import samba.dcerpc.base
23 import samba.dcerpc.epmapper
24 import samba.tests
25 from samba import gensec
26 from samba.credentials import Credentials
27 from samba.tests import TestCase
28 from samba.ndr import ndr_pack, ndr_unpack, ndr_unpack_out
29 from samba.compat import text_type
30
31
32 class RawDCERPCTest(TestCase):
33     """A raw DCE/RPC Test case."""
34
35     def _disconnect(self, reason):
36         if self.s is None:
37             return
38         self.s.close()
39         self.s = None
40         if self.do_hexdump:
41             sys.stderr.write("disconnect[%s]\n" % reason)
42
43     def connect(self):
44         try:
45             self.a = socket.getaddrinfo(self.host, self.tcp_port, socket.AF_UNSPEC,
46                                         socket.SOCK_STREAM, socket.SOL_TCP,
47                                         0)
48             self.s = socket.socket(self.a[0][0], self.a[0][1], self.a[0][2])
49             self.s.settimeout(10)
50             self.s.connect(self.a[0][4])
51         except socket.error as e:
52             self.s.close()
53             raise
54         except IOError as e:
55             self.s.close()
56             raise
57         except Exception as e:
58             raise
59         finally:
60             pass
61
62     def setUp(self):
63         super(RawDCERPCTest, self).setUp()
64         self.do_ndr_print = False
65         self.do_hexdump = False
66
67         self.ignore_random_pad = samba.tests.env_get_var_value('IGNORE_RANDOM_PAD',
68                                                                allow_missing=True)
69         self.host = samba.tests.env_get_var_value('SERVER')
70         self.target_hostname = samba.tests.env_get_var_value('TARGET_HOSTNAME', allow_missing=True)
71         if self.target_hostname is None:
72             self.target_hostname = self.host
73         self.tcp_port = 135
74
75         self.settings = {}
76         self.settings["lp_ctx"] = self.lp_ctx = samba.tests.env_loadparm()
77         self.settings["target_hostname"] = self.target_hostname
78
79         self.connect()
80
81     def tearDown(self):
82         self._disconnect("tearDown")
83         super(TestCase, self).tearDown()
84
85     def noop(self):
86         return
87
88     def second_connection(self, tcp_port=None):
89         c = RawDCERPCTest(methodName='noop')
90         c.do_ndr_print = self.do_ndr_print
91         c.do_hexdump = self.do_hexdump
92         c.ignore_random_pad = self.ignore_random_pad
93
94         c.host = self.host
95         c.target_hostname = self.target_hostname
96         if tcp_port is not None:
97             c.tcp_port = tcp_port
98         else:
99             c.tcp_port = self.tcp_port
100
101         c.settings = self.settings
102
103         c.connect()
104         return c
105
106     def get_user_creds(self):
107         c = Credentials()
108         c.guess()
109         username = samba.tests.env_get_var_value('USERNAME')
110         password = samba.tests.env_get_var_value('PASSWORD')
111         c.set_username(username)
112         c.set_password(password)
113         return c
114
115     def get_anon_creds(self):
116         c = Credentials()
117         c.set_anonymous()
118         return c
119
120     def get_auth_context_creds(self, creds, auth_type, auth_level,
121                                auth_context_id,
122                                g_auth_level=None):
123
124         if g_auth_level is None:
125             g_auth_level = auth_level
126
127         g = gensec.Security.start_client(self.settings)
128         g.set_credentials(creds)
129         g.want_feature(gensec.FEATURE_DCE_STYLE)
130         g.start_mech_by_authtype(auth_type, g_auth_level)
131
132         auth_context = {}
133         auth_context["auth_type"] = auth_type
134         auth_context["auth_level"] = auth_level
135         auth_context["auth_context_id"] = auth_context_id
136         auth_context["g_auth_level"] = g_auth_level
137         auth_context["gensec"] = g
138
139         return auth_context
140
    def do_generic_bind(self, ctx, auth_context=None,
                        pfc_flags=samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
                        samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST,
                        assoc_group_id=0, call_id=0,
                        nak_reason=None, alter_fault=None):
        """Bind one presentation context, optionally with authentication.

        A BIND carrying ``ctx`` is sent and the reply is checked:

        * if ``nak_reason`` is given, a BIND_NAK with that reason is
          expected and None is returned;
        * otherwise a BIND_ACK accepting ``ctx`` is expected.  Without an
          ``auth_context`` that ACK is returned.  With one, the next gensec
          token is sent in an ALTER_CONTEXT.  Its ALTER_RESP must finish
          the gensec exchange, or, when ``alter_fault`` is given, a FAULT
          with that status is expected and None is returned.

        :param ctx: the dcerpc.ctx_list entry to bind.
        :param auth_context: dict from get_auth_context_creds(), or None.
        :param pfc_flags: pfc_flags of the BIND; also expected in the ACK.
        :param assoc_group_id: association group to request.  When 0, the
            ACK must carry a non-zero group id, which is then used for the
            ALTER_RESP check.
        :param call_id: call id used for both the BIND and ALTER_CONTEXT.
        :param nak_reason: expected BIND_NAK reject reason, if any.
        :param alter_fault: expected FAULT status for the ALTER_CONTEXT.
        :return: the BIND_ACK packet, or None on the NAK/fault paths.
        """
        ctx_list = [ctx]

        if auth_context is not None:
            # First gensec leg: no server input yet; the exchange must not
            # be finished after this step.
            from_server = b""
            (finished, to_server) = auth_context["gensec"].update(from_server)
            self.assertFalse(finished)

            auth_info = self.generate_auth(auth_type=auth_context["auth_type"],
                                           auth_level=auth_context["auth_level"],
                                           auth_context_id=auth_context["auth_context_id"],
                                           auth_blob=to_server)
        else:
            auth_info = b""

        req = self.generate_bind(call_id=call_id,
                                 pfc_flags=pfc_flags,
                                 ctx_list=ctx_list,
                                 assoc_group_id=assoc_group_id,
                                 auth_info=auth_info)
        self.send_pdu(req)
        rep = self.recv_pdu()
        if nak_reason is not None:
            self.verify_pdu(rep, samba.dcerpc.dcerpc.DCERPC_PKT_BIND_NAK, req.call_id,
                            auth_length=0)
            self.assertEquals(rep.u.reject_reason, nak_reason)
            self.assertEquals(rep.u.num_versions, 1)
            self.assertEquals(rep.u.versions[0].rpc_vers, req.rpc_vers)
            self.assertEquals(rep.u.versions[0].rpc_vers_minor, req.rpc_vers_minor)
            self.assertPadding(rep.u._pad, 3)
            return
        self.verify_pdu(rep, samba.dcerpc.dcerpc.DCERPC_PKT_BIND_ACK, req.call_id,
                        pfc_flags=pfc_flags)
        self.assertEquals(rep.u.max_xmit_frag, req.u.max_xmit_frag)
        self.assertEquals(rep.u.max_recv_frag, req.u.max_recv_frag)
        if assoc_group_id != 0:
            self.assertEquals(rep.u.assoc_group_id, assoc_group_id)
        else:
            self.assertNotEquals(rep.u.assoc_group_id, 0)
            assoc_group_id = rep.u.assoc_group_id
        # The BIND_ACK secondary address is expected to be the port number
        # as a string plus a NUL terminator.  The padding rounds
        # (2 + port_len) up to a multiple of 4; the 2 is presumably the
        # 16-bit secondary_address_size field.
        port_str = "%d" % self.tcp_port
        port_len = len(port_str) + 1
        mod_len = (2 + port_len) % 4
        if mod_len != 0:
            port_pad = 4 - mod_len
        else:
            port_pad = 0
        self.assertEquals(rep.u.secondary_address_size, port_len)
        self.assertEquals(rep.u.secondary_address, port_str)
        self.assertPadding(rep.u._pad1, port_pad)
        self.assertEquals(rep.u.num_results, 1)
        self.assertEquals(rep.u.ctx_list[0].result,
                          samba.dcerpc.dcerpc.DCERPC_BIND_ACK_RESULT_ACCEPTANCE)
        self.assertEquals(rep.u.ctx_list[0].reason,
                          samba.dcerpc.dcerpc.DCERPC_BIND_ACK_REASON_NOT_SPECIFIED)
        self.assertNDRSyntaxEquals(rep.u.ctx_list[0].syntax, ctx.transfer_syntaxes[0])
        ack = rep
        if auth_context is None:
            self.assertEquals(rep.auth_length, 0)
            self.assertEquals(len(rep.u.auth_info), 0)
            return ack
        # auth_info includes the fixed trailer header; auth_length counts
        # only the credentials after it.
        self.assertNotEquals(rep.auth_length, 0)
        self.assertGreater(len(rep.u.auth_info), samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH)
        self.assertEquals(rep.auth_length, len(rep.u.auth_info) - samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH)

        a = self.parse_auth(rep.u.auth_info, auth_context=auth_context)

        # Second gensec leg: feed the server token; still not finished.
        from_server = a.credentials
        (finished, to_server) = auth_context["gensec"].update(from_server)
        self.assertFalse(finished)

        auth_info = self.generate_auth(auth_type=auth_context["auth_type"],
                                       auth_level=auth_context["auth_level"],
                                       auth_context_id=auth_context["auth_context_id"],
                                       auth_blob=to_server)
        # The ALTER_CONTEXT deliberately carries a different assoc_group_id
        # (0xffffffff - id); the ALTER_RESP below must still report the
        # group id established by the BIND.
        req = self.generate_alter(call_id=call_id,
                                  ctx_list=ctx_list,
                                  assoc_group_id=0xffffffff - assoc_group_id,
                                  auth_info=auth_info)
        self.send_pdu(req)
        rep = self.recv_pdu()
        if alter_fault is not None:
            self.verify_pdu(rep, samba.dcerpc.dcerpc.DCERPC_PKT_FAULT, req.call_id,
                            pfc_flags=req.pfc_flags |
                            samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_DID_NOT_EXECUTE,
                            auth_length=0)
            self.assertNotEquals(rep.u.alloc_hint, 0)
            self.assertEquals(rep.u.context_id, 0)
            self.assertEquals(rep.u.cancel_count, 0)
            self.assertEquals(rep.u.flags, 0)
            self.assertEquals(rep.u.status, alter_fault)
            self.assertEquals(rep.u.reserved, 0)
            self.assertEquals(len(rep.u.error_and_verifier), 0)
            return None
        self.verify_pdu(rep, samba.dcerpc.dcerpc.DCERPC_PKT_ALTER_RESP, req.call_id)
        self.assertEquals(rep.u.max_xmit_frag, req.u.max_xmit_frag)
        self.assertEquals(rep.u.max_recv_frag, req.u.max_recv_frag)
        self.assertEquals(rep.u.assoc_group_id, assoc_group_id)
        # Unlike the BIND_ACK, the ALTER_RESP carries an empty secondary
        # address (only the 2 padding bytes).
        self.assertEquals(rep.u.secondary_address_size, 0)
        self.assertEquals(rep.u.secondary_address, '')
        self.assertPadding(rep.u._pad1, 2)
        self.assertEquals(rep.u.num_results, 1)
        self.assertEquals(rep.u.ctx_list[0].result,
                          samba.dcerpc.dcerpc.DCERPC_BIND_ACK_RESULT_ACCEPTANCE)
        self.assertEquals(rep.u.ctx_list[0].reason,
                          samba.dcerpc.dcerpc.DCERPC_BIND_ACK_REASON_NOT_SPECIFIED)
        self.assertNDRSyntaxEquals(rep.u.ctx_list[0].syntax, ctx.transfer_syntaxes[0])
        self.assertNotEquals(rep.auth_length, 0)
        self.assertGreater(len(rep.u.auth_info), samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH)
        self.assertEquals(rep.auth_length, len(rep.u.auth_info) - samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH)

        a = self.parse_auth(rep.u.auth_info, auth_context=auth_context)

        # Final gensec leg: the security context must now be complete.
        from_server = a.credentials
        (finished, to_server) = auth_context["gensec"].update(from_server)
        self.assertTrue(finished)

        return ack
263
264     def prepare_presentation(self, abstract, transfer, object=None,
265                              context_id=0xffff, epmap=False, auth_context=None,
266                              pfc_flags=samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
267                              samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST,
268                              assoc_group_id=0,
269                              return_ack=False):
270         if epmap:
271             self.epmap_reconnect(abstract, transfer=transfer, object=object)
272
273         tsf1_list = [transfer]
274         ctx = samba.dcerpc.dcerpc.ctx_list()
275         ctx.context_id = context_id
276         ctx.num_transfer_syntaxes = len(tsf1_list)
277         ctx.abstract_syntax = abstract
278         ctx.transfer_syntaxes = tsf1_list
279
280         ack = self.do_generic_bind(ctx=ctx,
281                                    auth_context=auth_context,
282                                    pfc_flags=pfc_flags,
283                                    assoc_group_id=assoc_group_id)
284         if ack is None:
285             ctx = None
286
287         if return_ack:
288             return (ctx, ack)
289         return ctx
290
    def do_single_request(self, call_id, ctx, io,
                          auth_context=None,
                          object=None,
                          bigendian=False, ndr64=False,
                          allow_remaining=False,
                          send_req=True,
                          recv_rep=True,
                          fault_pfc_flags=(
                              samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
                              samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST),
                          fault_status=None,
                          fault_context_id=None,
                          timeout=None,
                          ndr_print=None,
                          hexdump=None):
        """Send one REQUEST for ``io`` and validate the RESPONSE or FAULT.

        The "in" part of ``io`` is NDR-marshalled as the request stub.  With
        an ``auth_context``, the stub is padded to
        DCERPC_AUTH_PAD_ALIGNMENT and an auth trailer is appended.  When
        the wire auth_level is PACKET or above, the request is signed with
        gensec.sign_packet().  A RESPONSE stub is unmarshalled into the
        "out" part of ``io``.

        :param call_id: call id of the request.
        :param ctx: bound presentation context; its context_id is used.
        :param io: NDR call object (provides opnum() and in/out members).
        :param auth_context: optional dict from get_auth_context_creds().
        :param object: object UUID; also sets DCERPC_PFC_FLAG_OBJECT_UUID.
        :param bigendian: NDR marshalling option for both stubs.
        :param ndr64: NDR marshalling option for both stubs.
        :param allow_remaining: passed through to ndr_unpack_out().
        :param send_req: when False, the request is built but not sent.
        :param recv_rep: when False, no reply is read.
        :param fault_pfc_flags: pfc_flags expected on an expected FAULT.
        :param fault_status: when truthy, a FAULT with this status is
            expected instead of a RESPONSE.
        :param fault_context_id: context_id expected in that FAULT;
            defaults to ``ctx.context_id``.
        :param timeout: receive timeout passed to recv_pdu_raw().
        :param ndr_print: debug override; defaults to self.do_ndr_print.
        :param hexdump: debug override; defaults to self.do_hexdump.
        :return: None; results are written into ``io``.
        """

        if fault_context_id is None:
            fault_context_id = ctx.context_id

        if ndr_print is None:
            ndr_print = self.do_ndr_print
        if hexdump is None:
            hexdump = self.do_hexdump

        if send_req:
            if ndr_print:
                sys.stderr.write("in: %s" % samba.ndr.ndr_print_in(io))
            stub_in = samba.ndr.ndr_pack_in(io, bigendian=bigendian, ndr64=ndr64)
            if hexdump:
                sys.stderr.write("stub_in: %d\n%s" % (len(stub_in), self.hexdump(stub_in)))
        else:
            # only used for sig_size calculation
            stub_in = b'\xff' * samba.dcerpc.dcerpc.DCERPC_AUTH_PAD_ALIGNMENT

        sig_size = 0
        if auth_context is not None:
            # Pad the stub to the auth alignment; the pad length is recorded
            # in the auth trailer.
            mod_len = len(stub_in) % samba.dcerpc.dcerpc.DCERPC_AUTH_PAD_ALIGNMENT
            auth_pad_length = 0
            if mod_len > 0:
                auth_pad_length = samba.dcerpc.dcerpc.DCERPC_AUTH_PAD_ALIGNMENT - mod_len
            stub_in += b'\x00' * auth_pad_length

            # gensec is only asked for a signature size at PACKET level or
            # above; below that a fixed 16-byte placeholder is used.
            if auth_context["g_auth_level"] >= samba.dcerpc.dcerpc.DCERPC_AUTH_LEVEL_PACKET:
                sig_size = auth_context["gensec"].sig_size(len(stub_in))
            else:
                sig_size = 16

            # First build the PDU with an all-zero signature so its final
            # layout (and the offsets used for signing) is known.
            zero_sig = b"\x00" * sig_size
            auth_info = self.generate_auth(auth_type=auth_context["auth_type"],
                                           auth_level=auth_context["auth_level"],
                                           auth_pad_length=auth_pad_length,
                                           auth_context_id=auth_context["auth_context_id"],
                                           auth_blob=zero_sig)
        else:
            auth_info = b""

        pfc_flags = samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST
        pfc_flags |= samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST
        if object is not None:
            pfc_flags |= samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_OBJECT_UUID

        req = self.generate_request(call_id=call_id,
                                    context_id=ctx.context_id,
                                    pfc_flags=pfc_flags,
                                    object=object,
                                    opnum=io.opnum(),
                                    stub=stub_in,
                                    auth_info=auth_info)

        if send_req:
            if sig_size != 0 and auth_context["auth_level"] >= samba.dcerpc.dcerpc.DCERPC_AUTH_LEVEL_PACKET:
                # Sign the stub data plus the whole PDU up to the signature,
                # then rebuild the request carrying the real signature.
                req_blob = samba.ndr.ndr_pack(req)
                ofs_stub = samba.dcerpc.dcerpc.DCERPC_REQUEST_LENGTH
                ofs_sig = len(req_blob) - req.auth_length
                ofs_trailer = ofs_sig - samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH
                req_data = req_blob[ofs_stub:ofs_trailer]
                req_whole = req_blob[0:ofs_sig]
                sig = auth_context["gensec"].sign_packet(req_data, req_whole)
                auth_info = self.generate_auth(auth_type=auth_context["auth_type"],
                                               auth_level=auth_context["auth_level"],
                                               auth_pad_length=auth_pad_length,
                                               auth_context_id=auth_context["auth_context_id"],
                                               auth_blob=sig)
                req = self.generate_request(call_id=call_id,
                                            context_id=ctx.context_id,
                                            pfc_flags=pfc_flags,
                                            object=object,
                                            opnum=io.opnum(),
                                            stub=stub_in,
                                            auth_info=auth_info)
            self.send_pdu(req, ndr_print=ndr_print, hexdump=hexdump)
        if recv_rep:
            (rep, rep_blob) = self.recv_pdu_raw(timeout=timeout,
                                                ndr_print=ndr_print,
                                                hexdump=hexdump)
            if fault_status:
                self.verify_pdu(rep, samba.dcerpc.dcerpc.DCERPC_PKT_FAULT, req.call_id,
                                pfc_flags=fault_pfc_flags, auth_length=0)
                self.assertNotEquals(rep.u.alloc_hint, 0)
                self.assertEquals(rep.u.context_id, fault_context_id)
                self.assertEquals(rep.u.cancel_count, 0)
                self.assertEquals(rep.u.flags, 0)
                self.assertEquals(rep.u.status, fault_status)
                self.assertEquals(rep.u.reserved, 0)
                self.assertEquals(len(rep.u.error_and_verifier), 0)
                return

            self.verify_pdu(rep, samba.dcerpc.dcerpc.DCERPC_PKT_RESPONSE, req.call_id,
                            auth_length=sig_size)
            self.assertNotEquals(rep.u.alloc_hint, 0)
            # NOTE(review): only the low 8 bits of the request's context_id
            # are compared here; presumably intentional -- confirm.
            self.assertEquals(rep.u.context_id, req.u.context_id & 0xff)
            self.assertEquals(rep.u.cancel_count, 0)
            self.assertGreaterEqual(len(rep.u.stub_and_verifier), rep.u.alloc_hint)
            stub_out = self.check_response_auth(rep, rep_blob, auth_context)
            self.assertEqual(len(stub_out), rep.u.alloc_hint)

            if hexdump:
                sys.stderr.write("stub_out: %d\n%s" % (len(stub_out), self.hexdump(stub_out)))
            ndr_unpack_out(io, stub_out, bigendian=bigendian, ndr64=ndr64,
                           allow_remaining=allow_remaining)
            if ndr_print:
                sys.stderr.write("out: %s" % samba.ndr.ndr_print_out(io))
413
414     def epmap_reconnect(self, abstract, transfer=None, object=None):
415         ndr32 = samba.dcerpc.base.transfer_syntax_ndr()
416
417         if transfer is None:
418             transfer = ndr32
419
420         if object is None:
421             object = samba.dcerpc.misc.GUID()
422
423         ctx = self.prepare_presentation(samba.dcerpc.epmapper.abstract_syntax(),
424                                         transfer, context_id=0)
425
426         data1 = ndr_pack(abstract)
427         lhs1 = samba.dcerpc.epmapper.epm_lhs()
428         lhs1.protocol = samba.dcerpc.epmapper.EPM_PROTOCOL_UUID
429         lhs1.lhs_data = data1[:18]
430         rhs1 = samba.dcerpc.epmapper.epm_rhs_uuid()
431         rhs1.unknown = data1[18:]
432         floor1 = samba.dcerpc.epmapper.epm_floor()
433         floor1.lhs = lhs1
434         floor1.rhs = rhs1
435         data2 = ndr_pack(transfer)
436         lhs2 = samba.dcerpc.epmapper.epm_lhs()
437         lhs2.protocol = samba.dcerpc.epmapper.EPM_PROTOCOL_UUID
438         lhs2.lhs_data = data2[:18]
439         rhs2 = samba.dcerpc.epmapper.epm_rhs_uuid()
440         rhs2.unknown = data1[18:]
441         floor2 = samba.dcerpc.epmapper.epm_floor()
442         floor2.lhs = lhs2
443         floor2.rhs = rhs2
444         lhs3 = samba.dcerpc.epmapper.epm_lhs()
445         lhs3.protocol = samba.dcerpc.epmapper.EPM_PROTOCOL_NCACN
446         lhs3.lhs_data = b""
447         floor3 = samba.dcerpc.epmapper.epm_floor()
448         floor3.lhs = lhs3
449         floor3.rhs.minor_version = 0
450         lhs4 = samba.dcerpc.epmapper.epm_lhs()
451         lhs4.protocol = samba.dcerpc.epmapper.EPM_PROTOCOL_TCP
452         lhs4.lhs_data = b""
453         floor4 = samba.dcerpc.epmapper.epm_floor()
454         floor4.lhs = lhs4
455         floor4.rhs.port = self.tcp_port
456         lhs5 = samba.dcerpc.epmapper.epm_lhs()
457         lhs5.protocol = samba.dcerpc.epmapper.EPM_PROTOCOL_IP
458         lhs5.lhs_data = b""
459         floor5 = samba.dcerpc.epmapper.epm_floor()
460         floor5.lhs = lhs5
461         floor5.rhs.ipaddr = "0.0.0.0"
462
463         floors = [floor1, floor2, floor3, floor4, floor5]
464         req_tower = samba.dcerpc.epmapper.epm_tower()
465         req_tower.num_floors = len(floors)
466         req_tower.floors = floors
467         req_twr = samba.dcerpc.epmapper.epm_twr_t()
468         req_twr.tower = req_tower
469
470         epm_map = samba.dcerpc.epmapper.epm_Map()
471         epm_map.in_object = object
472         epm_map.in_map_tower = req_twr
473         epm_map.in_entry_handle = samba.dcerpc.misc.policy_handle()
474         epm_map.in_max_towers = 4
475
476         self.do_single_request(call_id=2, ctx=ctx, io=epm_map)
477
478         self.assertGreaterEqual(epm_map.out_num_towers, 1)
479         rep_twr = epm_map.out_towers[0].twr
480         self.assertIsNotNone(rep_twr)
481         self.assertEqual(rep_twr.tower_length, 75)
482         self.assertEqual(rep_twr.tower.num_floors, 5)
483         self.assertEqual(len(rep_twr.tower.floors), 5)
484         self.assertEqual(rep_twr.tower.floors[3].lhs.protocol,
485                          samba.dcerpc.epmapper.EPM_PROTOCOL_TCP)
486         self.assertEqual(rep_twr.tower.floors[3].lhs.protocol,
487                          samba.dcerpc.epmapper.EPM_PROTOCOL_TCP)
488
489         # reconnect to the given port
490         self._disconnect("epmap_reconnect")
491         self.tcp_port = rep_twr.tower.floors[3].rhs.port
492         self.connect()
493
494     def send_pdu(self, req, ndr_print=None, hexdump=None):
495         if ndr_print is None:
496             ndr_print = self.do_ndr_print
497         if hexdump is None:
498             hexdump = self.do_hexdump
499         try:
500             req_pdu = ndr_pack(req)
501             if ndr_print:
502                 sys.stderr.write("send_pdu: %s" % samba.ndr.ndr_print(req))
503             if hexdump:
504                 sys.stderr.write("send_pdu: %d\n%s" % (len(req_pdu), self.hexdump(req_pdu)))
505             while True:
506                 sent = self.s.send(req_pdu, 0)
507                 if sent == len(req_pdu):
508                     break
509                 req_pdu = req_pdu[sent:]
510         except socket.error as e:
511             self._disconnect("send_pdu: %s" % e)
512             raise
513         except IOError as e:
514             self._disconnect("send_pdu: %s" % e)
515             raise
516         finally:
517             pass
518
519     def recv_raw(self, hexdump=None, timeout=None):
520         rep_pdu = None
521         if hexdump is None:
522             hexdump = self.do_hexdump
523         try:
524             if timeout is not None:
525                 self.s.settimeout(timeout)
526             rep_pdu = self.s.recv(0xffff, 0)
527             self.s.settimeout(10)
528             if len(rep_pdu) == 0:
529                 self._disconnect("recv_raw: EOF")
530                 return None
531             if hexdump:
532                 sys.stderr.write("recv_raw: %d\n%s" % (len(rep_pdu), self.hexdump(rep_pdu)))
533         except socket.timeout as e:
534             self.s.settimeout(10)
535             sys.stderr.write("recv_raw: TIMEOUT\n")
536             pass
537         except socket.error as e:
538             self._disconnect("recv_raw: %s" % e)
539             raise
540         except IOError as e:
541             self._disconnect("recv_raw: %s" % e)
542             raise
543         finally:
544             pass
545         return rep_pdu
546
547     def recv_pdu_raw(self, ndr_print=None, hexdump=None, timeout=None):
548         rep_pdu = None
549         rep = None
550         if ndr_print is None:
551             ndr_print = self.do_ndr_print
552         if hexdump is None:
553             hexdump = self.do_hexdump
554         try:
555             rep_pdu = self.recv_raw(hexdump=hexdump, timeout=timeout)
556             if rep_pdu is None:
557                 return (None, None)
558             rep = ndr_unpack(samba.dcerpc.dcerpc.ncacn_packet, rep_pdu, allow_remaining=True)
559             if ndr_print:
560                 sys.stderr.write("recv_pdu: %s" % samba.ndr.ndr_print(rep))
561             self.assertEqual(rep.frag_length, len(rep_pdu))
562         finally:
563             pass
564         return (rep, rep_pdu)
565
566     def recv_pdu(self, ndr_print=None, hexdump=None, timeout=None):
567         (rep, rep_pdu) = self.recv_pdu_raw(ndr_print=ndr_print,
568                                            hexdump=hexdump,
569                                            timeout=timeout)
570         return rep
571
572     def generate_auth(self,
573                       auth_type=None,
574                       auth_level=None,
575                       auth_pad_length=0,
576                       auth_context_id=None,
577                       auth_blob=None,
578                       ndr_print=None, hexdump=None):
579         if ndr_print is None:
580             ndr_print = self.do_ndr_print
581         if hexdump is None:
582             hexdump = self.do_hexdump
583
584         if auth_type is not None:
585             a = samba.dcerpc.dcerpc.auth()
586             a.auth_type = auth_type
587             a.auth_level = auth_level
588             a.auth_pad_length = auth_pad_length
589             a.auth_context_id = auth_context_id
590             a.credentials = auth_blob
591
592             ai = ndr_pack(a)
593             if ndr_print:
594                 sys.stderr.write("generate_auth: %s" % samba.ndr.ndr_print(a))
595             if hexdump:
596                 sys.stderr.write("generate_auth: %d\n%s" % (len(ai), self.hexdump(ai)))
597         else:
598             ai = b""
599
600         return ai
601
602     def parse_auth(self, auth_info, ndr_print=None, hexdump=None,
603                    auth_context=None, stub_len=0):
604         if ndr_print is None:
605             ndr_print = self.do_ndr_print
606         if hexdump is None:
607             hexdump = self.do_hexdump
608
609         if (len(auth_info) <= samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH):
610             return None
611
612         if hexdump:
613             sys.stderr.write("parse_auth: %d\n%s" % (len(auth_info), self.hexdump(auth_info)))
614         a = ndr_unpack(samba.dcerpc.dcerpc.auth, auth_info, allow_remaining=True)
615         if ndr_print:
616             sys.stderr.write("parse_auth: %s" % samba.ndr.ndr_print(a))
617
618         if auth_context is not None:
619             self.assertEquals(a.auth_type, auth_context["auth_type"])
620             self.assertEquals(a.auth_level, auth_context["auth_level"])
621             self.assertEquals(a.auth_reserved, 0)
622             self.assertEquals(a.auth_context_id, auth_context["auth_context_id"])
623
624             self.assertLessEqual(a.auth_pad_length, dcerpc.DCERPC_AUTH_PAD_ALIGNMENT)
625             self.assertLessEqual(a.auth_pad_length, stub_len)
626
627         return a
628
629     def check_response_auth(self, rep, rep_blob, auth_context=None,
630                             auth_pad_length=None):
631
632         if auth_context is None:
633             self.assertEquals(rep.auth_length, 0)
634             return rep.u.stub_and_verifier
635
636         ofs_stub = dcerpc.DCERPC_REQUEST_LENGTH
637         ofs_sig = rep.frag_length - rep.auth_length
638         ofs_trailer = ofs_sig - dcerpc.DCERPC_AUTH_TRAILER_LENGTH
639         rep_data = rep_blob[ofs_stub:ofs_trailer]
640         rep_whole = rep_blob[0:ofs_sig]
641         rep_sig = rep_blob[ofs_sig:]
642         rep_auth_info_blob = rep_blob[ofs_trailer:]
643
644         rep_auth_info = self.parse_auth(rep_auth_info_blob,
645                                         auth_context=auth_context,
646                                         stub_len=len(rep_data))
647         if auth_pad_length is not None:
648             self.assertEquals(rep_auth_info.auth_pad_length, auth_pad_length)
649         self.assertEquals(rep_auth_info.credentials, rep_sig)
650
651         if auth_context["auth_level"] >= dcerpc.DCERPC_AUTH_LEVEL_PRIVACY:
652             # TODO: not yet supported here
653             self.assertTrue(False)
654         elif auth_context["auth_level"] >= dcerpc.DCERPC_AUTH_LEVEL_PACKET:
655             auth_context["gensec"].check_packet(rep_data, rep_whole, rep_sig)
656
657         stub_out = rep_data[0:len(rep_data)-rep_auth_info.auth_pad_length]
658
659         return stub_out
660
661     def generate_pdu(self, ptype, call_id, payload,
662                      rpc_vers=5,
663                      rpc_vers_minor=0,
664                      pfc_flags=(samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
665                                 samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST),
666                      drep=[samba.dcerpc.dcerpc.DCERPC_DREP_LE, 0, 0, 0],
667                      ndr_print=None, hexdump=None):
668
669         if getattr(payload, 'auth_info', None):
670             ai = payload.auth_info
671         else:
672             ai = b""
673
674         p = samba.dcerpc.dcerpc.ncacn_packet()
675         p.rpc_vers = rpc_vers
676         p.rpc_vers_minor = rpc_vers_minor
677         p.ptype = ptype
678         p.pfc_flags = pfc_flags
679         p.drep = drep
680         p.frag_length = 0
681         if len(ai) > samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH:
682             p.auth_length = len(ai) - samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH
683         else:
684             p.auth_length = 0
685         p.call_id = call_id
686         p.u = payload
687
688         pdu = ndr_pack(p)
689         p.frag_length = len(pdu)
690
691         return p
692
693     def generate_request_auth(self, call_id,
694                               pfc_flags=(dcerpc.DCERPC_PFC_FLAG_FIRST |
695                                          dcerpc.DCERPC_PFC_FLAG_LAST),
696                               alloc_hint=None,
697                               context_id=None,
698                               opnum=None,
699                               object=None,
700                               stub=None,
701                               auth_context=None,
702                               ndr_print=None, hexdump=None):
703
704         if stub is None:
705             stub = b""
706
707         sig_size = 0
708         if auth_context is not None:
709             mod_len = len(stub) % dcerpc.DCERPC_AUTH_PAD_ALIGNMENT
710             auth_pad_length = 0
711             if mod_len > 0:
712                 auth_pad_length = dcerpc.DCERPC_AUTH_PAD_ALIGNMENT - mod_len
713             stub += b'\x00' * auth_pad_length
714
715             if auth_context["g_auth_level"] >= samba.dcerpc.dcerpc.DCERPC_AUTH_LEVEL_PACKET:
716                 sig_size = auth_context["gensec"].sig_size(len(stub))
717             else:
718                 sig_size = 16
719
720             zero_sig = b"\x00" * sig_size
721             auth_info = self.generate_auth(auth_type=auth_context["auth_type"],
722                                            auth_level=auth_context["auth_level"],
723                                            auth_pad_length=auth_pad_length,
724                                            auth_context_id=auth_context["auth_context_id"],
725                                            auth_blob=zero_sig)
726         else:
727             auth_info = b""
728
729         req = self.generate_request(call_id=call_id,
730                                     pfc_flags=pfc_flags,
731                                     alloc_hint=alloc_hint,
732                                     context_id=context_id,
733                                     opnum=opnum,
734                                     object=object,
735                                     stub=stub,
736                                     auth_info=auth_info,
737                                     ndr_print=ndr_print,
738                                     hexdump=hexdump)
739         if auth_context is None:
740             return req
741
742         req_blob = samba.ndr.ndr_pack(req)
743         ofs_stub = dcerpc.DCERPC_REQUEST_LENGTH
744         ofs_sig = len(req_blob) - req.auth_length
745         ofs_trailer = ofs_sig - dcerpc.DCERPC_AUTH_TRAILER_LENGTH
746         req_data = req_blob[ofs_stub:ofs_trailer]
747         req_whole = req_blob[0:ofs_sig]
748
749         if auth_context["auth_level"] >= dcerpc.DCERPC_AUTH_LEVEL_PRIVACY:
750             # TODO: not yet supported here
751             self.assertTrue(False)
752         elif auth_context["auth_level"] >= dcerpc.DCERPC_AUTH_LEVEL_PACKET:
753             req_sig = auth_context["gensec"].sign_packet(req_data, req_whole)
754         else:
755             return req
756         self.assertEquals(len(req_sig), req.auth_length)
757         self.assertEquals(len(req_sig), sig_size)
758
759         stub_sig_ofs = len(req.u.stub_and_verifier) - sig_size
760         stub = req.u.stub_and_verifier[0:stub_sig_ofs] + req_sig
761         req.u.stub_and_verifier = stub
762
763         return req
764
765     def verify_pdu(self, p, ptype, call_id,
766                    rpc_vers=5,
767                    rpc_vers_minor=0,
768                    pfc_flags=(samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
769                               samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST),
770                    drep=[samba.dcerpc.dcerpc.DCERPC_DREP_LE, 0, 0, 0],
771                    auth_length=None):
772
773         self.assertIsNotNone(p, "No valid pdu")
774
775         if getattr(p.u, 'auth_info', None):
776             ai = p.u.auth_info
777         else:
778             ai = b""
779
780         self.assertEqual(p.rpc_vers, rpc_vers)
781         self.assertEqual(p.rpc_vers_minor, rpc_vers_minor)
782         self.assertEqual(p.ptype, ptype)
783         self.assertEqual(p.pfc_flags, pfc_flags)
784         self.assertEqual(p.drep, drep)
785         self.assertGreaterEqual(p.frag_length,
786                                 samba.dcerpc.dcerpc.DCERPC_NCACN_PAYLOAD_OFFSET)
787         if len(ai) > samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH:
788             self.assertEqual(p.auth_length,
789                              len(ai) - samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH)
790         elif auth_length is not None:
791             self.assertEqual(p.auth_length, auth_length)
792         else:
793             self.assertEqual(p.auth_length, 0)
794         self.assertEqual(p.call_id, call_id)
795
796         return
797
798     def generate_bind(self, call_id,
799                       pfc_flags=(samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
800                                  samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST),
801                       max_xmit_frag=5840,
802                       max_recv_frag=5840,
803                       assoc_group_id=0,
804                       ctx_list=[],
805                       auth_info=b"",
806                       ndr_print=None, hexdump=None):
807
808         b = samba.dcerpc.dcerpc.bind()
809         b.max_xmit_frag = max_xmit_frag
810         b.max_recv_frag = max_recv_frag
811         b.assoc_group_id = assoc_group_id
812         b.num_contexts = len(ctx_list)
813         b.ctx_list = ctx_list
814         b.auth_info = auth_info
815
816         p = self.generate_pdu(ptype=samba.dcerpc.dcerpc.DCERPC_PKT_BIND,
817                               pfc_flags=pfc_flags,
818                               call_id=call_id,
819                               payload=b,
820                               ndr_print=ndr_print, hexdump=hexdump)
821
822         return p
823
824     def generate_alter(self, call_id,
825                        pfc_flags=(samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
826                                   samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST),
827                        max_xmit_frag=5840,
828                        max_recv_frag=5840,
829                        assoc_group_id=0,
830                        ctx_list=[],
831                        auth_info=b"",
832                        ndr_print=None, hexdump=None):
833
834         a = samba.dcerpc.dcerpc.bind()
835         a.max_xmit_frag = max_xmit_frag
836         a.max_recv_frag = max_recv_frag
837         a.assoc_group_id = assoc_group_id
838         a.num_contexts = len(ctx_list)
839         a.ctx_list = ctx_list
840         a.auth_info = auth_info
841
842         p = self.generate_pdu(ptype=samba.dcerpc.dcerpc.DCERPC_PKT_ALTER,
843                               pfc_flags=pfc_flags,
844                               call_id=call_id,
845                               payload=a,
846                               ndr_print=ndr_print, hexdump=hexdump)
847
848         return p
849
850     def generate_auth3(self, call_id,
851                        pfc_flags=(samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
852                                   samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST),
853                        auth_info=b"",
854                        ndr_print=None, hexdump=None):
855
856         a = samba.dcerpc.dcerpc.auth3()
857         a.auth_info = auth_info
858
859         p = self.generate_pdu(ptype=samba.dcerpc.dcerpc.DCERPC_PKT_AUTH3,
860                               pfc_flags=pfc_flags,
861                               call_id=call_id,
862                               payload=a,
863                               ndr_print=ndr_print, hexdump=hexdump)
864
865         return p
866
867     def generate_request(self, call_id,
868                          pfc_flags=(samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
869                                     samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST),
870                          alloc_hint=None,
871                          context_id=None,
872                          opnum=None,
873                          object=None,
874                          stub=None,
875                          auth_info=b"",
876                          ndr_print=None, hexdump=None):
877
878         if alloc_hint is None:
879             alloc_hint = len(stub)
880
881         r = samba.dcerpc.dcerpc.request()
882         r.alloc_hint = alloc_hint
883         r.context_id = context_id
884         r.opnum = opnum
885         if object is not None:
886             r.object = object
887         r.stub_and_verifier = stub + auth_info
888
889         p = self.generate_pdu(ptype=samba.dcerpc.dcerpc.DCERPC_PKT_REQUEST,
890                               pfc_flags=pfc_flags,
891                               call_id=call_id,
892                               payload=r,
893                               ndr_print=ndr_print, hexdump=hexdump)
894
895         if len(auth_info) > samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH:
896             p.auth_length = len(auth_info) - samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH
897
898         return p
899
900     def generate_co_cancel(self, call_id,
901                            pfc_flags=(samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
902                                       samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST),
903                            auth_info=b"",
904                            ndr_print=None, hexdump=None):
905
906         c = samba.dcerpc.dcerpc.co_cancel()
907         c.auth_info = auth_info
908
909         p = self.generate_pdu(ptype=samba.dcerpc.dcerpc.DCERPC_PKT_CO_CANCEL,
910                               pfc_flags=pfc_flags,
911                               call_id=call_id,
912                               payload=c,
913                               ndr_print=ndr_print, hexdump=hexdump)
914
915         return p
916
917     def generate_orphaned(self, call_id,
918                           pfc_flags=(samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
919                                      samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST),
920                           auth_info=b"",
921                           ndr_print=None, hexdump=None):
922
923         o = samba.dcerpc.dcerpc.orphaned()
924         o.auth_info = auth_info
925
926         p = self.generate_pdu(ptype=samba.dcerpc.dcerpc.DCERPC_PKT_ORPHANED,
927                               pfc_flags=pfc_flags,
928                               call_id=call_id,
929                               payload=o,
930                               ndr_print=ndr_print, hexdump=hexdump)
931
932         return p
933
934     def generate_shutdown(self, call_id,
935                           pfc_flags=(samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
936                                      samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST),
937                           ndr_print=None, hexdump=None):
938
939         s = samba.dcerpc.dcerpc.shutdown()
940
941         p = self.generate_pdu(ptype=samba.dcerpc.dcerpc.DCERPC_PKT_SHUTDOWN,
942                               pfc_flags=pfc_flags,
943                               call_id=call_id,
944                               payload=s,
945                               ndr_print=ndr_print, hexdump=hexdump)
946
947         return p
948
949     def assertIsConnected(self):
950         self.assertIsNotNone(self.s, msg="Not connected")
951         return
952
953     def assertNotConnected(self):
954         self.assertIsNone(self.s, msg="Is connected")
955         return
956
957     def assertNDRSyntaxEquals(self, s1, s2):
958         self.assertEqual(s1.uuid, s2.uuid)
959         self.assertEqual(s1.if_version, s2.if_version)
960         return
961
962     def assertPadding(self, pad, length):
963         self.assertEquals(len(pad), length)
964         #
965         # sometimes windows sends random bytes
966         #
967         # we have IGNORE_RANDOM_PAD=1 to
968         # disable the check
969         #
970         if self.ignore_random_pad:
971             return
972         zero_pad = b'\0' * length
973         self.assertEquals(pad, zero_pad)