python:tests: add more helper functions to RawDCERPCTest
[amitay/samba.git] / python / samba / tests / __init__.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 # Copyright (C) Stefan Metzmacher 2014,2015
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18
19 """Samba Python tests."""
20
21 import os
22 import ldb
23 import samba
24 import samba.auth
25 from samba import param
26 from samba.samdb import SamDB
27 from samba import credentials
28 import samba.ndr
29 import samba.dcerpc.dcerpc
30 import samba.dcerpc.base
31 import samba.dcerpc.epmapper
32 from samba.credentials import Credentials
33 from samba import gensec
34 import socket
35 import struct
36 import subprocess
37 import sys
38 import tempfile
39 import unittest
40
try:
    from unittest import SkipTest
except ImportError:
    # Older unittest modules do not export SkipTest; define a stand-in
    # so that callers can raise and catch it regardless.
    class SkipTest(Exception):
        """Test skipped."""

# Translation table used by TestCase.hexdump(): entry N is chr(N) when
# repr(chr(N)) is exactly three characters long (the character itself plus
# the two quotes), and '.' for everything else.
HEXDUMP_FILTER = ''.join(
    [chr(byte) if len(repr(chr(byte))) == 3 else '.' for byte in range(256)])
48
class TestCase(unittest.TestCase):
    """A Samba test case.

    Adds debug-level handling, loadparm/credentials accessors and a
    hexdump helper on top of unittest.TestCase, plus backports of
    unittest features that did not exist before Python 2.7.
    """

    def setUp(self):
        """Apply $TEST_DEBUG_LEVEL, if set, for the duration of the test."""
        super(TestCase, self).setUp()
        test_debug_level = os.getenv("TEST_DEBUG_LEVEL")
        if test_debug_level is not None:
            test_debug_level = int(test_debug_level)
            self._old_debug_level = samba.get_debug_level()
            samba.set_debug_level(test_debug_level)
            # Put back the level that was active before this test ran.
            self.addCleanup(samba.set_debug_level, self._old_debug_level)

    def get_loadparm(self):
        """Return the loadparm context built by env_loadparm()."""
        return env_loadparm()

    def get_credentials(self):
        """Return the module-level cmdline_credentials object."""
        return cmdline_credentials

    def hexdump(self, src):
        """Return a textual hex dump of src, 16 bytes per output line.

        Each line holds the offset, two groups of up to eight hex bytes and
        the same bytes filtered through HEXDUMP_FILTER.
        """
        offset = 0
        lines = []
        while src:
            ll = src[:8]
            lr = src[8:16]
            src = src[16:]
            hl = ' '.join(["%02X" % ord(x) for x in ll])
            hr = ' '.join(["%02X" % ord(x) for x in lr])
            ll = ll.translate(HEXDUMP_FILTER)
            lr = lr.translate(HEXDUMP_FILTER)
            lines.append("[%04X] %-*s  %-*s  %s %s\n" %
                         (offset, 8*3, hl, 8*3, hr, ll, lr))
            offset += 16
        return ''.join(lines)

    # These functions didn't exist before Python2.7:
    if sys.version_info < (2, 7):
        import warnings

        def skipTest(self, reason):
            raise SkipTest(reason)

        def assertIn(self, member, container, msg=None):
            self.assertTrue(member in container, msg)

        def assertIs(self, a, b, msg=None):
            self.assertTrue(a is b, msg)

        def assertIsNot(self, a, b, msg=None):
            self.assertTrue(a is not b, msg)

        def assertIsNotNone(self, a, msg=None):
            self.assertTrue(a is not None, msg)

        def assertIsInstance(self, a, b, msg=None):
            self.assertTrue(isinstance(a, b), msg)

        def assertIsNone(self, a, msg=None):
            self.assertTrue(a is None, msg)

        def assertGreater(self, a, b, msg=None):
            self.assertTrue(a > b, msg)

        def assertGreaterEqual(self, a, b, msg=None):
            self.assertTrue(a >= b, msg)

        def assertLess(self, a, b, msg=None):
            self.assertTrue(a < b, msg)

        def assertLessEqual(self, a, b, msg=None):
            self.assertTrue(a <= b, msg)

        def addCleanup(self, fn, *args, **kwargs):
            self._cleanups = getattr(self, "_cleanups", []) + [
                (fn, args, kwargs)]

        def _addSkip(self, result, reason):
            addSkip = getattr(result, 'addSkip', None)
            if addSkip is not None:
                addSkip(self, reason)
            else:
                warnings.warn("TestResult has no addSkip method, skips not reported",
                              RuntimeWarning, 2)
                result.addSuccess(self)

        def run(self, result=None):
            """Run the test, reporting skips, failures and errors to result.

            Registered cleanups are run, most recent first, after tearDown().
            """
            if result is None:
                result = self.defaultTestResult()
            result.startTest(self)
            testMethod = getattr(self, self._testMethodName)
            try:
                try:
                    self.setUp()
                except SkipTest as e:
                    self._addSkip(result, str(e))
                    return
                except KeyboardInterrupt:
                    raise
                except:
                    result.addError(self, self._exc_info())
                    return

                ok = False
                try:
                    testMethod()
                    ok = True
                except SkipTest as e:
                    self._addSkip(result, str(e))
                    return
                except self.failureException:
                    result.addFailure(self, self._exc_info())
                except KeyboardInterrupt:
                    raise
                except:
                    result.addError(self, self._exc_info())

                try:
                    self.tearDown()
                except SkipTest as e:
                    self._addSkip(result, str(e))
                except KeyboardInterrupt:
                    raise
                except:
                    result.addError(self, self._exc_info())
                    ok = False

                for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
                    fn(*args, **kwargs)
                if ok:
                    result.addSuccess(self)
            finally:
                result.stopTest(self)
177
178
class LdbTestCase(TestCase):
    """Trivial test case for running tests against a LDB."""

    def setUp(self):
        """Open a fresh Ldb backed by a newly created temporary file."""
        super(LdbTestCase, self).setUp()
        # mkstemp() creates the file atomically; os.tempnam() only returned
        # a name, which left a window for another process to claim it.
        (fd, self.filename) = tempfile.mkstemp()
        os.close(fd)
        self.ldb = samba.Ldb(self.filename)

    def set_modules(self, modules=()):
        """Change the modules for this Ldb.

        :param modules: iterable of module names, joined with "," into the
            @LIST attribute of the @MODULES record
        """
        m = ldb.Message()
        m.dn = ldb.Dn(self.ldb, "@MODULES")
        m["@LIST"] = ",".join(modules)
        self.ldb.add(m)
        # Re-open the database after the module list has been written.
        self.ldb = samba.Ldb(self.filename)
194
195
class TestCaseInTempDir(TestCase):
    """Test case that gets a private temporary directory in self.tempdir."""

    def setUp(self):
        """Create the temporary directory and schedule its removal."""
        super(TestCaseInTempDir, self).setUp()
        self.tempdir = tempfile.mkdtemp()
        self.addCleanup(self._remove_tempdir)

    def _remove_tempdir(self):
        """Check that the test left the directory empty, then remove it."""
        leftovers = os.listdir(self.tempdir)
        self.assertEquals([], leftovers)
        os.rmdir(self.tempdir)
        self.tempdir = None
207
208
def env_loadparm():
    """Return a LoadParm loaded from the file named by $SMB_CONF_PATH.

    :raises KeyError: if SMB_CONF_PATH is not set in the environment
    """
    # Only the environment lookup is guarded, so that a KeyError raised
    # while loading the file is not misreported as a missing variable.
    try:
        smb_conf_path = os.environ["SMB_CONF_PATH"]
    except KeyError:
        raise KeyError("SMB_CONF_PATH not set")
    lp = param.LoadParm()
    lp.load(smb_conf_path)
    return lp
216
217
def env_get_var_value(var_name):
    """Returns value for variable in os.environ

    Function throws AssertionError if variable is not defined.
    Unit-test based python tests require certain input params
    to be set in environment, otherwise they can't be run
    """
    # Raise explicitly instead of using an assert statement, so the check
    # is still performed when Python runs with -O.
    try:
        return os.environ[var_name]
    except KeyError:
        raise AssertionError("Please supply %s in environment" % var_name)
227
228
# Returned as-is by TestCase.get_credentials(); None until something
# assigns it (no assignment is visible in this part of the file).
cmdline_credentials = None
230
class RpcInterfaceTestCase(TestCase):
    """DCE/RPC Test case.

    Adds nothing to TestCase itself; it only provides a distinct base class.
    """
233
234 class RawDCERPCTest(TestCase):
235     """A raw DCE/RPC Test case."""
236
237     def _disconnect(self, reason):
238         if self.s is None:
239             return
240         self.s.close()
241         self.s = None
242         if self.do_hexdump:
243             sys.stderr.write("disconnect[%s]\n" % reason)
244
245     def connect(self):
246         try:
247             self.a = socket.getaddrinfo(self.host, self.tcp_port, socket.AF_UNSPEC,
248                                         socket.SOCK_STREAM, socket.SOL_TCP,
249                                         0)
250             self.s = socket.socket(self.a[0][0], self.a[0][1], self.a[0][2])
251             self.s.settimeout(10)
252             self.s.connect(self.a[0][4])
253         except socket.error as e:
254             self.s.close()
255             raise
256         except IOError as e:
257             self.s.close()
258             raise
259         except Exception as e:
260             raise
261         finally:
262             pass
263
264     def setUp(self):
265         super(RawDCERPCTest, self).setUp()
266         self.do_ndr_print = False
267         self.do_hexdump = False
268
269         self.host = samba.tests.env_get_var_value('SERVER')
270         self.tcp_port = 135
271
272         self.settings = {}
273         self.settings["lp_ctx"] = self.lp_ctx = samba.tests.env_loadparm()
274         self.settings["target_hostname"] = self.host
275
276         self.connect()
277
278     def get_user_creds(self):
279         c = Credentials()
280         c.guess()
281         username = samba.tests.env_get_var_value('USERNAME')
282         password = samba.tests.env_get_var_value('PASSWORD')
283         c.set_username(username)
284         c.set_password(password)
285         return c
286
287     def get_anon_creds(self):
288         c = Credentials()
289         c.set_anonymous()
290         return c
291
292     def get_auth_context_creds(self, creds, auth_type, auth_level,
293                                auth_context_id,
294                                g_auth_level=None):
295
296         if g_auth_level is None:
297             g_auth_level = auth_level
298
299         g = gensec.Security.start_client(self.settings)
300         g.set_credentials(creds)
301         g.want_feature(gensec.FEATURE_DCE_STYLE)
302         g.start_mech_by_authtype(auth_type, g_auth_level)
303
304         auth_context = {}
305         auth_context["auth_type"] = auth_type
306         auth_context["auth_level"] = auth_level
307         auth_context["auth_context_id"] = auth_context_id
308         auth_context["g_auth_level"] = g_auth_level
309         auth_context["gensec"] = g
310
311         return auth_context
312
    def do_generic_bind(self, ctx, auth_context=None,
                        pfc_flags=samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
                        samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST,
                        assoc_group_id=0, call_id=0,
                        nak_reason=None, alter_fault=None):
        """Send a bind for ctx and verify the server's reply.

        With an auth_context the bind carries the first gensec token and is
        followed by an alter context that carries the second one.

        :param ctx: presentation context to bind; its first transfer syntax
            is the one expected in the reply
        :param auth_context: dict as built by get_auth_context_creds(), or
            None for a bind without authentication
        :param nak_reason: if not None, a bind nak with this reject reason
            is expected instead of a bind ack
        :param alter_fault: if not None, a fault with this status is
            expected in reply to the alter context
        :return: the bind ack PDU, or None when nak_reason or alter_fault
            was given and the matching reply has been verified
        """
        ctx_list = [ctx]

        if auth_context is not None:
            # First gensec leg: nothing has been received from the server yet.
            from_server = ""
            (finished, to_server) = auth_context["gensec"].update(from_server)
            self.assertFalse(finished)

            auth_info = self.generate_auth(auth_type=auth_context["auth_type"],
                                           auth_level=auth_context["auth_level"],
                                           auth_context_id=auth_context["auth_context_id"],
                                           auth_blob=to_server)
        else:
            auth_info = ""

        req = self.generate_bind(call_id=call_id,
                                 pfc_flags=pfc_flags,
                                 ctx_list=ctx_list,
                                 assoc_group_id=assoc_group_id,
                                 auth_info=auth_info)
        self.send_pdu(req)
        rep = self.recv_pdu()
        if nak_reason is not None:
            # Expected rejection: check the bind nak and stop here.
            self.verify_pdu(rep, samba.dcerpc.dcerpc.DCERPC_PKT_BIND_NAK, req.call_id,
                            auth_length=0)
            self.assertEquals(rep.u.reject_reason, nak_reason)
            self.assertEquals(rep.u.num_versions, 1)
            self.assertEquals(rep.u.versions[0].rpc_vers, req.rpc_vers)
            self.assertEquals(rep.u.versions[0].rpc_vers_minor, req.rpc_vers_minor)
            self.assertEquals(len(rep.u._pad), 3)
            self.assertEquals(rep.u._pad, '\0' * 3)
            return
        self.verify_pdu(rep, samba.dcerpc.dcerpc.DCERPC_PKT_BIND_ACK, req.call_id,
                        pfc_flags=pfc_flags)
        self.assertEquals(rep.u.max_xmit_frag, req.u.max_xmit_frag)
        self.assertEquals(rep.u.max_recv_frag, req.u.max_recv_frag)
        if assoc_group_id != 0:
            self.assertEquals(rep.u.assoc_group_id, assoc_group_id)
        else:
            # No group was requested: the server must have picked one, which
            # is remembered for the checks on the alter response below.
            self.assertNotEquals(rep.u.assoc_group_id, 0)
            assoc_group_id = rep.u.assoc_group_id
        # The secondary address is expected to be the TCP port as a string.
        # NOTE(review): the "+ 1" presumably counts a terminating NUL and
        # the "2 +" the 16-bit size field in front of it - confirm.
        port_str = "%d" % self.tcp_port
        port_len = len(port_str) + 1
        mod_len = (2 + port_len) % 4
        if mod_len != 0:
            port_pad = 4 - mod_len
        else:
            port_pad = 0
        self.assertEquals(rep.u.secondary_address_size, port_len)
        self.assertEquals(rep.u.secondary_address, port_str)
        self.assertEquals(len(rep.u._pad1), port_pad)
        # sometimes windows sends random bytes
        # self.assertEquals(rep.u._pad1, '\0' * port_pad)
        self.assertEquals(rep.u.num_results, 1)
        self.assertEquals(rep.u.ctx_list[0].result,
                samba.dcerpc.dcerpc.DCERPC_BIND_ACK_RESULT_ACCEPTANCE)
        self.assertEquals(rep.u.ctx_list[0].reason,
                samba.dcerpc.dcerpc.DCERPC_BIND_ACK_REASON_NOT_SPECIFIED)
        self.assertNDRSyntaxEquals(rep.u.ctx_list[0].syntax, ctx.transfer_syntaxes[0])
        ack = rep
        if auth_context is None:
            # Unauthenticated bind: the ack must not carry any auth data.
            self.assertEquals(rep.auth_length, 0)
            self.assertEquals(len(rep.u.auth_info), 0)
            return ack
        self.assertNotEquals(rep.auth_length, 0)
        self.assertGreater(len(rep.u.auth_info), samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH)
        self.assertEquals(rep.auth_length, len(rep.u.auth_info) - samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH)

        a = self.parse_auth(rep.u.auth_info)

        # Second gensec leg: feed the server's token in; the exchange must
        # not be complete yet.
        from_server = a.credentials
        (finished, to_server) = auth_context["gensec"].update(from_server)
        self.assertFalse(finished)

        auth_info = self.generate_auth(auth_type=auth_context["auth_type"],
                                       auth_level=auth_context["auth_level"],
                                       auth_context_id=auth_context["auth_context_id"],
                                       auth_blob=to_server)
        # NOTE(review): the alter context sends a different assoc_group_id
        # than the one in use, yet the response is checked against the one
        # in use - this looks intentional; confirm.
        req = self.generate_alter(call_id=call_id,
                                  ctx_list=ctx_list,
                                  assoc_group_id=0xffffffff-assoc_group_id,
                                  auth_info=auth_info)
        self.send_pdu(req)
        rep = self.recv_pdu()
        if alter_fault is not None:
            # Expected failure of the alter context: check the fault and stop.
            self.verify_pdu(rep, samba.dcerpc.dcerpc.DCERPC_PKT_FAULT, req.call_id,
                            pfc_flags=req.pfc_flags |
                            samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_DID_NOT_EXECUTE,
                            auth_length=0)
            self.assertNotEquals(rep.u.alloc_hint, 0)
            self.assertEquals(rep.u.context_id, 0)
            self.assertEquals(rep.u.cancel_count, 0)
            self.assertEquals(rep.u.flags, 0)
            self.assertEquals(rep.u.status, alter_fault)
            self.assertEquals(rep.u.reserved, 0)
            self.assertEquals(len(rep.u.error_and_verifier), 0)
            return None
        self.verify_pdu(rep, samba.dcerpc.dcerpc.DCERPC_PKT_ALTER_RESP, req.call_id)
        self.assertEquals(rep.u.max_xmit_frag, req.u.max_xmit_frag)
        self.assertEquals(rep.u.max_recv_frag, req.u.max_recv_frag)
        self.assertEquals(rep.u.assoc_group_id, assoc_group_id)
        self.assertEquals(rep.u.secondary_address_size, 0)
        self.assertEquals(rep.u.secondary_address, '')
        self.assertEquals(len(rep.u._pad1), 2)
        # sometimes windows sends random bytes
        # self.assertEquals(rep.u._pad1, '\0' * 2)
        self.assertEquals(rep.u.num_results, 1)
        self.assertEquals(rep.u.ctx_list[0].result,
                samba.dcerpc.dcerpc.DCERPC_BIND_ACK_RESULT_ACCEPTANCE)
        self.assertEquals(rep.u.ctx_list[0].reason,
                samba.dcerpc.dcerpc.DCERPC_BIND_ACK_REASON_NOT_SPECIFIED)
        self.assertNDRSyntaxEquals(rep.u.ctx_list[0].syntax, ctx.transfer_syntaxes[0])
        self.assertNotEquals(rep.auth_length, 0)
        self.assertGreater(len(rep.u.auth_info), samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH)
        self.assertEquals(rep.auth_length, len(rep.u.auth_info) - samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH)

        a = self.parse_auth(rep.u.auth_info)

        # Final gensec leg: the exchange must be complete now.
        from_server = a.credentials
        (finished, to_server) = auth_context["gensec"].update(from_server)
        self.assertTrue(finished)

        return ack
440
441     def prepare_presentation(self, abstract, transfer, object=None,
442                              context_id=0xffff, epmap=False, auth_context=None,
443                              pfc_flags=samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
444                              samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST,
445                              assoc_group_id=0,
446                              return_ack=False):
447         if epmap:
448             self.epmap_reconnect(abstract, transfer=transfer, object=object)
449
450         tsf1_list = [transfer]
451         ctx = samba.dcerpc.dcerpc.ctx_list()
452         ctx.context_id = context_id
453         ctx.num_transfer_syntaxes = len(tsf1_list)
454         ctx.abstract_syntax = abstract
455         ctx.transfer_syntaxes = tsf1_list
456
457         ack = self.do_generic_bind(ctx=ctx,
458                                    auth_context=auth_context,
459                                    pfc_flags=pfc_flags,
460                                    assoc_group_id=assoc_group_id)
461         if ack is None:
462             ctx = None
463
464         if return_ack:
465             return (ctx, ack)
466         return ctx
467
468     def do_single_request(self, call_id, ctx, io,
469                           auth_context=None,
470                           object=None,
471                           bigendian=False, ndr64=False,
472                           allow_remaining=False,
473                           send_req=True,
474                           recv_rep=True,
475                           fault_pfc_flags = samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
476                           samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST,
477                           fault_status=None,
478                           fault_context_id=None,
479                           timeout=None,
480                           ndr_print=None,
481                           hexdump=None):
482
483         if fault_context_id is None:
484             fault_context_id = ctx.context_id
485
486         if ndr_print is None:
487             ndr_print = self.do_ndr_print
488         if hexdump is None:
489             hexdump = self.do_hexdump
490
491         if send_req:
492             if ndr_print:
493                 sys.stderr.write("in: %s" % samba.ndr.ndr_print_in(io))
494             stub_in = samba.ndr.ndr_pack_in(io, bigendian=bigendian, ndr64=ndr64)
495             if hexdump:
496                 sys.stderr.write("stub_in: %d\n%s" % (len(stub_in), self.hexdump(stub_in)))
497         else:
498             # only used for sig_size calculation
499             stub_in = '\xff' * samba.dcerpc.dcerpc.DCERPC_AUTH_PAD_ALIGNMENT
500
501         sig_size = 0
502         if auth_context is not None:
503             mod_len = len(stub_in) % samba.dcerpc.dcerpc.DCERPC_AUTH_PAD_ALIGNMENT
504             auth_pad_length = 0
505             if mod_len > 0:
506                 auth_pad_length = samba.dcerpc.dcerpc.DCERPC_AUTH_PAD_ALIGNMENT - mod_len
507             stub_in += '\x00' * auth_pad_length
508
509             if auth_context["g_auth_level"] >= samba.dcerpc.dcerpc.DCERPC_AUTH_LEVEL_PACKET:
510                 sig_size = auth_context["gensec"].sig_size(len(stub_in))
511             else:
512                 sig_size = 16
513
514             zero_sig = "\x00"*sig_size
515             auth_info = self.generate_auth(auth_type=auth_context["auth_type"],
516                                            auth_level=auth_context["auth_level"],
517                                            auth_pad_length=auth_pad_length,
518                                            auth_context_id=auth_context["auth_context_id"],
519                                            auth_blob=zero_sig)
520         else:
521             auth_info=""
522
523         pfc_flags =  samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST
524         pfc_flags |= samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST
525         if object is not None:
526             pfc_flags |= samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_OBJECT_UUID
527
528         req = self.generate_request(call_id=call_id,
529                                     context_id=ctx.context_id,
530                                     pfc_flags=pfc_flags,
531                                     object=object,
532                                     opnum=io.opnum(),
533                                     stub=stub_in,
534                                     auth_info=auth_info)
535
536         if send_req:
537             if sig_size != 0 and auth_context["auth_level"] >= samba.dcerpc.dcerpc.DCERPC_AUTH_LEVEL_PACKET:
538                 req_blob = samba.ndr.ndr_pack(req)
539                 ofs_stub = samba.dcerpc.dcerpc.DCERPC_REQUEST_LENGTH
540                 ofs_sig = len(req_blob) - req.auth_length
541                 ofs_trailer = ofs_sig - samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH
542                 req_data = req_blob[ofs_stub:ofs_trailer]
543                 req_whole = req_blob[0:ofs_sig]
544                 sig = auth_context["gensec"].sign_packet(req_data, req_whole)
545                 auth_info = self.generate_auth(auth_type=auth_context["auth_type"],
546                                                auth_level=auth_context["auth_level"],
547                                                auth_pad_length=auth_pad_length,
548                                                auth_context_id=auth_context["auth_context_id"],
549                                                auth_blob=sig)
550                 req = self.generate_request(call_id=call_id,
551                                             context_id=ctx.context_id,
552                                             pfc_flags=pfc_flags,
553                                             object=object,
554                                             opnum=io.opnum(),
555                                             stub=stub_in,
556                                             auth_info=auth_info)
557             self.send_pdu(req, ndr_print=ndr_print, hexdump=hexdump)
558         if recv_rep:
559             (rep, rep_blob) = self.recv_pdu_raw(timeout=timeout,
560                                                 ndr_print=ndr_print,
561                                                 hexdump=hexdump)
562             if fault_status:
563                 self.verify_pdu(rep, samba.dcerpc.dcerpc.DCERPC_PKT_FAULT, req.call_id,
564                                 pfc_flags=fault_pfc_flags, auth_length=0)
565                 self.assertNotEquals(rep.u.alloc_hint, 0)
566                 self.assertEquals(rep.u.context_id, fault_context_id)
567                 self.assertEquals(rep.u.cancel_count, 0)
568                 self.assertEquals(rep.u.flags, 0)
569                 self.assertEquals(rep.u.status, fault_status)
570                 self.assertEquals(rep.u.reserved, 0)
571                 self.assertEquals(len(rep.u.error_and_verifier), 0)
572                 return
573
574             self.verify_pdu(rep, samba.dcerpc.dcerpc.DCERPC_PKT_RESPONSE, req.call_id,
575                             auth_length=sig_size)
576             self.assertNotEquals(rep.u.alloc_hint, 0)
577             self.assertEquals(rep.u.context_id, req.u.context_id & 0xff)
578             self.assertEquals(rep.u.cancel_count, 0)
579             self.assertGreaterEqual(len(rep.u.stub_and_verifier), rep.u.alloc_hint)
580             if sig_size != 0:
581
582                 ofs_stub = samba.dcerpc.dcerpc.DCERPC_REQUEST_LENGTH
583                 ofs_sig = rep.frag_length - rep.auth_length
584                 ofs_trailer = ofs_sig - samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH
585                 rep_data = rep_blob[ofs_stub:ofs_trailer]
586                 rep_whole = rep_blob[0:ofs_sig]
587                 rep_sig = rep_blob[ofs_sig:]
588                 rep_auth_info_blob = rep_blob[ofs_trailer:]
589
590                 rep_auth_info = self.parse_auth(rep_auth_info_blob)
591                 self.assertEquals(rep_auth_info.auth_type, auth_context["auth_type"])
592                 self.assertEquals(rep_auth_info.auth_level, auth_context["auth_level"])
593                 self.assertLessEqual(rep_auth_info.auth_pad_length, len(rep_data))
594                 self.assertEquals(rep_auth_info.auth_reserved, 0)
595                 self.assertEquals(rep_auth_info.auth_context_id, auth_context["auth_context_id"])
596                 self.assertEquals(rep_auth_info.credentials, rep_sig)
597
598                 if auth_context["auth_level"] >= samba.dcerpc.dcerpc.DCERPC_AUTH_LEVEL_PACKET:
599                     auth_context["gensec"].check_packet(rep_data, rep_whole, rep_sig)
600
601                 stub_out = rep_data[0:-rep_auth_info.auth_pad_length]
602             else:
603                 stub_out = rep.u.stub_and_verifier
604
605             if hexdump:
606                 sys.stderr.write("stub_out: %d\n%s" % (len(stub_out), self.hexdump(stub_out)))
607             samba.ndr.ndr_unpack_out(io, stub_out, bigendian=bigendian, ndr64=ndr64,
608                                      allow_remaining=allow_remaining)
609             if ndr_print:
610                 sys.stderr.write("out: %s" % samba.ndr.ndr_print_out(io))
611
612     def epmap_reconnect(self, abstract, transfer=None, object=None):
613         ndr32 = samba.dcerpc.base.transfer_syntax_ndr()
614
615         if transfer is None:
616             transfer = ndr32
617
618         if object is None:
619             object = samba.dcerpc.misc.GUID()
620
621         ctx = self.prepare_presentation(samba.dcerpc.epmapper.abstract_syntax(),
622                                         transfer, context_id=0)
623
624         data1 = samba.ndr.ndr_pack(abstract)
625         lhs1 = samba.dcerpc.epmapper.epm_lhs()
626         lhs1.protocol = samba.dcerpc.epmapper.EPM_PROTOCOL_UUID
627         lhs1.lhs_data = data1[:18]
628         rhs1 = samba.dcerpc.epmapper.epm_rhs_uuid()
629         rhs1.unknown = data1[18:]
630         floor1 = samba.dcerpc.epmapper.epm_floor()
631         floor1.lhs = lhs1
632         floor1.rhs = rhs1
633         data2 = samba.ndr.ndr_pack(transfer)
634         lhs2 = samba.dcerpc.epmapper.epm_lhs()
635         lhs2.protocol = samba.dcerpc.epmapper.EPM_PROTOCOL_UUID
636         lhs2.lhs_data = data2[:18]
637         rhs2 = samba.dcerpc.epmapper.epm_rhs_uuid()
638         rhs2.unknown = data1[18:]
639         floor2 = samba.dcerpc.epmapper.epm_floor()
640         floor2.lhs = lhs2
641         floor2.rhs = rhs2
642         lhs3 = samba.dcerpc.epmapper.epm_lhs()
643         lhs3.protocol = samba.dcerpc.epmapper.EPM_PROTOCOL_NCACN
644         lhs3.lhs_data = ""
645         floor3 = samba.dcerpc.epmapper.epm_floor()
646         floor3.lhs = lhs3
647         floor3.rhs.minor_version = 0
648         lhs4 = samba.dcerpc.epmapper.epm_lhs()
649         lhs4.protocol = samba.dcerpc.epmapper.EPM_PROTOCOL_TCP
650         lhs4.lhs_data = ""
651         floor4 = samba.dcerpc.epmapper.epm_floor()
652         floor4.lhs = lhs4
653         floor4.rhs.port = self.tcp_port
654         lhs5 = samba.dcerpc.epmapper.epm_lhs()
655         lhs5.protocol = samba.dcerpc.epmapper.EPM_PROTOCOL_IP
656         lhs5.lhs_data = ""
657         floor5 = samba.dcerpc.epmapper.epm_floor()
658         floor5.lhs = lhs5
659         floor5.rhs.ipaddr = "0.0.0.0"
660
661         floors = [floor1,floor2,floor3,floor4,floor5]
662         req_tower = samba.dcerpc.epmapper.epm_tower()
663         req_tower.num_floors = len(floors)
664         req_tower.floors = floors
665         req_twr = samba.dcerpc.epmapper.epm_twr_t()
666         req_twr.tower = req_tower
667
668         epm_map = samba.dcerpc.epmapper.epm_Map()
669         epm_map.in_object = object
670         epm_map.in_map_tower = req_twr
671         epm_map.in_entry_handle = samba.dcerpc.misc.policy_handle()
672         epm_map.in_max_towers = 4
673
674         self.do_single_request(call_id=2, ctx=ctx, io=epm_map)
675
676         self.assertGreaterEqual(epm_map.out_num_towers, 1)
677         rep_twr = epm_map.out_towers[0].twr
678         self.assertIsNotNone(rep_twr)
679         self.assertEqual(rep_twr.tower_length, 75)
680         self.assertEqual(rep_twr.tower.num_floors, 5)
681         self.assertEqual(len(rep_twr.tower.floors), 5)
682         self.assertEqual(rep_twr.tower.floors[3].lhs.protocol,
683                           samba.dcerpc.epmapper.EPM_PROTOCOL_TCP)
684         self.assertEqual(rep_twr.tower.floors[3].lhs.protocol,
685                           samba.dcerpc.epmapper.EPM_PROTOCOL_TCP)
686
687         # reconnect to the given port
688         self._disconnect("epmap_reconnect")
689         self.tcp_port = rep_twr.tower.floors[3].rhs.port
690         self.connect()
691
692     def send_pdu(self, req, ndr_print=None, hexdump=None):
693         if ndr_print is None:
694             ndr_print = self.do_ndr_print
695         if hexdump is None:
696             hexdump = self.do_hexdump
697         try:
698             req_pdu = samba.ndr.ndr_pack(req)
699             if ndr_print:
700                 sys.stderr.write("send_pdu: %s" % samba.ndr.ndr_print(req))
701             if hexdump:
702                 sys.stderr.write("send_pdu: %d\n%s" % (len(req_pdu), self.hexdump(req_pdu)))
703             while True:
704                 sent = self.s.send(req_pdu, 0)
705                 if sent == len(req_pdu):
706                     break
707                 req_pdu = req_pdu[sent:]
708         except socket.error as e:
709             self._disconnect("send_pdu: %s" % e)
710             raise
711         except IOError as e:
712             self._disconnect("send_pdu: %s" % e)
713             raise
714         finally:
715             pass
716
717     def recv_raw(self, hexdump=None, timeout=None):
718         rep_pdu = None
719         if hexdump is None:
720             hexdump = self.do_hexdump
721         try:
722             if timeout is not None:
723                 self.s.settimeout(timeout)
724             rep_pdu = self.s.recv(0xffff, 0)
725             self.s.settimeout(10)
726             if len(rep_pdu) == 0:
727                 self._disconnect("recv_raw: EOF")
728                 return None
729             if hexdump:
730                 sys.stderr.write("recv_raw: %d\n%s" % (len(rep_pdu), self.hexdump(rep_pdu)))
731         except socket.timeout as e:
732             self.s.settimeout(10)
733             sys.stderr.write("recv_raw: TIMEOUT\n")
734             pass
735         except socket.error as e:
736             self._disconnect("recv_raw: %s" % e)
737             raise
738         except IOError as e:
739             self._disconnect("recv_raw: %s" % e)
740             raise
741         finally:
742             pass
743         return rep_pdu
744
745     def recv_pdu_raw(self, ndr_print=None, hexdump=None, timeout=None):
746         rep_pdu = None
747         rep = None
748         if ndr_print is None:
749             ndr_print = self.do_ndr_print
750         if hexdump is None:
751             hexdump = self.do_hexdump
752         try:
753             rep_pdu = self.recv_raw(hexdump=hexdump, timeout=timeout)
754             if rep_pdu is None:
755                 return (None,None)
756             rep = samba.ndr.ndr_unpack(samba.dcerpc.dcerpc.ncacn_packet, rep_pdu, allow_remaining=True)
757             if ndr_print:
758                 sys.stderr.write("recv_pdu: %s" % samba.ndr.ndr_print(rep))
759             self.assertEqual(rep.frag_length, len(rep_pdu))
760         finally:
761             pass
762         return (rep, rep_pdu)
763
764     def recv_pdu(self, ndr_print=None, hexdump=None, timeout=None):
765         (rep, rep_pdu) = self.recv_pdu_raw(ndr_print=ndr_print,
766                                            hexdump=hexdump,
767                                            timeout=timeout)
768         return rep
769
770     def generate_auth(self,
771                       auth_type=None,
772                       auth_level=None,
773                       auth_pad_length=0,
774                       auth_context_id=None,
775                       auth_blob=None,
776                       ndr_print=None, hexdump=None):
777         if ndr_print is None:
778             ndr_print = self.do_ndr_print
779         if hexdump is None:
780             hexdump = self.do_hexdump
781
782         if auth_type is not None:
783             a = samba.dcerpc.dcerpc.auth()
784             a.auth_type = auth_type
785             a.auth_level = auth_level
786             a.auth_pad_length = auth_pad_length
787             a.auth_context_id= auth_context_id
788             a.credentials = auth_blob
789
790             ai = samba.ndr.ndr_pack(a)
791             if ndr_print:
792                 sys.stderr.write("generate_auth: %s" % samba.ndr.ndr_print(a))
793             if hexdump:
794                 sys.stderr.write("generate_auth: %d\n%s" % (len(ai), self.hexdump(ai)))
795         else:
796             ai = ""
797
798         return ai
799
800     def parse_auth(self, auth_info, ndr_print=None, hexdump=None):
801         if ndr_print is None:
802             ndr_print = self.do_ndr_print
803         if hexdump is None:
804             hexdump = self.do_hexdump
805
806         if (len(auth_info) <= samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH):
807             return None
808
809         if hexdump:
810             sys.stderr.write("parse_auth: %d\n%s" % (len(auth_info), self.hexdump(auth_info)))
811         a = samba.ndr.ndr_unpack(samba.dcerpc.dcerpc.auth, auth_info, allow_remaining=True)
812         if ndr_print:
813             sys.stderr.write("parse_auth: %s" % samba.ndr.ndr_print(a))
814
815         return a
816
817     def generate_pdu(self, ptype, call_id, payload,
818                      rpc_vers=5,
819                      rpc_vers_minor=0,
820                      pfc_flags = samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
821                                  samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST,
822                      drep = [samba.dcerpc.dcerpc.DCERPC_DREP_LE, 0, 0, 0],
823                      ndr_print=None, hexdump=None):
824
825         if getattr(payload, 'auth_info', None):
826             ai = payload.auth_info
827         else:
828             ai = ""
829
830         p = samba.dcerpc.dcerpc.ncacn_packet()
831         p.rpc_vers = rpc_vers
832         p.rpc_vers_minor = rpc_vers_minor
833         p.ptype = ptype
834         p.pfc_flags = pfc_flags
835         p.drep = drep
836         p.frag_length = 0
837         if len(ai) > samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH:
838             p.auth_length = len(ai) - samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH
839         else:
840             p.auth_length = 0
841         p.call_id = call_id
842         p.u = payload
843
844         pdu = samba.ndr.ndr_pack(p)
845         p.frag_length = len(pdu)
846
847         return p
848
849     def verify_pdu(self, p, ptype, call_id,
850                    rpc_vers=5,
851                    rpc_vers_minor=0,
852                    pfc_flags = samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
853                                samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST,
854                    drep = [samba.dcerpc.dcerpc.DCERPC_DREP_LE, 0, 0, 0],
855                    auth_length=None):
856
857         self.assertIsNotNone(p, "No valid pdu")
858
859         if getattr(p.u, 'auth_info', None):
860             ai = p.u.auth_info
861         else:
862             ai = ""
863
864         self.assertEqual(p.rpc_vers, rpc_vers)
865         self.assertEqual(p.rpc_vers_minor, rpc_vers_minor)
866         self.assertEqual(p.ptype, ptype)
867         self.assertEqual(p.pfc_flags, pfc_flags)
868         self.assertEqual(p.drep, drep)
869         self.assertGreaterEqual(p.frag_length,
870                 samba.dcerpc.dcerpc.DCERPC_NCACN_PAYLOAD_OFFSET)
871         if len(ai) > samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH:
872             self.assertEqual(p.auth_length,
873                     len(ai) - samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH)
874         elif auth_length is not None:
875             self.assertEqual(p.auth_length, auth_length)
876         else:
877             self.assertEqual(p.auth_length, 0)
878         self.assertEqual(p.call_id, call_id)
879
880         return
881
882     def generate_bind(self, call_id,
883                       pfc_flags = samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
884                                   samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST,
885                       max_xmit_frag=5840,
886                       max_recv_frag=5840,
887                       assoc_group_id=0,
888                       ctx_list=[],
889                       auth_info="",
890                       ndr_print=None, hexdump=None):
891
892         b = samba.dcerpc.dcerpc.bind()
893         b.max_xmit_frag = max_xmit_frag
894         b.max_recv_frag = max_recv_frag
895         b.assoc_group_id = assoc_group_id
896         b.num_contexts = len(ctx_list)
897         b.ctx_list = ctx_list
898         b.auth_info = auth_info
899
900         p = self.generate_pdu(ptype=samba.dcerpc.dcerpc.DCERPC_PKT_BIND,
901                               pfc_flags=pfc_flags,
902                               call_id=call_id,
903                               payload=b,
904                               ndr_print=ndr_print, hexdump=hexdump)
905
906         return p
907
908     def generate_alter(self, call_id,
909                        pfc_flags = samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
910                                    samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST,
911                        max_xmit_frag=5840,
912                        max_recv_frag=5840,
913                        assoc_group_id=0,
914                        ctx_list=[],
915                        auth_info="",
916                        ndr_print=None, hexdump=None):
917
918         a = samba.dcerpc.dcerpc.bind()
919         a.max_xmit_frag = max_xmit_frag
920         a.max_recv_frag = max_recv_frag
921         a.assoc_group_id = assoc_group_id
922         a.num_contexts = len(ctx_list)
923         a.ctx_list = ctx_list
924         a.auth_info = auth_info
925
926         p = self.generate_pdu(ptype=samba.dcerpc.dcerpc.DCERPC_PKT_ALTER,
927                               pfc_flags=pfc_flags,
928                               call_id=call_id,
929                               payload=a,
930                               ndr_print=ndr_print, hexdump=hexdump)
931
932         return p
933
934     def generate_auth3(self, call_id,
935                        pfc_flags = samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
936                                    samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST,
937                        auth_info="",
938                        ndr_print=None, hexdump=None):
939
940         a = samba.dcerpc.dcerpc.auth3()
941         a.auth_info = auth_info
942
943         p = self.generate_pdu(ptype=samba.dcerpc.dcerpc.DCERPC_PKT_AUTH3,
944                               pfc_flags=pfc_flags,
945                               call_id=call_id,
946                               payload=a,
947                               ndr_print=ndr_print, hexdump=hexdump)
948
949         return p
950
951     def generate_request(self, call_id,
952                          pfc_flags = samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
953                                      samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST,
954                          alloc_hint=None,
955                          context_id=None,
956                          opnum=None,
957                          object=None,
958                          stub=None,
959                          auth_info="",
960                          ndr_print=None, hexdump=None):
961
962         if alloc_hint is None:
963             alloc_hint = len(stub)
964
965         r = samba.dcerpc.dcerpc.request()
966         r.alloc_hint = alloc_hint
967         r.context_id = context_id
968         r.opnum = opnum
969         if object is not None:
970             r.object = object
971         r.stub_and_verifier = stub + auth_info
972
973         p = self.generate_pdu(ptype=samba.dcerpc.dcerpc.DCERPC_PKT_REQUEST,
974                               pfc_flags=pfc_flags,
975                               call_id=call_id,
976                               payload=r,
977                               ndr_print=ndr_print, hexdump=hexdump)
978
979         if len(auth_info) > samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH:
980             p.auth_length = len(auth_info) - samba.dcerpc.dcerpc.DCERPC_AUTH_TRAILER_LENGTH
981
982         return p
983
984     def generate_co_cancel(self, call_id,
985                            pfc_flags = samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
986                                        samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST,
987                            auth_info="",
988                            ndr_print=None, hexdump=None):
989
990         c = samba.dcerpc.dcerpc.co_cancel()
991         c.auth_info = auth_info
992
993         p = self.generate_pdu(ptype=samba.dcerpc.dcerpc.DCERPC_PKT_CO_CANCEL,
994                               pfc_flags=pfc_flags,
995                               call_id=call_id,
996                               payload=c,
997                               ndr_print=ndr_print, hexdump=hexdump)
998
999         return p
1000
1001     def generate_orphaned(self, call_id,
1002                           pfc_flags = samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
1003                                       samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST,
1004                           auth_info="",
1005                           ndr_print=None, hexdump=None):
1006
1007         o = samba.dcerpc.dcerpc.orphaned()
1008         o.auth_info = auth_info
1009
1010         p = self.generate_pdu(ptype=samba.dcerpc.dcerpc.DCERPC_PKT_ORPHANED,
1011                               pfc_flags=pfc_flags,
1012                               call_id=call_id,
1013                               payload=o,
1014                               ndr_print=ndr_print, hexdump=hexdump)
1015
1016         return p
1017
1018     def generate_shutdown(self, call_id,
1019                           pfc_flags = samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_FIRST |
1020                                       samba.dcerpc.dcerpc.DCERPC_PFC_FLAG_LAST,
1021                           ndr_print=None, hexdump=None):
1022
1023         s = samba.dcerpc.dcerpc.shutdown()
1024
1025         p = self.generate_pdu(ptype=samba.dcerpc.dcerpc.DCERPC_PKT_SHUTDOWN,
1026                               pfc_flags=pfc_flags,
1027                               call_id=call_id,
1028                               payload=s,
1029                               ndr_print=ndr_print, hexdump=hexdump)
1030
1031         return p
1032
1033     def assertIsConnected(self):
1034         self.assertIsNotNone(self.s, msg="Not connected")
1035         return
1036
1037     def assertNotConnected(self):
1038         self.assertIsNone(self.s, msg="Is connected")
1039         return
1040
1041     def assertNDRSyntaxEquals(self, s1, s2):
1042         self.assertEqual(s1.uuid, s2.uuid)
1043         self.assertEqual(s1.if_version, s2.if_version)
1044         return
1045
class ValidNetbiosNameTests(TestCase):
    """Checks for samba.valid_netbios_name()."""

    def test_valid(self):
        # a short, plain name is accepted
        accepted = samba.valid_netbios_name("FOO")
        self.assertTrue(accepted)

    def test_too_long(self):
        # "FOO" repeated ten times is rejected
        overlong = "FOO"*10
        self.assertFalse(samba.valid_netbios_name(overlong))

    def test_invalid_characters(self):
        # a name containing "*" is rejected
        accepted = samba.valid_netbios_name("*BLA")
        self.assertFalse(accepted)
1056
1057
class BlackboxProcessError(Exception):
    """Raised when a blackbox command exits with a non-zero status.

    The instance carries the exact exit code (S.returncode), the
    command line (S.cmd), the process output (S.stdout) and the process
    error stream (S.stderr).  The same four values are also available,
    in that order, as S.args.
    """

    def __init__(self, returncode, cmd, stdout, stderr):
        """Record the details of the failed command.

        :param returncode: exit status of the process
        :param cmd: the command line that was run
        :param stdout: what the process wrote to stdout
        :param stderr: what the process wrote to stderr
        """
        # Hand the values to Exception as well: otherwise args stays
        # empty, repr() shows nothing useful and the instance cannot be
        # re-created from its args (e.g. when unpickling).
        super(BlackboxProcessError, self).__init__(returncode, cmd,
                                                   stdout, stderr)
        self.returncode = returncode
        self.cmd = cmd
        self.stdout = stdout
        self.stderr = stderr

    def __str__(self):
        """Return a one-line summary containing all four values."""
        return "Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" % (self.cmd, self.returncode,
                                                                             self.stdout, self.stderr)
1075
class BlackboxTestCase(TestCaseInTempDir):
    """Base test case for blackbox tests."""

    def _make_cmdline(self, line):
        """Point the command in line at the build's bin directory.

        :param line: shell command line; its first space-separated
            word is taken as the program name
        :return: line with the program name replaced by its path in
            the bin directory four levels above this file, if a file of
            that name exists there; otherwise line unchanged
        """
        bindir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../bin"))
        parts = line.split(" ")
        if os.path.exists(os.path.join(bindir, parts[0])):
            parts[0] = os.path.join(bindir, parts[0])
        line = " ".join(parts)
        return line

    def check_run(self, line):
        """Run line through the shell and require exit status 0.

        :param line: shell command line, see _make_cmdline()
        :raise BlackboxProcessError: if the exit status is non-zero
        """
        line = self._make_cmdline(line)
        p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
        # communicate() drains both pipes while waiting.  A bare wait()
        # can hang forever once the child has filled a pipe buffer
        # that nobody is reading.
        (stdoutdata, stderrdata) = p.communicate()
        retcode = p.returncode
        if retcode:
            raise BlackboxProcessError(retcode, line, stdoutdata, stderrdata)

    def check_output(self, line):
        """Run line through the shell and return what it printed.

        :param line: shell command line, see _make_cmdline()
        :return: everything the command wrote to stdout
        :raise BlackboxProcessError: if the exit status is non-zero
        """
        line = self._make_cmdline(line)
        p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, close_fds=True)
        # see check_run() for why this is not p.wait()
        (stdoutdata, stderrdata) = p.communicate()
        retcode = p.returncode
        if retcode:
            raise BlackboxProcessError(retcode, line, stdoutdata, stderrdata)
        return stdoutdata
1101
1102
def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
                  flags=0, ldb_options=None, ldap_only=False, global_schema=True):
    """Create a SamDB instance connected to samdb_url.

    :param samdb_url: Url for database to connect to. Without a
        "://" it becomes a tdb:// url if it names an existing file
        (and ldap_only is not set), otherwise an ldap:// url.
    :param lp: Optional loadparm object, defaults to env_loadparm()
    :param session_info: Optional session information, defaults to
        the system session for lp
    :param credentials: Optional credentials, defaults to the
        module's cmdline_credentials
    :param flags: Optional LDB flags
    :param ldb_options: Optional LDB options. For ldap:// urls the
        value is replaced by ["modules:paged_searches"].
    :param ldap_only: If set, only remote LDAP connection will be created.
    :param global_schema: Whether to use global schema.
    :raise AssertionError: if ldap_only is set but the url is not an
        ldap:// one

    The value added for tests is a shorthand that builds a proper URL
    for ldb.connect() and fills in connection defaults taken from the
    test environment.
    """
    if "://" not in samdb_url:
        is_local_file = (not ldap_only) and os.path.isfile(samdb_url)
        url_format = "tdb://%s" if is_local_file else "ldap://%s"
        samdb_url = url_format % samdb_url

    remote = samdb_url.startswith("ldap://")
    if remote:
        # use 'paged_search' module when connecting remotely
        ldb_options = ["modules:paged_searches"]
    elif ldap_only:
        raise AssertionError("Trying to connect to %s while remote "
                             "connection is required" % samdb_url)

    # Defaults for the test environment; session_info comes after lp
    # because the system session is created from lp.
    if lp is None:
        lp = env_loadparm()
    if session_info is None:
        session_info = samba.auth.system_session(lp)
    if credentials is None:
        credentials = cmdline_credentials

    connect_args = dict(url=samdb_url,
                        lp=lp,
                        session_info=session_info,
                        credentials=credentials,
                        flags=flags,
                        options=ldb_options,
                        global_schema=global_schema)
    return SamDB(**connect_args)
1146
1147
def connect_samdb_ex(samdb_url, lp=None, session_info=None, credentials=None,
                     flags=0, ldb_options=None, ldap_only=False):
    """Connect to samdb_url and fetch the rootDSE record as well.

    :param samdb_url: Url for database to connect to.
    :param lp: Optional loadparm object
    :param session_info: Optional session information
    :param credentials: Optional credentials
    :param flags: Optional LDB flags
    :param ldb_options: Optional LDB options
    :param ldap_only: If set, only remote LDAP connection will be created.
    :return: (sam_db_connection, rootDse_record) tuple

    All arguments are handed to connect_samdb() unchanged, so its
    defaults apply.
    """
    connection = connect_samdb(samdb_url,
                               lp=lp,
                               session_info=session_info,
                               credentials=credentials,
                               flags=flags,
                               ldb_options=ldb_options,
                               ldap_only=ldap_only)
    # the rootDSE is the only entry of a base search on the empty DN
    found = connection.search(base="", expression="", scope=ldb.SCOPE_BASE,
                              attrs=["*"])
    root_dse = found[0]
    return (connection, root_dse)
1166
1167
def connect_samdb_env(env_url, env_username, env_password, lp=None):
    """Connect to SamDB with URL and credentials taken from the environment.

    :param env_url: Name of the environment variable holding the
        database url
    :param env_username: Name of the environment variable holding the
        username
    :param env_password: Name of the environment variable holding the
        password
    :param lp: Optional loadparm object; when omitted a fresh
        param.LoadParm() is created and the credentials are guessed
        from it
    :return: sam_db_connection
    """
    url = env_get_var_value(env_url)
    user_creds = credentials.Credentials()
    if lp is None:
        # Guess the Credentials parameters here. Otherwise the
        # workstation and domain fields are NULL and the gencache code
        # segfaults.
        lp = param.LoadParm()
        user_creds.guess(lp)
    username = env_get_var_value(env_username)
    user_creds.set_username(username)
    password = env_get_var_value(env_password)
    user_creds.set_password(password)
    return connect_samdb(url, credentials=user_creds, lp=lp)
1186
1187
def delete_force(samdb, dn):
    """Delete dn from samdb, tolerating that it does not exist.

    :param samdb: database object providing a delete() method
    :param dn: the DN to delete
    :raise AssertionError: if the delete fails with an ldb.LdbError
        other than ldb.ERR_NO_SUCH_OBJECT
    """
    try:
        samdb.delete(dn)
    except ldb.LdbError as e:
        # "except X as e" is what the rest of this file uses and,
        # unlike "except X, (a, b)", is also accepted by Python 3.
        (num, errstr) = e.args
        # Raised explicitly rather than via an assert statement, so a
        # genuine failure is not silently swallowed under "python -O".
        if num != ldb.ERR_NO_SUCH_OBJECT:
            raise AssertionError("ldb.delete() failed: %s" % errstr)