traffic: new version of model with packet_rate, version number
[samba.git] / python / samba / emulate / traffic.py
index 03a24f4161c931ec7f8a78a1e49cf599ac50ed84..807fa8244e2c13836bb6bc46387eca0116f2859c 100644 (file)
@@ -25,9 +25,8 @@ import json
 import math
 import sys
 import signal
-import itertools
 
-from collections import OrderedDict, Counter, defaultdict
+from collections import OrderedDict, Counter, defaultdict, namedtuple
 from samba.emulate import traffic_packets
 from samba.samdb import SamDB
 import ldb
@@ -42,11 +41,21 @@ from samba.drs_utils import drs_DsBind
 import traceback
 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
 from samba.auth import system_session
-from samba.dsdb import UF_WORKSTATION_TRUST_ACCOUNT, UF_PASSWD_NOTREQD
-from samba.dsdb import UF_NORMAL_ACCOUNT
-from samba.dcerpc.misc import SEC_CHAN_WKSTA
+from samba.dsdb import (
+    UF_NORMAL_ACCOUNT,
+    UF_SERVER_TRUST_ACCOUNT,
+    UF_TRUSTED_FOR_DELEGATION,
+    UF_WORKSTATION_TRUST_ACCOUNT
+)
+from samba.dcerpc.misc import SEC_CHAN_BDC
 from samba import gensec
+from samba import sd_utils
+from samba.compat import get_string
+from samba.logger import get_samba_logger
+import bisect
 
+CURRENT_MODEL_VERSION = 2   # save as this
+REQUIRED_MODEL_VERSION = 2  # load accepts this or greater
 SLEEP_OVERHEAD = 3e-4
 
 # we don't use None, because it complicates [de]serialisation
@@ -84,6 +93,8 @@ NO_WAIT_LOG_TIME_RANGE = (-10, -3)
 # DEBUG_LEVEL can be changed by scripts with -d
 DEBUG_LEVEL = 0
 
+LOGGER = get_samba_logger(name=__name__)
+
 
 def debug(level, msg, *args):
     """Print a formatted debug message to standard error.
@@ -116,14 +127,26 @@ def debug_lineno(*args):
     sys.stderr.flush()
 
 
-def random_colour_print():
-    """Return a function that prints a randomly coloured line to stderr"""
-    n = 18 + random.randrange(214)
-    prefix = "\033[38;5;%dm" % n
-
-    def p(*args):
-        for a in args:
-            print("%s%s\033[00m" % (prefix, a), file=sys.stderr)
+def random_colour_print(seeds):
+    """Return a function that prints a coloured line to stderr. The colour
+    of the line depends on a sort of hash of the integer arguments."""
+    if seeds:
+        s = 214
+        for x in seeds:
+            s += 17
+            s *= x
+            s %= 214
+        prefix = "\033[38;5;%dm" % (18 + s)
+
+        def p(*args):
+            if DEBUG_LEVEL > 0:
+                for a in args:
+                    print("%s%s\033[00m" % (prefix, a), file=sys.stderr)
+    else:
+        def p(*args):
+            if DEBUG_LEVEL > 0:
+                for a in args:
+                    print(a, file=sys.stderr)
 
     return p
 
@@ -134,10 +157,35 @@ class FakePacketError(Exception):
 
 class Packet(object):
     """Details of a network packet"""
-    def __init__(self, fields):
-        if isinstance(fields, str):
-            fields = fields.rstrip('\n').split('\t')
+    __slots__ = ('timestamp',
+                 'ip_protocol',
+                 'stream_number',
+                 'src',
+                 'dest',
+                 'protocol',
+                 'opcode',
+                 'desc',
+                 'extra',
+                 'endpoints')
+
+    def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
+                 protocol, opcode, desc, extra):
+        self.timestamp = timestamp
+        self.ip_protocol = ip_protocol
+        self.stream_number = stream_number
+        self.src = src
+        self.dest = dest
+        self.protocol = protocol
+        self.opcode = opcode
+        self.desc = desc
+        self.extra = extra
+        if self.src < self.dest:
+            self.endpoints = (self.src, self.dest)
+        else:
+            self.endpoints = (self.dest, self.src)
 
+    @classmethod
+    def from_line(cls, line):
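+        # a traffic_summary line is tab-separated: timestamp, ip_protocol,
+        # stream_number, src, dest, protocol, opcode, desc, then any
+        # number of extra fields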
+        fields = line.rstrip('\n').split('\t')
         (timestamp,
          ip_protocol,
          stream_number,
@@ -148,23 +196,12 @@ class Packet(object):
          desc) = fields[:8]
         extra = fields[8:]
 
-        self.timestamp = float(timestamp)
-        self.ip_protocol = ip_protocol
-        try:
-            self.stream_number = int(stream_number)
-        except (ValueError, TypeError):
-            self.stream_number = None
-        self.src = int(src)
-        self.dest = int(dest)
-        self.protocol = protocol
-        self.opcode = opcode
-        self.desc = desc
-        self.extra = extra
+        timestamp = float(timestamp)
+        src = int(src)
+        dest = int(dest)
 
-        if self.src < self.dest:
-            self.endpoints = (self.src, self.dest)
-        else:
-            self.endpoints = (self.dest, self.src)
+        return cls(timestamp, ip_protocol, stream_number, src, dest,
+                   protocol, opcode, desc, extra)
 
     def as_summary(self, time_offset=0.0):
         """Format the packet as a traffic_summary line.
@@ -192,14 +229,15 @@ class Packet(object):
         return "<Packet @%s>" % self
 
     def copy(self):
-        return self.__class__([self.timestamp,
-                               self.ip_protocol,
-                               self.stream_number,
-                               self.src,
-                               self.dest,
-                               self.protocol,
-                               self.opcode,
-                               self.desc] + self.extra)
+        return self.__class__(self.timestamp,
+                              self.ip_protocol,
+                              self.stream_number,
+                              self.src,
+                              self.dest,
+                              self.protocol,
+                              self.opcode,
+                              self.desc,
+                              self.extra)
 
     def as_packet_type(self):
         t = '%s:%s' % (self.protocol, self.opcode)
@@ -228,7 +266,7 @@ class Packet(object):
             fn = getattr(traffic_packets, fn_name)
 
         except AttributeError as e:
-            print("Conversation(%s) Missing handler %s" % \
+            print("Conversation(%s) Missing handler %s" %
                   (conversation.conversation_id, fn_name),
                   file=sys.stderr)
             return
@@ -260,33 +298,38 @@ class Packet(object):
         return self.timestamp - other.timestamp
 
     def is_really_a_packet(self, missing_packet_stats=None):
-        """Is the packet one that can be ignored?
+        return is_a_real_packet(self.protocol, self.opcode)
 
-        If so removing it will have no effect on the replay
-        """
-        if self.protocol in SKIPPED_PROTOCOLS:
-            # Ignore any packets for the protocols we're not interested in.
-            return False
-        if self.protocol == "ldap" and self.opcode == '':
-            # skip ldap continuation packets
-            return False
 
-        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
-        try:
-            fn = getattr(traffic_packets, fn_name)
-            if fn is traffic_packets.null_packet:
-                return False
-        except AttributeError:
-            print("missing packet %s" % fn_name, file=sys.stderr)
-            return False
-        return True
+def is_a_real_packet(protocol, opcode):
+    """Is the packet one that can be ignored?
+
+    If so removing it will have no effect on the replay
+    """
+    if protocol in SKIPPED_PROTOCOLS:
+        # Ignore any packets for the protocols we're not interested in.
+        return False
+    if protocol == "ldap" and opcode == '':
+        # skip ldap continuation packets
+        return False
+
+    fn_name = 'packet_%s_%s' % (protocol, opcode)
+    fn = getattr(traffic_packets, fn_name, None)
+    if fn is None:
+        LOGGER.debug("missing packet %s" % fn_name)
+        return False
+    if fn is traffic_packets.null_packet:
+        return False
+    return True
 
 
 class ReplayContext(object):
-    """State/Context for an individual conversation between an simulated client
-       and a server.
+    """State/Context for a conversation between an simulated client and a
+       server. Some of the context is shared amongst all conversations
+       and should be generated before the fork, while other context is
+       specific to a particular conversation and should be generated
+       *after* the fork, in generate_process_local_config().
     """
-
     def __init__(self,
                  server=None,
                  lp=None,
@@ -301,13 +344,6 @@ class ReplayContext(object):
                  domain_sid=None):
 
         self.server                   = server
-        self.ldap_connections         = []
-        self.dcerpc_connections       = []
-        self.lsarpc_connections       = []
-        self.lsarpc_connections_named = []
-        self.drsuapi_connections      = []
-        self.srvsvc_connections       = []
-        self.samr_contexts            = []
         self.netlogon_connection      = None
         self.creds                    = creds
         self.lp                       = lp
@@ -331,7 +367,6 @@ class ReplayContext(object):
         self.last_netlogon_bad        = False
         self.last_samlogon_bad        = False
         self.generate_ldap_search_tables()
-        self.next_conversation_id = itertools.count().next
 
     def generate_ldap_search_tables(self):
         session = system_session()
@@ -366,7 +401,7 @@ class ReplayContext(object):
         # for k, v in self.dn_map.items():
         #     print >>sys.stderr, k, len(v)
 
-        for k, v in dn_map.items():
+        for k in list(dn_map.keys()):
             if k[-3:] != ',DC':
                 continue
             p = k[:-3]
@@ -384,8 +419,13 @@ class ReplayContext(object):
         self.attribute_clue_map = attribute_clue_map
 
     def generate_process_local_config(self, account, conversation):
-        if account is None:
-            return
+        self.ldap_connections         = []
+        self.dcerpc_connections       = []
+        self.lsarpc_connections       = []
+        self.lsarpc_connections_named = []
+        self.drsuapi_connections      = []
+        self.srvsvc_connections       = []
+        self.samr_contexts            = []
         self.netbios_name             = account.netbios_name
         self.machinepass              = account.machinepass
         self.username                 = account.username
@@ -395,8 +435,8 @@ class ReplayContext(object):
                                      'conversation-%d' %
                                      conversation.conversation_id)
 
-        self.lp.set("private dir",     self.tempdir)
-        self.lp.set("lock dir",        self.tempdir)
+        self.lp.set("private dir", self.tempdir)
+        self.lp.set("lock dir", self.tempdir)
         self.lp.set("state directory", self.tempdir)
         self.lp.set("tls verify peer", "no_check")
 
@@ -427,8 +467,8 @@ class ReplayContext(object):
            than that requested, but not significantly.
         """
         if not failed_last_time:
-            if (self.badpassword_frequency > 0 and
-               random.random() < self.badpassword_frequency):
+            if (self.badpassword_frequency and self.badpassword_frequency > 0
+                and random.random() < self.badpassword_frequency):
                 try:
                     f(bad)
                 except:
@@ -456,6 +496,7 @@ class ReplayContext(object):
         self.user_creds.set_workstation(self.netbios_name)
         self.user_creds.set_password(self.userpass)
         self.user_creds.set_username(self.username)
+        self.user_creds.set_domain(self.domain)
         if self.prefer_kerberos:
             self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
         else:
@@ -510,9 +551,10 @@ class ReplayContext(object):
         self.machine_creds = Credentials()
         self.machine_creds.guess(self.lp)
         self.machine_creds.set_workstation(self.netbios_name)
-        self.machine_creds.set_secure_channel_type(SEC_CHAN_WKSTA)
+        self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
         self.machine_creds.set_password(self.machinepass)
         self.machine_creds.set_username(self.netbios_name + "$")
+        self.machine_creds.set_domain(self.domain)
         if self.prefer_kerberos:
             self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
         else:
@@ -521,7 +563,7 @@ class ReplayContext(object):
         self.machine_creds_bad = Credentials()
         self.machine_creds_bad.guess(self.lp)
         self.machine_creds_bad.set_workstation(self.netbios_name)
-        self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_WKSTA)
+        self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
         self.machine_creds_bad.set_password(self.machinepass[:-4])
         self.machine_creds_bad.set_username(self.netbios_name + "$")
         if self.prefer_kerberos:
@@ -707,7 +749,7 @@ class ReplayContext(object):
     def get_authenticator(self):
         auth = self.machine_creds.new_client_authenticator()
         current  = netr_Authenticator()
-        current.cred.data = [ord(x) for x in auth["credential"]]
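+        # iterating the credential yields ints on py3, but 1-byte strs on
+        # py2, which still need ord()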
+        current.cred.data = [x if isinstance(x, int) else ord(x)
+                             for x in auth["credential"]]
         current.timestamp = auth["timestamp"]
 
         subsequent = netr_Authenticator()
@@ -747,14 +789,16 @@ class SamrContext(object):
 
 class Conversation(object):
     """Details of a converation between a simulated client and a server."""
-    conversation_id = None
-
-    def __init__(self, start_time=None, endpoints=None):
+    def __init__(self, start_time=None, endpoints=None, seq=(),
+                 conversation_id=None):
         self.start_time = start_time
         self.endpoints = endpoints
         self.packets = []
-        self.msg = random_colour_print()
+        self.msg = random_colour_print(endpoints)
         self.client_balance = 0.0
+        self.conversation_id = conversation_id
+        for p in seq:
+            self.add_short_packet(*p)
 
     def __cmp__(self, other):
         if self.start_time is None:
@@ -791,23 +835,24 @@ class Conversation(object):
         if p.is_really_a_packet():
             self.packets.append(p)
 
-    def add_short_packet(self, timestamp, p, extra, client=True):
+    def add_short_packet(self, timestamp, protocol, opcode, extra,
+                         client=True):
         """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
         (possibly empty) list of extra data. If client is True, assume
         this packet is from the client to the server.
         """
-        protocol, opcode = p.split(':', 1)
         src, dest = self.guess_client_server()
         if not client:
             src, dest = dest, src
-
-        desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
-        ip_protocol = IP_PROTOCOLS.get(protocol, '06')
-        fields = [timestamp - self.start_time, ip_protocol,
-                  '', src, dest,
-                  protocol, opcode, desc]
-        fields.extend(extra)
-        packet = Packet(fields)
+        desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
+        ip_protocol = IP_PROTOCOLS.get(protocol, '06')
+        packet = Packet(timestamp - self.start_time, ip_protocol,
+                        '', src, dest,
+                        protocol, opcode, desc, extra)
         # XXX we're assuming the timestamp is already adjusted for
         # this conversation?
         # XXX should we adjust client balance for guessed packets?
@@ -870,7 +915,7 @@ class Conversation(object):
             gap = t - now
             print("gap is now %f" % gap, file=sys.stderr)
 
-        self.conversation_id = context.next_conversation_id()
+        self.conversation_id = next(context.next_conversation_id)
         pid = os.fork()
         if pid != 0:
             return pid
@@ -945,18 +990,8 @@ class Conversation(object):
         :param s: start of the window
         :param e: end of the window
         """
-
-        new_packets = []
-        for p in self.packets:
-            if p.timestamp < s or p.timestamp > e:
-                continue
-            new_packets.append(p)
-
-        self.packets = new_packets
-        if new_packets:
-            self.start_time = new_packets[0].timestamp
-        else:
-            self.start_time = None
+        self.packets = [p for p in self.packets if s <= p.timestamp <= e]
+        self.start_time = self.packets[0].timestamp if self.packets else None
 
     def renormalise_times(self, start_time):
         """Adjust the packet start times relative to the new start time."""
@@ -1029,7 +1064,7 @@ def ingest_summaries(files, dns_mode='count'):
             f = open(f)
         print("Ingesting %s" % (f.name,), file=sys.stderr)
         for line in f:
-            p = Packet(line)
+            p = Packet.from_line(line)
             if p.protocol == 'dns' and dns_mode != 'include':
                 dns_counts[p.opcode] += 1
             else:
@@ -1045,11 +1080,11 @@ def ingest_summaries(files, dns_mode='count'):
 
     print("gathering packets into conversations", file=sys.stderr)
     conversations = OrderedDict()
-    for p in packets:
+    for i, p in enumerate(packets):
         p.timestamp -= start_time
         c = conversations.get(p.endpoints)
         if c is None:
-            c = Conversation()
+            c = Conversation(conversation_id=(i + 2))
             conversations[p.endpoints] = c
         c.add_packet(p)
 
@@ -1103,7 +1138,7 @@ class TrafficModel(object):
         self.n = n
         self.dns_opcounts = defaultdict(int)
         self.cumulative_duration = 0.0
-        self.conversation_rate = [0, 1]
+        self.packet_rate = [0, 1]
 
     def learn(self, conversations, dns_opcounts={}):
         prev = 0.0
@@ -1116,10 +1151,15 @@ class TrafficModel(object):
             self.dns_opcounts[k] += v
 
         if len(conversations) > 1:
-            elapsed =\
-                conversations[-1].start_time - conversations[0].start_time
-            self.conversation_rate[0] = len(conversations)
-            self.conversation_rate[1] = elapsed
+            first = conversations[0].start_time
+            total = 0
+            last = first + 0.1
+            for c in conversations:
+                total += len(c)
+                last = max(last, c.packets[-1].timestamp)
+
+            self.packet_rate[0] = total
+            self.packet_rate[1] = last - first
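+            # e.g. (illustrative): 10000 packets spanning 3600s of
+            # captured traffic gives packet_rate == [10000, 3600.0]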
 
         for c in conversations:
             client, server = c.guess_client_server(server)
@@ -1163,7 +1203,8 @@ class TrafficModel(object):
             'ngrams': ngrams,
             'query_details': query_details,
             'cumulative_duration': self.cumulative_duration,
-            'conversation_rate': self.conversation_rate,
+            'packet_rate': self.packet_rate,
+            'version': CURRENT_MODEL_VERSION
         }
         d['dns'] = self.dns_opcounts
 
@@ -1178,11 +1219,23 @@ class TrafficModel(object):
 
         d = json.load(f)
 
+        try:
+            version = d["version"]
+            if version < REQUIRED_MODEL_VERSION:
+                raise ValueError("the model file is version %d; "
+                                 "version %d is required" %
+                                 (version, REQUIRED_MODEL_VERSION))
+        except KeyError:
+            raise ValueError("the model file lacks a version number; "
+                             "version %d is required" %
+                             REQUIRED_MODEL_VERSION)
+
         for k, v in d['ngrams'].items():
             k = tuple(str(k).split('\t'))
             values = self.ngrams.setdefault(k, [])
             for p, count in v.items():
                 values.extend([str(p)] * count)
+            values.sort()
 
         for k, v in d['query_details'].items():
             values = self.query_details.setdefault(str(k), [])
@@ -1191,26 +1244,32 @@ class TrafficModel(object):
                     values.extend([()] * count)
                 else:
                     values.extend([tuple(str(p).split('\t'))] * count)
+            values.sort()
 
         if 'dns' in d:
             for k, v in d['dns'].items():
                 self.dns_opcounts[k] += v
 
         self.cumulative_duration = d['cumulative_duration']
-        self.conversation_rate = d['conversation_rate']
-
-    def construct_conversation(self, timestamp=0.0, client=2, server=1,
-                               hard_stop=None, packet_rate=1):
-        """Construct a individual converation from the model."""
-
-        c = Conversation(timestamp, (server, client))
-
+        self.packet_rate = d['packet_rate']
+
+    def construct_conversation_sequence(self, timestamp=0.0,
+                                        hard_stop=None,
+                                        replay_speed=1,
+                                        ignore_before=0):
+        """Construct an individual conversation packet sequence from the
+        model.
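+
+        Returns a list of (timestamp, protocol, opcode, extra) tuples,
+        which seq_to_conversations() can later turn into Conversation
+        objects.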
+        """
+        c = []
         key = (NON_PACKET,) * (self.n - 1)
+        if ignore_before is None:
+            ignore_before = timestamp - 1
 
-        while key in self.ngrams:
-            p = random.choice(self.ngrams.get(key, NON_PACKET))
+        while True:
+            p = random.choice(self.ngrams.get(key, (NON_PACKET,)))
             if p == NON_PACKET:
                 break
+
             if p in self.query_details:
                 extra = random.choice(self.query_details[p])
             else:
@@ -1219,55 +1278,58 @@ class TrafficModel(object):
             protocol, opcode = p.split(':', 1)
             if protocol == 'wait':
                 log_wait_time = int(opcode) + random.random()
-                wait = math.exp(log_wait_time) / (WAIT_SCALE * packet_rate)
+                wait = math.exp(log_wait_time) / (WAIT_SCALE * replay_speed)
                 timestamp += wait
             else:
                 log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
-                wait = math.exp(log_wait) / packet_rate
+                wait = math.exp(log_wait) / replay_speed
                 timestamp += wait
                 if hard_stop is not None and timestamp > hard_stop:
                     break
-                c.add_short_packet(timestamp, p, extra)
+                if timestamp >= ignore_before:
+                    c.append((timestamp, protocol, opcode, extra))
 
             key = key[1:] + (p,)
 
         return c
 
-    def generate_conversations(self, rate, duration, packet_rate=1):
+    def generate_conversations(self, scale, duration, replay_speed=1,
+                               server=1, client=2):
         """Generate a list of conversations from the model."""
 
-        # We run the simulation for at least ten times as long as our
-        # desired duration, and take a section near the start.
-        rate_n, rate_t  = self.conversation_rate
+        # We run the simulation for ten times as long as our desired
+        # duration, and take the section at the end.
+        lead_in = 9 * duration
+        rate_n, rate_t  = self.packet_rate
+        target_packets = int(duration * scale * rate_n / rate_t)
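+        # e.g. (illustrative): a model learnt at 10000 packets per 3600s,
+        # replayed with duration=60 and scale=2, targets
+        # int(60 * 2 * 10000 / 3600) == 333 packets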
 
-        duration2 = max(rate_t, duration * 2)
-        n = rate * duration2 * rate_n / rate_t
+        conversations = []
+        n_packets = 0
+
+        while n_packets < target_packets:
+            start = random.uniform(-lead_in, duration)
+            c = self.construct_conversation_sequence(start,
+                                                     hard_stop=duration,
+                                                     replay_speed=replay_speed,
+                                                     ignore_before=0)
+            conversations.append(c)
+            n_packets += len(c)
+
+        print(("we have %d packets (target %d) in %d conversations at scale %f"
+               % (n_packets, target_packets, len(conversations), scale)),
+              file=sys.stderr)
+        conversations.sort()  # sorts by first element == start time
+        return seq_to_conversations(conversations)
 
-        server = 1
-        client = 2
 
-        conversations = []
-        end = duration2
-        start = end - duration
-
-        while client < n + 2:
-            start = random.uniform(0, duration2)
-            c = self.construct_conversation(start,
-                                            client,
-                                            server,
-                                            hard_stop=(duration2 * 5),
-                                            packet_rate=packet_rate)
-
-            c.forget_packets_outside_window(start, end)
-            c.renormalise_times(start)
-            if len(c) != 0:
-                conversations.append(c)
+def seq_to_conversations(seq, server=1, client=2):
+    conversations = []
+    for s in seq:
+        if s:
+            c = Conversation(s[0][0], (server, client), s)
             client += 1
-
-        print(("we have %d conversations at rate %f" %
-                              (len(conversations), rate)), file=sys.stderr)
-        conversations.sort()
-        return conversations
+            conversations.append(c)
+    return conversations
 
 
 IP_PROTOCOLS = {
@@ -1435,7 +1497,7 @@ def replay(conversations,
 
     end = start + duration
 
-    print("Replaying traffic for %u conversations over %d seconds"
+    LOGGER.info("Replaying traffic for %u conversations over %d seconds"
           % (len(conversations), duration))
 
     children = {}
@@ -1498,7 +1560,7 @@ def replay(conversations,
     finally:
         for s in (15, 15, 9):
             print(("killing %d children with -%d" %
-                                 (len(children), s)), file=sys.stderr)
+                   (len(children), s)), file=sys.stderr)
             for pid in children:
                 try:
                     os.kill(pid, s)
@@ -1545,6 +1607,7 @@ def openLdb(host, creds, lp):
     session = system_session()
     ldb = SamDB(url="ldap://%s" % host,
                 session_info=session,
+                options=['modules:paged_searches'],
                 credentials=creds,
                 lp=lp)
     return ldb
@@ -1563,42 +1626,42 @@ def create_ou(ldb, instance_id):
     """
     ou = ou_name(ldb, instance_id)
     try:
-        ldb.add({"dn":          ou.split(',', 1)[1],
+        ldb.add({"dn": ou.split(',', 1)[1],
                  "objectclass": "organizationalunit"})
     except LdbError as e:
-        (status, _) = e
+        (status, _) = e.args
         # ignore already exists
         if status != 68:
             raise
     try:
-        ldb.add({"dn":          ou,
+        ldb.add({"dn": ou,
                  "objectclass": "organizationalunit"})
     except LdbError as e:
-        (status, _) = e
+        (status, _) = e.args
         # ignore already exists
         if status != 68:
             raise
     return ou
 
 
-class ConversationAccounts(object):
-    """Details of the machine and user accounts associated with a conversation.
-    """
-    def __init__(self, netbios_name, machinepass, username, userpass):
-        self.netbios_name = netbios_name
-        self.machinepass  = machinepass
-        self.username     = username
-        self.userpass     = userpass
+# ConversationAccounts holds details of the machine and user accounts
+# associated with a conversation.
+#
+# We use a named tuple to reduce shared memory usage.
+ConversationAccounts = namedtuple('ConversationAccounts',
+                                  ('netbios_name',
+                                   'machinepass',
+                                   'username',
+                                   'userpass'))
 
 
 def generate_replay_accounts(ldb, instance_id, number, password):
     """Generate a series of unique machine and user account names."""
 
-    generate_traffic_accounts(ldb, instance_id, number, password)
     accounts = []
     for i in range(1, number + 1):
-        netbios_name = "STGM-%d-%d" % (instance_id, i)
-        username     = "STGU-%d-%d" % (instance_id, i)
+        netbios_name = machine_name(instance_id, i)
+        username = user_name(instance_id, i)
 
         account = ConversationAccounts(netbios_name, password, username,
                                        password)
@@ -1606,80 +1669,36 @@ def generate_replay_accounts(ldb, instance_id, number, password):
     return accounts
 
 
-def generate_traffic_accounts(ldb, instance_id, number, password):
-    """Create the specified number of user and machine accounts.
-
-    As accounts are not explicitly deleted between runs. This function starts
-    with the last account and iterates backwards stopping either when it
-    finds an already existing account or it has generated all the required
-    accounts.
-    """
-    print(("Generating machine and conversation accounts, "
-           "as required for %d conversations" % number),
-          file=sys.stderr)
-    added = 0
-    for i in range(number, 0, -1):
-        try:
-            netbios_name = "STGM-%d-%d" % (instance_id, i)
-            create_machine_account(ldb, instance_id, netbios_name, password)
-            added += 1
-        except LdbError as e:
-            (status, _) = e
-            if status == 68:
-                break
-            else:
-                raise
-    if added > 0:
-        print("Added %d new machine accounts" % added,
-              file=sys.stderr)
-
-    added = 0
-    for i in range(number, 0, -1):
-        try:
-            username = "STGU-%d-%d" % (instance_id, i)
-            create_user_account(ldb, instance_id, username, password)
-            added += 1
-        except LdbError as e:
-            (status, _) = e
-            if status == 68:
-                break
-            else:
-                raise
-
-    if added > 0:
-        print("Added %d new user accounts" % added,
-              file=sys.stderr)
-
-
-def create_machine_account(ldb, instance_id, netbios_name, machinepass):
+def create_machine_account(ldb, instance_id, netbios_name, machinepass,
+                           traffic_account=True):
     """Create a machine account via ldap."""
 
     ou = ou_name(ldb, instance_id)
     dn = "cn=%s,%s" % (netbios_name, ou)
-    utf16pw = unicode(
-        '"' + machinepass.encode('utf-8') + '"', 'utf-8'
-    ).encode('utf-16-le')
-    start = time.time()
+    utf16pw = ('"%s"' % get_string(machinepass)).encode('utf-16-le')
+
+    if traffic_account:
+        # we set these bits for the machine account otherwise the replayed
+        # traffic throws up NT_STATUS_NO_TRUST_SAM_ACCOUNT errors
+        account_controls = str(UF_TRUSTED_FOR_DELEGATION |
+                               UF_SERVER_TRUST_ACCOUNT)
+
+    else:
+        account_controls = str(UF_WORKSTATION_TRUST_ACCOUNT)
+
     ldb.add({
         "dn": dn,
         "objectclass": "computer",
         "sAMAccountName": "%s$" % netbios_name,
-        "userAccountControl":
-        str(UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD),
+        "userAccountControl": account_controls,
         "unicodePwd": utf16pw})
-    end = time.time()
-    duration = end - start
-    print("%f\t0\tcreate\tmachine\t%f\tTrue\t" % (end, duration))
 
 
 def create_user_account(ldb, instance_id, username, userpass):
     """Create a user account via ldap."""
     ou = ou_name(ldb, instance_id)
     user_dn = "cn=%s,%s" % (username, ou)
-    utf16pw = unicode(
-        '"' + userpass.encode('utf-8') + '"', 'utf-8'
-    ).encode('utf-16-le')
-    start = time.time()
+    utf16pw = ('"%s"' % get_string(userpass)).encode('utf-16-le')
     ldb.add({
         "dn": user_dn,
         "objectclass": "user",
@@ -1687,9 +1706,10 @@ def create_user_account(ldb, instance_id, username, userpass):
         "userAccountControl": str(UF_NORMAL_ACCOUNT),
         "unicodePwd": utf16pw
     })
-    end = time.time()
-    duration = end - start
-    print("%f\t0\tcreate\tuser\t%f\tTrue\t" % (end, duration))
+
+    # grant user write permission to do things like write account SPN
+    sdutils = sd_utils.SDUtils(ldb)
+    sdutils.dacl_add_ace(user_dn, "(A;;WP;;;PS)")
 
 
 def create_group(ldb, instance_id, name):
@@ -1697,14 +1717,11 @@ def create_group(ldb, instance_id, name):
 
     ou = ou_name(ldb, instance_id)
     dn = "cn=%s,%s" % (name, ou)
-    start = time.time()
     ldb.add({
         "dn": dn,
         "objectclass": "group",
+        "sAMAccountName": name,
     })
-    end = time.time()
-    duration = end - start
-    print("%f\t0\tcreate\tgroup\t%f\tTrue\t" % (end, duration))
 
 
 def user_name(instance_id, i):
@@ -1712,25 +1729,62 @@ def user_name(instance_id, i):
     return "STGU-%d-%d" % (instance_id, i)
 
 
+def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
+    """Seach objectclass, return attr in a set"""
+    objs = ldb.search(
+        expression="(objectClass={})".format(objectclass),
+        attrs=[attr]
+    )
+    return {str(obj[attr]) for obj in objs}
+
+
 def generate_users(ldb, instance_id, number, password):
     """Add users to the server"""
+    existing_objects = search_objectclass(ldb, objectclass='user')
     users = 0
     for i in range(number, 0, -1):
-        try:
-            username = user_name(instance_id, i)
-            create_user_account(ldb, instance_id, username, password)
+        name = user_name(instance_id, i)
+        if name not in existing_objects:
+            create_user_account(ldb, instance_id, name, password)
             users += 1
-        except LdbError as e:
-            (status, _) = e
-            # Stop if entry exists
-            if status == 68:
-                break
-            else:
-                raise
+            if users % 50 == 0:
+                LOGGER.info("Created %u/%u users" % (users, number))
 
     return users
 
 
+def machine_name(instance_id, i, traffic_account=True):
+    """Generate a machine account name from instance id."""
+    if traffic_account:
+        # traffic accounts correspond to a given user, and use different
+        # userAccountControl flags to ensure packets get processed correctly
+        # by the DC
+        return "STGM-%d-%d" % (instance_id, i)
+    else:
+        # Otherwise we're just generating computer accounts to simulate a
+        # semi-realistic network. These use the default computer
+        # userAccountControl flags, so we use a different account name so that
+        # we don't try to use them when generating packets
+        return "PC-%d-%d" % (instance_id, i)
+
+
+def generate_machine_accounts(ldb, instance_id, number, password,
+                              traffic_account=True):
+    """Add machine accounts to the server"""
+    existing_objects = search_objectclass(ldb, objectclass='computer')
+    added = 0
+    for i in range(number, 0, -1):
+        name = machine_name(instance_id, i, traffic_account)
+        if name + "$" not in existing_objects:
+            create_machine_account(ldb, instance_id, name, password,
+                                   traffic_account)
+            added += 1
+            if added % 50 == 0:
+                LOGGER.info("Created %u/%u machine accounts" % (added, number))
+
+    return added
+
+
 def group_name(instance_id, i):
     """Generate a group name from instance id."""
     return "STGG-%d-%d" % (instance_id, i)
@@ -1738,19 +1792,16 @@ def group_name(instance_id, i):
 
 def generate_groups(ldb, instance_id, number):
     """Create the required number of groups on the server."""
+    existing_objects = search_objectclass(ldb, objectclass='group')
     groups = 0
     for i in range(number, 0, -1):
-        try:
-            name = group_name(instance_id, i)
+        name = group_name(instance_id, i)
+        if name not in existing_objects:
             create_group(ldb, instance_id, name)
             groups += 1
-        except LdbError as e:
-            (status, _) = e
-            # Stop if entry exists
-            if status == 68:
-                break
-            else:
-                raise
+            if groups % 1000 == 0:
+                LOGGER.info("Created %u/%u groups" % (groups, number))
+
     return groups
 
 
@@ -1760,7 +1811,7 @@ def clean_up_accounts(ldb, instance_id):
     try:
         ldb.delete(ou, ["tree_delete:1"])
     except LdbError as e:
-        (status, _) = e
+        (status, _) = e.args
         # ignore does not exist
         if status != 32:
             raise
@@ -1768,123 +1819,239 @@ def clean_up_accounts(ldb, instance_id):
 
 def generate_users_and_groups(ldb, instance_id, password,
                               number_of_users, number_of_groups,
-                              group_memberships):
+                              group_memberships, max_members,
+                              machine_accounts, traffic_accounts=True):
     """Generate the required users and groups, allocating the users to
        those groups."""
-    assignments = []
-    groups_added  = 0
+    memberships_added = 0
+    groups_added = 0
+    computers_added = 0
 
     create_ou(ldb, instance_id)
 
-    print("Generating dummy user accounts", file=sys.stderr)
+    LOGGER.info("Generating dummy user accounts")
     users_added = generate_users(ldb, instance_id, number_of_users, password)
 
+    LOGGER.info("Generating dummy machine accounts")
+    computers_added = generate_machine_accounts(ldb, instance_id,
+                                                machine_accounts, password,
+                                                traffic_accounts)
+
     if number_of_groups > 0:
-        print("Generating dummy groups", file=sys.stderr)
+        LOGGER.info("Generating dummy groups")
         groups_added = generate_groups(ldb, instance_id, number_of_groups)
 
     if group_memberships > 0:
-        print("Assigning users to groups", file=sys.stderr)
-        assignments = assign_groups(number_of_groups,
-                                    groups_added,
-                                    number_of_users,
-                                    users_added,
-                                    group_memberships)
-        print("Adding users to groups", file=sys.stderr)
+        LOGGER.info("Assigning users to groups")
+        assignments = GroupAssignments(number_of_groups,
+                                       groups_added,
+                                       number_of_users,
+                                       users_added,
+                                       group_memberships,
+                                       max_members)
+        LOGGER.info("Adding users to groups")
         add_users_to_groups(ldb, instance_id, assignments)
+        memberships_added = assignments.total()
 
     if (groups_added > 0 and users_added == 0 and
        number_of_groups != groups_added):
-        print("Warning: the added groups will contain no members",
-              file=sys.stderr)
-
-    print(("Added %d users, %d groups and %d group memberships" %
-           (users_added, groups_added, len(assignments))),
-          file=sys.stderr)
-
-
-def assign_groups(number_of_groups,
-                  groups_added,
-                  number_of_users,
-                  users_added,
-                  group_memberships):
-    """Allocate users to groups.
-
-    The intention is to have a few users that belong to most groups, while
-    the majority of users belong to a few groups.
-
-    A few groups will contain most users, with the remaining only having a
-    few users.
-    """
+        LOGGER.warning("The added groups will contain no members")
+
+    LOGGER.info("Added %d users (%d machines), %d groups and %d memberships" %
+                (users_added, computers_added, groups_added,
+                 memberships_added))
+
+
+class GroupAssignments(object):
+    def __init__(self, number_of_groups, groups_added, number_of_users,
+                 users_added, group_memberships, max_members):
+
+        self.count = 0
+        self.generate_group_distribution(number_of_groups)
+        self.generate_user_distribution(number_of_users, group_memberships)
+        self.max_members = max_members
+        self.assignments = defaultdict(list)
+        self.assign_groups(number_of_groups, groups_added, number_of_users,
+                           users_added, group_memberships)
+
+    def cumulative_distribution(self, weights):
+        # make sure the probabilities conform to a cumulative distribution
+        # spread between 0.0 and 1.0. Dividing by the weighted total gives each
+        # probability a proportional share of 1.0. Higher probabilities get a
+        # bigger share, so are more likely to be picked. We use the cumulative
+        # value, so we can use random.random() as a simple index into the list
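+        # e.g. (illustrative): weights [1, 1, 2] sum to 4, giving the
+        # cumulative distribution [0.25, 0.5, 1.0]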
+        dist = []
+        total = sum(weights)
+        if total == 0:
+            return None
+
+        cumulative = 0.0
+        for probability in weights:
+            cumulative += probability
+            dist.append(cumulative / total)
+        return dist
 
-    def generate_user_distribution(n):
+    def generate_user_distribution(self, num_users, num_memberships):
         """Probability distribution of a user belonging to a group.
         """
-        dist = []
-        for x in range(1, n + 1):
-            p = 1 / (x + 0.001)
-            dist.append(p)
-        return dist
+        # Assign a weighted probability to each user. Use the Pareto
+        # Distribution so that some users are in a lot of groups, and the
+        # bulk of users are in only a few groups. If we're assigning a large
+        # number of group memberships, use a higher shape. This means slightly
+        # fewer outlying users that are in large numbers of groups. The aim is
+        # to have no users belonging to more than ~500 groups.
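+        # (random.paretovariate(shape) returns values >= 1 with a heavy
+        # tail; the rare large values become the heavily-grouped users)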
+        if num_memberships > 5000000:
+            shape = 3.0
+        elif num_memberships > 2000000:
+            shape = 2.5
+        elif num_memberships > 300000:
+            shape = 2.25
+        else:
+            shape = 1.75
+
+        weights = []
+        for x in range(1, num_users + 1):
+            p = random.paretovariate(shape)
+            weights.append(p)
+
+        # convert the weights to a cumulative distribution between 0.0 and 1.0
+        self.user_dist = self.cumulative_distribution(weights)
 
-    def generate_group_distribution(n):
+    def generate_group_distribution(self, n):
         """Probability distribution of a group containing a user."""
-        dist = []
+
+        # Assign a weighted probability to each user. Probability decreases
+        # as the group-ID increases
+        weights = []
         for x in range(1, n + 1):
             p = 1 / (x**1.3)
-            dist.append(p)
-        return dist
+            weights.append(p)
+
+        # convert the weights to a cumulative distribution between 0.0 and 1.0
+        self.group_weights = weights
+        self.group_dist = self.cumulative_distribution(weights)
+
+    def generate_random_membership(self):
+        """Returns a randomly generated user-group membership"""
+
+        # the list items are cumulative distribution values between 0.0 and
+        # 1.0, which makes random() a handy way to index the list to get a
+        # weighted random user/group. (Here the user/group returned are
+        # zero-based array indexes)
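+        # e.g. (illustrative): bisect.bisect([0.25, 0.5, 1.0], 0.6) == 2,
+        # picking the last (and most probable) of three entries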
+        user = bisect.bisect(self.user_dist, random.random())
+        group = bisect.bisect(self.group_dist, random.random())
+
+        return user, group
+
+    def users_in_group(self, group):
+        return self.assignments[group]
+
+    def get_groups(self):
+        return self.assignments.keys()
+
+    def cap_group_membership(self, group, max_members):
+        """Prevent the group's membership from exceeding the max specified"""
+        num_members = len(self.assignments[group])
+        if num_members >= max_members:
+            LOGGER.info("Group {0} has {1} members".format(group, num_members))
+
+            # remove this group and then recalculate the cumulative
+            # distribution, so this group is no longer selected
+            self.group_weights[group - 1] = 0
+            new_dist = self.cumulative_distribution(self.group_weights)
+            self.group_dist = new_dist
+
+    def add_assignment(self, user, group):
+        # the assignments are stored in a dictionary where key=group,
+        # value=list-of-users-in-group (indexing by group-ID allows us to
+        # optimize for DB membership writes)
+        if user not in self.assignments[group]:
+            self.assignments[group].append(user)
+            self.count += 1
+
+        # check if there's a cap on how big the groups can grow
+        if self.max_members:
+            self.cap_group_membership(group, self.max_members)
+
+    def assign_groups(self, number_of_groups, groups_added,
+                      number_of_users, users_added, group_memberships):
+        """Allocate users to groups.
+
+        The intention is to have a few users that belong to most groups, while
+        the majority of users belong to a few groups.
+
+        A few groups will contain most users, with the remaining only having a
+        few users.
+        """
 
-    assignments = set()
-    if group_memberships <= 0:
-        return assignments
+        if group_memberships <= 0:
+            return
 
-    group_dist = generate_group_distribution(number_of_groups)
-    user_dist  = generate_user_distribution(number_of_users)
+        # Calculate the number of group memberships required
+        group_memberships = math.ceil(
+            float(group_memberships) *
+            (float(users_added) / float(number_of_users)))
 
-    # Calculate the number of group menberships required
-    group_memberships = math.ceil(
-        float(group_memberships) *
-        (float(users_added) / float(number_of_users)))
+        if self.max_members:
+            group_memberships = min(group_memberships,
+                                    self.max_members * number_of_groups)
 
-    existing_users  = number_of_users  - users_added  - 1
-    existing_groups = number_of_groups - groups_added - 1
-    while len(assignments) < group_memberships:
-        user        = random.randint(0, number_of_users - 1)
-        group       = random.randint(0, number_of_groups - 1)
-        probability = group_dist[group] * user_dist[user]
+        existing_users  = number_of_users  - users_added  - 1
+        existing_groups = number_of_groups - groups_added - 1
+        while self.total() < group_memberships:
+            user, group = self.generate_random_membership()
 
-        if ((random.random() < probability * 10000) and
-           (group > existing_groups or user > existing_users)):
-            # the + 1 converts the array index to the corresponding
-            # group or user number
-            assignments.add(((user + 1), (group + 1)))
+            if group > existing_groups or user > existing_users:
+                # the + 1 converts the array index to the corresponding
+                # group or user number
+                self.add_assignment(user + 1, group + 1)
 
-    return assignments
+    def total(self):
+        return self.count
 
 
 def add_users_to_groups(db, instance_id, assignments):
-    """Add users to their assigned groups.
+    """Takes the assignments of users to groups and applies them to the DB."""
+
+    total = assignments.total()
+    count = 0
+    added = 0
+
+    for group in assignments.get_groups():
+        users_in_group = assignments.users_in_group(group)
+        if len(users_in_group) == 0:
+            continue
+
+        # Split up the users into chunks, so we write no more than 1K at a
+        # time. (Minimizing the DB modifies is more efficient, but writing
+        # 10K+ users to a single group becomes inefficient memory-wise)
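+        # e.g. (illustrative): 2500 users are written in chunks of 1000,
+        # 1000 and 500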
+        for chunk in range(0, len(users_in_group), 1000):
+            chunk_of_users = users_in_group[chunk:chunk + 1000]
+            add_group_members(db, instance_id, group, chunk_of_users)
 
-    Takes the list of (group,user) tuples generated by assign_groups and
-    assign the users to their specified groups."""
+            added += len(chunk_of_users)
+            count += 1
+            if count % 50 == 0:
+                LOGGER.info("Added %u/%u memberships" % (added, total))
+
+
+def add_group_members(db, instance_id, group, users_in_group):
+    """Adds the given users to group specified."""
 
     ou = ou_name(db, instance_id)
 
     def build_dn(name):
         return("cn=%s,%s" % (name, ou))
 
-    for (user, group) in assignments:
-        user_dn  = build_dn(user_name(instance_id, user))
-        group_dn = build_dn(group_name(instance_id, group))
+    group_dn = build_dn(group_name(instance_id, group))
+    m = ldb.Message()
+    m.dn = ldb.Dn(db, group_dn)
 
-        m = ldb.Message()
-        m.dn = ldb.Dn(db, group_dn)
-        m["member"] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
-        start = time.time()
-        db.modify(m)
-        end = time.time()
-        duration = end - start
-        print("%f\t0\tadd\tuser\t%f\tTrue\t" % (end, duration))
+    for user in users_in_group:
+        user_dn = build_dn(user_name(instance_id, user))
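+        # each user needs a distinct Message key so earlier adds aren't
+        # overwritten; the attribute actually modified is the "member"
+        # name passed to the MessageElement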
+        idx = "member-" + str(user)
+        m[idx] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
+
+    db.modify(m)
 
 
 def generate_stats(statsdir, timing_file):
@@ -1956,25 +2123,16 @@ def generate_stats(statsdir, timing_file):
     else:
         failure_rate = failed / duration
 
-    # print the stats in more human-readable format when stdout is going to the
-    # console (as opposed to being redirected to a file)
-    if sys.stdout.isatty():
-        print("Total conversations:   %10d" % conversations)
-        print("Successful operations: %10d (%.3f per second)"
-              % (successful, success_rate))
-        print("Failed operations:     %10d (%.3f per second)"
-              % (failed, failure_rate))
-    else:
-        print("(%d, %d, %d, %.3f, %.3f)" %
-              (conversations, successful, failed, success_rate, failure_rate))
+    print("Total conversations:   %10d" % conversations)
+    print("Successful operations: %10d (%.3f per second)"
+          % (successful, success_rate))
+    print("Failed operations:     %10d (%.3f per second)"
+          % (failed, failure_rate))
+
+    print("Protocol    Op Code  Description                               "
+          " Count       Failed         Mean       Median          "
+          "95%        Range          Max")
 
-    if sys.stdout.isatty():
-        print("Protocol    Op Code  Description                               "
-              " Count       Failed         Mean       Median          "
-              "95%        Range          Max")
-    else:
-        print("proto\top_code\tdesc\tcount\tfailed\tmean\tmedian\t95%\trange"
-              "\tmax")
     protocols = sorted(latencies.keys())
     for protocol in protocols:
         packet_types = sorted(latencies[protocol], key=opcode_key)
@@ -2043,8 +2201,9 @@ def calc_percentile(values, percentile):
 
 
 def mk_masked_dir(*path):
-    """In a testenv we end up with 0777 diectories that look an alarming
+    """In a testenv we end up with 0777 directories that look an alarming
     green colour with ls. Use umask to avoid that."""
+    # py3 os.mkdir can do this
     d = os.path.join(*path)
     mask = os.umask(0o077)
     os.mkdir(d)