netcmd/ldapcmp: fix wrong way for string copy
[samba.git] / python / samba / netcmd / ldapcmp.py
index 01739f90f85d1f083673898f02789e404aeeaad9..646eedc0e8079f2d4ca95c5fc8570cb492f5194c 100644 (file)
@@ -37,9 +37,6 @@ from samba.netcmd import (
     Option,
 )
 
-global summary
-summary = {}
-
 
 class LDAPBase(object):
 
@@ -211,13 +208,15 @@ class LDAPBase(object):
         res = dict(res[0])
         # 'Dn' element is not iterable and we have it as 'distinguishedName'
         del res["dn"]
-        for key in list(res.keys()):
-            vals = list(res[key])
-            del res[key]
+
+        attributes = {}
+        for key, vals in res.items():
             name = self.get_attribute_name(key)
-            res[name] = self.get_attribute_values(object_dn, key, vals)
+            # sort vals and return a list, help to compare
+            vals = sorted(vals)
+            attributes[name] = self.get_attribute_values(object_dn, key, vals)
 
-        return res
+        return attributes
 
     def get_descriptor_sddl(self, object_dn):
         res = self.ldb.search(base=object_dn, scope=SCOPE_BASE, attrs=["nTSecurityDescriptor"])
@@ -566,8 +565,8 @@ class LDAPObject(object):
         else:
             raise Exception("Unknown --view option value.")
         #
-        self.screen_output = res[1][:-1]
-        other.screen_output = res[1][:-1]
+        self.screen_output = res[1]
+        other.screen_output = res[1]
         #
         return res[0]
 
@@ -682,8 +681,8 @@ class LDAPObject(object):
         other.summary["unique_attrs"] += other.unique_attrs
         other.summary["df_value_attrs"] += self.df_value_attrs  # they are the same
         #
-        self.screen_output = res[:-1]
-        other.screen_output = res[:-1]
+        self.screen_output = res
+        other.screen_output = res
         #
         return res == ""
 
@@ -746,69 +745,55 @@ class LDAPBundle(object):
             self.log("\n* DN lists have different size: %s != %s" % (self.size, other.size))
             if not self.skip_missing_dn:
                 res = False
+
+        self_dns = set([q.upper() for q in self.dn_list])
+        other_dns = set([q.upper() for q in other.dn_list])
+
         #
         # This is the case where we want to explicitly compare two objects with different DNs.
         # It does not matter if they are in the same DC, in two DC in one domain or in two
         # different domains.
-        if self.search_scope != SCOPE_BASE:
-            title = "\n* DNs found only in %s:" % self.con.host
-            for x in self.dn_list:
-                if not x.upper() in [q.upper() for q in other.dn_list]:
-                    if title and not self.skip_missing_dn:
-                        self.log(title)
-                        title = None
-                        res = False
+        if self.search_scope != SCOPE_BASE and not self.skip_missing_dn:
+
+            self_only = self_dns - other_dns  # missing in other
+            if self_only:
+                res = False
+                self.log("\n* DNs found only in %s:" % self.con.host)
+                for x in self_only:
                     self.log(4 * " " + x)
-                    self.dn_list[self.dn_list.index(x)] = ""
-            self.dn_list = [x for x in self.dn_list if x]
-            #
-            title = "\n* DNs found only in %s:" % other.con.host
-            for x in other.dn_list:
-                if not x.upper() in [q.upper() for q in self.dn_list]:
-                    if title and not self.skip_missing_dn:
-                        self.log(title)
-                        title = None
-                        res = False
+
+            other_only = other_dns - self_dns  # missing in self
+            if other_only:
+                res = False
+                self.log("\n* DNs found only in %s:" % other.con.host)
+                for x in other_only:
                     self.log(4 * " " + x)
-                    other.dn_list[other.dn_list.index(x)] = ""
-            other.dn_list = [x for x in other.dn_list if x]
-            #
-            self.update_size()
-            other.update_size()
-            assert self.size == other.size
-            assert sorted([x.upper() for x in self.dn_list]) == sorted([x.upper() for x in other.dn_list])
-        self.log("\n* Objects to be compared: %s" % self.size)
 
-        index = 0
-        while index < self.size:
-            skip = False
+        common_dns = self_dns & other_dns
+        self.log("\n* Objects to be compared: %d" % len(common_dns))
+
+        for dn in common_dns:
+
             try:
                 object1 = LDAPObject(connection=self.con,
-                                     dn=self.dn_list[index],
+                                     dn=dn,
                                      summary=self.summary,
                                      filter_list=self.filter_list,
                                      outf=self.outf, errf=self.errf)
             except LdbError as e:
-                (enum, estr) = e.args
-                if enum == ERR_NO_SUCH_OBJECT:
-                    self.log("\n!!! Object not found: %s" % self.dn_list[index])
-                    skip = True
-                raise
+                self.log("LdbError for dn %s: %s" % (dn, e))
+                continue
+
             try:
                 object2 = LDAPObject(connection=other.con,
-                                     dn=other.dn_list[index],
+                                     dn=dn,
                                      summary=other.summary,
                                      filter_list=self.filter_list,
                                      outf=self.outf, errf=self.errf)
-            except LdbError as e1:
-                (enum, estr) = e1.args
-                if enum == ERR_NO_SUCH_OBJECT:
-                    self.log("\n!!! Object not found: %s" % other.dn_list[index])
-                    skip = True
-                raise
-            if skip:
-                index += 1
+            except LdbError as e:
+                self.log("LdbError for dn %s: %s" % (dn, e))
                 continue
+
             if object1 == object2:
                 if self.con.verbose:
                     self.log("\nComparing:")
@@ -824,8 +809,7 @@ class LDAPBundle(object):
                 res = False
             self.summary = object1.summary
             other.summary = object2.summary
-            index += 1
-        #
+
         return res
 
     def get_dn_list(self, context):
@@ -863,9 +847,6 @@ class LDAPBundle(object):
             raise
         for x in res:
             dn_list.append(x["dn"].get_linearized())
-        #
-        global summary
-        #
         return dn_list
 
     def print_summary(self):
@@ -905,7 +886,7 @@ class cmd_ldapcmp(Command):
                help="Compare nTSecurityDescriptor attibutes only"),
         Option("--sort-aces", dest="sort_aces", action="store_true", default=False,
                help="Sort ACEs before comparison of nTSecurityDescriptor attribute"),
-        Option("--view", dest="view", default="section",
+        Option("--view", dest="view", default="section", choices=["section", "collision"],
                help="Display mode for nTSecurityDescriptor results. Possible values: section or collision."),
         Option("--base", dest="base", default="",
                help="Pass search base that will build DN list for the first DC."),
@@ -964,19 +945,17 @@ class cmd_ldapcmp(Command):
             raise CommandError("You cannot set --verbose and --quiet together")
         if (not base and base2) or (base and not base2):
             raise CommandError("You need to specify both --base and --base2 at the same time")
-        if descriptor and view.upper() not in ["SECTION", "COLLISION"]:
-            raise CommandError("Invalid --view value. Choose from: section or collision")
 
         con1 = LDAPBase(URL1, creds, lp,
                         two=two, quiet=quiet, descriptor=descriptor, sort_aces=sort_aces,
                         verbose=verbose, view=view, base=base, scope=scope,
-                        outf=self.outf, errf=self.errf)
+                        outf=self.outf, errf=self.errf, skip_missing_dn=skip_missing_dn)
         assert len(con1.base_dn) > 0
 
         con2 = LDAPBase(URL2, creds2, lp,
                         two=two, quiet=quiet, descriptor=descriptor, sort_aces=sort_aces,
                         verbose=verbose, view=view, base=base2, scope=scope,
-                        outf=self.outf, errf=self.errf)
+                        outf=self.outf, errf=self.errf, skip_missing_dn=skip_missing_dn)
         assert len(con2.base_dn) > 0
 
         filter_list = filter.split(",")