Add tests for findnss(), add some docstrings.
authorJelmer Vernooij <jelmer@samba.org>
Sat, 9 Feb 2008 01:10:49 +0000 (02:10 +0100)
committerJelmer Vernooij <jelmer@samba.org>
Sat, 9 Feb 2008 01:10:49 +0000 (02:10 +0100)
(This used to be commit 4eec2bbc9a139e927ce21c615ebfbb3026b26384)

source4/scripting/python/samba/provision.py
source4/scripting/python/samba/tests/provision.py

index e15f205813d61536b74f6f84f91b9008a89d87d9..b094581fb43108f83c7f92cf4be9922882e5ff98 100644 (file)
@@ -81,14 +81,19 @@ def check_install(lp, session_info, credentials):
         raise "No administrator account found"
 
 
-def findnss(nssfn, *names):
-    """Find a user or group from a list of possibilities."""
+def findnss(nssfn, names):
+    """Find a user or group from a list of possibilities.
+    
+    :param nssfn: NSS Function to try (should raise KeyError if not found)
+    :param names: Names to check.
+    :return: Value return by first names list.
+    """
     for name in names:
         try:
             return nssfn(name)
         except KeyError:
             pass
-    raise Exception("Unable to find user/group for %s" % arguments[1])
+    raise KeyError("Unable to find user/group %r" % names)
 
 
 def open_ldb(session_info, credentials, lp, dbname):
@@ -146,6 +151,14 @@ def setup_modify_ldif(ldb, ldif_path, substvars=None):
 
 
 def setup_ldb(ldb, ldif_path, subst_vars):
+    """Import a LDIF a file into a LDB handle, optionally substituting variables.
+
+    :note: Either all LDIF data will be added or none (using transactions).
+
+    :param ldb: LDB file to import into.
+    :param ldif_path: Path to the LDIF file.
+    :param subst_vars: Dictionary with substitution variables.
+    """
     assert ldb is not None
     ldb.transaction_start()
     try:
@@ -716,18 +729,18 @@ def provision(lp, setup_dir, message, paths, session_info,
     if dnspass is None:
         dnspass = misc.random_password(12)
     if root is None:
-        root = findnss(pwd.getpwnam, "root")[0]
+        root = findnss(pwd.getpwnam, ["root"])[0]
     if nobody is None:
-        nobody = findnss(pwd.getpwnam, "nobody")[0]
+        nobody = findnss(pwd.getpwnam, ["nobody"])[0]
     if nogroup is None:
-        nogroup = findnss(grp.getgrnam, "nogroup", "nobody")[0]
+        nogroup = findnss(grp.getgrnam, ["nogroup", "nobody"])[0]
     if users is None:
-        users = findnss(grp.getgrnam, "users", "guest", "other", "unknown", 
-                        "usr")[0]
+        users = findnss(grp.getgrnam, ["users", "guest", "other", "unknown", 
+                        "usr"])[0]
     if wheel is None:
-        wheel = findnss(grp.getgrnam, "wheel", "root", "staff", "adm")[0]
+        wheel = findnss(grp.getgrnam, ["wheel", "root", "staff", "adm"])[0]
     if backup is None:
-        backup = findnss(grp.getgrnam, "backup", "wheel", "root", "staff")[0]
+        backup = findnss(grp.getgrnam, ["backup", "wheel", "root", "staff"])[0]
     if aci is None:
         aci = "# no aci for local ldb"
     if serverrole is None:
@@ -781,10 +794,10 @@ def provision(lp, setup_dir, message, paths, session_info,
         domain = netbiosname
     
     if rootdn is None:
-       rootdn       = domaindn
+       rootdn = domaindn
        
-    configdn     = "CN=Configuration," + rootdn
-    schemadn     = "CN=Schema," + configdn
+    configdn = "CN=Configuration," + rootdn
+    schemadn = "CN=Schema," + configdn
 
     message("set DOMAIN SID: %s" % str(domainsid))
     message("Provisioning for %s in realm %s" % (domain, realm))
index 4e9fa9c3ef85a0fbd59d221fd6fd663c0a9ccbac..eb49f7af83725bac6b843ab9d197232b7a8b7f74 100644 (file)
 #
 
 import os
-from samba.provision import setup_secretsdb, secretsdb_become_dc
+from samba.provision import setup_secretsdb, secretsdb_become_dc, findnss
 import samba.tests
 from ldb import Dn
 import param
+import unittest
 
 lp = param.LoadParm()
 lp.load("st/dc/etc/smb.conf")
@@ -66,6 +67,25 @@ class ProvisionTestCase(samba.tests.TestCaseInTempDir):
             del secrets_ldb
             os.unlink(path)
 
+
+class FindNssTests(unittest.TestCase):
+    """Test findnss() function."""
+    def test_nothing(self):
+        def x(y):
+            raise KeyError
+        self.assertRaises(KeyError, findnss, x, [])
+
+    def test_first(self):
+        self.assertEquals("bla", findnss(lambda x: "bla", ["bla"]))
+
+    def test_skip_first(self):
+        def x(y):
+            if y != "bla":
+                raise KeyError
+            return "ha"
+        self.assertEquals("ha", findnss(x, ["bloe", "bla"]))
+
+
 class Disabled:
     def test_setup_templatesdb(self):
         raise NotImplementedError(self.test_setup_templatesdb)
@@ -100,3 +120,4 @@ class Disabled:
     def test_erase_partitions(self):
         raise NotImplementedError(self.test_erase_partitions)
 
+