samba-tool user: use an implicit_attrs list instead of add_ATTR variables
authorStefan Metzmacher <metze@samba.org>
Mon, 18 Jan 2021 14:51:37 +0000 (15:51 +0100)
committerAndrew Bartlett <abartlet@samba.org>
Mon, 1 Mar 2021 03:50:35 +0000 (03:50 +0000)
We'll extent GetPasswordCommand.get_password_attributes() to handle
more virtual formats in future. It'll be much easier to
to maintain a list of attributes we need to filter out again.

sAMAccountName and userPrincipalName are always implicitly
requested in order to keep the existing code sane.

supplementalCredentials and unicodePwd are requested by default
when generating virtual password attributes.

Pair-Programmed-With: Björn Baumbach <bb@sernet.de>

Signed-off-by: Stefan Metzmacher <metze@samba.org>
Signed-off-by: Björn Baumbach <bb@sernet.de>
Reviewed-by: Andrew Bartlett <abartlet@samba.org>
python/samba/netcmd/user.py

index 8ba5168c04955c4643ac3b12544b7fc0091f6375..ce633bf2e8a423fcfdd5bf8db4d0c1b49160692f 100644 (file)
@@ -152,27 +152,6 @@ def get_crypt_value(alg, utf8pw, rounds=0):
             crypt_salt, len(crypt_value), expected_len))
     return crypt_value
 
-# Extract the rounds value from the options of a virtualCrypt attribute
-# i.e. options = "rounds=20;other=ignored;" will return 20
-# if the rounds option is not found or the value is not a number, 0 is returned
-# which indicates that the default number of rounds should be used.
-
-
-def get_rounds(options):
-    if not options:
-        return 0
-
-    opts = options.split(';')
-    for o in opts:
-        if o.lower().startswith("rounds="):
-            (key, _, val) = o.partition('=')
-            try:
-                return int(val)
-            except ValueError:
-                return 0
-    return 0
-
-
 try:
     import hashlib
     h = hashlib.sha1()
@@ -1159,44 +1138,78 @@ class GetPasswordCommand(Command):
     def get_account_attributes(self, samdb, username, basedn, filter, scope,
                                attrs, decrypt):
 
+        def get_option(opts, name):
+            if not opts:
+                return None
+            for o in opts:
+                if o.lower().startswith("%s=" % name.lower()):
+                    (key, _, val) = o.partition('=')
+                    return val
+            return None
+
+        def get_virtual_attr_definition(attr):
+            for van in sorted(virtual_attributes.keys()):
+                if van.lower() != attr.lower():
+                    continue
+                return virtual_attributes[van]
+            return None
+
+        def parse_raw_attr(raw_attr, is_hidden=False):
+            (attr, _, fullopts) = raw_attr.partition(';')
+            if fullopts:
+                opts = fullopts.split(';')
+            else:
+                opts = []
+            a = {}
+            a["raw_attr"] = raw_attr
+            a["attr"] = attr
+            a["opts"] = opts
+            a["vattr"] = get_virtual_attr_definition(attr)
+            a["is_hidden"] = is_hidden
+            return a
+
         raw_attrs = attrs[:]
+        has_wildcard_attr = "*" in raw_attrs
+        has_virtual_attrs = False
+        requested_attrs = []
+        implicit_attrs = []
+
+        for raw_attr in raw_attrs:
+            a = parse_raw_attr(raw_attr)
+            requested_attrs.append(a)
+
         search_attrs = []
-        attr_opts = {}
-        for a in raw_attrs:
-            (attr, _, opts) = a.partition(';')
-            if opts:
-                attr_opts[attr] = opts
-            else:
-                attr_opts[attr] = None
-            search_attrs.append(attr)
-        lower_attrs = [x.lower() for x in search_attrs]
-
-        require_supplementalCredentials = False
-        for a in virtual_attributes.keys():
-            if a.lower() in lower_attrs:
-                require_supplementalCredentials = True
-        add_supplementalCredentials = False
-        add_unicodePwd = False
-        if require_supplementalCredentials:
-            a = "supplementalCredentials"
-            if a.lower() not in lower_attrs:
-                search_attrs += [a]
-                add_supplementalCredentials = True
-            a = "unicodePwd"
-            if a.lower() not in lower_attrs:
-                search_attrs += [a]
-                add_unicodePwd = True
-        add_sAMAcountName = False
-        a = "sAMAccountName"
-        if a.lower() not in lower_attrs:
-            search_attrs += [a]
-            add_sAMAcountName = True
-
-        add_userPrincipalName = False
-        upn = "userPrincipalName"
-        if upn.lower() not in lower_attrs:
-            search_attrs += [upn]
-            add_userPrincipalName = True
+        has_virtual_attrs = False
+        for a in requested_attrs:
+            if a["vattr"] is not None:
+                has_virtual_attrs = True
+                continue
+            if a["raw_attr"] in search_attrs:
+                continue
+            search_attrs.append(a["raw_attr"])
+
+        if not has_wildcard_attr:
+            required_attrs = [
+                "sAMAccountName",
+                "userPrincipalName"
+            ]
+            for required_attr in required_attrs:
+                a = parse_raw_attr(required_attr)
+                implicit_attrs.append(a)
+
+        if has_virtual_attrs:
+            required_attrs = [
+                "supplementalCredentials",
+                "unicodePwd",
+            ]
+            for required_attr in required_attrs:
+                a = parse_raw_attr(required_attr, is_hidden=True)
+                implicit_attrs.append(a)
+
+        for a in implicit_attrs:
+            if a["attr"] in search_attrs:
+                continue
+            search_attrs.append(a["attr"])
 
         if scope == ldb.SCOPE_BASE:
             search_controls = ["show_deleted:1", "show_recycled:1"]
@@ -1220,22 +1233,14 @@ class GetPasswordCommand(Command):
         if "supplementalCredentials" in obj:
             sc_blob = obj["supplementalCredentials"][0]
             sc = ndr_unpack(drsblobs.supplementalCredentialsBlob, sc_blob)
-            if add_supplementalCredentials:
-                del obj["supplementalCredentials"]
         if "unicodePwd" in obj:
             unicodePwd = obj["unicodePwd"][0]
-            if add_unicodePwd:
-                del obj["unicodePwd"]
         account_name = str(obj["sAMAccountName"][0])
-        if add_sAMAcountName:
-            del obj["sAMAccountName"]
         if "userPrincipalName" in obj:
             account_upn = str(obj["userPrincipalName"][0])
         else:
             realm = samdb.domain_dns_name()
             account_upn = "%s@%s" % (account_name, realm.lower())
-        if add_userPrincipalName:
-            del obj["userPrincipalName"]
 
         calculated = {}
 
@@ -1479,10 +1484,32 @@ class GetPasswordCommand(Command):
                                    primary_krb5)
             return (krb5_blob.version, krb5_blob.ctr)
 
+        # Extract the rounds value from the options of a virtualCrypt attribute
+        # i.e. options = "rounds=20;other=ignored;" will return 20
+        # if the rounds option is not found or the value is not a number, 0 is returned
+        # which indicates that the default number of rounds should be used.
+        def get_rounds(opts):
+            val = get_option(opts, "rounds")
+            if val is None:
+                return 0
+            try:
+                return int(val)
+            except ValueError:
+                return 0
+
         # We use sort here in order to have a predictable processing order
         for a in sorted(virtual_attributes.keys()):
-            if not a.lower() in lower_attrs:
+            vattr = None
+            for ra in requested_attrs:
+                if ra["vattr"] is None:
+                    continue
+                if ra["attr"].lower() != a.lower():
+                    continue
+                vattr = ra
+                break
+            if vattr is None:
                 continue
+            attr_opts = vattr["opts"]
 
             if a == "virtualClearTextUTF8":
                 b = get_package("Primary:CLEARTEXT")
@@ -1510,13 +1537,13 @@ class GetPasswordCommand(Command):
                 bv = h.digest() + salt
                 v = "{SSHA}" + base64.b64encode(bv).decode('utf8')
             elif a == "virtualCryptSHA256":
-                rounds = get_rounds(attr_opts[a])
+                rounds = get_rounds(attr_opts)
                 x = get_virtual_crypt_value(a, 5, rounds, username, account_name)
                 if x is None:
                     continue
                 v = x
             elif a == "virtualCryptSHA512":
-                rounds = get_rounds(attr_opts[a])
+                rounds = get_rounds(attr_opts)
                 x = get_virtual_crypt_value(a, 6, rounds, username, account_name)
                 if x is None:
                     continue
@@ -1552,6 +1579,30 @@ class GetPasswordCommand(Command):
             else:
                 continue
             obj[a] = ldb.MessageElement(v, ldb.FLAG_MOD_REPLACE, a)
+
+        # Now filter out implicit attributes
+        for delname in obj.keys():
+            keep = False
+            for ra in requested_attrs:
+                if delname.lower() != ra["raw_attr"].lower():
+                    continue
+                keep = True
+                break
+            if keep:
+                continue
+
+            dattr = None
+            for ia in implicit_attrs:
+                if delname.lower() != ia["attr"].lower():
+                    continue
+                dattr = ia
+                break
+            if dattr is None:
+                continue
+
+            if has_wildcard_attr and not dattr["is_hidden"]:
+                continue
+            del obj[delname]
         return obj
 
     def parse_attributes(self, attributes):