python/tests/krb5: Allow getting a TGT in pkinit tests
authorAndrew Bartlett <abartlet@samba.org>
Tue, 26 Mar 2024 01:42:20 +0000 (14:42 +1300)
committerAndrew Bartlett <abartlet@samba.org>
Thu, 28 Mar 2024 01:50:41 +0000 (01:50 +0000)
Signed-off-by: Andrew Bartlett <abartlet@samba.org>
Reviewed-by: Jo Sutton <josutton@catalyst.net.nz>
python/samba/tests/krb5/pkinit_tests.py

index 6aabc08483065f6d6d3b7cf2f60ceed946c4061d..ac54d8e890046e1c0a8ef0c23c2e9d400a9750b0 100755 (executable)
@@ -51,6 +51,7 @@ from samba.tests.krb5.rfc4120_constants import (
     KDC_ERR_PREAUTH_REQUIRED,
     KU_PA_ENC_TIMESTAMP,
     NT_PRINCIPAL,
+    NT_SRV_INST,
     PADATA_AS_FRESHNESS,
     PADATA_ENC_TIMESTAMP,
     PADATA_PK_AS_REP_19,
@@ -625,8 +626,12 @@ class PkInitTests(KDCBaseTest):
         target_name = target_creds.get_username()
         target_realm = target_creds.get_realm()
 
-        sname = self.PrincipalName_create(name_type=NT_PRINCIPAL,
-                                          names=['host', target_name[:-1]])
+        if target_name == "krbtgt":
+            sname = self.PrincipalName_create(name_type=NT_SRV_INST,
+                                              names=['krbtgt', target_realm])
+        else:
+            sname = self.PrincipalName_create(name_type=NT_PRINCIPAL,
+                                              names=['host', target_name[:-1]])
 
         if expect_error:
             check_error_fn = self.generic_check_kdc_error
@@ -637,8 +642,11 @@ class PkInitTests(KDCBaseTest):
             check_error_fn = None
             check_rep_fn = self.generic_check_kdc_rep
 
-            expected_sname = self.PrincipalName_create(name_type=NT_PRINCIPAL,
-                                                       names=[target_name])
+            if target_name == "krbtgt":
+                expected_sname = sname
+            else:
+                expected_sname = self.PrincipalName_create(name_type=NT_PRINCIPAL,
+                                                           names=[target_name])
 
         kdc_options = ('forwardable,'
                        'renewable,'
@@ -1146,21 +1154,27 @@ class PkInitTests(KDCBaseTest):
         target_name = target_creds.get_username()
         target_realm = target_creds.get_realm()
 
-        sname = self.PrincipalName_create(name_type=NT_PRINCIPAL,
-                                          names=['host', target_name[:-1]])
+        target_name = target_creds.get_username()
+        if target_name == "krbtgt":
+            target_sname = self.PrincipalName_create(name_type=NT_SRV_INST,
+                                                     names=['krbtgt', target_realm])
+            expected_sname = target_sname
+        else:
+            target_sname = self.PrincipalName_create(name_type=NT_PRINCIPAL,
+                                                     names=['host', target_name[:-1]])
+
+            expected_sname = self.PrincipalName_create(name_type=NT_PRINCIPAL,
+                                                           names=[target_name])
 
         if expect_error:
             check_error_fn = self.generic_check_kdc_error
             check_rep_fn = None
 
-            expected_sname = sname
+            expected_sname = target_sname
         else:
             check_error_fn = None
             check_rep_fn = self.generic_check_kdc_rep
 
-            expected_sname = self.PrincipalName_create(name_type=NT_PRINCIPAL,
-                                                       names=[target_name])
-
         kdc_options = ('forwardable,'
                        'renewable,'
                        'canonicalize,'
@@ -1213,7 +1227,7 @@ class PkInitTests(KDCBaseTest):
         rep = self._generic_kdc_exchange(kdc_exchange_dict,
                                          cname=cname,
                                          realm=target_realm,
-                                         sname=sname,
+                                         sname=target_sname,
                                          till_time=till,
                                          etypes=etypes)
         if expect_error: