dns: auto-delete incorrect SRV entries for our hostname
[samba.git] / source4 / scripting / bin / samba_dnsupdate
index c5af17a759efe04b743a9f53e510071d1df0c3a0..9a90eac9dcfe536ba0648e98214042600aad16bd 100755 (executable)
@@ -94,6 +94,8 @@ class dnsobj(object):
         self.dest = None
         self.port = None
         self.ip = None
+        self.existing_port = None
+        self.existing_weight = None
     def __str__(self):
         if d.type == "A":     return "%s:%s:%s" % (self.type, self.name, self.ip)
         if d.type == "SRV":   return "%s:%s:%s:%s" % (self.type, self.name, self.dest, self.port)
@@ -131,10 +133,11 @@ def hostname_match(h1, h2):
 ############################################
 # check that a DNS entry exists
 def check_dns_name(d):
+    normalised_name = d.name.rstrip('.') + '.'
     if opts.verbose:
-        print "Looking for DNS entry %s" % d
+        print "Looking for DNS entry %s as %s" % (d, normalised_name)
     try:
-        ans = dns.resolver.query(d.name, d.type)
+        ans = dns.resolver.query(normalised_name, d.type)
     except dns.resolver.NXDOMAIN:
         return False
     if d.type == 'A':
@@ -147,14 +150,15 @@ def check_dns_name(d):
             if hostname_match(ans[i].target, d.dest):
                 return True
     if d.type == 'SRV':
-        if opts.verbose:
-            print "Got %u replies in SRV lookup for %s" % (len(ans), d.name)
-        for i in range(len(ans)):
-            rdata = ans[i]
+        for rdata in ans:
             if opts.verbose:
                 print "Checking %s against %s" % (rdata, d)
-            if str(rdata.port) == str(d.port) and hostname_match(rdata.target, d.dest):
-                return True
+            if hostname_match(rdata.target, d.dest):
+                if str(rdata.port) == str(d.port):
+                    return True
+                else:
+                    d.existing_port     = str(rdata.port)
+                    d.existing_weight = str(rdata.weight)
     if opts.verbose:
         print "Failed to find DNS entry %s" % d
     return False
@@ -190,6 +194,9 @@ def call_nsupdate(d):
     if d.type == "A":
         f.write("update add %s %u A %s\n" % (d.name, default_ttl, d.ip))
     if d.type == "SRV":
+        if d.existing_port is not None:
+            f.write("update delete %s SRV 0 %s %s %s\n" % (d.name, d.existing_weight,
+                                                           d.existing_port, d.dest))
         f.write("update add %s %u SRV 0 100 %s %s\n" % (d.name, default_ttl, d.port, d.dest))
     if d.type == "CNAME":
         f.write("update add %s %u SRV %s\n" % (d.name, default_ttl, d.dest))