traffic: new version of model with packet_rate, version number
[samba.git] / python / samba / emulate / traffic.py
index b6097cd..807fa82 100644 (file)
@@ -54,6 +54,8 @@ from samba.compat import get_string
 from samba.logger import get_samba_logger
 import bisect
 
+CURRENT_MODEL_VERSION = 2   # save as this
+REQUIRED_MODEL_VERSION = 2  # load accepts this or greater
 SLEEP_OVERHEAD = 3e-4
 
 # we don't use None, because it complicates [de]serialisation
@@ -1136,7 +1138,7 @@ class TrafficModel(object):
         self.n = n
         self.dns_opcounts = defaultdict(int)
         self.cumulative_duration = 0.0
-        self.conversation_rate = [0, 1]
+        self.packet_rate = [0, 1]
 
     def learn(self, conversations, dns_opcounts={}):
         prev = 0.0
@@ -1149,10 +1151,15 @@ class TrafficModel(object):
             self.dns_opcounts[k] += v
 
         if len(conversations) > 1:
-            elapsed =\
-                conversations[-1].start_time - conversations[0].start_time
-            self.conversation_rate[0] = len(conversations)
-            self.conversation_rate[1] = elapsed
+            first = conversations[0].start_time
+            total = 0
+            last = first + 0.1
+            for c in conversations:
+                total += len(c)
+                last = max(last, c.packets[-1].timestamp)
+
+            self.packet_rate[0] = total
+            self.packet_rate[1] = last - first
 
         for c in conversations:
             client, server = c.guess_client_server(server)
@@ -1196,7 +1203,8 @@ class TrafficModel(object):
             'ngrams': ngrams,
             'query_details': query_details,
             'cumulative_duration': self.cumulative_duration,
-            'conversation_rate': self.conversation_rate,
+            'packet_rate': self.packet_rate,
+            'version': CURRENT_MODEL_VERSION
         }
         d['dns'] = self.dns_opcounts
 
@@ -1211,6 +1219,17 @@ class TrafficModel(object):
 
         d = json.load(f)
 
+        try:
+            version = d["version"]
+            if version < REQUIRED_MODEL_VERSION:
+                raise ValueError("the model file is version %d; "
+                                 "version %d is required" %
+                                 (version, REQUIRED_MODEL_VERSION))
+        except KeyError:
+                raise ValueError("the model file lacks a version number; "
+                                 "version %d is required" %
+                                 (REQUIRED_MODEL_VERSION))
+
         for k, v in d['ngrams'].items():
             k = tuple(str(k).split('\t'))
             values = self.ngrams.setdefault(k, [])
@@ -1232,18 +1251,22 @@ class TrafficModel(object):
                 self.dns_opcounts[k] += v
 
         self.cumulative_duration = d['cumulative_duration']
-        self.conversation_rate = d['conversation_rate']
-
-    def construct_conversation(self, timestamp=0.0, client=2, server=1,
-                               hard_stop=None, replay_speed=1):
-        """Construct a individual converation from the model."""
-
-        c = Conversation(timestamp, (server, client), conversation_id=client)
-
+        self.packet_rate = d['packet_rate']
+
+    def construct_conversation_sequence(self, timestamp=0.0,
+                                        hard_stop=None,
+                                        replay_speed=1,
+                                        ignore_before=0):
+        """Construct an individual conversation packet sequence from the
+        model.
+        """
+        c = []
         key = (NON_PACKET,) * (self.n - 1)
+        if ignore_before is None:
+            ignore_before = timestamp - 1
 
-        while key in self.ngrams:
-            p = random.choice(self.ngrams.get(key, NON_PACKET))
+        while True:
+            p = random.choice(self.ngrams.get(key, (NON_PACKET,)))
             if p == NON_PACKET:
                 break
 
@@ -1263,47 +1286,50 @@ class TrafficModel(object):
                 timestamp += wait
                 if hard_stop is not None and timestamp > hard_stop:
                     break
-                c.add_short_packet(timestamp, protocol, opcode, extra)
+                if timestamp >= ignore_before:
+                    c.append((timestamp, protocol, opcode, extra))
 
             key = key[1:] + (p,)
 
         return c
 
-    def generate_conversations(self, rate, duration, replay_speed=1):
+    def generate_conversations(self, scale, duration, replay_speed=1,
+                               server=1, client=2):
         """Generate a list of conversations from the model."""
 
-        # We run the simulation for at least ten times as long as our
-        # desired duration, and take a section near the start.
-        rate_n, rate_t  = self.conversation_rate
+        # We run the simulation for ten times as long as our desired
+        # duration, and take the section at the end.
+        lead_in = 9 * duration
+        rate_n, rate_t  = self.packet_rate
+        target_packets = int(duration * scale * rate_n / rate_t)
 
-        duration2 = max(rate_t, duration * 2)
-        n = rate * duration2 * rate_n / rate_t
+        conversations = []
+        n_packets = 0
+
+        while n_packets < target_packets:
+            start = random.uniform(-lead_in, duration)
+            c = self.construct_conversation_sequence(start,
+                                                     hard_stop=duration,
+                                                     replay_speed=replay_speed,
+                                                     ignore_before=0)
+            conversations.append(c)
+            n_packets += len(c)
+
+        print(("we have %d packets (target %d) in %d conversations at scale %f"
+               % (n_packets, target_packets, len(conversations), scale)),
+              file=sys.stderr)
+        conversations.sort()  # sorts by first element == start time
+        return seq_to_conversations(conversations)
 
-        server = 1
-        client = 2
 
-        conversations = []
-        end = duration2
-        start = end - duration
-
-        while client < n + 2:
-            start = random.uniform(0, duration2)
-            c = self.construct_conversation(start,
-                                            client,
-                                            server,
-                                            hard_stop=(duration2 * 5),
-                                            replay_speed=replay_speed)
-
-            c.forget_packets_outside_window(start, end)
-            c.renormalise_times(start)
-            if len(c) != 0:
-                conversations.append(c)
+def seq_to_conversations(seq, server=1, client=2):
+    conversations = []
+    for s in seq:
+        if s:
+            c = Conversation(s[0][0], (server, client), s)
             client += 1
-
-        print(("we have %d conversations at rate %f" %
-               (len(conversations), rate)), file=sys.stderr)
-        conversations.sort()
-        return conversations
+            conversations.append(c)
+    return conversations
 
 
 IP_PROTOCOLS = {