traffic: new version of model with packet_rate, version number
[samba.git] / python / samba / emulate / traffic.py
index 291162f279ac066757f1d4459df360759604ef9c..807fa8244e2c13836bb6bc46387eca0116f2859c 100644 (file)
@@ -25,9 +25,8 @@ import json
 import math
 import sys
 import signal
-import itertools
 
-from collections import OrderedDict, Counter, defaultdict
+from collections import OrderedDict, Counter, defaultdict, namedtuple
 from samba.emulate import traffic_packets
 from samba.samdb import SamDB
 import ldb
@@ -55,6 +54,8 @@ from samba.compat import get_string
 from samba.logger import get_samba_logger
 import bisect
 
+CURRENT_MODEL_VERSION = 2   # save as this
+REQUIRED_MODEL_VERSION = 2  # load accepts this or greater
 SLEEP_OVERHEAD = 3e-4
 
 # we don't use None, because it complicates [de]serialisation
@@ -126,14 +127,26 @@ def debug_lineno(*args):
     sys.stderr.flush()
 
 
-def random_colour_print():
-    """Return a function that prints a randomly coloured line to stderr"""
-    n = 18 + random.randrange(214)
-    prefix = "\033[38;5;%dm" % n
-
-    def p(*args):
-        for a in args:
-            print("%s%s\033[00m" % (prefix, a), file=sys.stderr)
+def random_colour_print(seeds):
+    """Return a function that prints a coloured line to stderr. The colour
+    of the line depends on a sort of hash of the integer arguments."""
+    if seeds:
+        s = 214
+        for x in seeds:
+            s += 17
+            s *= x
+            s %= 214
+        prefix = "\033[38;5;%dm" % (18 + s)
+
+        def p(*args):
+            if DEBUG_LEVEL > 0:
+                for a in args:
+                    print("%s%s\033[00m" % (prefix, a), file=sys.stderr)
+    else:
+        def p(*args):
+            if DEBUG_LEVEL > 0:
+                for a in args:
+                    print(a, file=sys.stderr)
 
     return p
 
@@ -144,9 +157,18 @@ class FakePacketError(Exception):
 
 class Packet(object):
     """Details of a network packet"""
+    __slots__ = ('timestamp',
+                 'ip_protocol',
+                 'stream_number',
+                 'src',
+                 'dest',
+                 'protocol',
+                 'opcode',
+                 'desc',
+                 'extra',
+                 'endpoints')
     def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
                  protocol, opcode, desc, extra):
-
         self.timestamp = timestamp
         self.ip_protocol = ip_protocol
         self.stream_number = stream_number
@@ -162,7 +184,7 @@ class Packet(object):
             self.endpoints = (self.dest, self.src)
 
     @classmethod
-    def from_line(self, line):
+    def from_line(cls, line):
         fields = line.rstrip('\n').split('\t')
         (timestamp,
          ip_protocol,
@@ -178,8 +200,8 @@ class Packet(object):
         src = int(src)
         dest = int(dest)
 
-        return Packet(timestamp, ip_protocol, stream_number, src, dest,
-                      protocol, opcode, desc, extra)
+        return cls(timestamp, ip_protocol, stream_number, src, dest,
+                   protocol, opcode, desc, extra)
 
     def as_summary(self, time_offset=0.0):
         """Format the packet as a traffic_summary line.
@@ -276,32 +298,38 @@ class Packet(object):
         return self.timestamp - other.timestamp
 
     def is_really_a_packet(self, missing_packet_stats=None):
-        """Is the packet one that can be ignored?
+        return is_a_real_packet(self.protocol, self.opcode)
 
-        If so removing it will have no effect on the replay
-        """
-        if self.protocol in SKIPPED_PROTOCOLS:
-            # Ignore any packets for the protocols we're not interested in.
-            return False
-        if self.protocol == "ldap" and self.opcode == '':
-            # skip ldap continuation packets
-            return False
 
-        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
-        fn = getattr(traffic_packets, fn_name, None)
-        if not fn:
-            print("missing packet %s" % fn_name, file=sys.stderr)
-            return False
-        if fn is traffic_packets.null_packet:
-            return False
-        return True
+def is_a_real_packet(protocol, opcode):
+    """Is the packet one that can be ignored?
+
+    If so removing it will have no effect on the replay
+    """
+    if protocol in SKIPPED_PROTOCOLS:
+        # Ignore any packets for the protocols we're not interested in.
+        return False
+    if protocol == "ldap" and opcode == '':
+        # skip ldap continuation packets
+        return False
+
+    fn_name = 'packet_%s_%s' % (protocol, opcode)
+    fn = getattr(traffic_packets, fn_name, None)
+    if fn is None:
+        LOGGER.debug("missing packet %s", fn_name)
+        return False
+    if fn is traffic_packets.null_packet:
+        return False
+    return True
 
 
 class ReplayContext(object):
-    """State/Context for an individual conversation between an simulated client
-       and a server.
+    """State/Context for a conversation between a simulated client and a
+       server. Some of the context is shared amongst all conversations
+       and should be generated before the fork, while other context is
+       specific to a particular conversation and should be generated
+       *after* the fork, in generate_process_local_config().
     """
-
     def __init__(self,
                  server=None,
                  lp=None,
@@ -316,13 +344,6 @@ class ReplayContext(object):
                  domain_sid=None):
 
         self.server                   = server
-        self.ldap_connections         = []
-        self.dcerpc_connections       = []
-        self.lsarpc_connections       = []
-        self.lsarpc_connections_named = []
-        self.drsuapi_connections      = []
-        self.srvsvc_connections       = []
-        self.samr_contexts            = []
         self.netlogon_connection      = None
         self.creds                    = creds
         self.lp                       = lp
@@ -346,7 +367,6 @@ class ReplayContext(object):
         self.last_netlogon_bad        = False
         self.last_samlogon_bad        = False
         self.generate_ldap_search_tables()
-        self.next_conversation_id = itertools.count()
 
     def generate_ldap_search_tables(self):
         session = system_session()
@@ -399,8 +419,13 @@ class ReplayContext(object):
         self.attribute_clue_map = attribute_clue_map
 
     def generate_process_local_config(self, account, conversation):
-        if account is None:
-            return
+        self.ldap_connections         = []
+        self.dcerpc_connections       = []
+        self.lsarpc_connections       = []
+        self.lsarpc_connections_named = []
+        self.drsuapi_connections      = []
+        self.srvsvc_connections       = []
+        self.samr_contexts            = []
         self.netbios_name             = account.netbios_name
         self.machinepass              = account.machinepass
         self.username                 = account.username
@@ -764,14 +789,16 @@ class SamrContext(object):
 
 class Conversation(object):
     """Details of a converation between a simulated client and a server."""
-    conversation_id = None
-
-    def __init__(self, start_time=None, endpoints=None):
+    def __init__(self, start_time=None, endpoints=None, seq=(),
+                 conversation_id=None):
         self.start_time = start_time
         self.endpoints = endpoints
         self.packets = []
-        self.msg = random_colour_print()
+        self.msg = random_colour_print(endpoints)
         self.client_balance = 0.0
+        self.conversation_id = conversation_id
+        for p in seq:
+            self.add_short_packet(*p)
 
     def __cmp__(self, other):
         if self.start_time is None:
@@ -1053,11 +1080,11 @@ def ingest_summaries(files, dns_mode='count'):
 
     print("gathering packets into conversations", file=sys.stderr)
     conversations = OrderedDict()
-    for p in packets:
+    for i, p in enumerate(packets):
         p.timestamp -= start_time
         c = conversations.get(p.endpoints)
         if c is None:
-            c = Conversation()
+            c = Conversation(conversation_id=(i + 2))
             conversations[p.endpoints] = c
         c.add_packet(p)
 
@@ -1111,7 +1138,7 @@ class TrafficModel(object):
         self.n = n
         self.dns_opcounts = defaultdict(int)
         self.cumulative_duration = 0.0
-        self.conversation_rate = [0, 1]
+        self.packet_rate = [0, 1]
 
     def learn(self, conversations, dns_opcounts={}):
         prev = 0.0
@@ -1124,10 +1151,15 @@ class TrafficModel(object):
             self.dns_opcounts[k] += v
 
         if len(conversations) > 1:
-            elapsed =\
-                conversations[-1].start_time - conversations[0].start_time
-            self.conversation_rate[0] = len(conversations)
-            self.conversation_rate[1] = elapsed
+            first = conversations[0].start_time
+            total = 0
+            last = first + 0.1
+            for c in conversations:
+                total += len(c)
+                last = max(last, c.packets[-1].timestamp)
+
+            self.packet_rate[0] = total
+            self.packet_rate[1] = last - first
 
         for c in conversations:
             client, server = c.guess_client_server(server)
@@ -1171,7 +1203,8 @@ class TrafficModel(object):
             'ngrams': ngrams,
             'query_details': query_details,
             'cumulative_duration': self.cumulative_duration,
-            'conversation_rate': self.conversation_rate,
+            'packet_rate': self.packet_rate,
+            'version': CURRENT_MODEL_VERSION
         }
         d['dns'] = self.dns_opcounts
 
@@ -1186,11 +1219,23 @@ class TrafficModel(object):
 
         d = json.load(f)
 
+        try:
+            version = d["version"]
+            if version < REQUIRED_MODEL_VERSION:
+                raise ValueError("the model file is version %d; "
+                                 "version %d is required" %
+                                 (version, REQUIRED_MODEL_VERSION))
+        except KeyError:
+            raise ValueError("the model file lacks a version number; "
+                             "version %d is required" %
+                             (REQUIRED_MODEL_VERSION))
+
         for k, v in d['ngrams'].items():
             k = tuple(str(k).split('\t'))
             values = self.ngrams.setdefault(k, [])
             for p, count in v.items():
                 values.extend([str(p)] * count)
+            values.sort()
 
         for k, v in d['query_details'].items():
             values = self.query_details.setdefault(str(k), [])
@@ -1199,26 +1244,32 @@ class TrafficModel(object):
                     values.extend([()] * count)
                 else:
                     values.extend([tuple(str(p).split('\t'))] * count)
+            values.sort()
 
         if 'dns' in d:
             for k, v in d['dns'].items():
                 self.dns_opcounts[k] += v
 
         self.cumulative_duration = d['cumulative_duration']
-        self.conversation_rate = d['conversation_rate']
-
-    def construct_conversation(self, timestamp=0.0, client=2, server=1,
-                               hard_stop=None, packet_rate=1):
-        """Construct a individual converation from the model."""
-
-        c = Conversation(timestamp, (server, client))
-
+        self.packet_rate = d['packet_rate']
+
+    def construct_conversation_sequence(self, timestamp=0.0,
+                                        hard_stop=None,
+                                        replay_speed=1,
+                                        ignore_before=0):
+        """Construct an individual conversation packet sequence from the
+        model.
+        """
+        c = []
         key = (NON_PACKET,) * (self.n - 1)
+        if ignore_before is None:
+            ignore_before = timestamp - 1
 
-        while key in self.ngrams:
-            p = random.choice(self.ngrams.get(key, NON_PACKET))
+        while True:
+            p = random.choice(self.ngrams.get(key, (NON_PACKET,)))
             if p == NON_PACKET:
                 break
+
             if p in self.query_details:
                 extra = random.choice(self.query_details[p])
             else:
@@ -1227,55 +1278,58 @@ class TrafficModel(object):
             protocol, opcode = p.split(':', 1)
             if protocol == 'wait':
                 log_wait_time = int(opcode) + random.random()
-                wait = math.exp(log_wait_time) / (WAIT_SCALE * packet_rate)
+                wait = math.exp(log_wait_time) / (WAIT_SCALE * replay_speed)
                 timestamp += wait
             else:
                 log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
-                wait = math.exp(log_wait) / packet_rate
+                wait = math.exp(log_wait) / replay_speed
                 timestamp += wait
                 if hard_stop is not None and timestamp > hard_stop:
                     break
-                c.add_short_packet(timestamp, protocol, opcode, extra)
+                if timestamp >= ignore_before:
+                    c.append((timestamp, protocol, opcode, extra))
 
             key = key[1:] + (p,)
 
         return c
 
-    def generate_conversations(self, rate, duration, packet_rate=1):
+    def generate_conversations(self, scale, duration, replay_speed=1,
+                               server=1, client=2):
         """Generate a list of conversations from the model."""
 
-        # We run the simulation for at least ten times as long as our
-        # desired duration, and take a section near the start.
-        rate_n, rate_t  = self.conversation_rate
+        # We run the simulation for ten times as long as our desired
+        # duration, and take the section at the end.
+        lead_in = 9 * duration
+        rate_n, rate_t  = self.packet_rate
+        target_packets = int(duration * scale * rate_n / rate_t)
 
-        duration2 = max(rate_t, duration * 2)
-        n = rate * duration2 * rate_n / rate_t
+        conversations = []
+        n_packets = 0
+
+        while n_packets < target_packets:
+            start = random.uniform(-lead_in, duration)
+            c = self.construct_conversation_sequence(start,
+                                                     hard_stop=duration,
+                                                     replay_speed=replay_speed,
+                                                     ignore_before=0)
+            conversations.append(c)
+            n_packets += len(c)
+
+        print(("we have %d packets (target %d) in %d conversations at scale %f"
+               % (n_packets, target_packets, len(conversations), scale)),
+              file=sys.stderr)
+        conversations.sort()  # sorts by first element == start time
+        return seq_to_conversations(conversations)
 
-        server = 1
-        client = 2
 
-        conversations = []
-        end = duration2
-        start = end - duration
-
-        while client < n + 2:
-            start = random.uniform(0, duration2)
-            c = self.construct_conversation(start,
-                                            client,
-                                            server,
-                                            hard_stop=(duration2 * 5),
-                                            packet_rate=packet_rate)
-
-            c.forget_packets_outside_window(start, end)
-            c.renormalise_times(start)
-            if len(c) != 0:
-                conversations.append(c)
+def seq_to_conversations(seq, server=1, client=2):
+    conversations = []
+    for s in seq:
+        if s:
+            c = Conversation(s[0][0], (server, client), s)
             client += 1
-
-        print(("we have %d conversations at rate %f" %
-               (len(conversations), rate)), file=sys.stderr)
-        conversations.sort()
-        return conversations
+            conversations.append(c)
+    return conversations
 
 
 IP_PROTOCOLS = {
@@ -1590,14 +1644,15 @@ def create_ou(ldb, instance_id):
     return ou
 
 
-class ConversationAccounts(object):
-    """Details of the machine and user accounts associated with a conversation.
-    """
-    def __init__(self, netbios_name, machinepass, username, userpass):
-        self.netbios_name = netbios_name
-        self.machinepass  = machinepass
-        self.username     = username
-        self.userpass     = userpass
+# ConversationAccounts holds details of the machine and user accounts
+# associated with a conversation.
+#
+# We use a named tuple to reduce shared memory usage.
+ConversationAccounts = namedtuple('ConversationAccounts',
+                                  ('netbios_name',
+                                   'machinepass',
+                                   'username',
+                                   'userpass'))
 
 
 def generate_replay_accounts(ldb, instance_id, number, password):
@@ -2146,8 +2201,9 @@ def calc_percentile(values, percentile):
 
 
 def mk_masked_dir(*path):
-    """In a testenv we end up with 0777 diectories that look an alarming
+    """In a testenv we end up with 0777 directories that look an alarming
     green colour with ls. Use umask to avoid that."""
+    # py3 os.mkdir can do this
     d = os.path.join(*path)
     mask = os.umask(0o077)
     os.mkdir(d)