Add tests for findnss(), add some docstrings.
[ira/wip.git] / source / scripting / python / samba / tests / provision.py
index 1456b6751ce555171686024e96b4d40ca6e06529..eb49f7af83725bac6b843ab9d197232b7a8b7f74 100644 (file)
 #
 
 import os
-from samba.provision import setup_secretsdb, secretsdb_become_dc
+from samba.provision import setup_secretsdb, secretsdb_become_dc, findnss
 import samba.tests
 from ldb import Dn
+import param
+import unittest
+
+lp = param.LoadParm()
+lp.load("st/dc/etc/smb.conf")
 
 setup_dir = "setup"
 def setup_path(file):
@@ -30,7 +35,7 @@ def setup_path(file):
 class ProvisionTestCase(samba.tests.TestCaseInTempDir):
     def test_setup_secretsdb(self):
         path = os.path.join(self.tempdir, "secrets.ldb")
-        ldb = setup_secretsdb(path, setup_path, None, None, None)
+        ldb = setup_secretsdb(path, setup_path, None, None, lp=lp)
         try:
             self.assertEquals("LSA Secrets",
                  ldb.searchone(basedn="CN=LSA Secrets", attribute="CN"))
@@ -40,7 +45,7 @@ class ProvisionTestCase(samba.tests.TestCaseInTempDir):
             
     def test_become_dc(self):
         path = os.path.join(self.tempdir, "secrets.ldb")
-        secrets_ldb = setup_secretsdb(path, setup_path, None, None, None)
+        secrets_ldb = setup_secretsdb(path, setup_path, None, None, lp=lp)
         try:
             secretsdb_become_dc(secrets_ldb, setup_path, domain="EXAMPLE", 
                    realm="example", netbiosname="myhost", 
@@ -62,6 +67,25 @@ class ProvisionTestCase(samba.tests.TestCaseInTempDir):
             del secrets_ldb
             os.unlink(path)
 
+
+class FindNssTests(unittest.TestCase):
+    """Test findnss() function."""
+    def test_nothing(self):
+        def x(y):
+            raise KeyError
+        self.assertRaises(KeyError, findnss, x, [])
+
+    def test_first(self):
+        self.assertEquals("bla", findnss(lambda x: "bla", ["bla"]))
+
+    def test_skip_first(self):
+        def x(y):
+            if y != "bla":
+                raise KeyError
+            return "ha"
+        self.assertEquals("ha", findnss(x, ["bloe", "bla"]))
+
+
 class Disabled:
     def test_setup_templatesdb(self):
         raise NotImplementedError(self.test_setup_templatesdb)
@@ -96,3 +120,4 @@ class Disabled:
     def test_erase_partitions(self):
         raise NotImplementedError(self.test_erase_partitions)
 
+