PEP8: add spaces after operators
[nivanova/samba-autobuild/.git] / python / samba / kcc / kcc_utils.py
index 1e5586a2c4199cef55359c001232f7cc47e63c48..c099140c9363cee4d616772213d002d19cd59808 100644 (file)
@@ -28,9 +28,10 @@ from samba.dcerpc import (
     drsblobs,
     drsuapi,
     misc,
-    )
+)
 from samba.common import dsdb_Dn
 from samba.ndr import ndr_unpack, ndr_pack
+from collections import Counter
 
 
 class KCCError(Exception):
@@ -40,6 +41,7 @@ class KCCError(Exception):
 class NCType(object):
     (unknown, schema, domain, config, application) = range(0, 5)
 
+
 # map the NCType enum to strings for debugging
 nctype_lut = dict((v, k) for k, v in NCType.__dict__.items() if k[:2] != '__')
 
@@ -83,7 +85,8 @@ class NamingContext(object):
             res = samdb.search(base=self.nc_dnstr,
                                scope=ldb.SCOPE_BASE, attrs=attrs)
 
-        except ldb.LdbError, (enum, estr):
+        except ldb.LdbError as e:
+            (enum, estr) = e.args
             raise KCCError("Unable to find naming context (%s) - (%s)" %
                            (self.nc_dnstr, estr))
         msg = res[0]
@@ -95,21 +98,6 @@ class NamingContext(object):
 
         assert self.nc_guid is not None
 
-    def is_schema(self):
-        '''Return True if NC is schema'''
-        assert self.nc_type != NCType.unknown
-        return self.nc_type == NCType.schema
-
-    def is_domain(self):
-        '''Return True if NC is domain'''
-        assert self.nc_type != NCType.unknown
-        return self.nc_type == NCType.domain
-
-    def is_application(self):
-        '''Return True if NC is application'''
-        assert self.nc_type != NCType.unknown
-        return self.nc_type == NCType.application
-
     def is_config(self):
         '''Return True if NC is config'''
         assert self.nc_type != NCType.unknown
@@ -178,14 +166,14 @@ class NCReplica(NamingContext):
     class) and it identifies unique attributes of the DSA's replica for a NC.
     """
 
-    def __init__(self, dsa_dnstr, dsa_guid, nc_dnstr):
+    def __init__(self, dsa, nc_dnstr):
         """Instantiate a Naming Context Replica
 
         :param dsa_guid: GUID of DSA where replica appears
         :param nc_dnstr: NC dn string
         """
-        self.rep_dsa_dnstr = dsa_dnstr
-        self.rep_dsa_guid = dsa_guid
+        self.rep_dsa_dnstr = dsa.dsa_dnstr
+        self.rep_dsa_guid = dsa.dsa_guid
         self.rep_default = False  # replica for DSA's default domain
         self.rep_partial = False
         self.rep_ro = False
@@ -228,12 +216,9 @@ class NCReplica(NamingContext):
 
         return "%s\n%s" % (NamingContext.__str__(self), text)
 
-    def set_instantiated_flags(self, flags=None):
+    def set_instantiated_flags(self, flags=0):
         '''Set or clear NC replica instantiated flags'''
-        if flags is None:
-            self.rep_instantiated_flags = 0
-        else:
-            self.rep_instantiated_flags = flags
+        self.rep_instantiated_flags = flags
 
     def identify_by_dsa_attr(self, samdb, attr):
         """Given an NC which has been discovered thru the
@@ -319,7 +304,8 @@ class NCReplica(NamingContext):
             res = samdb.search(base=self.nc_dnstr, scope=ldb.SCOPE_BASE,
                                attrs=["repsFrom"])
 
-        except ldb.LdbError, (enum, estr):
+        except ldb.LdbError as e1:
+            (enum, estr) = e1.args
             raise KCCError("Unable to find NC for (%s) - (%s)" %
                            (self.nc_dnstr, estr))
 
@@ -391,7 +377,7 @@ class NCReplica(NamingContext):
         try:
             samdb.modify(m)
 
-        except ldb.LdbError, estr:
+        except ldb.LdbError as estr:
             raise KCCError("Could not set repsFrom for (%s) - (%s)" %
                            (self.nc_dnstr, estr))
 
@@ -407,7 +393,8 @@ class NCReplica(NamingContext):
             res = samdb.search(base=self.nc_dnstr, scope=ldb.SCOPE_BASE,
                                attrs=["replUpToDateVector"])
 
-        except ldb.LdbError, (enum, estr):
+        except ldb.LdbError as e2:
+            (enum, estr) = e2.args
             raise KCCError("Unable to find NC for (%s) - (%s)" %
                            (self.nc_dnstr, estr))
 
@@ -441,7 +428,8 @@ class NCReplica(NamingContext):
             res = samdb.search(base=self.nc_dnstr, scope=ldb.SCOPE_BASE,
                                attrs=["fSMORoleOwner"])
 
-        except ldb.LdbError, (enum, estr):
+        except ldb.LdbError as e3:
+            (enum, estr) = e3.args
             raise KCCError("Unable to find NC for (%s) - (%s)" %
                            (self.nc_dnstr, estr))
 
@@ -470,7 +458,8 @@ class NCReplica(NamingContext):
             res = samdb.search(base=self.nc_dnstr, scope=ldb.SCOPE_BASE,
                                attrs=["repsTo"])
 
-        except ldb.LdbError, (enum, estr):
+        except ldb.LdbError as e4:
+            (enum, estr) = e4.args
             raise KCCError("Unable to find NC for (%s) - (%s)" %
                            (self.nc_dnstr, estr))
 
@@ -542,7 +531,7 @@ class NCReplica(NamingContext):
         try:
             samdb.modify(m)
 
-        except ldb.LdbError, estr:
+        except ldb.LdbError as estr:
             raise KCCError("Could not set repsTo for (%s) - (%s)" %
                            (self.nc_dnstr, estr))
 
@@ -561,7 +550,7 @@ class DirectoryServiceAgent(object):
         self.dsa_ivid = None
         self.dsa_is_ro = False
         self.dsa_is_istg = False
-        self.dsa_options = 0
+        self.options = 0
         self.dsa_behavior = 0
         self.default_dnstr = None  # default domain dn string for dsa
 
@@ -662,7 +651,8 @@ class DirectoryServiceAgent(object):
             res = samdb.search(base=self.dsa_dnstr, scope=ldb.SCOPE_BASE,
                                attrs=attrs)
 
-        except ldb.LdbError, (enum, estr):
+        except ldb.LdbError as e5:
+            (enum, estr) = e5.args
             raise KCCError("Unable to find nTDSDSA for (%s) - (%s)" %
                            (self.dsa_dnstr, estr))
 
@@ -723,7 +713,8 @@ class DirectoryServiceAgent(object):
             res = samdb.search(base=self.dsa_dnstr, scope=ldb.SCOPE_BASE,
                                attrs=ncattrs)
 
-        except ldb.LdbError, (enum, estr):
+        except ldb.LdbError as e6:
+            (enum, estr) = e6.args
             raise KCCError("Unable to find nTDSDSA NCs for (%s) - (%s)" %
                            (self.dsa_dnstr, estr))
 
@@ -749,12 +740,12 @@ class DirectoryServiceAgent(object):
                 for value in res[0][k]:
                     # Turn dn into a dsdb_Dn so we can use
                     # its methods to parse a binary DN
-                    dsdn = dsdb_Dn(samdb, value)
+                    dsdn = dsdb_Dn(samdb, value.decode('utf8'))
                     flags = dsdn.get_binary_integer()
                     dnstr = str(dsdn.dn)
 
-                    if not dnstr in tmp_table:
-                        rep = NCReplica(self.dsa_dnstr, self.dsa_guid, dnstr)
+                    if dnstr not in tmp_table:
+                        rep = NCReplica(self, dnstr)
                         tmp_table[dnstr] = rep
                     else:
                         rep = tmp_table[dnstr]
@@ -791,7 +782,8 @@ class DirectoryServiceAgent(object):
                                scope=ldb.SCOPE_SUBTREE,
                                expression="(objectClass=nTDSConnection)")
 
-        except ldb.LdbError, (enum, estr):
+        except ldb.LdbError as e7:
+            (enum, estr) = e7.args
             raise KCCError("Unable to find nTDSConnection for (%s) - (%s)" %
                            (self.dsa_dnstr, estr))
 
@@ -969,7 +961,8 @@ class NTDSConnection(object):
             res = samdb.search(base=self.dnstr, scope=ldb.SCOPE_BASE,
                                attrs=attrs)
 
-        except ldb.LdbError, (enum, estr):
+        except ldb.LdbError as e8:
+            (enum, estr) = e8.args
             raise KCCError("Unable to find nTDSConnection for (%s) - (%s)" %
                            (self.dnstr, estr))
 
@@ -994,7 +987,7 @@ class NTDSConnection(object):
                            "for (%s)" % (self.dnstr))
 
         if "transportType" in msg:
-            dsdn = dsdb_Dn(samdb, msg["transportType"][0])
+            dsdn = dsdb_Dn(samdb, msg["transportType"][0].decode('utf8'))
             self.load_connection_transport(samdb, str(dsdn.dn))
 
         if "schedule" in msg:
@@ -1004,7 +997,7 @@ class NTDSConnection(object):
             self.whenCreated = ldb.string_to_time(msg["whenCreated"][0])
 
         if "fromServer" in msg:
-            dsdn = dsdb_Dn(samdb, msg["fromServer"][0])
+            dsdn = dsdb_Dn(samdb, msg["fromServer"][0].decode('utf8'))
             self.from_dnstr = str(dsdn.dn)
             assert self.from_dnstr is not None
 
@@ -1019,7 +1012,8 @@ class NTDSConnection(object):
             res = samdb.search(base=tdnstr,
                                scope=ldb.SCOPE_BASE, attrs=attrs)
 
-        except ldb.LdbError, (enum, estr):
+        except ldb.LdbError as e9:
+            (enum, estr) = e9.args
             raise KCCError("Unable to find transport (%s) - (%s)" %
                            (tdnstr, estr))
 
@@ -1046,7 +1040,8 @@ class NTDSConnection(object):
 
         try:
             samdb.delete(self.dnstr)
-        except ldb.LdbError, (enum, estr):
+        except ldb.LdbError as e10:
+            (enum, estr) = e10.args
             raise KCCError("Could not delete nTDSConnection for (%s) - (%s)" %
                            (self.dnstr, estr))
 
@@ -1070,7 +1065,8 @@ class NTDSConnection(object):
             if len(msg) != 0:
                 found = True
 
-        except ldb.LdbError, (enum, estr):
+        except ldb.LdbError as e11:
+            (enum, estr) = e11.args
             if enum != ldb.ERR_NO_SUCH_OBJECT:
                 raise KCCError("Unable to search for (%s) - (%s)" %
                                (self.dnstr, estr))
@@ -1115,7 +1111,8 @@ class NTDSConnection(object):
                                    ldb.FLAG_MOD_ADD, "schedule")
         try:
             samdb.add(m)
-        except ldb.LdbError, (enum, estr):
+        except ldb.LdbError as e12:
+            (enum, estr) = e12.args
             raise KCCError("Could not add nTDSConnection for (%s) - (%s)" %
                            (self.dnstr, estr))
 
@@ -1138,7 +1135,8 @@ class NTDSConnection(object):
             # of self.dnstr in the database.
             samdb.search(base=self.dnstr, scope=ldb.SCOPE_BASE)
 
-        except ldb.LdbError, (enum, estr):
+        except ldb.LdbError as e13:
+            (enum, estr) = e13.args
             if enum == ldb.ERR_NO_SUCH_OBJECT:
                 raise KCCError("nTDSConnection for (%s) doesn't exist!" %
                                self.dnstr)
@@ -1184,7 +1182,8 @@ class NTDSConnection(object):
                 ldb.MessageElement([], ldb.FLAG_MOD_DELETE, "schedule")
         try:
             samdb.modify(m)
-        except ldb.LdbError, (enum, estr):
+        except ldb.LdbError as e14:
+            (enum, estr) = e14.args
             raise KCCError("Could not modify nTDSConnection for (%s) - (%s)" %
                            (self.dnstr, estr))
 
@@ -1211,11 +1210,19 @@ class NTDSConnection(object):
 
         :param shed: schedule to compare to
         """
-        if self.schedule is not None:
-            if sched is None:
-                return False
-        elif sched is None:
-            return True
+        # There are 4 cases, where either self.schedule or sched can be None
+        #
+        #                   |  self. is None  |   self. is not None
+        #     --------------+-----------------+--------------------
+        #     sched is None |     True        |     False
+        #     --------------+-----------------+--------------------
+        # sched is not None |    False        |    do calculations
+
+        if self.schedule is None:
+            return sched is None
+
+        if sched is None:
+            return False
 
         if ((self.schedule.size != sched.size or
              self.schedule.bandwidth != sched.bandwidth or
@@ -1336,7 +1343,8 @@ class Partition(NamingContext):
             res = samdb.search(base=self.partstr, scope=ldb.SCOPE_BASE,
                                attrs=attrs)
 
-        except ldb.LdbError, (enum, estr):
+        except ldb.LdbError as e15:
+            (enum, estr) = e15.args
             raise KCCError("Unable to find partition for (%s) - (%s)" %
                            (self.partstr, estr))
         msg = res[0]
@@ -1356,7 +1364,7 @@ class Partition(NamingContext):
                 continue
 
             for value in msg[k]:
-                dsdn = dsdb_Dn(samdb, value)
+                dsdn = dsdb_Dn(samdb, value.decode('utf8'))
                 dnstr = str(dsdn.dn)
 
                 if k == "nCName":
@@ -1476,7 +1484,8 @@ class Site(object):
                                attrs=attrs)
             self_res = samdb.search(base=self.site_dnstr, scope=ldb.SCOPE_BASE,
                                     attrs=['objectGUID'])
-        except ldb.LdbError, (enum, estr):
+        except ldb.LdbError as e16:
+            (enum, estr) = e16.args
             raise KCCError("Unable to find site settings for (%s) - (%s)" %
                            (ssdn, estr))
 
@@ -1507,7 +1516,8 @@ class Site(object):
             res = samdb.search(self.site_dnstr,
                                scope=ldb.SCOPE_SUBTREE,
                                expression="(objectClass=nTDSDSA)")
-        except ldb.LdbError, (enum, estr):
+        except ldb.LdbError as e17:
+            (enum, estr) = e17.args
             raise KCCError("Unable to find nTDSDSAs - (%s)" % estr)
 
         for msg in res:
@@ -1527,12 +1537,6 @@ class Site(object):
             if not dsa.is_ro():
                 self.rw_dsa_table[dnstr] = dsa
 
-    def get_dsa_by_guidstr(self, guidstr):  # XXX unused
-        for dsa in self.dsa_table.values():
-            if str(dsa.dsa_guid) == guidstr:
-                return dsa
-        return None
-
     def get_dsa(self, dnstr):
         """Return a previously loaded DSA object by consulting
         the sites dsa_table for the provided DSA dn string
@@ -1578,7 +1582,9 @@ class Site(object):
         # Which is a fancy way of saying "sort all the nTDSDSA objects
         # in the site by guid in ascending order".   Place sorted list
         # in D_sort[]
-        D_sort = sorted(self.rw_dsa_table.values(), cmp=sort_dsa_by_guid)
+        D_sort = sorted(
+            self.rw_dsa_table.values(),
+            key=lambda dsa: ndr_pack(dsa.dsa_guid))
 
         # double word number of 100 nanosecond intervals since 1600s
 
@@ -1640,7 +1646,7 @@ class Site(object):
                 i_idx = j_idx
                 t_time = 0
 
-            #XXX doc says current time < c.timeLastSyncSuccess - f
+            # XXX doc says current time < c.timeLastSyncSuccess - f
             # which is true only if f is negative or clocks are wrong.
             # f is not negative in the default case (2 hours).
             elif self.nt_now - cursor.last_sync_success > f:
@@ -1666,7 +1672,7 @@ class Site(object):
         #
         # Note: We don't want to divide by zero here so they must
         #       have meant "f" instead of "o!interSiteTopologyFailover"
-        k_idx = (i_idx + ((self.nt_now - t_time) / f)) % len(D_sort)
+        k_idx = (i_idx + ((self.nt_now - t_time) // f)) % len(D_sort)
 
         # The local writable DC acts as an ISTG for its site if and
         # only if dk is the nTDSDSA object for the local DC. If the
@@ -1702,7 +1708,7 @@ class Site(object):
         try:
             samdb.modify(m)
 
-        except ldb.LdbError, estr:
+        except ldb.LdbError as estr:
             raise KCCError(
                 "Could not set interSiteTopologyGenerator for (%s) - (%s)" %
                 (ssdn, estr))
@@ -1780,7 +1786,9 @@ class GraphNode(object):
         text = text + "\n\tmax_edges=%d" % self.max_edges
 
         for i, edge in enumerate(self.edge_from):
-            text = text + "\n\tedge_from[%d]=%s" % (i, edge)
+            if isinstance(edge, str):
+                text += "\n\tedge_from[%d]=%s" % (i, edge)
+
         return text
 
     def add_edge_from(self, from_dsa_dnstr):
@@ -1788,7 +1796,7 @@ class GraphNode(object):
 
         :param from_dsa_dnstr: the dsa that the edge emanates from
         """
-        assert from_dsa_dnstr is not None
+        assert isinstance(from_dsa_dnstr, str)
 
         # No edges from myself to myself
         if from_dsa_dnstr == self.dsa_dnstr:
@@ -1899,7 +1907,8 @@ class Transport(object):
             res = samdb.search(base=self.dnstr, scope=ldb.SCOPE_BASE,
                                attrs=attrs)
 
-        except ldb.LdbError, (enum, estr):
+        except ldb.LdbError as e18:
+            (enum, estr) = e18.args
             raise KCCError("Unable to find Transport for (%s) - (%s)" %
                            (self.dnstr, estr))
 
@@ -1918,7 +1927,7 @@ class Transport(object):
 
         if "bridgeheadServerListBL" in msg:
             for value in msg["bridgeheadServerListBL"]:
-                dsdn = dsdb_Dn(samdb, value)
+                dsdn = dsdb_Dn(samdb, value.decode('utf8'))
                 dnstr = str(dsdn.dn)
                 if dnstr not in self.bridgehead_list:
                     self.bridgehead_list.append(dnstr)
@@ -2140,8 +2149,8 @@ class SiteLink(object):
                     text = text + "0x%X " % slot
                 text = text + "]"
 
-        for dnstr in self.site_list:
-            text = text + "\n\tsite_list=%s" % dnstr
+        for guid, dn in self.site_list:
+            text = text + "\n\tsite_list=%s (%s)" % (guid, dn)
         return text
 
     def load_sitelink(self, samdb):
@@ -2159,7 +2168,8 @@ class SiteLink(object):
             res = samdb.search(base=self.dnstr, scope=ldb.SCOPE_BASE,
                                attrs=attrs, controls=['extended_dn:0'])
 
-        except ldb.LdbError, (enum, estr):
+        except ldb.LdbError as e19:
+            (enum, estr) = e19.args
             raise KCCError("Unable to find SiteLink for (%s) - (%s)" %
                            (self.dnstr, estr))
 
@@ -2179,10 +2189,11 @@ class SiteLink(object):
 
         if "siteList" in msg:
             for value in msg["siteList"]:
-                dsdn = dsdb_Dn(samdb, value)
+                dsdn = dsdb_Dn(samdb, value.decode('utf8'))
                 guid = misc.GUID(dsdn.dn.get_extended_component('GUID'))
-                if guid not in self.site_list:
-                    self.site_list.append(guid)
+                dnstr = str(dsdn.dn)
+                if (guid, dnstr) not in self.site_list:
+                    self.site_list.append((guid, dnstr))
 
         if "schedule" in msg:
             self.schedule = ndr_unpack(drsblobs.schedule, value)
@@ -2214,11 +2225,6 @@ def get_dsa_config_rep(dsa):
                    dsa.dsa_dnstr)
 
 
-def sort_dsa_by_guid(dsa1, dsa2):
-    "use ndr_pack for GUID comparison, as appears correct in some places"""
-    return cmp(ndr_pack(dsa1.dsa_guid), ndr_pack(dsa2.dsa_guid))
-
-
 def new_connection_schedule():
     """Create a default schedule for an NTDSConnection or Sitelink. This
     is packed differently from the repltimes schedule used elsewhere
@@ -2246,3 +2252,114 @@ def new_connection_schedule():
 
     schedule.dataArray = [data]
     return schedule
+
+
+##################################################
+# DNS related calls
+##################################################
+
+def uncovered_sites_to_cover(samdb, site_name):
+    """
+    Discover which sites have no DCs and whose lowest single-hop cost
+    distance for any link attached to that site is linked to the site supplied.
+
+    We compare the lowest cost of your single-hop link to this site to all of
+    those available (if it exists). This means that a lower ranked siteLink
+    with only the uncovered site can trump any available links (but this can
+    only be done with specific, poorly enacted user configuration).
+
+    If the site is connected to more than one other site with the same
+    siteLink, only the largest site (failing that sorted alphabetically)
+    creates the DNS records.
+
+    :param samdb database
+    :param site_name origin site (with a DC)
+
+    :return a list of sites this site should be covering (for DNS)
+    """
+    sites_to_cover = []
+
+    server_res = samdb.search(base=samdb.get_config_basedn(),
+                              scope=ldb.SCOPE_SUBTREE,
+                              expression="(&(objectClass=server)"
+                              "(serverReference=*))")
+
+    site_res = samdb.search(base=samdb.get_config_basedn(),
+                            scope=ldb.SCOPE_SUBTREE,
+                            expression="(objectClass=site)")
+
+    sites_in_use = Counter()
+    dc_count = 0
+
+    # Assume server is of form DC,Servers,Site-ABCD because of schema
+    for msg in server_res:
+        site_dn = msg.dn.parent().parent()
+        sites_in_use[site_dn.canonical_str()] += 1
+
+        if site_dn.get_rdn_value().lower() == site_name.lower():
+            dc_count += 1
+
+    if len(sites_in_use) != len(site_res):
+        # There is a possible uncovered site
+        sites_uncovered = []
+
+        for msg in site_res:
+            if msg.dn.canonical_str() not in sites_in_use:
+                sites_uncovered.append(msg)
+
+        own_site_dn = "CN={},CN=Sites,{}".format(
+            ldb.binary_encode(site_name),
+            ldb.binary_encode(str(samdb.get_config_basedn()))
+        )
+
+        for site in sites_uncovered:
+            encoded_dn = ldb.binary_encode(str(site.dn))
+
+            # Get a sorted list of all siteLinks featuring the uncovered site
+            link_res1 = samdb.search(base=samdb.get_config_basedn(),
+                                     scope=ldb.SCOPE_SUBTREE, attrs=["cost"],
+                                     expression="(&(objectClass=siteLink)"
+                                     "(siteList={}))".format(encoded_dn),
+                                     controls=["server_sort:1:0:cost"])
+
+            # Get a sorted list of all siteLinks connecting this an the
+            # uncovered site
+            link_res2 = samdb.search(base=samdb.get_config_basedn(),
+                                     scope=ldb.SCOPE_SUBTREE,
+                                     attrs=["cost", "siteList"],
+                                     expression="(&(objectClass=siteLink)"
+                                     "(siteList={})(siteList={}))".format(
+                                         own_site_dn,
+                                         encoded_dn),
+                                     controls=["server_sort:1:0:cost"])
+
+            # Add to list if your link is equal in cost to lowest cost link
+            if len(link_res1) > 0 and len(link_res2) > 0:
+                cost1 = int(link_res1[0]['cost'][0])
+                cost2 = int(link_res2[0]['cost'][0])
+
+                # Own siteLink must match the lowest cost link
+                if cost1 != cost2:
+                    continue
+
+                # In a siteLink with more than 2 sites attached, only pick the
+                # largest site, and if there are multiple, the earliest
+                # alphabetically.
+                to_cover = True
+                for site_val in link_res2[0]['siteList']:
+                    site_dn = ldb.Dn(samdb, str(site_val))
+                    site_dn_str = site_dn.canonical_str()
+                    site_rdn = site_dn.get_rdn_value().lower()
+                    if sites_in_use[site_dn_str] > dc_count:
+                        to_cover = False
+                        break
+                    elif (sites_in_use[site_dn_str] == dc_count and
+                          site_rdn < site_name.lower()):
+                        to_cover = False
+                        break
+
+                if to_cover:
+                    site_cover_rdn = site.dn.get_rdn_value()
+                    sites_to_cover.append(site_cover_rdn.lower())
+
+    return sites_to_cover