d9efde8273a8302cf4c226bc105f5b4a08645060
[samba.git] / python / samba / tests / krb5 / kdc_base_test.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Stefan Metzmacher 2020
3 # Copyright (C) 2020-2021 Catalyst.Net Ltd
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18
19 import sys
20 import os
21 from datetime import datetime, timezone
22 import tempfile
23 import binascii
24 import collections
25 import secrets
26 from enum import Enum
27
28 from collections import namedtuple
29 import ldb
30 from ldb import SCOPE_BASE
31 from samba import generate_random_password
32 from samba.auth import system_session
33 from samba.credentials import (
34     Credentials,
35     SPECIFIED,
36     DONT_USE_KERBEROS,
37     MUST_USE_KERBEROS,
38 )
39 from samba.dcerpc import drsblobs, drsuapi, misc, krb5pac, krb5ccache, security
40 from samba.drs_utils import drs_Replicate, drsuapi_connect
41 from samba.dsdb import (
42     DSDB_SYNTAX_BINARY_DN,
43     DS_DOMAIN_FUNCTION_2000,
44     DS_DOMAIN_FUNCTION_2008,
45     DS_GUID_COMPUTERS_CONTAINER,
46     DS_GUID_DOMAIN_CONTROLLERS_CONTAINER,
47     DS_GUID_USERS_CONTAINER,
48     UF_WORKSTATION_TRUST_ACCOUNT,
49     UF_NO_AUTH_DATA_REQUIRED,
50     UF_NORMAL_ACCOUNT,
51     UF_NOT_DELEGATED,
52     UF_PARTIAL_SECRETS_ACCOUNT,
53     UF_SERVER_TRUST_ACCOUNT,
54     UF_TRUSTED_TO_AUTHENTICATE_FOR_DELEGATION
55 )
56 from samba.join import DCJoinContext
57 from samba.ndr import ndr_pack, ndr_unpack
58 from samba import net
59 from samba.samdb import SamDB, dsdb_Dn
60
61 from samba.tests import delete_force
62 import samba.tests.krb5.kcrypto as kcrypto
63 from samba.tests.krb5.raw_testcase import (
64     KerberosCredentials,
65     KerberosTicketCreds,
66     RawKerberosTest
67 )
68 import samba.tests.krb5.rfc4120_pyasn1 as krb5_asn1
69 from samba.tests.krb5.rfc4120_constants import (
70     AD_IF_RELEVANT,
71     AD_WIN2K_PAC,
72     AES256_CTS_HMAC_SHA1_96,
73     ARCFOUR_HMAC_MD5,
74     KDC_ERR_PREAUTH_REQUIRED,
75     KRB_AS_REP,
76     KRB_TGS_REP,
77     KRB_ERROR,
78     KU_AS_REP_ENC_PART,
79     KU_ENC_CHALLENGE_CLIENT,
80     KU_PA_ENC_TIMESTAMP,
81     KU_TICKET,
82     NT_PRINCIPAL,
83     NT_SRV_INST,
84     PADATA_ENCRYPTED_CHALLENGE,
85     PADATA_ENC_TIMESTAMP,
86     PADATA_ETYPE_INFO2,
87 )
88
# Put bin/python at the front of the module search path.
sys.path.insert(0, "bin/python")
# Set PYTHONUNBUFFERED in this process's environment (presumably so that
# output is not held back by buffering -- TODO confirm who consumes it).
os.environ["PYTHONUNBUFFERED"] = "1"

# Module-wide debug defaults; KDCBaseTest.setUp() copies them into
# self.do_asn1_print and self.do_hexdump for every test.
global_asn1_print = False
global_hexdump = False
94
95
96 class KDCBaseTest(RawKerberosTest):
97     """ Base class for KDC tests.
98     """
99
    class AccountType(Enum):
        """Kinds of account distinguished by the account helpers."""
        # Every member gets its own object() as value, so members are only
        # ever told apart by identity.

        # create_account(): objectClass 'user', UF_NORMAL_ACCOUNT.
        USER = object()
        # create_account(): objectClass 'computer',
        # UF_WORKSTATION_TRUST_ACCOUNT.
        COMPUTER = object()
        # create_account(): objectClass 'computer', UF_SERVER_TRUST_ACCOUNT.
        SERVER = object()
        # Not accepted by create_account(), which fails the test for it.
        RODC = object()
105
106     @classmethod
107     def setUpClass(cls):
108         super().setUpClass()
109         cls._lp = None
110
111         cls._ldb = None
112         cls._rodc_ldb = None
113
114         cls._functional_level = None
115
116         # An identifier to ensure created accounts have unique names. Windows
117         # caches accounts based on usernames, so account names being different
118         # across test runs avoids previous test runs affecting the results.
119         cls.account_base = f'{secrets.token_hex(4)}_'
120         cls.account_id = 0
121
122         # A list containing DNs of accounts created as part of testing.
123         cls.accounts = []
124
125         cls.account_cache = {}
126         cls.tkt_cache = {}
127
128         cls._rodc_ctx = None
129
130         cls.ldb_cleanups = []
131
132     @classmethod
133     def tearDownClass(cls):
134         # Clean up any accounts created by create_account. This is
135         # done in tearDownClass() rather than tearDown(), so that
136         # accounts need only be created once for permutation tests.
137         if cls._ldb is not None:
138             for cleanup in reversed(cls.ldb_cleanups):
139                 try:
140                     cls._ldb.modify(cleanup)
141                 except ldb.LdbError:
142                     pass
143
144             for dn in reversed(cls.accounts):
145                 delete_force(cls._ldb, dn)
146
147         if cls._rodc_ctx is not None:
148             cls._rodc_ctx.cleanup_old_join(force=True)
149
150         super().tearDownClass()
151
152     def setUp(self):
153         super().setUp()
154         self.do_asn1_print = global_asn1_print
155         self.do_hexdump = global_hexdump
156
157     def get_lp(self):
158         if self._lp is None:
159             type(self)._lp = self.get_loadparm()
160
161         return self._lp
162
163     def get_samdb(self):
164         if self._ldb is None:
165             creds = self.get_admin_creds()
166             lp = self.get_lp()
167
168             session = system_session()
169             type(self)._ldb = SamDB(url="ldap://%s" % self.dc_host,
170                                     session_info=session,
171                                     credentials=creds,
172                                     lp=lp)
173
174         return self._ldb
175
176     def get_rodc_samdb(self):
177         if self._rodc_ldb is None:
178             creds = self.get_admin_creds()
179             lp = self.get_lp()
180
181             session = system_session()
182             type(self)._rodc_ldb = SamDB(url="ldap://%s" % self.host,
183                                          session_info=session,
184                                          credentials=creds,
185                                          lp=lp,
186                                          am_rodc=True)
187
188         return self._rodc_ldb
189
190     def get_server_dn(self, samdb):
191         server = samdb.get_serverName()
192
193         res = samdb.search(base=server,
194                            scope=ldb.SCOPE_BASE,
195                            attrs=['serverReference'])
196         dn = ldb.Dn(samdb, res[0]['serverReference'][0].decode('utf8'))
197
198         return dn
199
200     def get_mock_rodc_ctx(self):
201         if self._rodc_ctx is None:
202             admin_creds = self.get_admin_creds()
203             lp = self.get_lp()
204
205             rodc_name = self.get_new_username()
206             site_name = 'Default-First-Site-Name'
207
208             rodc_ctx = DCJoinContext(server=self.dc_host,
209                                      creds=admin_creds,
210                                      lp=lp,
211                                      site=site_name,
212                                      netbios_name=rodc_name,
213                                      targetdir=None,
214                                      domain=None)
215             self.create_rodc(rodc_ctx)
216
217             type(self)._rodc_ctx = rodc_ctx
218
219         return self._rodc_ctx
220
221     def get_domain_functional_level(self, ldb):
222         if self._functional_level is None:
223             res = ldb.search(base='',
224                              scope=SCOPE_BASE,
225                              attrs=['domainFunctionality'])
226             try:
227                 functional_level = int(res[0]['domainFunctionality'][0])
228             except KeyError:
229                 functional_level = DS_DOMAIN_FUNCTION_2000
230
231             type(self)._functional_level = functional_level
232
233         return self._functional_level
234
235     def get_default_enctypes(self):
236         samdb = self.get_samdb()
237         functional_level = self.get_domain_functional_level(samdb)
238
239         # RC4 should always be supported
240         default_enctypes = {kcrypto.Enctype.RC4}
241         if functional_level >= DS_DOMAIN_FUNCTION_2008:
242             # AES is only supported at functional level 2008 or higher
243             default_enctypes.add(kcrypto.Enctype.AES256)
244             default_enctypes.add(kcrypto.Enctype.AES128)
245
246         return default_enctypes
247
248     def create_group(self, samdb, name, ou=None):
249         if ou is None:
250             ou = samdb.get_wellknown_dn(samdb.get_default_basedn(),
251                                         DS_GUID_USERS_CONTAINER)
252
253         dn = f'CN={name},{ou}'
254
255         # Remove the group if it exists; this will happen if a previous test
256         # run failed.
257         delete_force(samdb, dn)
258
259         # Save the group name so it can be deleted in tearDownClass.
260         self.accounts.append(dn)
261
262         details = {
263             'dn': dn,
264             'objectClass': 'group'
265         }
266         samdb.add(details)
267
268         return dn
269
270     def create_account(self, samdb, name, account_type=AccountType.USER,
271                        spn=None, upn=None, additional_details=None,
272                        ou=None, account_control=0, add_dollar=True):
273         '''Create an account for testing.
274            The dn of the created account is added to self.accounts,
275            which is used by tearDownClass to clean up the created accounts.
276         '''
277         if ou is None:
278             if account_type is account_type.COMPUTER:
279                 guid = DS_GUID_COMPUTERS_CONTAINER
280             elif account_type is account_type.SERVER:
281                 guid = DS_GUID_DOMAIN_CONTROLLERS_CONTAINER
282             else:
283                 guid = DS_GUID_USERS_CONTAINER
284
285             ou = samdb.get_wellknown_dn(samdb.get_default_basedn(), guid)
286
287         dn = "CN=%s,%s" % (name, ou)
288
289         # remove the account if it exists, this will happen if a previous test
290         # run failed
291         delete_force(samdb, dn)
292         account_name = name
293         if account_type is self.AccountType.USER:
294             object_class = "user"
295             account_control |= UF_NORMAL_ACCOUNT
296         else:
297             object_class = "computer"
298             if add_dollar:
299                 account_name += '$'
300             if account_type is self.AccountType.COMPUTER:
301                 account_control |= UF_WORKSTATION_TRUST_ACCOUNT
302             elif account_type is self.AccountType.SERVER:
303                 account_control |= UF_SERVER_TRUST_ACCOUNT
304             else:
305                 self.fail()
306
307         password = generate_random_password(32, 32)
308         utf16pw = ('"%s"' % password).encode('utf-16-le')
309
310         details = {
311             "dn": dn,
312             "objectclass": object_class,
313             "sAMAccountName": account_name,
314             "userAccountControl": str(account_control),
315             "unicodePwd": utf16pw}
316         if upn is not None:
317             upn = upn.format(account=account_name)
318         if spn is not None:
319             if isinstance(spn, str):
320                 spn = spn.format(account=account_name)
321             else:
322                 spn = tuple(s.format(account=account_name) for s in spn)
323             details["servicePrincipalName"] = spn
324         if upn is not None:
325             details["userPrincipalName"] = upn
326         if additional_details is not None:
327             details.update(additional_details)
328         # Save the account name so it can be deleted in tearDownClass
329         self.accounts.append(dn)
330         samdb.add(details)
331
332         creds = KerberosCredentials()
333         creds.guess(self.get_lp())
334         creds.set_realm(samdb.domain_dns_name().upper())
335         creds.set_domain(samdb.domain_netbios_name().upper())
336         creds.set_password(password)
337         creds.set_username(account_name)
338         if account_type is self.AccountType.USER:
339             creds.set_workstation('')
340         else:
341             creds.set_workstation(name)
342         creds.set_dn(ldb.Dn(samdb, dn))
343         creds.set_upn(upn)
344         creds.set_spn(spn)
345
346         self.creds_set_enctypes(creds)
347
348         res = samdb.search(base=dn,
349                            scope=ldb.SCOPE_BASE,
350                            attrs=['msDS-KeyVersionNumber'])
351         kvno = res[0].get('msDS-KeyVersionNumber', idx=0)
352         if kvno is not None:
353             self.assertEqual(int(kvno), 1)
354         creds.set_kvno(1)
355
356         return (creds, dn)
357
358     def get_security_descriptor(self, dn):
359         samdb = self.get_samdb()
360
361         sid = self.get_objectSid(samdb, dn)
362
363         owner_sid = security.dom_sid(security.SID_BUILTIN_ADMINISTRATORS)
364
365         ace = security.ace()
366         ace.access_mask = security.SEC_ADS_CONTROL_ACCESS
367
368         ace.trustee = security.dom_sid(sid)
369
370         dacl = security.acl()
371         dacl.revision = security.SECURITY_ACL_REVISION_ADS
372         dacl.aces = [ace]
373         dacl.num_aces = 1
374
375         security_desc = security.descriptor()
376         security_desc.type |= security.SEC_DESC_DACL_PRESENT
377         security_desc.owner_sid = owner_sid
378         security_desc.dacl = dacl
379
380         return ndr_pack(security_desc)
381
    def create_rodc(self, ctx):
        """Configure 'ctx' as an RODC join and add its directory objects.

        ctx is the join context built by get_mock_rodc_ctx() (a
        DCJoinContext). Its attributes are filled in for an RODC, any
        previous join is cleaned up, and ctx.join_add_objects() is called.
        If that raises, the partial join is cleaned up and the exception is
        re-raised.
        """
        # Naming contexts: domain, configuration and schema.
        ctx.nc_list = [ctx.base_dn, ctx.config_dn, ctx.schema_dn]
        ctx.full_nc_list = [ctx.base_dn, ctx.config_dn, ctx.schema_dn]
        ctx.krbtgt_dn = f'CN=krbtgt_{ctx.myname},CN=Users,{ctx.base_dn}'

        # SIDs set as never_reveal_sid / reveal_sid on the context
        # (presumably the groups whose secrets may never / may be revealed
        # to this RODC -- TODO confirm against DCJoinContext).
        ctx.never_reveal_sid = [f'<SID={ctx.domsid}-{security.DOMAIN_RID_RODC_DENY}>',
                                f'<SID={security.SID_BUILTIN_ADMINISTRATORS}>',
                                f'<SID={security.SID_BUILTIN_SERVER_OPERATORS}>',
                                f'<SID={security.SID_BUILTIN_BACKUP_OPERATORS}>',
                                f'<SID={security.SID_BUILTIN_ACCOUNT_OPERATORS}>']
        ctx.reveal_sid = f'<SID={ctx.domsid}-{security.DOMAIN_RID_RODC_ALLOW}>'

        # managedby is set to the SID returned by ctx.get_mysid().
        mysid = ctx.get_mysid()
        admin_dn = f'<SID={mysid}>'
        ctx.managedby = admin_dn

        ctx.userAccountControl = (UF_WORKSTATION_TRUST_ACCOUNT |
                                  UF_TRUSTED_TO_AUTHENTICATE_FOR_DELEGATION |
                                  UF_PARTIAL_SECRETS_ACCOUNT)

        ctx.connection_dn = f'CN=RODC Connection (FRS),{ctx.ntds_dn}'
        ctx.secure_channel_type = misc.SEC_CHAN_RODC
        ctx.RODC = True
        ctx.replica_flags = (drsuapi.DRSUAPI_DRS_INIT_SYNC |
                             drsuapi.DRSUAPI_DRS_PER_SYNC |
                             drsuapi.DRSUAPI_DRS_GET_ANC |
                             drsuapi.DRSUAPI_DRS_NEVER_SYNCED |
                             drsuapi.DRSUAPI_DRS_SPECIAL_SECRET_PROCESSING)
        # The domain flags are the general flags plus CRITICAL_ONLY.
        ctx.domain_replica_flags = ctx.replica_flags | drsuapi.DRSUAPI_DRS_CRITICAL_ONLY

        ctx.build_nc_lists()

        # Remove whatever an earlier join under the same name left behind.
        ctx.cleanup_old_join()

        try:
            ctx.join_add_objects()
        except Exception:
            # cleanup the failed join (checking we still have a live LDB
            # connection to the remote DC first)
            ctx.refresh_ldb_connection()
            ctx.cleanup_old_join()
            raise
424
425     def replicate_account_to_rodc(self, dn):
426         samdb = self.get_samdb()
427         rodc_samdb = self.get_rodc_samdb()
428
429         repl_val = f'{samdb.get_dsServiceName()}:{dn}:SECRETS_ONLY'
430
431         msg = ldb.Message()
432         msg.dn = ldb.Dn(rodc_samdb, '')
433         msg['replicateSingleObject'] = ldb.MessageElement(
434             repl_val,
435             ldb.FLAG_MOD_REPLACE,
436             'replicateSingleObject')
437
438         try:
439             # Try replication using the replicateSingleObject rootDSE
440             # operation.
441             rodc_samdb.modify(msg)
442         except ldb.LdbError as err:
443             enum, estr = err.args
444             self.assertEqual(enum, ldb.ERR_UNWILLING_TO_PERFORM)
445             self.assertIn('rootdse_modify: unknown attribute to change!',
446                           estr)
447
448             # If that method wasn't supported, we may be in the rodc:local test
449             # environment, where we can try replicating to the local database.
450
451             lp = self.get_lp()
452
453             rodc_creds = Credentials()
454             rodc_creds.guess(lp)
455             rodc_creds.set_machine_account(lp)
456
457             local_samdb = SamDB(url=None, session_info=system_session(),
458                                 credentials=rodc_creds, lp=lp)
459
460             destination_dsa_guid = misc.GUID(local_samdb.get_ntds_GUID())
461
462             repl = drs_Replicate(f'ncacn_ip_tcp:{self.dc_host}[seal]',
463                                  lp, rodc_creds,
464                                  local_samdb, destination_dsa_guid)
465
466             source_dsa_invocation_id = misc.GUID(samdb.invocation_id)
467
468             repl.replicate(dn,
469                            source_dsa_invocation_id,
470                            destination_dsa_guid,
471                            exop=drsuapi.DRSUAPI_EXOP_REPL_SECRET,
472                            rodc=True)
473
474     def reveal_account_to_mock_rodc(self, dn):
475         samdb = self.get_samdb()
476         rodc_ctx = self.get_mock_rodc_ctx()
477
478         self.get_secrets(
479             samdb,
480             dn,
481             destination_dsa_guid=rodc_ctx.ntds_guid,
482             source_dsa_invocation_id=misc.GUID(samdb.invocation_id))
483
484     def check_revealed(self, dn, rodc_dn, revealed=True):
485         samdb = self.get_samdb()
486
487         res = samdb.search(base=rodc_dn,
488                            scope=ldb.SCOPE_BASE,
489                            attrs=['msDS-RevealedUsers'])
490
491         revealed_users = res[0].get('msDS-RevealedUsers')
492         if revealed_users is None:
493             self.assertFalse(revealed)
494             return
495
496         revealed_dns = set(str(dsdb_Dn(samdb, str(user),
497                                        syntax_oid=DSDB_SYNTAX_BINARY_DN).dn)
498                            for user in revealed_users)
499
500         if revealed:
501             self.assertIn(str(dn), revealed_dns)
502         else:
503             self.assertNotIn(str(dn), revealed_dns)
504
505     def get_secrets(self, samdb, dn,
506                     destination_dsa_guid,
507                     source_dsa_invocation_id):
508         admin_creds = self.get_admin_creds()
509
510         dns_hostname = samdb.host_dns_name()
511         (bind, handle, _) = drsuapi_connect(dns_hostname,
512                                             self.get_lp(),
513                                             admin_creds)
514
515         req = drsuapi.DsGetNCChangesRequest8()
516
517         req.destination_dsa_guid = destination_dsa_guid
518         req.source_dsa_invocation_id = source_dsa_invocation_id
519
520         naming_context = drsuapi.DsReplicaObjectIdentifier()
521         naming_context.dn = dn
522
523         req.naming_context = naming_context
524
525         hwm = drsuapi.DsReplicaHighWaterMark()
526         hwm.tmp_highest_usn = 0
527         hwm.reserved_usn = 0
528         hwm.highest_usn = 0
529
530         req.highwatermark = hwm
531         req.uptodateness_vector = None
532
533         req.replica_flags = 0
534
535         req.max_object_count = 1
536         req.max_ndr_size = 402116
537         req.extended_op = drsuapi.DRSUAPI_EXOP_REPL_SECRET
538
539         attids = [drsuapi.DRSUAPI_ATTID_supplementalCredentials,
540                   drsuapi.DRSUAPI_ATTID_unicodePwd]
541
542         partial_attribute_set = drsuapi.DsPartialAttributeSet()
543         partial_attribute_set.version = 1
544         partial_attribute_set.attids = attids
545         partial_attribute_set.num_attids = len(attids)
546
547         req.partial_attribute_set = partial_attribute_set
548
549         req.partial_attribute_set_ex = None
550         req.mapping_ctr.num_mappings = 0
551         req.mapping_ctr.mappings = None
552
553         _, ctr = bind.DsGetNCChanges(handle, 8, req)
554
555         self.assertEqual(1, ctr.object_count)
556
557         identifier = ctr.first_object.object.identifier
558         attributes = ctr.first_object.object.attribute_ctr.attributes
559
560         self.assertEqual(dn, identifier.dn)
561
562         return bind, identifier, attributes
563
564     def get_keys(self, samdb, dn, expected_etypes=None):
565         admin_creds = self.get_admin_creds()
566
567         bind, identifier, attributes = self.get_secrets(
568             samdb,
569             str(dn),
570             destination_dsa_guid=misc.GUID(samdb.get_ntds_GUID()),
571             source_dsa_invocation_id=misc.GUID())
572
573         rid = identifier.sid.split()[1]
574
575         net_ctx = net.Net(admin_creds)
576
577         keys = {}
578
579         for attr in attributes:
580             if attr.attid == drsuapi.DRSUAPI_ATTID_supplementalCredentials:
581                 net_ctx.replicate_decrypt(bind, attr, rid)
582                 attr_val = attr.value_ctr.values[0].blob
583
584                 spl = ndr_unpack(drsblobs.supplementalCredentialsBlob,
585                                  attr_val)
586                 for pkg in spl.sub.packages:
587                     if pkg.name == 'Primary:Kerberos-Newer-Keys':
588                         krb5_new_keys_raw = binascii.a2b_hex(pkg.data)
589                         krb5_new_keys = ndr_unpack(
590                             drsblobs.package_PrimaryKerberosBlob,
591                             krb5_new_keys_raw)
592                         for key in krb5_new_keys.ctr.keys:
593                             keytype = key.keytype
594                             if keytype in (kcrypto.Enctype.AES256,
595                                            kcrypto.Enctype.AES128):
596                                 keys[keytype] = key.value.hex()
597             elif attr.attid == drsuapi.DRSUAPI_ATTID_unicodePwd:
598                 net_ctx.replicate_decrypt(bind, attr, rid)
599                 pwd = attr.value_ctr.values[0].blob
600                 keys[kcrypto.Enctype.RC4] = pwd.hex()
601
602         if expected_etypes is None:
603             expected_etypes = self.get_default_enctypes()
604
605         self.assertCountEqual(expected_etypes, keys)
606
607         return keys
608
609     def creds_set_keys(self, creds, keys):
610         if keys is not None:
611             for enctype, key in keys.items():
612                 creds.set_forced_key(enctype, key)
613
614     def creds_set_enctypes(self, creds):
615         samdb = self.get_samdb()
616
617         res = samdb.search(creds.get_dn(),
618                            scope=ldb.SCOPE_BASE,
619                            attrs=['msDS-SupportedEncryptionTypes'])
620         supported_enctypes = res[0].get('msDS-SupportedEncryptionTypes', idx=0)
621
622         if supported_enctypes is None:
623             supported_enctypes = 0
624
625         creds.set_as_supported_enctypes(supported_enctypes)
626         creds.set_tgs_supported_enctypes(supported_enctypes)
627         creds.set_ap_supported_enctypes(supported_enctypes)
628
629     def creds_set_default_enctypes(self, creds,
630                                    fast_support=False,
631                                    claims_support=False,
632                                    compound_id_support=False):
633         default_enctypes = self.get_default_enctypes()
634         supported_enctypes = KerberosCredentials.etypes_to_bits(
635             default_enctypes)
636
637         if fast_support:
638             supported_enctypes |= security.KERB_ENCTYPE_FAST_SUPPORTED
639         if claims_support:
640             supported_enctypes |= security.KERB_ENCTYPE_CLAIMS_SUPPORTED
641         if compound_id_support:
642             supported_enctypes |= (
643                 security.KERB_ENCTYPE_COMPOUND_IDENTITY_SUPPORTED)
644
645         creds.set_as_supported_enctypes(supported_enctypes)
646         creds.set_tgs_supported_enctypes(supported_enctypes)
647         creds.set_ap_supported_enctypes(supported_enctypes)
648
649     def add_to_group(self, account_dn, group_dn, group_attr, expect_attr=True):
650         samdb = self.get_samdb()
651
652         try:
653             res = samdb.search(base=group_dn,
654                                scope=ldb.SCOPE_BASE,
655                                attrs=[group_attr])
656         except ldb.LdbError as err:
657             num, _ = err.args
658             if num != ldb.ERR_NO_SUCH_OBJECT:
659                 raise
660
661             self.fail(err)
662
663         orig_msg = res[0]
664         members = orig_msg.get(group_attr)
665         if expect_attr:
666             self.assertIsNotNone(members)
667         elif members is None:
668             members = ()
669
670         members = list(members)
671         members.append(account_dn)
672
673         msg = ldb.Message()
674         msg.dn = group_dn
675         msg[group_attr] = ldb.MessageElement(members,
676                                              ldb.FLAG_MOD_REPLACE,
677                                              group_attr)
678
679         cleanup = samdb.msg_diff(msg, orig_msg)
680         self.ldb_cleanups.append(cleanup)
681         samdb.modify(msg)
682
683         return cleanup
684
685     def remove_from_group(self, account_dn, group_dn):
686         samdb = self.get_samdb()
687
688         res = samdb.search(base=group_dn,
689                            scope=ldb.SCOPE_BASE,
690                            attrs=['member'])
691         orig_msg = res[0]
692         self.assertIn('member', orig_msg)
693         members = list(orig_msg['member'])
694
695         account_dn = str(account_dn).encode('utf-8')
696         self.assertIn(account_dn, members)
697         members.remove(account_dn)
698
699         msg = ldb.Message()
700         msg.dn = group_dn
701         msg['member'] = ldb.MessageElement(members,
702                                            ldb.FLAG_MOD_REPLACE,
703                                            'member')
704
705         cleanup = samdb.msg_diff(msg, orig_msg)
706         self.ldb_cleanups.append(cleanup)
707         samdb.modify(msg)
708
709         return cleanup
710
711     def get_cached_creds(self, *,
712                          account_type,
713                          opts=None,
714                          use_cache=True):
715         if opts is None:
716             opts = {}
717
718         opts_default = {
719             'name_prefix': None,
720             'name_suffix': None,
721             'add_dollar': True,
722             'upn': None,
723             'spn': None,
724             'additional_details': None,
725             'allowed_replication': False,
726             'allowed_replication_mock': False,
727             'denied_replication': False,
728             'denied_replication_mock': False,
729             'revealed_to_rodc': False,
730             'revealed_to_mock_rodc': False,
731             'no_auth_data_required': False,
732             'supported_enctypes': None,
733             'not_delegated': False,
734             'delegation_to_spn': None,
735             'delegation_from_dn': None,
736             'trusted_to_auth_for_delegation': False,
737             'fast_support': False,
738             'member_of': None,
739             'kerberos_enabled': True,
740             'secure_channel_type': None,
741             'id': None
742         }
743
744         account_opts = {
745             'account_type': account_type,
746             **opts_default,
747             **opts
748         }
749
750         cache_key = tuple(sorted(account_opts.items()))
751
752         if use_cache:
753             creds = self.account_cache.get(cache_key)
754             if creds is not None:
755                 return creds
756
757         creds = self.create_account_opts(**account_opts)
758         if use_cache:
759             self.account_cache[cache_key] = creds
760
761         return creds
762
763     def create_account_opts(self, *,
764                             account_type,
765                             name_prefix,
766                             name_suffix,
767                             add_dollar,
768                             upn,
769                             spn,
770                             additional_details,
771                             allowed_replication,
772                             allowed_replication_mock,
773                             denied_replication,
774                             denied_replication_mock,
775                             revealed_to_rodc,
776                             revealed_to_mock_rodc,
777                             no_auth_data_required,
778                             supported_enctypes,
779                             not_delegated,
780                             delegation_to_spn,
781                             delegation_from_dn,
782                             trusted_to_auth_for_delegation,
783                             fast_support,
784                             member_of,
785                             kerberos_enabled,
786                             secure_channel_type,
787                             id):
788         if account_type is self.AccountType.USER:
789             self.assertIsNone(spn)
790             self.assertIsNone(delegation_to_spn)
791             self.assertIsNone(delegation_from_dn)
792             self.assertFalse(trusted_to_auth_for_delegation)
793         else:
794             self.assertFalse(not_delegated)
795
796         samdb = self.get_samdb()
797
798         user_name = self.get_new_username()
799         if name_prefix is not None:
800             user_name = name_prefix + user_name
801         if name_suffix is not None:
802             user_name += name_suffix
803
804         user_account_control = 0
805         if trusted_to_auth_for_delegation:
806             user_account_control |= UF_TRUSTED_TO_AUTHENTICATE_FOR_DELEGATION
807         if not_delegated:
808             user_account_control |= UF_NOT_DELEGATED
809         if no_auth_data_required:
810             user_account_control |= UF_NO_AUTH_DATA_REQUIRED
811
812         if additional_details:
813             details = {k: v for k, v in additional_details}
814         else:
815             details = {}
816
817         enctypes = supported_enctypes
818         if fast_support:
819             enctypes = enctypes or 0
820             enctypes |= KerberosCredentials.fast_supported_bits
821
822         if enctypes is not None:
823             details['msDS-SupportedEncryptionTypes'] = str(enctypes)
824
825         if delegation_to_spn:
826             details['msDS-AllowedToDelegateTo'] = delegation_to_spn
827
828         if delegation_from_dn:
829             security_descriptor = self.get_security_descriptor(
830                 delegation_from_dn)
831             details['msDS-AllowedToActOnBehalfOfOtherIdentity'] = (
832                 security_descriptor)
833
834         if spn is None and account_type is not self.AccountType.USER:
835             spn = 'host/' + user_name
836
837         creds, dn = self.create_account(samdb, user_name,
838                                         account_type=account_type,
839                                         upn=upn,
840                                         spn=spn,
841                                         additional_details=details,
842                                         account_control=user_account_control,
843                                         add_dollar=add_dollar)
844
845         keys = self.get_keys(samdb, dn)
846         self.creds_set_keys(creds, keys)
847
848         # Handle secret replication to the RODC.
849
850         if allowed_replication or revealed_to_rodc:
851             rodc_samdb = self.get_rodc_samdb()
852             rodc_dn = self.get_server_dn(rodc_samdb)
853
854             # Allow replicating this account's secrets if requested, or allow
855             # it only temporarily if we're about to replicate them.
856             allowed_cleanup = self.add_to_group(
857                 dn, rodc_dn,
858                 'msDS-RevealOnDemandGroup')
859
860             if revealed_to_rodc:
861                 # Replicate this account's secrets to the RODC.
862                 self.replicate_account_to_rodc(dn)
863
864             if not allowed_replication:
865                 # If we don't want replicating secrets to be allowed for this
866                 # account, disable it again.
867                 samdb.modify(allowed_cleanup)
868
869             self.check_revealed(dn,
870                                 rodc_dn,
871                                 revealed=revealed_to_rodc)
872
873         if denied_replication:
874             rodc_samdb = self.get_rodc_samdb()
875             rodc_dn = self.get_server_dn(rodc_samdb)
876
877             # Deny replicating this account's secrets to the RODC.
878             self.add_to_group(dn, rodc_dn, 'msDS-NeverRevealGroup')
879
880         # Handle secret replication to the mock RODC.
881
882         if allowed_replication_mock or revealed_to_mock_rodc:
883             # Allow replicating this account's secrets if requested, or allow
884             # it only temporarily if we want to add the account to the mock
885             # RODC's msDS-RevealedUsers.
886             rodc_ctx = self.get_mock_rodc_ctx()
887             mock_rodc_dn = ldb.Dn(samdb, rodc_ctx.acct_dn)
888
889             allowed_mock_cleanup = self.add_to_group(
890                 dn, mock_rodc_dn,
891                 'msDS-RevealOnDemandGroup')
892
893             if revealed_to_mock_rodc:
894                 # Request replicating this account's secrets to the mock RODC,
895                 # which updates msDS-RevealedUsers.
896                 self.reveal_account_to_mock_rodc(dn)
897
898             if not allowed_replication_mock:
899                 # If we don't want replicating secrets to be allowed for this
900                 # account, disable it again.
901                 samdb.modify(allowed_mock_cleanup)
902
903             self.check_revealed(dn,
904                                 mock_rodc_dn,
905                                 revealed=revealed_to_mock_rodc)
906
907         if denied_replication_mock:
908             # Deny replicating this account's secrets to the mock RODC.
909             rodc_ctx = self.get_mock_rodc_ctx()
910             mock_rodc_dn = ldb.Dn(samdb, rodc_ctx.acct_dn)
911
912             self.add_to_group(dn, mock_rodc_dn, 'msDS-NeverRevealGroup')
913
914         if member_of is not None:
915             for group_dn in member_of:
916                 self.add_to_group(dn, ldb.Dn(samdb, group_dn), 'member',
917                                   expect_attr=False)
918
919         if kerberos_enabled:
920             creds.set_kerberos_state(MUST_USE_KERBEROS)
921         else:
922             creds.set_kerberos_state(DONT_USE_KERBEROS)
923
924         if secure_channel_type is not None:
925             creds.set_secure_channel_type(secure_channel_type)
926
927         return creds
928
929     def get_new_username(self):
930         user_name = self.account_base + str(self.account_id)
931         type(self).account_id += 1
932
933         return user_name
934
935     def get_client_creds(self,
936                          allow_missing_password=False,
937                          allow_missing_keys=True):
938         def create_client_account():
939             return self.get_cached_creds(account_type=self.AccountType.USER)
940
941         c = self._get_krb5_creds(prefix='CLIENT',
942                                  allow_missing_password=allow_missing_password,
943                                  allow_missing_keys=allow_missing_keys,
944                                  fallback_creds_fn=create_client_account)
945         return c
946
947     def get_mach_creds(self,
948                        allow_missing_password=False,
949                        allow_missing_keys=True):
950         def create_mach_account():
951             return self.get_cached_creds(account_type=self.AccountType.COMPUTER,
952                                          opts={'fast_support': True})
953
954         c = self._get_krb5_creds(prefix='MAC',
955                                  allow_missing_password=allow_missing_password,
956                                  allow_missing_keys=allow_missing_keys,
957                                  fallback_creds_fn=create_mach_account)
958         return c
959
960     def get_service_creds(self,
961                           allow_missing_password=False,
962                           allow_missing_keys=True):
963         def create_service_account():
964             return self.get_cached_creds(
965                 account_type=self.AccountType.COMPUTER,
966                 opts={
967                     'trusted_to_auth_for_delegation': True,
968                     'fast_support': True
969                 })
970
971         c = self._get_krb5_creds(prefix='SERVICE',
972                                  allow_missing_password=allow_missing_password,
973                                  allow_missing_keys=allow_missing_keys,
974                                  fallback_creds_fn=create_service_account)
975         return c
976
977     def get_rodc_krbtgt_creds(self,
978                               require_keys=True,
979                               require_strongest_key=False):
980         if require_strongest_key:
981             self.assertTrue(require_keys)
982
983         def download_rodc_krbtgt_creds():
984             samdb = self.get_samdb()
985             rodc_samdb = self.get_rodc_samdb()
986
987             rodc_dn = self.get_server_dn(rodc_samdb)
988
989             res = samdb.search(rodc_dn,
990                                scope=ldb.SCOPE_BASE,
991                                attrs=['msDS-KrbTgtLink'])
992             krbtgt_dn = res[0]['msDS-KrbTgtLink'][0]
993
994             res = samdb.search(krbtgt_dn,
995                                scope=ldb.SCOPE_BASE,
996                                attrs=['sAMAccountName',
997                                       'msDS-KeyVersionNumber',
998                                       'msDS-SecondaryKrbTgtNumber'])
999             krbtgt_dn = res[0].dn
1000             username = str(res[0]['sAMAccountName'])
1001
1002             creds = KerberosCredentials()
1003             creds.set_domain(self.env_get_var('DOMAIN', 'RODC_KRBTGT'))
1004             creds.set_realm(self.env_get_var('REALM', 'RODC_KRBTGT'))
1005             creds.set_username(username)
1006
1007             kvno = int(res[0]['msDS-KeyVersionNumber'][0])
1008             krbtgt_number = int(res[0]['msDS-SecondaryKrbTgtNumber'][0])
1009
1010             rodc_kvno = krbtgt_number << 16 | kvno
1011             creds.set_kvno(rodc_kvno)
1012             creds.set_dn(krbtgt_dn)
1013
1014             keys = self.get_keys(samdb, krbtgt_dn)
1015             self.creds_set_keys(creds, keys)
1016
1017             # The RODC krbtgt account should support the default enctypes,
1018             # although it might not have the msDS-SupportedEncryptionTypes
1019             # attribute.
1020             self.creds_set_default_enctypes(
1021                 creds,
1022                 fast_support=self.kdc_fast_support,
1023                 claims_support=self.kdc_claims_support,
1024                 compound_id_support=self.kdc_compound_id_support)
1025
1026             return creds
1027
1028         c = self._get_krb5_creds(prefix='RODC_KRBTGT',
1029                                  allow_missing_password=True,
1030                                  allow_missing_keys=not require_keys,
1031                                  require_strongest_key=require_strongest_key,
1032                                  fallback_creds_fn=download_rodc_krbtgt_creds)
1033         return c
1034
1035     def get_mock_rodc_krbtgt_creds(self,
1036                                    require_keys=True,
1037                                    require_strongest_key=False):
1038         if require_strongest_key:
1039             self.assertTrue(require_keys)
1040
1041         def create_rodc_krbtgt_account():
1042             samdb = self.get_samdb()
1043
1044             rodc_ctx = self.get_mock_rodc_ctx()
1045
1046             krbtgt_dn = rodc_ctx.new_krbtgt_dn
1047
1048             res = samdb.search(base=ldb.Dn(samdb, krbtgt_dn),
1049                                scope=ldb.SCOPE_BASE,
1050                                attrs=['msDS-KeyVersionNumber',
1051                                       'msDS-SecondaryKrbTgtNumber'])
1052             dn = res[0].dn
1053             username = str(rodc_ctx.krbtgt_name)
1054
1055             creds = KerberosCredentials()
1056             creds.set_domain(self.env_get_var('DOMAIN', 'RODC_KRBTGT'))
1057             creds.set_realm(self.env_get_var('REALM', 'RODC_KRBTGT'))
1058             creds.set_username(username)
1059
1060             kvno = int(res[0]['msDS-KeyVersionNumber'][0])
1061             krbtgt_number = int(res[0]['msDS-SecondaryKrbTgtNumber'][0])
1062
1063             rodc_kvno = krbtgt_number << 16 | kvno
1064             creds.set_kvno(rodc_kvno)
1065             creds.set_dn(dn)
1066
1067             keys = self.get_keys(samdb, dn)
1068             self.creds_set_keys(creds, keys)
1069
1070             self.creds_set_enctypes(creds)
1071
1072             return creds
1073
1074         c = self._get_krb5_creds(prefix='MOCK_RODC_KRBTGT',
1075                                  allow_missing_password=True,
1076                                  allow_missing_keys=not require_keys,
1077                                  require_strongest_key=require_strongest_key,
1078                                  fallback_creds_fn=create_rodc_krbtgt_account)
1079         return c
1080
1081     def get_krbtgt_creds(self,
1082                          require_keys=True,
1083                          require_strongest_key=False):
1084         if require_strongest_key:
1085             self.assertTrue(require_keys)
1086
1087         def download_krbtgt_creds():
1088             samdb = self.get_samdb()
1089
1090             krbtgt_rid = security.DOMAIN_RID_KRBTGT
1091             krbtgt_sid = '%s-%d' % (samdb.get_domain_sid(), krbtgt_rid)
1092
1093             res = samdb.search(base='<SID=%s>' % krbtgt_sid,
1094                                scope=ldb.SCOPE_BASE,
1095                                attrs=['sAMAccountName',
1096                                       'msDS-KeyVersionNumber'])
1097             dn = res[0].dn
1098             username = str(res[0]['sAMAccountName'])
1099
1100             creds = KerberosCredentials()
1101             creds.set_domain(self.env_get_var('DOMAIN', 'KRBTGT'))
1102             creds.set_realm(self.env_get_var('REALM', 'KRBTGT'))
1103             creds.set_username(username)
1104
1105             kvno = int(res[0]['msDS-KeyVersionNumber'][0])
1106             creds.set_kvno(kvno)
1107             creds.set_dn(dn)
1108
1109             keys = self.get_keys(samdb, dn)
1110             self.creds_set_keys(creds, keys)
1111
1112             # The krbtgt account should support the default enctypes, although
1113             # it might not (on Samba) have the msDS-SupportedEncryptionTypes
1114             # attribute.
1115             self.creds_set_default_enctypes(
1116                 creds,
1117                 fast_support=self.kdc_fast_support,
1118                 claims_support=self.kdc_claims_support,
1119                 compound_id_support=self.kdc_compound_id_support)
1120
1121             return creds
1122
1123         c = self._get_krb5_creds(prefix='KRBTGT',
1124                                  default_username='krbtgt',
1125                                  allow_missing_password=True,
1126                                  allow_missing_keys=not require_keys,
1127                                  require_strongest_key=require_strongest_key,
1128                                  fallback_creds_fn=download_krbtgt_creds)
1129         return c
1130
1131     def get_dc_creds(self,
1132                      require_keys=True,
1133                      require_strongest_key=False):
1134         if require_strongest_key:
1135             self.assertTrue(require_keys)
1136
1137         def download_dc_creds():
1138             samdb = self.get_samdb()
1139
1140             dc_rid = 1000
1141             dc_sid = '%s-%d' % (samdb.get_domain_sid(), dc_rid)
1142
1143             res = samdb.search(base='<SID=%s>' % dc_sid,
1144                                scope=ldb.SCOPE_BASE,
1145                                attrs=['sAMAccountName',
1146                                       'msDS-KeyVersionNumber'])
1147             dn = res[0].dn
1148             username = str(res[0]['sAMAccountName'])
1149
1150             creds = KerberosCredentials()
1151             creds.set_domain(self.env_get_var('DOMAIN', 'DC'))
1152             creds.set_realm(self.env_get_var('REALM', 'DC'))
1153             creds.set_username(username)
1154
1155             kvno = int(res[0]['msDS-KeyVersionNumber'][0])
1156             creds.set_kvno(kvno)
1157             creds.set_dn(dn)
1158
1159             keys = self.get_keys(samdb, dn)
1160             self.creds_set_keys(creds, keys)
1161
1162             self.creds_set_enctypes(creds)
1163
1164             return creds
1165
1166         c = self._get_krb5_creds(prefix='DC',
1167                                  allow_missing_password=True,
1168                                  allow_missing_keys=not require_keys,
1169                                  require_strongest_key=require_strongest_key,
1170                                  fallback_creds_fn=download_dc_creds)
1171         return c
1172
1173     def get_server_creds(self,
1174                      require_keys=True,
1175                      require_strongest_key=False):
1176         if require_strongest_key:
1177             self.assertTrue(require_keys)
1178
1179         def download_server_creds():
1180             samdb = self.get_samdb()
1181
1182             res = samdb.search(base=samdb.get_default_basedn(),
1183                                expression=(f'(|(sAMAccountName={self.host}*)'
1184                                            f'(dNSHostName={self.host}))'),
1185                                scope=ldb.SCOPE_SUBTREE,
1186                                attrs=['sAMAccountName',
1187                                       'msDS-KeyVersionNumber'])
1188             self.assertEqual(1, len(res))
1189             dn = res[0].dn
1190             username = str(res[0]['sAMAccountName'])
1191
1192             creds = KerberosCredentials()
1193             creds.set_domain(self.env_get_var('DOMAIN', 'SERVER'))
1194             creds.set_realm(self.env_get_var('REALM', 'SERVER'))
1195             creds.set_username(username)
1196
1197             kvno = int(res[0]['msDS-KeyVersionNumber'][0])
1198             creds.set_kvno(kvno)
1199             creds.set_dn(dn)
1200
1201             keys = self.get_keys(samdb, dn)
1202             self.creds_set_keys(creds, keys)
1203
1204             self.creds_set_enctypes(creds)
1205
1206             return creds
1207
1208         c = self._get_krb5_creds(prefix='SERVER',
1209                                  allow_missing_password=True,
1210                                  allow_missing_keys=not require_keys,
1211                                  require_strongest_key=require_strongest_key,
1212                                  fallback_creds_fn=download_server_creds)
1213         return c
1214
1215     def as_req(self, cname, sname, realm, etypes, padata=None, kdc_options=0):
1216         '''Send a Kerberos AS_REQ, returns the undecoded response
1217         '''
1218
1219         till = self.get_KerberosTime(offset=36000)
1220
1221         req = self.AS_REQ_create(padata=padata,
1222                                  kdc_options=str(kdc_options),
1223                                  cname=cname,
1224                                  realm=realm,
1225                                  sname=sname,
1226                                  from_time=None,
1227                                  till_time=till,
1228                                  renew_time=None,
1229                                  nonce=0x7fffffff,
1230                                  etypes=etypes,
1231                                  addresses=None,
1232                                  additional_tickets=None)
1233         rep = self.send_recv_transaction(req)
1234         return rep
1235
1236     def get_as_rep_key(self, creds, rep):
1237         '''Extract the session key from an AS-REP
1238         '''
1239         rep_padata = self.der_decode(
1240             rep['e-data'],
1241             asn1Spec=krb5_asn1.METHOD_DATA())
1242
1243         for pa in rep_padata:
1244             if pa['padata-type'] == PADATA_ETYPE_INFO2:
1245                 padata_value = pa['padata-value']
1246                 break
1247
1248         etype_info2 = self.der_decode(
1249             padata_value, asn1Spec=krb5_asn1.ETYPE_INFO2())
1250
1251         key = self.PasswordKey_from_etype_info2(creds, etype_info2[0],
1252                                                 creds.get_kvno())
1253         return key
1254
1255     def get_enc_timestamp_pa_data(self, creds, rep, skew=0):
1256         '''generate the pa_data data element for an AS-REQ
1257         '''
1258
1259         key = self.get_as_rep_key(creds, rep)
1260
1261         return self.get_enc_timestamp_pa_data_from_key(key, skew=skew)
1262
1263     def get_enc_timestamp_pa_data_from_key(self, key, skew=0):
1264         (patime, pausec) = self.get_KerberosTimeWithUsec(offset=skew)
1265         padata = self.PA_ENC_TS_ENC_create(patime, pausec)
1266         padata = self.der_encode(padata, asn1Spec=krb5_asn1.PA_ENC_TS_ENC())
1267
1268         padata = self.EncryptedData_create(key, KU_PA_ENC_TIMESTAMP, padata)
1269         padata = self.der_encode(padata, asn1Spec=krb5_asn1.EncryptedData())
1270
1271         padata = self.PA_DATA_create(PADATA_ENC_TIMESTAMP, padata)
1272
1273         return padata
1274
1275     def get_challenge_pa_data(self, client_challenge_key, skew=0):
1276         patime, pausec = self.get_KerberosTimeWithUsec(offset=skew)
1277         padata = self.PA_ENC_TS_ENC_create(patime, pausec)
1278         padata = self.der_encode(padata,
1279                                  asn1Spec=krb5_asn1.PA_ENC_TS_ENC())
1280
1281         padata = self.EncryptedData_create(client_challenge_key,
1282                                            KU_ENC_CHALLENGE_CLIENT,
1283                                            padata)
1284         padata = self.der_encode(padata,
1285                                  asn1Spec=krb5_asn1.EncryptedData())
1286
1287         padata = self.PA_DATA_create(PADATA_ENCRYPTED_CHALLENGE,
1288                                      padata)
1289
1290         return padata
1291
1292     def get_as_rep_enc_data(self, key, rep):
1293         ''' Decrypt and Decode the encrypted data in an AS-REP
1294         '''
1295         enc_part = key.decrypt(KU_AS_REP_ENC_PART, rep['enc-part']['cipher'])
1296         # MIT KDC encodes both EncASRepPart and EncTGSRepPart with
1297         # application tag 26
1298         try:
1299             enc_part = self.der_decode(
1300                 enc_part, asn1Spec=krb5_asn1.EncASRepPart())
1301         except Exception:
1302             enc_part = self.der_decode(
1303                 enc_part, asn1Spec=krb5_asn1.EncTGSRepPart())
1304
1305         return enc_part
1306
1307     def check_pre_authentication(self, rep):
1308         """ Check that the kdc response was pre-authentication required
1309         """
1310         self.check_error_rep(rep, KDC_ERR_PREAUTH_REQUIRED)
1311
1312     def check_as_reply(self, rep):
1313         """ Check that the kdc response is an AS-REP and that the
1314             values for:
1315                 msg-type
1316                 pvno
1317                 tkt-pvno
1318                 kvno
1319             match the expected values
1320         """
1321         self.check_reply(rep, msg_type=KRB_AS_REP)
1322
1323     def check_tgs_reply(self, rep):
1324         """ Check that the kdc response is an TGS-REP and that the
1325             values for:
1326                 msg-type
1327                 pvno
1328                 tkt-pvno
1329                 kvno
1330             match the expected values
1331         """
1332         self.check_reply(rep, msg_type=KRB_TGS_REP)
1333
1334     def check_reply(self, rep, msg_type):
1335
1336         # Should have a reply, and it should an TGS-REP message.
1337         self.assertIsNotNone(rep)
1338         self.assertEqual(rep['msg-type'], msg_type, "rep = {%s}" % rep)
1339
1340         # Protocol version number should be 5
1341         pvno = int(rep['pvno'])
1342         self.assertEqual(5, pvno, "rep = {%s}" % rep)
1343
1344         # The ticket version number should be 5
1345         tkt_vno = int(rep['ticket']['tkt-vno'])
1346         self.assertEqual(5, tkt_vno, "rep = {%s}" % rep)
1347
1348         # Check that the kvno is not an RODC kvno
1349         # MIT kerberos does not provide the kvno, so we treat it as optional.
1350         # This is tested in compatability_test.py
1351         if 'kvno' in rep['enc-part']:
1352             kvno = int(rep['enc-part']['kvno'])
1353             # If the high order bits are set this is an RODC kvno.
1354             self.assertEqual(0, kvno & 0xFFFF0000, "rep = {%s}" % rep)
1355
1356     def check_error_rep(self, rep, expected):
1357         """ Check that the reply is an error message, with the expected
1358             error-code specified.
1359         """
1360         self.assertIsNotNone(rep)
1361         self.assertEqual(rep['msg-type'], KRB_ERROR, "rep = {%s}" % rep)
1362         if isinstance(expected, collections.abc.Container):
1363             self.assertIn(rep['error-code'], expected, "rep = {%s}" % rep)
1364         else:
1365             self.assertEqual(rep['error-code'], expected, "rep = {%s}" % rep)
1366
1367     def tgs_req(self, cname, sname, realm, ticket, key, etypes,
1368                 expected_error_mode=0, padata=None, kdc_options=0,
1369                 to_rodc=False, service_creds=None, expect_pac=True,
1370                 expect_edata=None, expected_flags=None, unexpected_flags=None):
1371         '''Send a TGS-REQ, returns the response and the decrypted and
1372            decoded enc-part
1373         '''
1374
1375         subkey = self.RandomKey(key.etype)
1376
1377         (ctime, cusec) = self.get_KerberosTimeWithUsec()
1378
1379         tgt = KerberosTicketCreds(ticket,
1380                                   key,
1381                                   crealm=realm,
1382                                   cname=cname)
1383
1384         if service_creds is not None:
1385             decryption_key = self.TicketDecryptionKey_from_creds(
1386                 service_creds)
1387             expected_supported_etypes = service_creds.tgs_supported_enctypes
1388         else:
1389             decryption_key = None
1390             expected_supported_etypes = None
1391
1392         if not expected_error_mode:
1393             check_error_fn = None
1394             check_rep_fn = self.generic_check_kdc_rep
1395         else:
1396             check_error_fn = self.generic_check_kdc_error
1397             check_rep_fn = None
1398
1399         def generate_padata(_kdc_exchange_dict,
1400                             _callback_dict,
1401                             req_body):
1402
1403             return padata, req_body
1404
1405         kdc_exchange_dict = self.tgs_exchange_dict(
1406             expected_crealm=realm,
1407             expected_cname=cname,
1408             expected_srealm=realm,
1409             expected_sname=sname,
1410             expected_error_mode=expected_error_mode,
1411             expected_flags=expected_flags,
1412             unexpected_flags=unexpected_flags,
1413             expected_supported_etypes=expected_supported_etypes,
1414             check_error_fn=check_error_fn,
1415             check_rep_fn=check_rep_fn,
1416             check_kdc_private_fn=self.generic_check_kdc_private,
1417             ticket_decryption_key=decryption_key,
1418             generate_padata_fn=generate_padata if padata is not None else None,
1419             tgt=tgt,
1420             authenticator_subkey=subkey,
1421             kdc_options=str(kdc_options),
1422             expect_edata=expect_edata,
1423             expect_pac=expect_pac,
1424             to_rodc=to_rodc)
1425
1426         rep = self._generic_kdc_exchange(kdc_exchange_dict,
1427                                          cname=None,
1428                                          realm=realm,
1429                                          sname=sname,
1430                                          etypes=etypes)
1431
1432         if expected_error_mode:
1433             enc_part = None
1434         else:
1435             ticket_creds = kdc_exchange_dict['rep_ticket_creds']
1436             enc_part = ticket_creds.encpart_private
1437
1438         return rep, enc_part
1439
1440     def get_service_ticket(self, tgt, target_creds, service='host',
1441                            target_name=None, till=None, rc4_support=True,
1442                            to_rodc=False, kdc_options=None,
1443                            expected_flags=None, unexpected_flags=None,
1444                            pac_request=True, expect_pac=True, fresh=False):
1445         user_name = tgt.cname['name-string'][0]
1446         if target_name is None:
1447             target_name = target_creds.get_username()[:-1]
1448         cache_key = (user_name, target_name, service, to_rodc, kdc_options,
1449                      pac_request, str(expected_flags), str(unexpected_flags),
1450                      till, rc4_support,
1451                      expect_pac)
1452
1453         if not fresh:
1454             ticket = self.tkt_cache.get(cache_key)
1455
1456             if ticket is not None:
1457                 return ticket
1458
1459         etype = (AES256_CTS_HMAC_SHA1_96, ARCFOUR_HMAC_MD5)
1460
1461         if kdc_options is None:
1462             kdc_options = '0'
1463         kdc_options = str(krb5_asn1.KDCOptions(kdc_options))
1464
1465         sname = self.PrincipalName_create(name_type=NT_PRINCIPAL,
1466                                           names=[service, target_name])
1467         srealm = target_creds.get_realm()
1468
1469         authenticator_subkey = self.RandomKey(kcrypto.Enctype.AES256)
1470
1471         decryption_key = self.TicketDecryptionKey_from_creds(target_creds)
1472
1473         kdc_exchange_dict = self.tgs_exchange_dict(
1474             expected_crealm=tgt.crealm,
1475             expected_cname=tgt.cname,
1476             expected_srealm=srealm,
1477             expected_sname=sname,
1478             expected_supported_etypes=target_creds.tgs_supported_enctypes,
1479             expected_flags=expected_flags,
1480             unexpected_flags=unexpected_flags,
1481             ticket_decryption_key=decryption_key,
1482             check_rep_fn=self.generic_check_kdc_rep,
1483             check_kdc_private_fn=self.generic_check_kdc_private,
1484             tgt=tgt,
1485             authenticator_subkey=authenticator_subkey,
1486             kdc_options=kdc_options,
1487             pac_request=pac_request,
1488             expect_pac=expect_pac,
1489             rc4_support=rc4_support,
1490             to_rodc=to_rodc)
1491
1492         rep = self._generic_kdc_exchange(kdc_exchange_dict,
1493                                          cname=None,
1494                                          realm=srealm,
1495                                          sname=sname,
1496                                          till_time=till,
1497                                          etypes=etype)
1498         self.check_tgs_reply(rep)
1499
1500         service_ticket_creds = kdc_exchange_dict['rep_ticket_creds']
1501
1502         if to_rodc:
1503             krbtgt_creds = self.get_rodc_krbtgt_creds()
1504         else:
1505             krbtgt_creds = self.get_krbtgt_creds()
1506         krbtgt_key = self.TicketDecryptionKey_from_creds(krbtgt_creds)
1507         self.verify_ticket(service_ticket_creds, krbtgt_key,
1508                            service_ticket=True, expect_pac=expect_pac,
1509                            expect_ticket_checksum=self.tkt_sig_support)
1510
1511         self.tkt_cache[cache_key] = service_ticket_creds
1512
1513         return service_ticket_creds
1514
1515     def get_tgt(self, creds, to_rodc=False, kdc_options=None,
1516                 client_account=None, client_name_type=NT_PRINCIPAL,
1517                 expected_flags=None, unexpected_flags=None,
1518                 expected_account_name=None, expected_upn_name=None,
1519                 expected_cname=None,
1520                 expected_sid=None,
1521                 pac_request=True, expect_pac=True,
1522                 expect_pac_attrs=None, expect_pac_attrs_pac_request=None,
1523                 expect_requester_sid=None,
1524                 rc4_support=True,
1525                 fresh=False):
1526         if client_account is not None:
1527             user_name = client_account
1528         else:
1529             user_name = creds.get_username()
1530
1531         cache_key = (user_name, to_rodc, kdc_options, pac_request,
1532                      client_name_type,
1533                      str(expected_flags), str(unexpected_flags),
1534                      expected_account_name, expected_upn_name, expected_sid,
1535                      str(expected_cname),
1536                      rc4_support,
1537                      expect_pac, expect_pac_attrs,
1538                      expect_pac_attrs_pac_request, expect_requester_sid)
1539
1540         if not fresh:
1541             tgt = self.tkt_cache.get(cache_key)
1542
1543             if tgt is not None:
1544                 return tgt
1545
1546         realm = creds.get_realm()
1547
1548         salt = creds.get_salt()
1549
1550         etype = (AES256_CTS_HMAC_SHA1_96, ARCFOUR_HMAC_MD5)
1551         cname = self.PrincipalName_create(name_type=client_name_type,
1552                                           names=user_name.split('/'))
1553         sname = self.PrincipalName_create(name_type=NT_SRV_INST,
1554                                           names=['krbtgt', realm])
1555
1556         if expected_cname is None:
1557             expected_cname = cname
1558
1559         till = self.get_KerberosTime(offset=36000)
1560
1561         if to_rodc:
1562             krbtgt_creds = self.get_rodc_krbtgt_creds()
1563         else:
1564             krbtgt_creds = self.get_krbtgt_creds()
1565         ticket_decryption_key = (
1566             self.TicketDecryptionKey_from_creds(krbtgt_creds))
1567
1568         expected_etypes = krbtgt_creds.tgs_supported_enctypes
1569
1570         if kdc_options is None:
1571             kdc_options = ('forwardable,'
1572                            'renewable,'
1573                            'canonicalize,'
1574                            'renewable-ok')
1575         kdc_options = krb5_asn1.KDCOptions(kdc_options)
1576
1577         pac_options = '1'  # supports claims
1578
1579         rep, kdc_exchange_dict = self._test_as_exchange(
1580             cname=cname,
1581             realm=realm,
1582             sname=sname,
1583             till=till,
1584             client_as_etypes=etype,
1585             expected_error_mode=KDC_ERR_PREAUTH_REQUIRED,
1586             expected_crealm=realm,
1587             expected_cname=expected_cname,
1588             expected_srealm=realm,
1589             expected_sname=sname,
1590             expected_account_name=expected_account_name,
1591             expected_upn_name=expected_upn_name,
1592             expected_sid=expected_sid,
1593             expected_salt=salt,
1594             expected_flags=expected_flags,
1595             unexpected_flags=unexpected_flags,
1596             expected_supported_etypes=expected_etypes,
1597             etypes=etype,
1598             padata=None,
1599             kdc_options=kdc_options,
1600             preauth_key=None,
1601             ticket_decryption_key=ticket_decryption_key,
1602             pac_request=pac_request,
1603             pac_options=pac_options,
1604             expect_pac=expect_pac,
1605             expect_pac_attrs=expect_pac_attrs,
1606             expect_pac_attrs_pac_request=expect_pac_attrs_pac_request,
1607             expect_requester_sid=expect_requester_sid,
1608             rc4_support=rc4_support,
1609             to_rodc=to_rodc)
1610         self.check_pre_authentication(rep)
1611
1612         etype_info2 = kdc_exchange_dict['preauth_etype_info2']
1613
1614         preauth_key = self.PasswordKey_from_etype_info2(creds,
1615                                                         etype_info2[0],
1616                                                         creds.get_kvno())
1617
1618         ts_enc_padata = self.get_enc_timestamp_pa_data_from_key(preauth_key)
1619
1620         padata = [ts_enc_padata]
1621
1622         expected_realm = realm.upper()
1623
1624         expected_sname = self.PrincipalName_create(
1625             name_type=NT_SRV_INST, names=['krbtgt', realm.upper()])
1626
1627         rep, kdc_exchange_dict = self._test_as_exchange(
1628             cname=cname,
1629             realm=realm,
1630             sname=sname,
1631             till=till,
1632             client_as_etypes=etype,
1633             expected_error_mode=0,
1634             expected_crealm=expected_realm,
1635             expected_cname=expected_cname,
1636             expected_srealm=expected_realm,
1637             expected_sname=expected_sname,
1638             expected_account_name=expected_account_name,
1639             expected_upn_name=expected_upn_name,
1640             expected_sid=expected_sid,
1641             expected_salt=salt,
1642             expected_flags=expected_flags,
1643             unexpected_flags=unexpected_flags,
1644             expected_supported_etypes=expected_etypes,
1645             etypes=etype,
1646             padata=padata,
1647             kdc_options=kdc_options,
1648             preauth_key=preauth_key,
1649             ticket_decryption_key=ticket_decryption_key,
1650             pac_request=pac_request,
1651             pac_options=pac_options,
1652             expect_pac=expect_pac,
1653             expect_pac_attrs=expect_pac_attrs,
1654             expect_pac_attrs_pac_request=expect_pac_attrs_pac_request,
1655             expect_requester_sid=expect_requester_sid,
1656             rc4_support=rc4_support,
1657             to_rodc=to_rodc)
1658         self.check_as_reply(rep)
1659
1660         ticket_creds = kdc_exchange_dict['rep_ticket_creds']
1661
1662         self.tkt_cache[cache_key] = ticket_creds
1663
1664         return ticket_creds
1665
1666     # Named tuple to contain values of interest when the PAC is decoded.
1667     PacData = namedtuple(
1668         "PacData",
1669         "account_name account_sid logon_name upn domain_name")
1670
1671     def get_pac_data(self, authorization_data):
1672         '''Decode the PAC element contained in the authorization-data element
1673         '''
1674         account_name = None
1675         user_sid = None
1676         logon_name = None
1677         upn = None
1678         domain_name = None
1679
1680         # The PAC data will be wrapped in an AD_IF_RELEVANT element
1681         ad_if_relevant_elements = (
1682             x for x in authorization_data if x['ad-type'] == AD_IF_RELEVANT)
1683         for dt in ad_if_relevant_elements:
1684             buf = self.der_decode(
1685                 dt['ad-data'], asn1Spec=krb5_asn1.AD_IF_RELEVANT())
1686             # The PAC data is further wrapped in a AD_WIN2K_PAC element
1687             for ad in (x for x in buf if x['ad-type'] == AD_WIN2K_PAC):
1688                 pb = ndr_unpack(krb5pac.PAC_DATA, ad['ad-data'])
1689                 for pac in pb.buffers:
1690                     if pac.type == krb5pac.PAC_TYPE_LOGON_INFO:
1691                         account_name = (
1692                             pac.info.info.info3.base.account_name)
1693                         user_sid = (
1694                             str(pac.info.info.info3.base.domain_sid)
1695                             + "-" + str(pac.info.info.info3.base.rid))
1696                     elif pac.type == krb5pac.PAC_TYPE_LOGON_NAME:
1697                         logon_name = pac.info.account_name
1698                     elif pac.type == krb5pac.PAC_TYPE_UPN_DNS_INFO:
1699                         upn = pac.info.upn_name
1700                         domain_name = pac.info.dns_domain_name
1701
1702         return self.PacData(
1703             account_name,
1704             user_sid,
1705             logon_name,
1706             upn,
1707             domain_name)
1708
1709     def decode_service_ticket(self, creds, ticket):
1710         '''Decrypt and decode a service ticket
1711         '''
1712
1713         name = creds.get_username()
1714         if name.endswith('$'):
1715             name = name[:-1]
1716         realm = creds.get_realm()
1717         salt = "%s.%s@%s" % (name, realm.lower(), realm.upper())
1718
1719         key = self.PasswordKey_create(
1720             ticket['enc-part']['etype'],
1721             creds.get_password(),
1722             salt,
1723             ticket['enc-part']['kvno'])
1724
1725         enc_part = key.decrypt(KU_TICKET, ticket['enc-part']['cipher'])
1726         enc_ticket_part = self.der_decode(
1727             enc_part, asn1Spec=krb5_asn1.EncTicketPart())
1728         return enc_ticket_part
1729
1730     def modify_ticket_flag(self, enc_part, flag, value):
1731         self.assertIsInstance(value, bool)
1732
1733         flag = krb5_asn1.TicketFlags(flag)
1734         pos = len(tuple(flag)) - 1
1735
1736         flags = enc_part['flags']
1737         self.assertLessEqual(pos, len(flags))
1738
1739         new_flags = flags[:pos] + str(int(value)) + flags[pos + 1:]
1740         enc_part['flags'] = new_flags
1741
1742         return enc_part
1743
1744     def get_objectSid(self, samdb, dn):
1745         ''' Get the objectSID for a DN
1746             Note: performs an Ldb query.
1747         '''
1748         res = samdb.search(dn, scope=SCOPE_BASE, attrs=["objectSID"])
1749         self.assertTrue(len(res) == 1, "did not get objectSid for %s" % dn)
1750         sid = samdb.schema_format_value("objectSID", res[0]["objectSID"][0])
1751         return sid.decode('utf8')
1752
1753     def add_attribute(self, samdb, dn_str, name, value):
1754         if isinstance(value, list):
1755             values = value
1756         else:
1757             values = [value]
1758         flag = ldb.FLAG_MOD_ADD
1759
1760         dn = ldb.Dn(samdb, dn_str)
1761         msg = ldb.Message(dn)
1762         msg[name] = ldb.MessageElement(values, flag, name)
1763         samdb.modify(msg)
1764
1765     def modify_attribute(self, samdb, dn_str, name, value):
1766         if isinstance(value, list):
1767             values = value
1768         else:
1769             values = [value]
1770         flag = ldb.FLAG_MOD_REPLACE
1771
1772         dn = ldb.Dn(samdb, dn_str)
1773         msg = ldb.Message(dn)
1774         msg[name] = ldb.MessageElement(values, flag, name)
1775         samdb.modify(msg)
1776
1777     def create_ccache(self, cname, ticket, enc_part):
1778         """ Lay out a version 4 on-disk credentials cache, to be read using the
1779             FILE: protocol.
1780         """
1781
1782         field = krb5ccache.DELTATIME_TAG()
1783         field.kdc_sec_offset = 0
1784         field.kdc_usec_offset = 0
1785
1786         v4tag = krb5ccache.V4TAG()
1787         v4tag.tag = 1
1788         v4tag.field = field
1789
1790         v4tags = krb5ccache.V4TAGS()
1791         v4tags.tag = v4tag
1792         v4tags.further_tags = b''
1793
1794         optional_header = krb5ccache.V4HEADER()
1795         optional_header.v4tags = v4tags
1796
1797         cname_string = cname['name-string']
1798
1799         cprincipal = krb5ccache.PRINCIPAL()
1800         cprincipal.name_type = cname['name-type']
1801         cprincipal.component_count = len(cname_string)
1802         cprincipal.realm = ticket['realm']
1803         cprincipal.components = cname_string
1804
1805         sname = ticket['sname']
1806         sname_string = sname['name-string']
1807
1808         sprincipal = krb5ccache.PRINCIPAL()
1809         sprincipal.name_type = sname['name-type']
1810         sprincipal.component_count = len(sname_string)
1811         sprincipal.realm = ticket['realm']
1812         sprincipal.components = sname_string
1813
1814         key = self.EncryptionKey_import(enc_part['key'])
1815
1816         key_data = key.export_obj()
1817         keyblock = krb5ccache.KEYBLOCK()
1818         keyblock.enctype = key_data['keytype']
1819         keyblock.data = key_data['keyvalue']
1820
1821         addresses = krb5ccache.ADDRESSES()
1822         addresses.count = 0
1823         addresses.data = []
1824
1825         authdata = krb5ccache.AUTHDATA()
1826         authdata.count = 0
1827         authdata.data = []
1828
1829         # Re-encode the ticket, since it was decoded by another layer.
1830         ticket_data = self.der_encode(ticket, asn1Spec=krb5_asn1.Ticket())
1831
1832         authtime = enc_part['authtime']
1833         starttime = enc_part.get('starttime', authtime)
1834         endtime = enc_part['endtime']
1835
1836         cred = krb5ccache.CREDENTIAL()
1837         cred.client = cprincipal
1838         cred.server = sprincipal
1839         cred.keyblock = keyblock
1840         cred.authtime = self.get_EpochFromKerberosTime(authtime)
1841         cred.starttime = self.get_EpochFromKerberosTime(starttime)
1842         cred.endtime = self.get_EpochFromKerberosTime(endtime)
1843
1844         # Account for clock skew of up to five minutes.
1845         self.assertLess(cred.authtime - 5 * 60,
1846                         datetime.now(timezone.utc).timestamp(),
1847                         "Ticket not yet valid - clocks may be out of sync.")
1848         self.assertLess(cred.starttime - 5 * 60,
1849                         datetime.now(timezone.utc).timestamp(),
1850                         "Ticket not yet valid - clocks may be out of sync.")
1851         self.assertGreater(cred.endtime - 60 * 60,
1852                            datetime.now(timezone.utc).timestamp(),
1853                            "Ticket already expired/about to expire - "
1854                            "clocks may be out of sync.")
1855
1856         cred.renew_till = cred.endtime
1857         cred.is_skey = 0
1858         cred.ticket_flags = int(enc_part['flags'], 2)
1859         cred.addresses = addresses
1860         cred.authdata = authdata
1861         cred.ticket = ticket_data
1862         cred.second_ticket = b''
1863
1864         ccache = krb5ccache.CCACHE()
1865         ccache.pvno = 5
1866         ccache.version = 4
1867         ccache.optional_header = optional_header
1868         ccache.principal = cprincipal
1869         ccache.cred = cred
1870
1871         # Serialise the credentials cache structure.
1872         result = ndr_pack(ccache)
1873
1874         # Create a temporary file and write the credentials.
1875         cachefile = tempfile.NamedTemporaryFile(dir=self.tempdir, delete=False)
1876         cachefile.write(result)
1877         cachefile.close()
1878
1879         return cachefile
1880
1881     def create_ccache_with_user(self, user_credentials, mach_credentials,
1882                                 service="host", target_name=None, pac=True):
1883         # Obtain a service ticket authorising the user and place it into a
1884         # newly created credentials cache file.
1885
1886         user_name = user_credentials.get_username()
1887         realm = user_credentials.get_realm()
1888
1889         cname = self.PrincipalName_create(name_type=NT_PRINCIPAL,
1890                                           names=[user_name])
1891
1892         tgt = self.get_tgt(user_credentials)
1893
1894         # Request a ticket to the host service on the machine account
1895         ticket = self.get_service_ticket(tgt, mach_credentials,
1896                                          service=service,
1897                                          target_name=target_name)
1898
1899         if not pac:
1900             ticket = self.modified_ticket(ticket, exclude_pac=True)
1901
1902         # Write the ticket into a credentials cache file that can be ingested
1903         # by the main credentials code.
1904         cachefile = self.create_ccache(cname, ticket.ticket,
1905                                        ticket.encpart_private)
1906
1907         # Create a credentials object to reference the credentials cache.
1908         creds = Credentials()
1909         creds.set_kerberos_state(MUST_USE_KERBEROS)
1910         creds.set_username(user_name, SPECIFIED)
1911         creds.set_realm(realm)
1912         creds.set_named_ccache(cachefile.name, SPECIFIED, self.get_lp())
1913
1914         # Return the credentials along with the cache file.
1915         return (creds, cachefile)