py/dnsserver: add .from_string() methods
authorDouglas Bagnall <douglas.bagnall@catalyst.net.nz>
Wed, 7 Apr 2021 01:34:50 +0000 (13:34 +1200)
committerDouglas Bagnall <dbagnall@samba.org>
Thu, 8 Apr 2021 21:54:35 +0000 (21:54 +0000)
The logic to parse DNS value strings (e.g. "example.com 10" for an MX,
which needs to be split on the space) is repeated at least in
samba-tool dns and tests/dcerpc/dnsserver.py. Here we bring it
together so we can do it once.

The sep= keyword allows callers to separate on all runs of
whitespace (the default, as samba-tool dns does) or, using sep='', to
separate on true spaces only.

Signed-off-by: Douglas Bagnall <douglas.bagnall@catalyst.net.nz>
Reviewed-by: Andreas Schneider <asn@samba.org>
python/samba/dnsserver.py

index 7703fa9186d6bc753c856624c9d744b66bf38706..42de46b8d4d40e10684005c412a8ef06563d1e88 100644 (file)
@@ -16,6 +16,7 @@
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 #
 
+import shlex
 from samba.dcerpc import dnsserver, dnsp
 
 # Note: these are not quite the same as similar looking classes in
@@ -39,6 +40,9 @@ from samba.dcerpc import dnsserver, dnsp
 # them can represent any type of record.
 
 
+class DNSParseError(ValueError):
+    pass
+
 
 class ARecord(dnsserver.DNS_RPC_RECORD):
     def __init__(self, ip_addr, serial=1, ttl=900, rank=dnsp.DNS_RANK_ZONE,
@@ -50,6 +54,10 @@ class ARecord(dnsserver.DNS_RPC_RECORD):
         self.dwTtlSeconds = ttl
         self.data = ip_addr
 
+    @classmethod
+    def from_string(cls, data, sep=None, **kwargs):
+        return cls(data, **kwargs)
+
 
 class AAAARecord(dnsserver.DNS_RPC_RECORD):
 
@@ -62,6 +70,10 @@ class AAAARecord(dnsserver.DNS_RPC_RECORD):
         self.dwTtlSeconds = ttl
         self.data = ip6_addr
 
+    @classmethod
+    def from_string(cls, data, sep=None, **kwargs):
+        return cls(data, **kwargs)
+
 
 class PTRRecord(dnsserver.DNS_RPC_RECORD):
 
@@ -77,6 +89,10 @@ class PTRRecord(dnsserver.DNS_RPC_RECORD):
         ptr_name.len = len(ptr)
         self.data = ptr_name
 
+    @classmethod
+    def from_string(cls, data, sep=None, **kwargs):
+        return cls(data, **kwargs)
+
 
 class CNAMERecord(dnsserver.DNS_RPC_RECORD):
 
@@ -92,6 +108,10 @@ class CNAMERecord(dnsserver.DNS_RPC_RECORD):
         cname_name.len = len(cname)
         self.data = cname_name
 
+    @classmethod
+    def from_string(cls, data, sep=None, **kwargs):
+        return cls(data, **kwargs)
+
 
 class NSRecord(dnsserver.DNS_RPC_RECORD):
 
@@ -107,6 +127,10 @@ class NSRecord(dnsserver.DNS_RPC_RECORD):
         ns.len = len(dns_server)
         self.data = ns
 
+    @classmethod
+    def from_string(cls, data, sep=None, **kwargs):
+        return cls(data, **kwargs)
+
 
 class MXRecord(dnsserver.DNS_RPC_RECORD):
 
@@ -123,6 +147,16 @@ class MXRecord(dnsserver.DNS_RPC_RECORD):
         mx.nameExchange.len = len(mail_server)
         self.data = mx
 
+    @classmethod
+    def from_string(cls, data, sep=None, **kwargs):
+        try:
+            server, priority = data.split(sep)
+            priority = int(priority)
+        except ValueError as e:
+            raise DNSParseError("MX data must have server and priority "
+                                "(space separated), not %r" % data) from e
+        return cls(server, priority, **kwargs)
+
 
 class SOARecord(dnsserver.DNS_RPC_RECORD):
 
@@ -146,6 +180,21 @@ class SOARecord(dnsserver.DNS_RPC_RECORD):
         soa.ZoneAdministratorEmail.len = len(rname)
         self.data = soa
 
+    @classmethod
+    def from_string(cls, data, sep=None, **kwargs):
+        args = data.split(sep)
+        if len(args) != 7:
+            raise DNSParseError('Data requires 7 space separated elements - '
+                                'nameserver, email, serial, '
+                                'refresh, retry, expire, minimumttl')
+        try:
+            for i in range(2, 7):
+                args[i] = int(args[i])
+        except ValueError as e:
+            raise DNSParseError("SOA serial, refresh, retry, expire, minimumttl' "
+                                "should be integers") from e
+        return cls(*args, **kwargs)
+
 
 class SRVRecord(dnsserver.DNS_RPC_RECORD):
 
@@ -164,6 +213,23 @@ class SRVRecord(dnsserver.DNS_RPC_RECORD):
         srv.nameTarget.len = len(target)
         self.data = srv
 
+    @classmethod
+    def from_string(cls, data, sep=None, **kwargs):
+        try:
+            target, port, priority, weight = data.split(sep)
+        except ValueError as e:
+            raise DNSParseError("SRV data must have four space "
+                                "separated elements: "
+                                "server, port, priority, weight; "
+                                "not %r" % data) from e
+        try:
+            args = (target, int(port), int(priority), int(weight))
+        except ValueError as e:
+            raise DNSParseError("SRV port, priority, and weight "
+                                "must be integers") from e
+
+        return cls(*args, **kwargs)
+
 
 class TXTRecord(dnsserver.DNS_RPC_RECORD):
 
@@ -184,3 +250,8 @@ class TXTRecord(dnsserver.DNS_RPC_RECORD):
         txt.count = len(slist)
         txt.str = names
         self.data = txt
+
+    @classmethod
+    def from_string(cls, data, sep=None, **kwargs):
+        slist = shlex.split(data)
+        return cls(slist, **kwargs)