selftest: Expand tokenGroups test to also compare with samr.GetGroupsForUser
authorAndrew Bartlett <abartlet@samba.org>
Wed, 8 Jun 2016 02:46:07 +0000 (14:46 +1200)
committerGarming Sam <garming@samba.org>
Thu, 16 Jun 2016 02:40:12 +0000 (04:40 +0200)
Signed-off-by: Andrew Bartlett <abartlet@samba.org>
Reviewed-by: Garming Sam <garming@catalyst.net.nz>
source4/dsdb/tests/python/token_group.py

index a04765bb4668b9e51f39dd9597d410a9c09b1963..9143077cccddba61cd5546dabaa90fa2ec3e4562 100755 (executable)
@@ -23,7 +23,7 @@ from samba.credentials import Credentials, DONT_USE_KERBEROS
 from samba.dsdb import GTYPE_SECURITY_GLOBAL_GROUP, GTYPE_SECURITY_UNIVERSAL_GROUP
 import samba.tests
 from samba.tests import delete_force
-
+from samba.dcerpc import samr, security
 from samba.auth import AUTH_SESSION_INFO_DEFAULT_GROUPS, AUTH_SESSION_INFO_AUTHENTICATED, AUTH_SESSION_INFO_SIMPLE_PRIVILEGES
 
 
@@ -513,6 +513,76 @@ class DynamicTokenTest(samba.tests.TestCase):
         if len(tokenGroupsSet.difference(T)):
             self.fail(msg="additional tokenGroupsGlobalAndUniversal: %s" % tokenGroupsSet.difference(T))
 
+    def test_samr_GetGroupsForUser(self):
+        # Confirm that we get the correct results against SAMR also
+        if not url.startswith("ldap://"):
+            self.fail(msg="This test is only valid on ldap (so we an find the hostname and use SAMR)")
+        host = url.split("://")[1]
+        (domain_sid, user_rid) = self.user_sid.split()
+        samr_conn = samba.dcerpc.samr.samr("ncacn_ip_tcp:%s[sign]" % host, lp, creds)
+        samr_handle = samr_conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
+        samr_domain = samr_conn.OpenDomain(samr_handle, security.SEC_FLAG_MAXIMUM_ALLOWED,
+                                      domain_sid)
+        user_handle = samr_conn.OpenUser(samr_domain, security.SEC_FLAG_MAXIMUM_ALLOWED, user_rid)
+        rids = samr_conn.GetGroupsForUser(user_handle)
+        samr_dns = set()
+        for rid in rids.rids:
+            self.assertEqual(rid.attributes, samr.SE_GROUP_MANDATORY | samr.SE_GROUP_ENABLED_BY_DEFAULT| samr.SE_GROUP_ENABLED)
+            sid = "%s-%d" % (domain_sid, rid.rid)
+            res = self.admin_ldb.search(base="<SID=%s>" % sid, scope=ldb.SCOPE_BASE,
+                                  attrs=[])
+            samr_dns.add(res[0].dn.get_casefold())
+
+        user_info = samr_conn.QueryUserInfo(user_handle, 1)
+
+        tokenGroupsSet = set()
+        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroupsGlobalAndUniversal"])
+        for sid in res[0]['tokenGroupsGlobalAndUniversal']:
+            sid = ndr_unpack(samba.dcerpc.security.dom_sid, sid)
+            res3 = self.admin_ldb.search(base="<SID=%s>" % sid, scope=ldb.SCOPE_BASE,
+                                         attrs=[],
+                                         expression="(&(|(grouptype=%d)(grouptype=%d))(objectclass=group))"
+                                         % (GTYPE_SECURITY_GLOBAL_GROUP, GTYPE_SECURITY_UNIVERSAL_GROUP))
+            if len(res) == 1:
+                tokenGroupsSet.add(res3[0].dn.get_casefold())
+
+        if len(samr_dns.difference(tokenGroupsSet)):
+            self.fail(msg="additional samr_GetUserGroups over tokenGroups: %s" % samr_dns.difference(tokenGroupsSet))
+
+        memberOf = set()
+        # Add the primary group
+        primary_group_sid = "%s-%d" % (domain_sid, user_info.primary_gid)
+        res2 = self.admin_ldb.search(base="<SID=%s>" % sid, scope=ldb.SCOPE_BASE,
+                                     attrs=[])
+
+        memberOf.add(res2[0].dn.get_casefold())
+        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["memberOf"])
+        for dn in res[0]['memberOf']:
+            res3 = self.admin_ldb.search(base=dn, scope=ldb.SCOPE_BASE,
+                                         attrs=[],
+                                         expression="(&(|(grouptype=%d)(grouptype=%d))(objectclass=group))"
+                                         % (GTYPE_SECURITY_GLOBAL_GROUP, GTYPE_SECURITY_UNIVERSAL_GROUP))
+            if len(res3) == 1:
+                memberOf.add(res3[0].dn.get_casefold())
+
+        if len(memberOf.difference(samr_dns)):
+            self.fail(msg="additional memberOf over samr_GetUserGroups: %s" % memberOf.difference(samr_dns))
+
+        if len(samr_dns.difference(memberOf)):
+            self.fail(msg="additional samr_GetUserGroups over memberOf: %s" % samr_dns.difference(memberOf))
+
+        S = set()
+        S.add(self.test_user_dn.get_casefold())
+
+        self.filtered_closure(S, GTYPE_SECURITY_GLOBAL_GROUP)
+        self.filtered_closure(S, GTYPE_SECURITY_UNIVERSAL_GROUP)
+
+        # Now remove the user DN and primary group
+        S.remove(self.test_user_dn.get_casefold())
+
+        if len(samr_dns.difference(S)):
+            self.fail(msg="additional samr_GetUserGroups over filtered_closure: %s" % samr_dns.difference(S))
+
 if not "://" in url:
     if os.path.isfile(url):
         url = "tdb://%s" % url