tests/krb5: Overhaul check_device_info()
authorJoseph Sutton <josephsutton@catalyst.net.nz>
Fri, 3 Mar 2023 00:41:19 +0000 (13:41 +1300)
committerAndrew Bartlett <abartlet@samba.org>
Wed, 8 Mar 2023 04:39:32 +0000 (04:39 +0000)
With expected_device_groups, tests can now specify particular group
arrangements they expect to see.

Signed-off-by: Joseph Sutton <josephsutton@catalyst.net.nz>
Reviewed-by: Andrew Bartlett <abartlet@samba.org>
python/samba/tests/krb5/raw_testcase.py

index 7911a2ca41ef0842e9bcf9dde4af593e92e1854d..007c2a0eb6136ea0cfaf1b903e54d4d9a49b924f 100644 (file)
@@ -2518,6 +2518,7 @@ class RawKerberosTest(TestCaseInTempDir):
                          expected_device_claims=None,
                          unexpected_device_claims=None,
                          expect_resource_groups_flag=None,
+                         expected_device_groups=None,
                          to_rodc=False):
         if expected_error_mode == 0:
             expected_error_mode = ()
@@ -2589,6 +2590,7 @@ class RawKerberosTest(TestCaseInTempDir):
             'expected_device_claims': expected_device_claims,
             'unexpected_device_claims': unexpected_device_claims,
             'expect_resource_groups_flag': expect_resource_groups_flag,
+            'expected_device_groups': expected_device_groups,
             'to_rodc': to_rodc
         }
         if callback_dict is None:
@@ -2609,6 +2611,7 @@ class RawKerberosTest(TestCaseInTempDir):
                           expected_sid=None,
                           expected_requester_sid=None,
                           expected_domain_sid=None,
+                          expected_device_domain_sid=None,
                           expected_supported_etypes=None,
                           expected_flags=None,
                           unexpected_flags=None,
@@ -2658,6 +2661,7 @@ class RawKerberosTest(TestCaseInTempDir):
                           expected_device_claims=None,
                           unexpected_device_claims=None,
                           expect_resource_groups_flag=None,
+                          expected_device_groups=None,
                           to_rodc=False):
         if expected_error_mode == 0:
             expected_error_mode = ()
@@ -2682,6 +2686,7 @@ class RawKerberosTest(TestCaseInTempDir):
             'expected_sid': expected_sid,
             'expected_requester_sid': expected_requester_sid,
             'expected_domain_sid': expected_domain_sid,
+            'expected_device_domain_sid': expected_device_domain_sid,
             'expected_supported_etypes': expected_supported_etypes,
             'expected_flags': expected_flags,
             'unexpected_flags': unexpected_flags,
@@ -2731,6 +2736,7 @@ class RawKerberosTest(TestCaseInTempDir):
             'expected_device_claims': expected_device_claims,
             'unexpected_device_claims': unexpected_device_claims,
             'expect_resource_groups_flag': expect_resource_groups_flag,
+            'expected_device_groups': expected_device_groups,
             'to_rodc': to_rodc
         }
         if callback_dict is None:
@@ -3315,39 +3321,93 @@ class RawKerberosTest(TestCaseInTempDir):
                 break
         else:
             self.fail('missing logon info for armor PAC')
-
         self.assertEqual(armor_info.base.rid, device_info.rid)
 
-        self.assertEqual(armor_info.base.primary_gid,
-                         device_info.primary_gid)
-        self.assertEqual(security.DOMAIN_RID_DOMAIN_MEMBERS,
-                         device_info.primary_gid)
+        device_domain_sid = kdc_exchange_dict['expected_device_domain_sid']
+        expected_device_groups = kdc_exchange_dict['expected_device_groups']
+        if kdc_exchange_dict['expect_device_info']:
+            self.assertIsNotNone(device_domain_sid)
+            self.assertIsNotNone(expected_device_groups)
 
-        self.assertEqual(armor_info.base.domain_sid,
-                         device_info.domain_sid)
+        if device_domain_sid is not None:
+            self.assertEqual(device_domain_sid, str(device_info.domain_sid))
+        else:
+            device_domain_sid = str(device_info.domain_sid)
 
-        def get_groups(groups):
-            return [(x.rid, x.attributes) for x in groups.rids]
+        # Check the device info SIDs.
 
-        self.assertEqual(get_groups(armor_info.base.groups),
-                         get_groups(device_info.groups))
+        # A representation of the device info groups.
+        primary_sid = f'{device_domain_sid}-{device_info.primary_gid}'
+        got_sids = {
+            (primary_sid, self.SidType.PRIMARY_GID, None),
+        }
 
-        self.assertEqual(1, device_info.sid_count)
-        self.assertEqual(
-            security.SID_AUTHENTICATION_AUTHORITY_ASSERTED_IDENTITY,
-            str(device_info.sids[0].sid))
+        # Collect the groups.
+        if device_info.groups.rids is not None:
+            self.assertTrue(device_info.groups.rids, 'got empty RIDs')
+
+            for group in device_info.groups.rids:
+                got_sid = f'{device_domain_sid}-{group.rid}'
+
+                device_sid = (got_sid, self.SidType.BASE_SID, group.attributes)
+                self.assertNotIn(device_sid, got_sids, 'got duplicated SID')
+                got_sids.add(device_sid)
+
+        # Collect the SIDs.
+        if device_info.sids is not None:
+            self.assertTrue(device_info.sids, 'got empty SIDs')
+
+            for sid_attr in device_info.sids:
+                got_sid = str(sid_attr.sid)
 
-        claims_valid_sid, claims_valid_rid = (
-            security.SID_CLAIMS_VALID.rsplit('-', 1))
+                in_a_domain = sid_attr.sid.num_auths == 5 and (
+                    str(sid_attr.sid).startswith('S-1-5-21-'))
+                self.assertFalse(in_a_domain,
+                                 f'got unexpected SID for domain: {got_sid} '
+                                 f'(should be in device_info.domain_groups)')
 
-        self.assertEqual(1, device_info.domain_group_count)
-        domain_group = device_info.domain_groups[0]
-        self.assertEqual(claims_valid_sid,
-                         str(domain_group.domain_sid))
+                device_sid = (got_sid,
+                              self.SidType.EXTRA_SID,
+                              sid_attr.attributes)
+                self.assertNotIn(device_sid, got_sids, 'got duplicated SID')
+                got_sids.add(device_sid)
 
-        self.assertEqual(1, domain_group.groups.count)
-        self.assertEqual(int(claims_valid_rid),
-                         domain_group.groups.rids[0].rid)
+        # Collect the domain groups.
+        if device_info.domain_groups is not None:
+            self.assertTrue(device_info.domain_groups, 'got empty domain groups')
+
+            for domain_group in device_info.domain_groups:
+                self.assertTrue(domain_group, 'got empty domain group')
+
+                got_domain_sids = set()
+
+                resource_group_sid = domain_group.domain_sid
+
+                in_a_domain = resource_group_sid.num_auths == 4 and (
+                    str(resource_group_sid).startswith('S-1-5-21-'))
+                self.assertTrue(
+                    in_a_domain,
+                    f'got unexpected domain SID for non-domain: {resource_group_sid} '
+                    f'(should be in device_info.sids)')
+
+                for resource_group in domain_group.groups.rids:
+                    got_sid = f'{resource_group_sid}-{resource_group.rid}'
+
+                    device_sid = (got_sid,
+                                  self.SidType.RESOURCE_SID,
+                                  resource_group.attributes)
+                    self.assertNotIn(device_sid, got_domain_sids, 'got duplicated SID')
+                    got_domain_sids.add(device_sid)
+
+                got_domain_sids = frozenset(got_domain_sids)
+                self.assertNotIn(got_domain_sids, got_sids)
+                got_sids.add(got_domain_sids)
+
+        # Compare the aggregated device SIDs against the set of expected device
+        # SIDs.
+        if expected_device_groups is not None:
+            self.assertEqual(expected_device_groups, got_sids,
+                             'expected != got')
 
     def check_pac_buffers(self, pac_data, kdc_exchange_dict):
         pac = ndr_unpack(krb5pac.PAC_DATA, pac_data)
@@ -3404,6 +3464,8 @@ class RawKerberosTest(TestCaseInTempDir):
         expected_device_claims = kdc_exchange_dict['expected_device_claims']
         unexpected_device_claims = kdc_exchange_dict['unexpected_device_claims']
 
+        expected_device_groups = kdc_exchange_dict['expected_device_groups']
+
         if (self.kdc_claims_support and self.kdc_compound_id_support
                 and expect_device_claims and compound_id):
             expected_types.append(krb5pac.PAC_TYPE_DEVICE_CLAIMS_INFO)
@@ -3430,6 +3492,9 @@ class RawKerberosTest(TestCaseInTempDir):
             self.assertFalse(expect_device_info,
                              'expected device info with no armor TGT or '
                              'for non-TGS request')
+            self.assertFalse(expected_device_groups,
+                             'expected device groups, but device info not '
+                             'expected in PAC')
 
             if expect_device_info is None and compound_id:
                 unchecked.add(krb5pac.PAC_TYPE_DEVICE_INFO)