PEP8: fix E241: multiple spaces after ','
[samba.git] / python / samba / emulate / traffic.py
index 0fb60b4c87b3ed8e7710922a140cd75941512139..a60ebbba9e2b550fe9a853aff42d485e60a27aab 100644 (file)
@@ -16,7 +16,7 @@
 # You should have received a copy of the GNU General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 #
-from __future__ import print_function
+from __future__ import print_function, division
 
 import time
 import os
@@ -42,10 +42,14 @@ from samba.drs_utils import drs_DsBind
 import traceback
 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
 from samba.auth import system_session
-from samba.dsdb import UF_WORKSTATION_TRUST_ACCOUNT, UF_PASSWD_NOTREQD
-from samba.dsdb import UF_NORMAL_ACCOUNT
-from samba.dcerpc.misc import SEC_CHAN_WKSTA
+from samba.dsdb import (
+    UF_NORMAL_ACCOUNT,
+    UF_SERVER_TRUST_ACCOUNT,
+    UF_TRUSTED_FOR_DELEGATION
+)
+from samba.dcerpc.misc import SEC_CHAN_BDC
 from samba import gensec
+from samba import sd_utils
 
 SLEEP_OVERHEAD = 3e-4
 
@@ -134,10 +138,26 @@ class FakePacketError(Exception):
 
 class Packet(object):
     """Details of a network packet"""
-    def __init__(self, fields):
-        if isinstance(fields, str):
-            fields = fields.rstrip('\n').split('\t')
+    def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
+                 protocol, opcode, desc, extra):
 
+        self.timestamp = timestamp
+        self.ip_protocol = ip_protocol
+        self.stream_number = stream_number
+        self.src = src
+        self.dest = dest
+        self.protocol = protocol
+        self.opcode = opcode
+        self.desc = desc
+        self.extra = extra
+        if self.src < self.dest:
+            self.endpoints = (self.src, self.dest)
+        else:
+            self.endpoints = (self.dest, self.src)
+
+    @classmethod
+    def from_line(self, line):
+        fields = line.rstrip('\n').split('\t')
         (timestamp,
          ip_protocol,
          stream_number,
@@ -148,23 +168,12 @@ class Packet(object):
          desc) = fields[:8]
         extra = fields[8:]
 
-        self.timestamp = float(timestamp)
-        self.ip_protocol = ip_protocol
-        try:
-            self.stream_number = int(stream_number)
-        except (ValueError, TypeError):
-            self.stream_number = None
-        self.src = int(src)
-        self.dest = int(dest)
-        self.protocol = protocol
-        self.opcode = opcode
-        self.desc = desc
-        self.extra = extra
+        timestamp = float(timestamp)
+        src = int(src)
+        dest = int(dest)
 
-        if self.src < self.dest:
-            self.endpoints = (self.src, self.dest)
-        else:
-            self.endpoints = (self.dest, self.src)
+        return Packet(timestamp, ip_protocol, stream_number, src, dest,
+                      protocol, opcode, desc, extra)
 
     def as_summary(self, time_offset=0.0):
         """Format the packet as a traffic_summary line.
@@ -192,14 +201,15 @@ class Packet(object):
         return "<Packet @%s>" % self
 
     def copy(self):
-        return self.__class__([self.timestamp,
-                               self.ip_protocol,
-                               self.stream_number,
-                               self.src,
-                               self.dest,
-                               self.protocol,
-                               self.opcode,
-                               self.desc] + self.extra)
+        return self.__class__(self.timestamp,
+                              self.ip_protocol,
+                              self.stream_number,
+                              self.src,
+                              self.dest,
+                              self.protocol,
+                              self.opcode,
+                              self.desc,
+                              self.extra)
 
     def as_packet_type(self):
         t = '%s:%s' % (self.protocol, self.opcode)
@@ -272,13 +282,12 @@ class Packet(object):
             return False
 
         fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
-        try:
-            fn = getattr(traffic_packets, fn_name)
-            if fn is traffic_packets.null_packet:
-                return False
-        except AttributeError:
+        fn = getattr(traffic_packets, fn_name, None)
+        if not fn:
             print("missing packet %s" % fn_name, file=sys.stderr)
             return False
+        if fn is traffic_packets.null_packet:
+            return False
         return True
 
 
@@ -331,7 +340,7 @@ class ReplayContext(object):
         self.last_netlogon_bad        = False
         self.last_samlogon_bad        = False
         self.generate_ldap_search_tables()
-        self.next_conversation_id = itertools.count().next
+        self.next_conversation_id = itertools.count()
 
     def generate_ldap_search_tables(self):
         session = system_session()
@@ -343,6 +352,7 @@ class ReplayContext(object):
 
         res = db.search(db.domain_dn(),
                         scope=ldb.SCOPE_SUBTREE,
+                        controls=["paged_results:1:1000"],
                         attrs=['dn'])
 
         # find a list of dns for each pattern
@@ -365,7 +375,7 @@ class ReplayContext(object):
         # for k, v in self.dn_map.items():
         #     print >>sys.stderr, k, len(v)
 
-        for k, v in dn_map.items():
+        for k in list(dn_map.keys()):
             if k[-3:] != ',DC':
                 continue
             p = k[:-3]
@@ -394,8 +404,8 @@ class ReplayContext(object):
                                      'conversation-%d' %
                                      conversation.conversation_id)
 
-        self.lp.set("private dir",     self.tempdir)
-        self.lp.set("lock dir",        self.tempdir)
+        self.lp.set("private dir", self.tempdir)
+        self.lp.set("lock dir", self.tempdir)
         self.lp.set("state directory", self.tempdir)
         self.lp.set("tls verify peer", "no_check")
 
@@ -455,6 +465,7 @@ class ReplayContext(object):
         self.user_creds.set_workstation(self.netbios_name)
         self.user_creds.set_password(self.userpass)
         self.user_creds.set_username(self.username)
+        self.user_creds.set_domain(self.domain)
         if self.prefer_kerberos:
             self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
         else:
@@ -509,9 +520,10 @@ class ReplayContext(object):
         self.machine_creds = Credentials()
         self.machine_creds.guess(self.lp)
         self.machine_creds.set_workstation(self.netbios_name)
-        self.machine_creds.set_secure_channel_type(SEC_CHAN_WKSTA)
+        self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
         self.machine_creds.set_password(self.machinepass)
         self.machine_creds.set_username(self.netbios_name + "$")
+        self.machine_creds.set_domain(self.domain)
         if self.prefer_kerberos:
             self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
         else:
@@ -520,7 +532,7 @@ class ReplayContext(object):
         self.machine_creds_bad = Credentials()
         self.machine_creds_bad.guess(self.lp)
         self.machine_creds_bad.set_workstation(self.netbios_name)
-        self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_WKSTA)
+        self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
         self.machine_creds_bad.set_password(self.machinepass[:-4])
         self.machine_creds_bad.set_username(self.netbios_name + "$")
         if self.prefer_kerberos:
@@ -643,6 +655,15 @@ class ReplayContext(object):
             return self.ldap_connections[-1]
 
         def simple_bind(creds):
+            """
+            To run simple bind against Windows, we need to run
+            following commands in PowerShell:
+
+                Install-windowsfeature ADCS-Cert-Authority
+                Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
+                Restart-Computer
+
+            """
             return SamDB('ldaps://%s' % self.server,
                          credentials=creds,
                          lp=self.lp)
@@ -669,7 +690,8 @@ class ReplayContext(object):
 
     def get_samr_context(self, new=False):
         if not self.samr_contexts or new:
-            self.samr_contexts.append(SamrContext(self.server))
+            self.samr_contexts.append(
+                SamrContext(self.server, lp=self.lp, creds=self.creds))
         return self.samr_contexts[-1]
 
     def get_netlogon_connection(self):
@@ -706,7 +728,7 @@ class ReplayContext(object):
 class SamrContext(object):
     """State/Context associated with a samr connection.
     """
-    def __init__(self, server):
+    def __init__(self, server, lp=None, creds=None):
         self.connection    = None
         self.handle        = None
         self.domain_handle = None
@@ -715,10 +737,16 @@ class SamrContext(object):
         self.user_handle   = None
         self.rids          = None
         self.server        = server
+        self.lp            = lp
+        self.creds         = creds
 
     def get_connection(self):
         if not self.connection:
-            self.connection = samr.samr("ncacn_ip_tcp:%s" % (self.server))
+            self.connection = samr.samr(
+                "ncacn_ip_tcp:%s[seal]" % (self.server),
+                lp_ctx=self.lp,
+                credentials=self.creds)
+
         return self.connection
 
     def get_handle(self):
@@ -774,23 +802,24 @@ class Conversation(object):
         if p.is_really_a_packet():
             self.packets.append(p)
 
-    def add_short_packet(self, timestamp, p, extra, client=True):
+    def add_short_packet(self, timestamp, protocol, opcode, extra,
+                         client=True):
         """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
         (possibly empty) list of extra data. If client is True, assume
         this packet is from the client to the server.
         """
-        protocol, opcode = p.split(':', 1)
         src, dest = self.guess_client_server()
         if not client:
             src, dest = dest, src
-
-        desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
-        ip_protocol = IP_PROTOCOLS.get(protocol, '06')
-        fields = [timestamp - self.start_time, ip_protocol,
-                  '', src, dest,
-                  protocol, opcode, desc]
-        fields.extend(extra)
-        packet = Packet(fields)
+        key = (protocol, opcode)
+        desc = OP_DESCRIPTIONS[key] if key in OP_DESCRIPTIONS else ''
+        if protocol in IP_PROTOCOLS:
+            ip_protocol = IP_PROTOCOLS[protocol]
+        else:
+            ip_protocol = '06'
+        packet = Packet(timestamp - self.start_time, ip_protocol,
+                        '', src, dest,
+                        protocol, opcode, desc, extra)
         # XXX we're assuming the timestamp is already adjusted for
         # this conversation?
         # XXX should we adjust client balance for guessed packets?
@@ -853,7 +882,7 @@ class Conversation(object):
             gap = t - now
             print("gap is now %f" % gap, file=sys.stderr)
 
-        self.conversation_id = context.next_conversation_id()
+        self.conversation_id = next(context.next_conversation_id)
         pid = os.fork()
         if pid != 0:
             return pid
@@ -928,18 +957,8 @@ class Conversation(object):
         :param s: start of the window
         :param e: end of the window
         """
-
-        new_packets = []
-        for p in self.packets:
-            if p.timestamp < s or p.timestamp > e:
-                continue
-            new_packets.append(p)
-
-        self.packets = new_packets
-        if new_packets:
-            self.start_time = new_packets[0].timestamp
-        else:
-            self.start_time = None
+        self.packets = [p for p in self.packets if s <= p.timestamp <= e]
+        self.start_time = self.packets[0].timestamp if self.packets else None
 
     def renormalise_times(self, start_time):
         """Adjust the packet start times relative to the new start time."""
@@ -1012,7 +1031,7 @@ def ingest_summaries(files, dns_mode='count'):
             f = open(f)
         print("Ingesting %s" % (f.name,), file=sys.stderr)
         for line in f:
-            p = Packet(line)
+            p = Packet.from_line(line)
             if p.protocol == 'dns' and dns_mode != 'include':
                 dns_counts[p.opcode] += 1
             else:
@@ -1210,7 +1229,7 @@ class TrafficModel(object):
                 timestamp += wait
                 if hard_stop is not None and timestamp > hard_stop:
                     break
-                c.add_short_packet(timestamp, p, extra)
+                c.add_short_packet(timestamp, protocol, opcode, extra)
 
             key = key[1:] + (p,)
 
@@ -1248,7 +1267,7 @@ class TrafficModel(object):
             client += 1
 
         print(("we have %d conversations at rate %f" %
-                              (len(conversations), rate)), file=sys.stderr)
+               (len(conversations), rate)), file=sys.stderr)
         conversations.sort()
         return conversations
 
@@ -1398,9 +1417,9 @@ def replay(conversations,
         print(("we have %d accounts but %d conversations" %
                (accounts, conversations)), file=sys.stderr)
 
-    cstack = zip(sorted(conversations,
-                        key=lambda x: x.start_time, reverse=True),
-                 accounts)
+    cstack = list(zip(
+        sorted(conversations, key=lambda x: x.start_time, reverse=True),
+        accounts))
 
     # Set the process group so that the calling scripts are not killed
     # when the forked child processes are killed.
@@ -1481,7 +1500,7 @@ def replay(conversations,
     finally:
         for s in (15, 15, 9):
             print(("killing %d children with -%d" %
-                                 (len(children), s)), file=sys.stderr)
+                   (len(children), s)), file=sys.stderr)
             for pid in children:
                 try:
                     os.kill(pid, s)
@@ -1546,18 +1565,18 @@ def create_ou(ldb, instance_id):
     """
     ou = ou_name(ldb, instance_id)
     try:
-        ldb.add({"dn":          ou.split(',', 1)[1],
+        ldb.add({"dn": ou.split(',', 1)[1],
                  "objectclass": "organizationalunit"})
     except LdbError as e:
-        (status, _) = e
+        (status, _) = e.args
         # ignore already exists
         if status != 68:
             raise
     try:
-        ldb.add({"dn":          ou,
+        ldb.add({"dn": ou,
                  "objectclass": "organizationalunit"})
     except LdbError as e:
-        (status, _) = e
+        (status, _) = e.args
         # ignore already exists
         if status != 68:
             raise
@@ -1607,7 +1626,7 @@ def generate_traffic_accounts(ldb, instance_id, number, password):
             create_machine_account(ldb, instance_id, netbios_name, password)
             added += 1
         except LdbError as e:
-            (status, _) = e
+            (status, _) = e.args
             if status == 68:
                 break
             else:
@@ -1623,7 +1642,7 @@ def generate_traffic_accounts(ldb, instance_id, number, password):
             create_user_account(ldb, instance_id, username, password)
             added += 1
         except LdbError as e:
-            (status, _) = e
+            (status, _) = e.args
             if status == 68:
                 break
             else:
@@ -1648,7 +1667,7 @@ def create_machine_account(ldb, instance_id, netbios_name, machinepass):
         "objectclass": "computer",
         "sAMAccountName": "%s$" % netbios_name,
         "userAccountControl":
-        str(UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD),
+            str(UF_TRUSTED_FOR_DELEGATION | UF_SERVER_TRUST_ACCOUNT),
         "unicodePwd": utf16pw})
     end = time.time()
     duration = end - start
@@ -1670,6 +1689,11 @@ def create_user_account(ldb, instance_id, username, userpass):
         "userAccountControl": str(UF_NORMAL_ACCOUNT),
         "unicodePwd": utf16pw
     })
+
+    # grant user write permission to do things like write account SPN
+    sdutils = sd_utils.SDUtils(ldb)
+    sdutils.dacl_add_ace(user_dn, "(A;;WP;;;PS)")
+
     end = time.time()
     duration = end - start
     print("%f\t0\tcreate\tuser\t%f\tTrue\t" % (end, duration))
@@ -1684,6 +1708,7 @@ def create_group(ldb, instance_id, name):
     ldb.add({
         "dn": dn,
         "objectclass": "group",
+        "sAMAccountName": name,
     })
     end = time.time()
     duration = end - start
@@ -1704,7 +1729,7 @@ def generate_users(ldb, instance_id, number, password):
             create_user_account(ldb, instance_id, username, password)
             users += 1
         except LdbError as e:
-            (status, _) = e
+            (status, _) = e.args
             # Stop if entry exists
             if status == 68:
                 break
@@ -1728,7 +1753,7 @@ def generate_groups(ldb, instance_id, number):
             create_group(ldb, instance_id, name)
             groups += 1
         except LdbError as e:
-            (status, _) = e
+            (status, _) = e.args
             # Stop if entry exists
             if status == 68:
                 break
@@ -1743,7 +1768,7 @@ def clean_up_accounts(ldb, instance_id):
     try:
         ldb.delete(ou, ["tree_delete:1"])
     except LdbError as e:
-        (status, _) = e
+        (status, _) = e.args
         # ignore does not exist
         if status != 32:
             raise
@@ -1939,25 +1964,16 @@ def generate_stats(statsdir, timing_file):
     else:
         failure_rate = failed / duration
 
-    # print the stats in more human-readable format when stdout is going to the
-    # console (as opposed to being redirected to a file)
-    if sys.stdout.isatty():
-        print("Total conversations:   %10d" % conversations)
-        print("Successful operations: %10d (%.3f per second)"
-              % (successful, success_rate))
-        print("Failed operations:     %10d (%.3f per second)"
-              % (failed, failure_rate))
-    else:
-        print("(%d, %d, %d, %.3f, %.3f)" %
-              (conversations, successful, failed, success_rate, failure_rate))
+    print("Total conversations:   %10d" % conversations)
+    print("Successful operations: %10d (%.3f per second)"
+          % (successful, success_rate))
+    print("Failed operations:     %10d (%.3f per second)"
+          % (failed, failure_rate))
+
+    print("Protocol    Op Code  Description                               "
+          " Count       Failed         Mean       Median          "
+          "95%        Range          Max")
 
-    if sys.stdout.isatty():
-        print("Protocol    Op Code  Description                               "
-              " Count       Failed         Mean       Median          "
-              "95%        Range          Max")
-    else:
-        print("proto\top_code\tdesc\tcount\tfailed\tmean\tmedian\t95%\trange"
-              "\tmax")
     protocols = sorted(latencies.keys())
     for protocol in protocols:
         packet_types = sorted(latencies[protocol], key=opcode_key)