traffic: new version of model with packet_rate, version number
[samba.git] / python / samba / emulate / traffic.py
index 6595996e3146772201a4d9bd98dcd8f95eea45bc..807fa8244e2c13836bb6bc46387eca0116f2859c 100644 (file)
@@ -25,7 +25,6 @@ import json
 import math
 import sys
 import signal
-import itertools
 
 from collections import OrderedDict, Counter, defaultdict, namedtuple
 from samba.emulate import traffic_packets
@@ -55,6 +54,8 @@ from samba.compat import get_string
 from samba.logger import get_samba_logger
 import bisect
 
+CURRENT_MODEL_VERSION = 2   # save as this
+REQUIRED_MODEL_VERSION = 2  # load accepts this or greater
 SLEEP_OVERHEAD = 3e-4
 
 # we don't use None, because it complicates [de]serialisation
@@ -156,9 +157,18 @@ class FakePacketError(Exception):
 
 class Packet(object):
     """Details of a network packet"""
+    __slots__ = ('timestamp',
+                 'ip_protocol',
+                 'stream_number',
+                 'src',
+                 'dest',
+                 'protocol',
+                 'opcode',
+                 'desc',
+                 'extra',
+                 'endpoints')
     def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
                  protocol, opcode, desc, extra):
-
         self.timestamp = timestamp
         self.ip_protocol = ip_protocol
         self.stream_number = stream_number
@@ -288,32 +298,38 @@ class Packet(object):
         return self.timestamp - other.timestamp
 
     def is_really_a_packet(self, missing_packet_stats=None):
-        """Is the packet one that can be ignored?
+        return is_a_real_packet(self.protocol, self.opcode)
 
-        If so removing it will have no effect on the replay
-        """
-        if self.protocol in SKIPPED_PROTOCOLS:
-            # Ignore any packets for the protocols we're not interested in.
-            return False
-        if self.protocol == "ldap" and self.opcode == '':
-            # skip ldap continuation packets
-            return False
 
-        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
-        fn = getattr(traffic_packets, fn_name, None)
-        if not fn:
-            print("missing packet %s" % fn_name, file=sys.stderr)
-            return False
-        if fn is traffic_packets.null_packet:
-            return False
-        return True
+def is_a_real_packet(protocol, opcode):
+    """Is the packet one that can be ignored?
+
+    If so removing it will have no effect on the replay
+    """
+    if protocol in SKIPPED_PROTOCOLS:
+        # Ignore any packets for the protocols we're not interested in.
+        return False
+    if protocol == "ldap" and opcode == '':
+        # skip ldap continuation packets
+        return False
+
+    fn_name = 'packet_%s_%s' % (protocol, opcode)
+    fn = getattr(traffic_packets, fn_name, None)
+    if fn is None:
+        LOGGER.debug("missing packet %s" % fn_name, file=sys.stderr)
+        return False
+    if fn is traffic_packets.null_packet:
+        return False
+    return True
 
 
 class ReplayContext(object):
-    """State/Context for an individual conversation between an simulated client
-       and a server.
+    """State/Context for a conversation between an simulated client and a
+       server. Some of the context is shared amongst all conversations
+       and should be generated before the fork, while other context is
+       specific to a particular conversation and should be generated
+       *after* the fork, in generate_process_local_config().
     """
-
     def __init__(self,
                  server=None,
                  lp=None,
@@ -328,13 +344,6 @@ class ReplayContext(object):
                  domain_sid=None):
 
         self.server                   = server
-        self.ldap_connections         = []
-        self.dcerpc_connections       = []
-        self.lsarpc_connections       = []
-        self.lsarpc_connections_named = []
-        self.drsuapi_connections      = []
-        self.srvsvc_connections       = []
-        self.samr_contexts            = []
         self.netlogon_connection      = None
         self.creds                    = creds
         self.lp                       = lp
@@ -358,7 +367,6 @@ class ReplayContext(object):
         self.last_netlogon_bad        = False
         self.last_samlogon_bad        = False
         self.generate_ldap_search_tables()
-        self.next_conversation_id = itertools.count()
 
     def generate_ldap_search_tables(self):
         session = system_session()
@@ -411,8 +419,13 @@ class ReplayContext(object):
         self.attribute_clue_map = attribute_clue_map
 
     def generate_process_local_config(self, account, conversation):
-        if account is None:
-            return
+        self.ldap_connections         = []
+        self.dcerpc_connections       = []
+        self.lsarpc_connections       = []
+        self.lsarpc_connections_named = []
+        self.drsuapi_connections      = []
+        self.srvsvc_connections       = []
+        self.samr_contexts            = []
         self.netbios_name             = account.netbios_name
         self.machinepass              = account.machinepass
         self.username                 = account.username
@@ -776,14 +789,16 @@ class SamrContext(object):
 
 class Conversation(object):
     """Details of a converation between a simulated client and a server."""
-    conversation_id = None
-
-    def __init__(self, start_time=None, endpoints=None):
+    def __init__(self, start_time=None, endpoints=None, seq=(),
+                 conversation_id=None):
         self.start_time = start_time
         self.endpoints = endpoints
         self.packets = []
         self.msg = random_colour_print(endpoints)
         self.client_balance = 0.0
+        self.conversation_id = conversation_id
+        for p in seq:
+            self.add_short_packet(*p)
 
     def __cmp__(self, other):
         if self.start_time is None:
@@ -1065,11 +1080,11 @@ def ingest_summaries(files, dns_mode='count'):
 
     print("gathering packets into conversations", file=sys.stderr)
     conversations = OrderedDict()
-    for p in packets:
+    for i, p in enumerate(packets):
         p.timestamp -= start_time
         c = conversations.get(p.endpoints)
         if c is None:
-            c = Conversation()
+            c = Conversation(conversation_id=(i + 2))
             conversations[p.endpoints] = c
         c.add_packet(p)
 
@@ -1123,7 +1138,7 @@ class TrafficModel(object):
         self.n = n
         self.dns_opcounts = defaultdict(int)
         self.cumulative_duration = 0.0
-        self.conversation_rate = [0, 1]
+        self.packet_rate = [0, 1]
 
     def learn(self, conversations, dns_opcounts={}):
         prev = 0.0
@@ -1136,10 +1151,15 @@ class TrafficModel(object):
             self.dns_opcounts[k] += v
 
         if len(conversations) > 1:
-            elapsed =\
-                conversations[-1].start_time - conversations[0].start_time
-            self.conversation_rate[0] = len(conversations)
-            self.conversation_rate[1] = elapsed
+            first = conversations[0].start_time
+            total = 0
+            last = first + 0.1
+            for c in conversations:
+                total += len(c)
+                last = max(last, c.packets[-1].timestamp)
+
+            self.packet_rate[0] = total
+            self.packet_rate[1] = last - first
 
         for c in conversations:
             client, server = c.guess_client_server(server)
@@ -1183,7 +1203,8 @@ class TrafficModel(object):
             'ngrams': ngrams,
             'query_details': query_details,
             'cumulative_duration': self.cumulative_duration,
-            'conversation_rate': self.conversation_rate,
+            'packet_rate': self.packet_rate,
+            'version': CURRENT_MODEL_VERSION
         }
         d['dns'] = self.dns_opcounts
 
@@ -1198,11 +1219,23 @@ class TrafficModel(object):
 
         d = json.load(f)
 
+        try:
+            version = d["version"]
+            if version < REQUIRED_MODEL_VERSION:
+                raise ValueError("the model file is version %d; "
+                                 "version %d is required" %
+                                 (version, REQUIRED_MODEL_VERSION))
+        except KeyError:
+                raise ValueError("the model file lacks a version number; "
+                                 "version %d is required" %
+                                 (REQUIRED_MODEL_VERSION))
+
         for k, v in d['ngrams'].items():
             k = tuple(str(k).split('\t'))
             values = self.ngrams.setdefault(k, [])
             for p, count in v.items():
                 values.extend([str(p)] * count)
+            values.sort()
 
         for k, v in d['query_details'].items():
             values = self.query_details.setdefault(str(k), [])
@@ -1211,26 +1244,32 @@ class TrafficModel(object):
                     values.extend([()] * count)
                 else:
                     values.extend([tuple(str(p).split('\t'))] * count)
+            values.sort()
 
         if 'dns' in d:
             for k, v in d['dns'].items():
                 self.dns_opcounts[k] += v
 
         self.cumulative_duration = d['cumulative_duration']
-        self.conversation_rate = d['conversation_rate']
-
-    def construct_conversation(self, timestamp=0.0, client=2, server=1,
-                               hard_stop=None, packet_rate=1):
-        """Construct a individual converation from the model."""
-
-        c = Conversation(timestamp, (server, client))
-
+        self.packet_rate = d['packet_rate']
+
+    def construct_conversation_sequence(self, timestamp=0.0,
+                                        hard_stop=None,
+                                        replay_speed=1,
+                                        ignore_before=0):
+        """Construct an individual conversation packet sequence from the
+        model.
+        """
+        c = []
         key = (NON_PACKET,) * (self.n - 1)
+        if ignore_before is None:
+            ignore_before = timestamp - 1
 
-        while key in self.ngrams:
-            p = random.choice(self.ngrams.get(key, NON_PACKET))
+        while True:
+            p = random.choice(self.ngrams.get(key, (NON_PACKET,)))
             if p == NON_PACKET:
                 break
+
             if p in self.query_details:
                 extra = random.choice(self.query_details[p])
             else:
@@ -1239,55 +1278,58 @@ class TrafficModel(object):
             protocol, opcode = p.split(':', 1)
             if protocol == 'wait':
                 log_wait_time = int(opcode) + random.random()
-                wait = math.exp(log_wait_time) / (WAIT_SCALE * packet_rate)
+                wait = math.exp(log_wait_time) / (WAIT_SCALE * replay_speed)
                 timestamp += wait
             else:
                 log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
-                wait = math.exp(log_wait) / packet_rate
+                wait = math.exp(log_wait) / replay_speed
                 timestamp += wait
                 if hard_stop is not None and timestamp > hard_stop:
                     break
-                c.add_short_packet(timestamp, protocol, opcode, extra)
+                if timestamp >= ignore_before:
+                    c.append((timestamp, protocol, opcode, extra))
 
             key = key[1:] + (p,)
 
         return c
 
-    def generate_conversations(self, rate, duration, packet_rate=1):
+    def generate_conversations(self, scale, duration, replay_speed=1,
+                               server=1, client=2):
         """Generate a list of conversations from the model."""
 
-        # We run the simulation for at least ten times as long as our
-        # desired duration, and take a section near the start.
-        rate_n, rate_t  = self.conversation_rate
+        # We run the simulation for ten times as long as our desired
+        # duration, and take the section at the end.
+        lead_in = 9 * duration
+        rate_n, rate_t  = self.packet_rate
+        target_packets = int(duration * scale * rate_n / rate_t)
 
-        duration2 = max(rate_t, duration * 2)
-        n = rate * duration2 * rate_n / rate_t
+        conversations = []
+        n_packets = 0
+
+        while n_packets < target_packets:
+            start = random.uniform(-lead_in, duration)
+            c = self.construct_conversation_sequence(start,
+                                                     hard_stop=duration,
+                                                     replay_speed=replay_speed,
+                                                     ignore_before=0)
+            conversations.append(c)
+            n_packets += len(c)
+
+        print(("we have %d packets (target %d) in %d conversations at scale %f"
+               % (n_packets, target_packets, len(conversations), scale)),
+              file=sys.stderr)
+        conversations.sort()  # sorts by first element == start time
+        return seq_to_conversations(conversations)
 
-        server = 1
-        client = 2
 
-        conversations = []
-        end = duration2
-        start = end - duration
-
-        while client < n + 2:
-            start = random.uniform(0, duration2)
-            c = self.construct_conversation(start,
-                                            client,
-                                            server,
-                                            hard_stop=(duration2 * 5),
-                                            packet_rate=packet_rate)
-
-            c.forget_packets_outside_window(start, end)
-            c.renormalise_times(start)
-            if len(c) != 0:
-                conversations.append(c)
+def seq_to_conversations(seq, server=1, client=2):
+    conversations = []
+    for s in seq:
+        if s:
+            c = Conversation(s[0][0], (server, client), s)
             client += 1
-
-        print(("we have %d conversations at rate %f" %
-               (len(conversations), rate)), file=sys.stderr)
-        conversations.sort()
-        return conversations
+            conversations.append(c)
+    return conversations
 
 
 IP_PROTOCOLS = {