traffic: new version of model with packet_rate, version number
[samba.git] / python / samba / emulate / traffic.py
1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
3 #
4 # Copyright (C) Catalyst IT Ltd. 2017
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19 from __future__ import print_function, division
20
21 import time
22 import os
23 import random
24 import json
25 import math
26 import sys
27 import signal
28
29 from collections import OrderedDict, Counter, defaultdict, namedtuple
30 from samba.emulate import traffic_packets
31 from samba.samdb import SamDB
32 import ldb
33 from ldb import LdbError
34 from samba.dcerpc import ClientConnection
35 from samba.dcerpc import security, drsuapi, lsa
36 from samba.dcerpc import netlogon
37 from samba.dcerpc.netlogon import netr_Authenticator
38 from samba.dcerpc import srvsvc
39 from samba.dcerpc import samr
40 from samba.drs_utils import drs_DsBind
41 import traceback
42 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
43 from samba.auth import system_session
44 from samba.dsdb import (
45     UF_NORMAL_ACCOUNT,
46     UF_SERVER_TRUST_ACCOUNT,
47     UF_TRUSTED_FOR_DELEGATION,
48     UF_WORKSTATION_TRUST_ACCOUNT
49 )
50 from samba.dcerpc.misc import SEC_CHAN_BDC
51 from samba import gensec
52 from samba import sd_utils
53 from samba.compat import get_string
54 from samba.logger import get_samba_logger
55 import bisect
56
# Model file format versioning: files are written as CURRENT_MODEL_VERSION;
# loading accepts REQUIRED_MODEL_VERSION or greater.
CURRENT_MODEL_VERSION = 2   # save as this
REQUIRED_MODEL_VERSION = 2  # load accepts this or greater
# presumably the fixed per-call cost (in seconds) of sleeping, compensated
# for during replay -- TODO confirm against the replay loop
SLEEP_OVERHEAD = 3e-4

# we don't use None, because it complicates [de]serialisation
NON_PACKET = '-'

# (protocol, opcode) pairs that suggest the sender of the packet is the
# client side of the conversation; used by Packet.client_score().
CLIENT_CLUES = {
    ('dns', '0'): 1.0,      # query
    ('smb', '0x72'): 1.0,   # Negotiate protocol
    ('ldap', '0'): 1.0,     # bind
    ('ldap', '3'): 1.0,     # searchRequest
    ('ldap', '2'): 1.0,     # unbindRequest
    ('cldap', '3'): 1.0,
    ('dcerpc', '11'): 1.0,  # bind
    ('dcerpc', '14'): 1.0,  # Alter_context
    ('nbns', '0'): 1.0,     # query
}

# (protocol, opcode) pairs that suggest the sender is the server side;
# the negated weight is used by Packet.client_score().
SERVER_CLUES = {
    ('dns', '1'): 1.0,      # response
    ('ldap', '1'): 1.0,     # bind response
    ('ldap', '4'): 1.0,     # search result
    ('ldap', '5'): 1.0,     # search done
    ('cldap', '5'): 1.0,
    ('dcerpc', '12'): 1.0,  # bind_ack
    ('dcerpc', '13'): 1.0,  # bind_nak
    ('dcerpc', '15'): 1.0,  # Alter_context response
}

# Protocols ignored entirely by is_a_real_packet() (not replayed).
SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}

# NOTE(review): wait-time model tuning constants -- used elsewhere in the
# file (outside this view); semantics not verifiable from here.
WAIT_SCALE = 10.0
WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
NO_WAIT_LOG_TIME_RANGE = (-10, -3)

# DEBUG_LEVEL can be changed by scripts with -d
DEBUG_LEVEL = 0

LOGGER = get_samba_logger(name=__name__)
97
98
def debug(level, msg, *args):
    """Print a formatted debug message to standard error.

    :param level: The debug level, message will be printed if it is <= the
                  currently set debug level. The debug level can be set with
                  the -d option.
    :param msg:   The message to be logged, can contain C-Style format
                  specifiers
    :param args:  The parameters required by the format specifiers
    """
    if level > DEBUG_LEVEL:
        return
    # Only apply %-formatting when arguments were supplied, so that a
    # bare message containing literal '%' characters is printed as-is.
    text = msg % tuple(args) if args else msg
    print(text, file=sys.stderr)
116
def debug_lineno(*args):
    """Print an unformatted log message to stderr, prefixed by the calling
    function's name and (highlighted) line number."""
    caller = traceback.extract_stack(limit=2)[0]
    # caller[2] is the function name, caller[1] the line number; the
    # escape codes render the line number in bold yellow.
    header = " %s:\033[01;33m%s \033[00m" % (caller[2], caller[1])
    print(header, end=' ', file=sys.stderr)
    for item in args:
        print(item, file=sys.stderr)
    print(file=sys.stderr)
    sys.stderr.flush()
128
129
def random_colour_print(seeds):
    """Return a function that prints a coloured line to stderr. The colour
    of the line depends on a sort of hash of the integer arguments."""
    if not seeds:
        # No seeds: plain, uncoloured output.
        def printer(*args):
            if DEBUG_LEVEL > 0:
                for item in args:
                    print(item, file=sys.stderr)
        return printer

    # Fold the seeds into a value in [0, 214) and map it onto one of the
    # 256-colour terminal palette entries starting at 18.
    acc = 214
    for seed in seeds:
        acc = ((acc + 17) * seed) % 214
    colour_prefix = "\033[38;5;%dm" % (18 + acc)

    def printer(*args):
        if DEBUG_LEVEL > 0:
            for item in args:
                print("%s%s\033[00m" % (colour_prefix, item), file=sys.stderr)
    return printer
152
153
class FakePacketError(Exception):
    """Raised when packet data is inconsistent with its conversation
    (e.g. the packet's endpoints don't match the conversation's)."""
156
157
class Packet(object):
    """Details of a network packet.

    Instances are normally built from a tab-separated traffic summary
    line via from_line(). The endpoint pair is stored in sorted order so
    both directions of a conversation share the same key.
    """
    __slots__ = ('timestamp',
                 'ip_protocol',
                 'stream_number',
                 'src',
                 'dest',
                 'protocol',
                 'opcode',
                 'desc',
                 'extra',
                 'endpoints')

    def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
                 protocol, opcode, desc, extra):
        """Initialise a packet.

        :param timestamp: time of the packet in seconds (float)
        :param ip_protocol: transport protocol name (e.g. 'tcp')
        :param stream_number: capture stream identifier (may be falsy)
        :param src: integer id of the sending endpoint
        :param dest: integer id of the receiving endpoint
        :param protocol: application protocol name (e.g. 'ldap')
        :param opcode: protocol-specific operation code (string)
        :param desc: short human-readable description
        :param extra: list of additional protocol-specific fields
        """
        self.timestamp = timestamp
        self.ip_protocol = ip_protocol
        self.stream_number = stream_number
        self.src = src
        self.dest = dest
        self.protocol = protocol
        self.opcode = opcode
        self.desc = desc
        self.extra = extra
        # Canonical (sorted) endpoint pair, independent of direction.
        if self.src < self.dest:
            self.endpoints = (self.src, self.dest)
        else:
            self.endpoints = (self.dest, self.src)

    @classmethod
    def from_line(cls, line):
        """Parse a tab-separated traffic_summary line into a Packet.

        Fields beyond the first eight are kept verbatim in self.extra.
        """
        fields = line.rstrip('\n').split('\t')
        (timestamp,
         ip_protocol,
         stream_number,
         src,
         dest,
         protocol,
         opcode,
         desc) = fields[:8]
        extra = fields[8:]

        timestamp = float(timestamp)
        src = int(src)
        dest = int(dest)

        return cls(timestamp, ip_protocol, stream_number, src, dest,
                   protocol, opcode, desc, extra)

    def as_summary(self, time_offset=0.0):
        """Format the packet as a traffic_summary line.

        :param time_offset: added to the packet timestamp in the output
        :return: a (timestamp, line) tuple
        """
        extra = '\t'.join(self.extra)
        t = self.timestamp + time_offset
        return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
                (t,
                 self.ip_protocol,
                 self.stream_number or '',
                 self.src,
                 self.dest,
                 self.protocol,
                 self.opcode,
                 self.desc,
                 extra))

    def __str__(self):
        return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
                (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
                 self.stream_number, self.protocol, self.opcode, self.desc,
                 ('«' + ' '.join(self.extra) + '»' if self.extra else '')))

    def __repr__(self):
        return "<Packet @%s>" % self

    def copy(self):
        """Return a new Packet with the same field values."""
        return self.__class__(self.timestamp,
                              self.ip_protocol,
                              self.stream_number,
                              self.src,
                              self.dest,
                              self.protocol,
                              self.opcode,
                              self.desc,
                              self.extra)

    def as_packet_type(self):
        """Return the 'protocol:opcode' key for this packet."""
        t = '%s:%s' % (self.protocol, self.opcode)
        return t

    def client_score(self):
        """A positive number means we think it is a client; a negative number
        means we think it is a server. Zero means no idea. range: -1 to 1.
        """
        key = (self.protocol, self.opcode)
        if key in CLIENT_CLUES:
            return CLIENT_CLUES[key]
        if key in SERVER_CLUES:
            return -SERVER_CLUES[key]
        return 0.0

    def play(self, conversation, context):
        """Send the packet over the network, if required.

        Some packets are ignored, i.e. for  protocols not handled,
        server response messages, or messages that are generated by the
        protocol layer associated with other packets.

        Timing data for each handled packet (or failure) is printed as a
        tab-separated line on stdout.
        """
        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        try:
            fn = getattr(traffic_packets, fn_name)

        except AttributeError as e:
            print("Conversation(%s) Missing handler %s" %
                  (conversation.conversation_id, fn_name),
                  file=sys.stderr)
            return

        # Don't display a message for kerberos packets, they're not directly
        # generated they're used to indicate kerberos should be used
        if self.protocol != "kerberos":
            debug(2, "Conversation(%s) Calling handler %s" %
                     (conversation.conversation_id, fn_name))

        start = time.time()
        try:
            if fn(self, conversation, context):
                # Only collect timing data for functions that generate
                # network traffic, or fail
                end = time.time()
                duration = end - start
                print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
                      (end, conversation.conversation_id, self.protocol,
                       self.opcode, duration))
        except Exception as e:
            end = time.time()
            duration = end - start
            print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
                  (end, conversation.conversation_id, self.protocol,
                   self.opcode, duration, e))

    def __cmp__(self, other):
        # Python 2 ordering hook; ignored by Python 3 (see __lt__).
        return self.timestamp - other.timestamp

    def __lt__(self, other):
        # Python 3 ignores __cmp__, so provide a rich comparison to keep
        # timestamp-ordered sorting/bisecting of packets working there.
        return self.timestamp < other.timestamp

    def is_really_a_packet(self, missing_packet_stats=None):
        """Return True if this packet plays a part in the replay."""
        return is_a_real_packet(self.protocol, self.opcode)
303
def is_a_real_packet(protocol, opcode):
    """Return True if the packet affects the replay.

    Packets for which this returns False can be removed without
    changing the replay: skipped protocols, ldap continuations, and
    packets with no handler (or the null handler) in traffic_packets.
    """
    if protocol in SKIPPED_PROTOCOLS:
        # Ignore any packets for the protocols we're not interested in.
        return False
    if protocol == "ldap" and opcode == '':
        # skip ldap continuation packets
        return False

    fn_name = 'packet_%s_%s' % (protocol, opcode)
    fn = getattr(traffic_packets, fn_name, None)
    if fn is None:
        # Fix: Logger.debug() takes no 'file' argument -- passing one
        # raised TypeError whenever a handler was missing. Use lazy
        # %-style args instead of pre-formatting.
        LOGGER.debug("missing packet %s", fn_name)
        return False
    if fn is traffic_packets.null_packet:
        return False
    return True
324
325
class ReplayContext(object):
    """State/Context for a conversation between a simulated client and a
       server. Some of the context is shared amongst all conversations
       and should be generated before the fork, while other context is
       specific to a particular conversation and should be generated
       *after* the fork, in generate_process_local_config().
    """
    def __init__(self,
                 server=None,
                 lp=None,
                 creds=None,
                 badpassword_frequency=None,
                 prefer_kerberos=None,
                 tempdir=None,
                 statsdir=None,
                 ou=None,
                 base_dn=None,
                 domain=None,
                 domain_sid=None):
        """Set up the pre-fork shared state.

        :param server: name/address of the server being replayed against
        :param lp: loadparm context (also supplies the realm)
        :param creds: credentials for the shared connections
        :param badpassword_frequency: probability (0-1) of trying an
            operation with bad credentials first
        :param prefer_kerberos: if True generated credentials are set to
            MUST_USE_KERBEROS, otherwise DONT_USE_KERBEROS
        :param tempdir: parent directory for per-conversation temp dirs
        :param statsdir: directory for statistics output
        :param ou: OU DN under which conversation accounts live
        :param base_dn: fallback DN for get_matching_dn()
        :param domain: NT domain name (may be filled in later from the
            DOMAIN environment variable; see
            generate_process_local_config)
        :param domain_sid: SID of the domain
        """
        self.server                   = server
        self.netlogon_connection      = None
        self.creds                    = creds
        self.lp                       = lp
        self.prefer_kerberos          = prefer_kerberos
        self.ou                       = ou
        self.base_dn                  = base_dn
        self.domain                   = domain
        self.statsdir                 = statsdir
        self.global_tempdir           = tempdir
        self.domain_sid               = domain_sid
        self.realm                    = lp.get('realm')

        # Bad password attempt controls; the last_*_bad flags prevent two
        # consecutive bad-password attempts on the same connection type
        # (see with_random_bad_credentials).
        self.badpassword_frequency    = badpassword_frequency
        self.last_lsarpc_bad          = False
        self.last_lsarpc_named_bad    = False
        self.last_simple_bind_bad     = False
        self.last_bind_bad            = False
        self.last_srvsvc_bad          = False
        self.last_drsuapi_bad         = False
        self.last_netlogon_bad        = False
        self.last_samlogon_bad        = False
        self.generate_ldap_search_tables()

    def generate_ldap_search_tables(self):
        """Build lookup tables mapping DN shape patterns (e.g.
        'CN,CN,DC,DC') to lists of real DNs on the server, used by
        get_matching_dn() to find plausible search targets."""
        session = system_session()

        db = SamDB(url="ldap://%s" % self.server,
                   session_info=session,
                   credentials=self.creds,
                   lp=self.lp)

        res = db.search(db.domain_dn(),
                        scope=ldb.SCOPE_SUBTREE,
                        controls=["paged_results:1:1000"],
                        attrs=['dn'])

        # find a list of dns for each pattern
        # e.g. CN,CN,CN,DC,DC
        dn_map = {}
        attribute_clue_map = {
            'invocationId': []
        }

        for r in res:
            dn = str(r.dn)
            pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
            dns = dn_map.setdefault(pattern, [])
            dns.append(dn)
            if dn.startswith('CN=NTDS Settings,'):
                attribute_clue_map['invocationId'].append(dn)

        # extend the map in case we are working with a different
        # number of DC components.
        # for k, v in self.dn_map.items():
        #     print >>sys.stderr, k, len(v)

        for k in list(dn_map.keys()):
            if k[-3:] != ',DC':
                continue
            p = k[:-3]
            while p[-3:] == ',DC':
                p = p[:-3]
            # alias the same DN list under 1..5 trailing DC components
            for i in range(5):
                p += ',DC'
                if p != k and p in dn_map:
                    print('dn_map collison %s %s' % (k, p),
                          file=sys.stderr)
                    continue
                dn_map[p] = dn_map[k]

        self.dn_map = dn_map
        self.attribute_clue_map = attribute_clue_map

    def generate_process_local_config(self, account, conversation):
        """Set up state that must be created *after* the fork: the
        per-conversation account details, temp directory, and user and
        machine credentials.

        :param account: object carrying netbios_name, machinepass,
            username and userpass for this conversation
        :param conversation: the Conversation this context serves
        """
        self.ldap_connections         = []
        self.dcerpc_connections       = []
        self.lsarpc_connections       = []
        self.lsarpc_connections_named = []
        self.drsuapi_connections      = []
        self.srvsvc_connections       = []
        self.samr_contexts            = []
        self.netbios_name             = account.netbios_name
        self.machinepass              = account.machinepass
        self.username                 = account.username
        self.userpass                 = account.userpass

        self.tempdir = mk_masked_dir(self.global_tempdir,
                                     'conversation-%d' %
                                     conversation.conversation_id)

        self.lp.set("private dir", self.tempdir)
        self.lp.set("lock dir", self.tempdir)
        self.lp.set("state directory", self.tempdir)
        self.lp.set("tls verify peer", "no_check")

        # If the domain was not specified, check for the environment
        # variable.
        if self.domain is None:
            self.domain = os.environ["DOMAIN"]

        self.remoteAddress = "/root/ncalrpc_as_system"
        self.samlogon_dn   = ("cn=%s,%s" %
                              (self.netbios_name, self.ou))
        self.user_dn       = ("cn=%s,%s" %
                              (self.username, self.ou))

        self.generate_machine_creds()
        self.generate_user_creds()

    def with_random_bad_credentials(self, f, good, bad, failed_last_time):
        """Execute the supplied logon function, randomly choosing the
           bad credentials.

           Based on the frequency in badpassword_frequency randomly perform the
           function with the supplied bad credentials.
           If run with bad credentials, the function is re-run with the good
           credentials.
           failed_last_time is used to prevent consecutive bad credential
           attempts. So the over all bad credential frequency will be lower
           than that requested, but not significantly.

           :return: (result of f(good), new failed_last_time flag)
        """
        if not failed_last_time:
            if (self.badpassword_frequency and self.badpassword_frequency > 0
                and random.random() < self.badpassword_frequency):
                try:
                    f(bad)
                except Exception:
                    # Ignore any exceptions as the operation may fail
                    # as it's being performed with bad credentials.
                    # (Catch Exception, not a bare except, so that
                    # KeyboardInterrupt/SystemExit still propagate.)
                    pass
                failed_last_time = True
            else:
                failed_last_time = False

        result = f(good)
        return (result, failed_last_time)

    def generate_user_creds(self):
        """Generate the conversation specific user Credentials.

        Each Conversation has an associated user account used to simulate
        any non Administrative user traffic.

        Generates user credentials with good and bad passwords and ldap
        simple bind credentials with good and bad passwords.
        """
        self.user_creds = Credentials()
        self.user_creds.guess(self.lp)
        self.user_creds.set_workstation(self.netbios_name)
        self.user_creds.set_password(self.userpass)
        self.user_creds.set_username(self.username)
        self.user_creds.set_domain(self.domain)
        if self.prefer_kerberos:
            self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.user_creds.set_kerberos_state(DONT_USE_KERBEROS)

        # "bad" variant: password truncated by four characters.
        self.user_creds_bad = Credentials()
        self.user_creds_bad.guess(self.lp)
        self.user_creds_bad.set_workstation(self.netbios_name)
        self.user_creds_bad.set_password(self.userpass[:-4])
        self.user_creds_bad.set_username(self.username)
        if self.prefer_kerberos:
            self.user_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.user_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)

        # Credentials for ldap simple bind.
        self.simple_bind_creds = Credentials()
        self.simple_bind_creds.guess(self.lp)
        self.simple_bind_creds.set_workstation(self.netbios_name)
        self.simple_bind_creds.set_password(self.userpass)
        self.simple_bind_creds.set_username(self.username)
        self.simple_bind_creds.set_gensec_features(
            self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
        if self.prefer_kerberos:
            self.simple_bind_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.simple_bind_creds.set_kerberos_state(DONT_USE_KERBEROS)
        self.simple_bind_creds.set_bind_dn(self.user_dn)

        self.simple_bind_creds_bad = Credentials()
        self.simple_bind_creds_bad.guess(self.lp)
        self.simple_bind_creds_bad.set_workstation(self.netbios_name)
        self.simple_bind_creds_bad.set_password(self.userpass[:-4])
        self.simple_bind_creds_bad.set_username(self.username)
        self.simple_bind_creds_bad.set_gensec_features(
            self.simple_bind_creds_bad.get_gensec_features() |
            gensec.FEATURE_SEAL)
        if self.prefer_kerberos:
            self.simple_bind_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.simple_bind_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
        self.simple_bind_creds_bad.set_bind_dn(self.user_dn)

    def generate_machine_creds(self):
        """Generate the conversation specific machine Credentials.

        Each Conversation has an associated machine account.

        Generates machine credentials with good and bad passwords.
        """

        self.machine_creds = Credentials()
        self.machine_creds.guess(self.lp)
        self.machine_creds.set_workstation(self.netbios_name)
        self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds.set_password(self.machinepass)
        self.machine_creds.set_username(self.netbios_name + "$")
        self.machine_creds.set_domain(self.domain)
        if self.prefer_kerberos:
            self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.machine_creds.set_kerberos_state(DONT_USE_KERBEROS)

        # "bad" variant: password truncated by four characters.
        self.machine_creds_bad = Credentials()
        self.machine_creds_bad.guess(self.lp)
        self.machine_creds_bad.set_workstation(self.netbios_name)
        self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds_bad.set_password(self.machinepass[:-4])
        self.machine_creds_bad.set_username(self.netbios_name + "$")
        if self.prefer_kerberos:
            self.machine_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.machine_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)

    def get_matching_dn(self, pattern, attributes=None):
        """Return a random real DN matching the pattern's shape.

        If the pattern is an empty string, we assume ROOTDSE,
        Otherwise we try adding or removing DC suffixes, then
        shorter leading patterns until we hit one.
        e.g if there is no CN,CN,CN,CN,DC,DC
        we first try       CN,CN,CN,CN,DC
        and                CN,CN,CN,CN,DC,DC,DC
        then change to        CN,CN,CN,DC,DC
        and as last resort we use the base_dn
        """
        attr_clue = self.attribute_clue_map.get(attributes)
        if attr_clue:
            return random.choice(attr_clue)

        pattern = pattern.upper()
        while pattern:
            if pattern in self.dn_map:
                return random.choice(self.dn_map[pattern])
            # chop one off the front and try it all again.
            pattern = pattern[3:]

        return self.base_dn

    def get_dcerpc_connection(self, new=False):
        """Return a dcerpc connection, reusing the last one unless new
        is True."""
        guid = '12345678-1234-abcd-ef00-01234567cffb'  # RPC_NETLOGON UUID
        if self.dcerpc_connections and not new:
            return self.dcerpc_connections[-1]
        c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
                             (guid, 1), self.lp)
        self.dcerpc_connections.append(c)
        return c

    def get_srvsvc_connection(self, new=False):
        """Return a srvsvc connection (possibly made with bad credentials
        first), reusing the last one unless new is True."""
        if self.srvsvc_connections and not new:
            return self.srvsvc_connections[-1]

        def connect(creds):
            return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
                                 self.lp,
                                 creds)

        (c, self.last_srvsvc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_srvsvc_bad)

        self.srvsvc_connections.append(c)
        return c

    def get_lsarpc_connection(self, new=False):
        """Return an lsarpc connection over TCP, reusing the last one
        unless new is True."""
        if self.lsarpc_connections and not new:
            return self.lsarpc_connections[-1]

        def connect(creds):
            binding_options = 'schannel,seal,sign'
            return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
                              (self.server, binding_options),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_bad)

        self.lsarpc_connections.append(c)
        return c

    def get_lsarpc_named_pipe_connection(self, new=False):
        """Return an lsarpc connection over a named pipe, reusing the
        last one unless new is True."""
        if self.lsarpc_connections_named and not new:
            return self.lsarpc_connections_named[-1]

        def connect(creds):
            return lsa.lsarpc("ncacn_np:%s" % (self.server),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_named_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_named_bad)

        self.lsarpc_connections_named.append(c)
        return c

    def get_drsuapi_connection_pair(self, new=False, unbind=False):
        """get a (drs, drs_handle) tuple

        NOTE(review): the unbind parameter is currently unused.
        """
        if self.drsuapi_connections and not new:
            c = self.drsuapi_connections[-1]
            return c

        def connect(creds):
            binding_options = 'seal'
            binding_string = "ncacn_ip_tcp:%s[%s]" %\
                             (self.server, binding_options)
            return drsuapi.drsuapi(binding_string, self.lp, creds)

        (drs, self.last_drsuapi_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_drsuapi_bad)

        (drs_handle, supported_extensions) = drs_DsBind(drs)
        c = (drs, drs_handle)
        self.drsuapi_connections.append(c)
        return c

    def get_ldap_connection(self, new=False, simple=False):
        """Return a SamDB LDAP connection (simple bind over ldaps, or
        SASL bind over ldap), reusing the last one unless new is True."""
        if self.ldap_connections and not new:
            return self.ldap_connections[-1]

        def simple_bind(creds):
            """
            To run simple bind against Windows, we need to run
            following commands in PowerShell:

                Install-windowsfeature ADCS-Cert-Authority
                Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
                Restart-Computer

            """
            return SamDB('ldaps://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)

        def sasl_bind(creds):
            return SamDB('ldap://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)
        if simple:
            (samdb, self.last_simple_bind_bad) = \
                self.with_random_bad_credentials(simple_bind,
                                                 self.simple_bind_creds,
                                                 self.simple_bind_creds_bad,
                                                 self.last_simple_bind_bad)
        else:
            (samdb, self.last_bind_bad) = \
                self.with_random_bad_credentials(sasl_bind,
                                                 self.user_creds,
                                                 self.user_creds_bad,
                                                 self.last_bind_bad)

        self.ldap_connections.append(samdb)
        return samdb

    def get_samr_context(self, new=False):
        """Return a SamrContext, reusing the last one unless new is
        True."""
        if not self.samr_contexts or new:
            self.samr_contexts.append(
                SamrContext(self.server, lp=self.lp, creds=self.creds))
        return self.samr_contexts[-1]

    def get_netlogon_connection(self):
        """Return the netlogon connection, creating it on first call."""
        if self.netlogon_connection:
            return self.netlogon_connection

        def connect(creds):
            return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
                                     (self.server),
                                     self.lp,
                                     creds)
        (c, self.last_netlogon_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_netlogon_bad)
        self.netlogon_connection = c
        return c

    def guess_a_dns_lookup(self):
        """Return a (name, type) pair for a plausible DNS query: an A
        record for the realm."""
        return (self.realm, 'A')

    def get_authenticator(self):
        """Return a (current, subsequent) netr_Authenticator pair for a
        netlogon call, built from the machine credentials."""
        auth = self.machine_creds.new_client_authenticator()
        current = netr_Authenticator()
        # Credential bytes may arrive as ints (Python 3) or as
        # single-character strings (Python 2); normalise to ints.
        current.cred.data = [x if isinstance(x, int) else ord(x)
                             for x in auth["credential"]]
        current.timestamp = auth["timestamp"]

        subsequent = netr_Authenticator()
        return (current, subsequent)
757
758
class SamrContext(object):
    """State/Context associated with a samr connection.

    The connection and handle are created lazily on first use.
    """
    def __init__(self, server, lp=None, creds=None):
        # All connection state starts unset and is filled in on demand.
        self.connection    = None
        self.handle        = None
        self.domain_handle = None
        self.domain_sid    = None
        self.group_handle  = None
        self.user_handle   = None
        self.rids          = None
        self.server        = server
        self.lp            = lp
        self.creds         = creds

    def get_connection(self):
        """Return the samr connection, creating it on first call."""
        if not self.connection:
            binding = "ncacn_ip_tcp:%s[seal]" % (self.server)
            self.connection = samr.samr(binding,
                                        lp_ctx=self.lp,
                                        credentials=self.creds)

        return self.connection

    def get_handle(self):
        """Return the samr handle, connecting first if necessary."""
        if not self.handle:
            conn = self.get_connection()
            self.handle = conn.Connect2(None,
                                        security.SEC_FLAG_MAXIMUM_ALLOWED)
        return self.handle
788
789
class Conversation(object):
    """Details of a conversation between a simulated client and a server."""
    def __init__(self, start_time=None, endpoints=None, seq=(),
                 conversation_id=None):
        self.start_time = start_time
        self.endpoints = endpoints
        self.packets = []
        self.msg = random_colour_print(endpoints)
        self.client_balance = 0.0
        self.conversation_id = conversation_id
        for p in seq:
            self.add_short_packet(*p)

    def __cmp__(self, other):
        # NOTE: __cmp__ is only honoured by Python 2; Python 3 callers
        # sort with an explicit key (see replay()).
        if self.start_time is None:
            if other.start_time is None:
                return 0
            return -1
        if other.start_time is None:
            return 1
        return self.start_time - other.start_time

    def add_packet(self, packet):
        """Add a packet object to this conversation, making a local copy with
        a conversation-relative timestamp."""
        p = packet.copy()

        if self.start_time is None:
            self.start_time = p.timestamp

        if self.endpoints is None:
            self.endpoints = p.endpoints

        if p.endpoints != self.endpoints:
            # note the trailing space: these adjacent literals are
            # concatenated into one message
            raise FakePacketError("Conversation endpoints %s don't match "
                                  "packet endpoints %s" %
                                  (self.endpoints, p.endpoints))

        p.timestamp -= self.start_time

        if p.src == p.endpoints[0]:
            self.client_balance -= p.client_score()
        else:
            self.client_balance += p.client_score()

        if p.is_really_a_packet():
            self.packets.append(p)

    def add_short_packet(self, timestamp, protocol, opcode, extra,
                         client=True):
        """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
        (possibly empty) list of extra data. If client is True, assume
        this packet is from the client to the server.
        """
        src, dest = self.guess_client_server()
        if not client:
            src, dest = dest, src
        key = (protocol, opcode)
        desc = OP_DESCRIPTIONS[key] if key in OP_DESCRIPTIONS else ''
        if protocol in IP_PROTOCOLS:
            ip_protocol = IP_PROTOCOLS[protocol]
        else:
            ip_protocol = '06'
        packet = Packet(timestamp - self.start_time, ip_protocol,
                        '', src, dest,
                        protocol, opcode, desc, extra)
        # XXX we're assuming the timestamp is already adjusted for
        # this conversation?
        # XXX should we adjust client balance for guessed packets?
        if packet.src == packet.endpoints[0]:
            self.client_balance -= packet.client_score()
        else:
            self.client_balance += packet.client_score()
        if packet.is_really_a_packet():
            self.packets.append(packet)

    def __str__(self):
        return ("<Conversation %s %s starting %.3f %d packets>" %
                (self.conversation_id, self.endpoints, self.start_time,
                 len(self.packets)))

    __repr__ = __str__

    def __iter__(self):
        return iter(self.packets)

    def __len__(self):
        return len(self.packets)

    def get_duration(self):
        """Return the time between the first and last packet, or 0 for
        conversations with fewer than two packets."""
        if len(self.packets) < 2:
            return 0
        return self.packets[-1].timestamp - self.packets[0].timestamp

    def replay_as_summary_lines(self):
        """Return the conversation as a list of summary-file lines."""
        lines = []
        for p in self.packets:
            lines.append(p.as_summary(self.start_time))
        return lines

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        """Fork a new process and replay the conversation.
        """
        def signal_handler(signum, frame):
            """Signal handler closes standard out and error.

            Triggered by a sigterm, ensures that the log messages are flushed
            to disk and not lost.
            """
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

        t = self.start_time
        now = time.time() - start
        gap = t - now
        # we are replaying strictly in order, so it is safe to sleep
        # in the main process if the gap is big enough. This reduces
        # the number of concurrent threads, which allows us to make
        # larger loads.
        # NOTE: the "and False" deliberately disables this branch.
        if gap > 0.15 and False:
            print("sleeping for %f in main process" % (gap - 0.1),
                  file=sys.stderr)
            time.sleep(gap - 0.1)
            now = time.time() - start
            gap = t - now
            print("gap is now %f" % gap, file=sys.stderr)

        self.conversation_id = next(context.next_conversation_id)
        pid = os.fork()
        if pid != 0:
            # parent: hand the child's pid back to the caller
            return pid
        pid = os.getpid()
        signal.signal(signal.SIGTERM, signal_handler)
        # we must never return, or we'll end up running parts of the
        # parent's clean-up code. So we work in a try...finally, and
        # try to print any exceptions.

        try:
            context.generate_process_local_config(account, self)
            sys.stdin.close()
            os.close(0)
            filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
                                    self.conversation_id)
            sys.stdout.close()
            sys.stdout = open(filename, 'w')

            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            miss = t - (time.time() - start)
            self.msg("starting %s [miss %.3f pid %d]" % (self, miss, pid))
            self.replay(context)
        except Exception:
            print(("EXCEPTION in child PID %d, conversation %s" % (pid, self)),
                  file=sys.stderr)
            # print_exc's first positional parameter is 'limit', not the
            # output stream, so 'file' must be passed by keyword.
            traceback.print_exc(file=sys.stderr)
        finally:
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

    def replay(self, context=None):
        """Replay the conversation's packets in order, sleeping to match
        the recorded timing. With no context, just log what would be sent."""
        start = time.time()

        for p in self.packets:
            now = time.time() - start
            gap = p.timestamp - now
            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            miss = p.timestamp - (time.time() - start)
            if context is None:
                self.msg("packet %s [miss %.3f pid %d]" % (p, miss,
                                                           os.getpid()))
                continue
            p.play(self, context)

    def guess_client_server(self, server_clue=None):
        """Have a go at deciding who is the server and who is the client.
        returns (client, server)
        """
        a, b = self.endpoints

        if self.client_balance < 0:
            return (a, b)

        # in the absence of a clue, we will fall through to assuming
        # the lowest number is the server (which is usually true).

        if self.client_balance == 0 and server_clue == b:
            return (a, b)

        return (b, a)

    def forget_packets_outside_window(self, s, e):
        """Prune any packets outside the time window we're interested in

        :param s: start of the window
        :param e: end of the window
        """
        self.packets = [p for p in self.packets if s <= p.timestamp <= e]
        self.start_time = self.packets[0].timestamp if self.packets else None

    def renormalise_times(self, start_time):
        """Adjust the packet start times relative to the new start time."""
        for p in self.packets:
            p.timestamp -= start_time

        if self.start_time is not None:
            self.start_time -= start_time
1003
1004
class DnsHammer(Conversation):
    """A lightweight conversation that generates a lot of dns:0 packets on
    the fly"""

    def __init__(self, dns_rate, duration):
        # Pre-generate a sorted schedule of random send times across the
        # whole duration, at the requested average rate.
        n_packets = int(dns_rate * duration)
        self.times = sorted(random.uniform(0, duration)
                            for _ in range(n_packets))
        self.rate = dns_rate
        self.duration = duration
        self.start_time = 0
        self.msg = random_colour_print()

    def __str__(self):
        return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
                (len(self.times), self.duration, self.rate))

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        """Delegate to the ordinary Conversation fork-and-replay logic."""
        return Conversation.replay_in_fork_with_delay(self,
                                                      start,
                                                      context,
                                                      account)

    def replay(self, context=None):
        """Send the scheduled dns:0 packets, logging one result line per
        packet. With no context, just log what would be sent."""
        begin = time.time()
        send_dns = traffic_packets.packet_dns_0
        for when in self.times:
            pause = (when - (time.time() - begin)) - SLEEP_OVERHEAD
            if pause > 0:
                time.sleep(pause)

            if context is None:
                late = when - (time.time() - begin)
                self.msg("packet %s [miss %.3f pid %d]" % (when, late,
                                                           os.getpid()))
                continue

            sent_at = time.time()
            try:
                send_dns(self, self, context)
                finished = time.time()
                print("%f\tDNS\tdns\t0\t%f\tTrue\t" % (finished,
                                                       finished - sent_at))
            except Exception as e:
                finished = time.time()
                print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" % (finished,
                                                          finished - sent_at,
                                                          e))
1054
1055
def ingest_summaries(files, dns_mode='count'):
    """Load summary traffic files and generate Conversations from them.

    :param files: a list of filenames or file objects to read
    :param dns_mode: unless 'include', DNS packets are tallied separately
        rather than added to conversations
    :return: a (conversations, mean_interval, duration, dns_counts) tuple
    """

    dns_counts = defaultdict(int)
    packets = []
    for f in files:
        if isinstance(f, str):
            f = open(f)
        print("Ingesting %s" % (f.name,), file=sys.stderr)
        for line in f:
            p = Packet.from_line(line)
            if p.protocol == 'dns' and dns_mode != 'include':
                dns_counts[p.opcode] += 1
            else:
                packets.append(p)

        f.close()

    if not packets:
        # return the same 4-tuple shape as the normal path, so callers
        # unpacking four values don't blow up on empty input.
        return [], 0, 0, dns_counts

    start_time = min(p.timestamp for p in packets)
    last_packet = max(p.timestamp for p in packets)

    print("gathering packets into conversations", file=sys.stderr)
    conversations = OrderedDict()
    for i, p in enumerate(packets):
        p.timestamp -= start_time
        c = conversations.get(p.endpoints)
        if c is None:
            c = Conversation(conversation_id=(i + 2))
            conversations[p.endpoints] = c
        c.add_packet(p)

    # We only care about conversations with actual traffic, so we
    # filter out conversations with nothing to say. We do that here,
    # rather than earlier, because those empty packets contain useful
    # hints as to which end of the conversation was the client.
    conversation_list = []
    for c in conversations.values():
        if len(c) != 0:
            conversation_list.append(c)

    # This is obviously not correct, as many conversations will appear
    # to start roughly simultaneously at the beginning of the snapshot.
    # To which we say: oh well, so be it.
    duration = float(last_packet - start_time)
    mean_interval = len(conversations) / duration

    return conversation_list, mean_interval, duration, dns_counts
1107
1108
def guess_server_address(conversations):
    """Guess which address is the server: the one that appears in the most
    conversation endpoints.

    :param conversations: an iterable of objects with an `endpoints` pair
    :return: the most common address, or None if there are no endpoints
    """
    addresses = Counter()
    for c in conversations:
        addresses.update(c.endpoints)
    if addresses:
        # most_common(1) returns [(address, count)]; we want the bare
        # address, not the (address, count) pair.
        return addresses.most_common(1)[0][0]
1116
1117
def stringify_keys(x):
    """Return a copy of dict x with each tuple key joined into a single
    tab-separated string (the inverse of unstringify_keys)."""
    return {'\t'.join(k): v for k, v in x.items()}
1124
1125
def unstringify_keys(x):
    """Return a copy of dict x with each tab-separated string key split
    back into a tuple (the inverse of stringify_keys)."""
    return {tuple(str(k).split('\t')): v for k, v in x.items()}
1132
1133
class TrafficModel(object):
    """A model of the learned traffic: n-gram transition tables between
    packet types, per-packet-type query details, DNS opcode counts, and
    the observed packet rate."""
    def __init__(self, n=3):
        self.ngrams = {}
        self.query_details = {}
        self.n = n
        self.dns_opcounts = defaultdict(int)
        self.cumulative_duration = 0.0
        # [packet count, seconds] over the learned traffic.
        self.packet_rate = [0, 1]

    def learn(self, conversations, dns_opcounts=None):
        """Fold the given conversations into the model.

        :param conversations: a list of Conversation objects
        :param dns_opcounts: an optional mapping of DNS opcode -> count

        Note: the default is None rather than {} — a mutable {} default
        would be shared between calls.
        """
        if dns_opcounts is None:
            dns_opcounts = {}
        prev = 0.0
        cum_duration = 0.0
        key = (NON_PACKET,) * (self.n - 1)

        server = guess_server_address(conversations)

        for k, v in dns_opcounts.items():
            self.dns_opcounts[k] += v

        if len(conversations) > 1:
            first = conversations[0].start_time
            total = 0
            # ensure a non-zero interval even for near-simultaneous traffic
            last = first + 0.1
            for c in conversations:
                total += len(c)
                last = max(last, c.packets[-1].timestamp)

            self.packet_rate[0] = total
            self.packet_rate[1] = last - first

        for c in conversations:
            client, server = c.guess_client_server(server)
            cum_duration += c.get_duration()
            key = (NON_PACKET,) * (self.n - 1)
            for p in c:
                # we only model the client side of each conversation
                if p.src != client:
                    continue

                elapsed = p.timestamp - prev
                prev = p.timestamp
                if elapsed > WAIT_THRESHOLD:
                    # add the wait as an extra state
                    wait = 'wait:%d' % (math.log(max(1.0,
                                                     elapsed * WAIT_SCALE)))
                    self.ngrams.setdefault(key, []).append(wait)
                    key = key[1:] + (wait,)

                short_p = p.as_packet_type()
                self.query_details.setdefault(short_p,
                                              []).append(tuple(p.extra))
                self.ngrams.setdefault(key, []).append(short_p)
                key = key[1:] + (short_p,)

        self.cumulative_duration += cum_duration
        # add in the end
        self.ngrams.setdefault(key, []).append(NON_PACKET)

    def save(self, f):
        """Write the model as JSON to f (a filename or file object)."""
        ngrams = {}
        for k, v in self.ngrams.items():
            k = '\t'.join(k)
            ngrams[k] = dict(Counter(v))

        query_details = {}
        for k, v in self.query_details.items():
            query_details[k] = dict(Counter('\t'.join(x) if x else '-'
                                            for x in v))

        d = {
            'ngrams': ngrams,
            'query_details': query_details,
            'cumulative_duration': self.cumulative_duration,
            'packet_rate': self.packet_rate,
            'version': CURRENT_MODEL_VERSION
        }
        d['dns'] = self.dns_opcounts

        if isinstance(f, str):
            # we opened the file, so we close it
            with open(f, 'w') as fh:
                json.dump(d, fh, indent=2)
        else:
            json.dump(d, f, indent=2)

    def load(self, f):
        """Load the model from f (a filename or file object written by
        save()).

        :raises ValueError: if the file lacks a version number, or its
            version is older than REQUIRED_MODEL_VERSION.
        """
        if isinstance(f, str):
            # we opened the file, so we close it
            with open(f) as fh:
                d = json.load(fh)
        else:
            d = json.load(f)

        try:
            version = d["version"]
            if version < REQUIRED_MODEL_VERSION:
                raise ValueError("the model file is version %d; "
                                 "version %d is required" %
                                 (version, REQUIRED_MODEL_VERSION))
        except KeyError:
            raise ValueError("the model file lacks a version number; "
                             "version %d is required" %
                             (REQUIRED_MODEL_VERSION))

        for k, v in d['ngrams'].items():
            k = tuple(str(k).split('\t'))
            values = self.ngrams.setdefault(k, [])
            for p, count in v.items():
                values.extend([str(p)] * count)
            values.sort()

        for k, v in d['query_details'].items():
            values = self.query_details.setdefault(str(k), [])
            for p, count in v.items():
                if p == '-':
                    values.extend([()] * count)
                else:
                    values.extend([tuple(str(p).split('\t'))] * count)
            values.sort()

        if 'dns' in d:
            for k, v in d['dns'].items():
                self.dns_opcounts[k] += v

        self.cumulative_duration = d['cumulative_duration']
        self.packet_rate = d['packet_rate']

    def construct_conversation_sequence(self, timestamp=0.0,
                                        hard_stop=None,
                                        replay_speed=1,
                                        ignore_before=0):
        """Construct an individual conversation packet sequence from the
        model.
        """
        c = []
        key = (NON_PACKET,) * (self.n - 1)
        if ignore_before is None:
            ignore_before = timestamp - 1

        while True:
            p = random.choice(self.ngrams.get(key, (NON_PACKET,)))
            if p == NON_PACKET:
                break

            if p in self.query_details:
                extra = random.choice(self.query_details[p])
            else:
                extra = []

            protocol, opcode = p.split(':', 1)
            if protocol == 'wait':
                log_wait_time = int(opcode) + random.random()
                wait = math.exp(log_wait_time) / (WAIT_SCALE * replay_speed)
                timestamp += wait
            else:
                log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
                wait = math.exp(log_wait) / replay_speed
                timestamp += wait
                if hard_stop is not None and timestamp > hard_stop:
                    break
                if timestamp >= ignore_before:
                    c.append((timestamp, protocol, opcode, extra))

            key = key[1:] + (p,)

        return c

    def generate_conversations(self, scale, duration, replay_speed=1,
                               server=1, client=2):
        """Generate a list of conversations from the model."""

        # We run the simulation for ten times as long as our desired
        # duration, and take the section at the end.
        lead_in = 9 * duration
        rate_n, rate_t  = self.packet_rate
        target_packets = int(duration * scale * rate_n / rate_t)

        conversations = []
        n_packets = 0

        while n_packets < target_packets:
            start = random.uniform(-lead_in, duration)
            c = self.construct_conversation_sequence(start,
                                                     hard_stop=duration,
                                                     replay_speed=replay_speed,
                                                     ignore_before=0)
            conversations.append(c)
            n_packets += len(c)

        print(("we have %d packets (target %d) in %d conversations at scale %f"
               % (n_packets, target_packets, len(conversations), scale)),
              file=sys.stderr)
        conversations.sort()  # sorts by first element == start time
        # forward the endpoint arguments rather than silently dropping them
        return seq_to_conversations(conversations,
                                    server=server,
                                    client=client)
1323
1324
def seq_to_conversations(seq, server=1, client=2):
    """Wrap each non-empty packet sequence in a Conversation, giving
    every one a distinct client endpoint."""
    conversations = []
    for packets in seq:
        if not packets:
            continue
        conversations.append(Conversation(packets[0][0],
                                          (server, client),
                                          packets))
        client += 1
    return conversations
1333
1334
# Map each application protocol name to the IP protocol number (as a
# two-digit hex string) that carries it: '06' is TCP, '11' (0x11 == 17)
# is UDP. Protocols not listed here default to TCP ('06') at the call
# sites.
IP_PROTOCOLS = {
    'dns': '11',
    'rpc_netlogon': '06',
    'kerberos': '06',      # ratio 16248:258
    'smb': '06',
    'smb2': '06',
    'ldap': '06',
    'cldap': '11',
    'lsarpc': '06',
    'samr': '06',
    'dcerpc': '06',
    'epm': '06',
    'drsuapi': '06',
    'browser': '11',
    'smb_netlogon': '11',
    'srvsvc': '06',
    'nbns': '11',
}
1353
# Map (protocol, opcode) pairs to human-readable operation names, used
# as the description field in packet summaries.
OP_DESCRIPTIONS = {
    ('browser', '0x01'): 'Host Announcement (0x01)',
    ('browser', '0x02'): 'Request Announcement (0x02)',
    ('browser', '0x08'): 'Browser Election Request (0x08)',
    ('browser', '0x09'): 'Get Backup List Request (0x09)',
    ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
    ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
    ('cldap', '3'): 'searchRequest',
    ('cldap', '5'): 'searchResDone',
    ('dcerpc', '0'): 'Request',
    ('dcerpc', '11'): 'Bind',
    ('dcerpc', '12'): 'Bind_ack',
    ('dcerpc', '13'): 'Bind_nak',
    ('dcerpc', '14'): 'Alter_context',
    ('dcerpc', '15'): 'Alter_context_resp',
    ('dcerpc', '16'): 'AUTH3',
    ('dcerpc', '2'): 'Response',
    ('dns', '0'): 'query',
    ('dns', '1'): 'response',
    ('drsuapi', '0'): 'DsBind',
    ('drsuapi', '12'): 'DsCrackNames',
    ('drsuapi', '13'): 'DsWriteAccountSpn',
    ('drsuapi', '1'): 'DsUnbind',
    ('drsuapi', '2'): 'DsReplicaSync',
    ('drsuapi', '3'): 'DsGetNCChanges',
    ('drsuapi', '4'): 'DsReplicaUpdateRefs',
    ('epm', '3'): 'Map',
    ('kerberos', ''): '',
    ('ldap', '0'): 'bindRequest',
    ('ldap', '1'): 'bindResponse',
    ('ldap', '2'): 'unbindRequest',
    ('ldap', '3'): 'searchRequest',
    ('ldap', '4'): 'searchResEntry',
    ('ldap', '5'): 'searchResDone',
    ('ldap', ''): '*** Unknown ***',
    ('lsarpc', '14'): 'lsa_LookupNames',
    ('lsarpc', '15'): 'lsa_LookupSids',
    ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
    ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
    ('lsarpc', '6'): 'lsa_OpenPolicy',
    ('lsarpc', '76'): 'lsa_LookupSids3',
    ('lsarpc', '77'): 'lsa_LookupNames4',
    ('nbns', '0'): 'query',
    ('nbns', '1'): 'response',
    ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
    ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
    ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
    ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
    ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
    ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
    ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
    ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
    ('samr', '0',): 'Connect',
    ('samr', '16'): 'GetAliasMembership',
    ('samr', '17'): 'LookupNames',
    ('samr', '18'): 'LookupRids',
    ('samr', '19'): 'OpenGroup',
    ('samr', '1'): 'Close',
    ('samr', '25'): 'QueryGroupMember',
    ('samr', '34'): 'OpenUser',
    ('samr', '36'): 'QueryUserInfo',
    ('samr', '39'): 'GetGroupsForUser',
    ('samr', '3'): 'QuerySecurity',
    ('samr', '5'): 'LookupDomain',
    ('samr', '64'): 'Connect5',
    ('samr', '6'): 'EnumDomains',
    ('samr', '7'): 'OpenDomain',
    ('samr', '8'): 'QueryDomainInfo',
    ('smb', '0x04'): 'Close (0x04)',
    ('smb', '0x24'): 'Locking AndX (0x24)',
    ('smb', '0x2e'): 'Read AndX (0x2e)',
    ('smb', '0x32'): 'Trans2 (0x32)',
    ('smb', '0x71'): 'Tree Disconnect (0x71)',
    ('smb', '0x72'): 'Negotiate Protocol (0x72)',
    ('smb', '0x73'): 'Session Setup AndX (0x73)',
    ('smb', '0x74'): 'Logoff AndX (0x74)',
    ('smb', '0x75'): 'Tree Connect AndX (0x75)',
    ('smb', '0xa2'): 'NT Create AndX (0xa2)',
    ('smb2', '0'): 'NegotiateProtocol',
    ('smb2', '11'): 'Ioctl',
    ('smb2', '14'): 'Find',
    ('smb2', '16'): 'GetInfo',
    ('smb2', '18'): 'Break',
    ('smb2', '1'): 'SessionSetup',
    ('smb2', '2'): 'SessionLogoff',
    ('smb2', '3'): 'TreeConnect',
    ('smb2', '4'): 'TreeDisconnect',
    ('smb2', '5'): 'Create',
    ('smb2', '6'): 'Close',
    ('smb2', '8'): 'Read',
    ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
    ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
                               'user unknown (0x17)'),
    ('srvsvc', '16'): 'NetShareGetInfo',
    ('srvsvc', '21'): 'NetSrvGetInfo',
}
1450
1451
def expand_short_packet(p, timestamp, src, dest, extra):
    """Expand a 'protocol:opcode' short packet into a tab-separated
    summary line."""
    protocol, opcode = p.split(':', 1)
    fields = [timestamp,
              IP_PROTOCOLS.get(protocol, '06'),
              '',
              src,
              dest,
              protocol,
              opcode,
              OP_DESCRIPTIONS.get((protocol, opcode), '')]
    fields.extend(extra)
    return '\t'.join(fields)
1460
1461
def replay(conversations,
           host=None,
           creds=None,
           lp=None,
           accounts=None,
           dns_rate=0,
           duration=None,
           **kwargs):
    """Replay the conversations against host, forking one child process
    per conversation.

    :param conversations: a list of Conversation objects to replay
    :param host: the server to replay against
    :param creds: credentials for the ReplayContext
    :param lp: loadparm context
    :param accounts: a sequence of accounts, one per conversation
    :param dns_rate: rate of extra DNS packets (0 disables the DnsHammer)
    :param duration: seconds to run for (default: one second after the
        last conversation's last packet)
    :param kwargs: passed through to ReplayContext
    """

    context = ReplayContext(server=host,
                            creds=creds,
                            lp=lp,
                            **kwargs)

    if len(accounts) < len(conversations):
        # %d needs the counts, not the lists themselves
        print(("we have %d accounts but %d conversations" %
               (len(accounts), len(conversations))), file=sys.stderr)

    cstack = list(zip(
        sorted(conversations, key=lambda x: x.start_time, reverse=True),
        accounts))

    # Set the process group so that the calling scripts are not killed
    # when the forked child processes are killed.
    os.setpgrp()

    start = time.time()

    if duration is None:
        # end 1 second after the last packet of the last conversation
        # to start. Conversations other than the last could still be
        # going, but we don't care.
        duration = cstack[0][0].packets[-1].timestamp + 1.0
        print("We will stop after %.1f seconds" % duration,
              file=sys.stderr)

    end = start + duration

    LOGGER.info("Replaying traffic for %u conversations over %d seconds"
          % (len(conversations), duration))

    children = {}
    if dns_rate:
        dns_hammer = DnsHammer(dns_rate, duration)
        cstack.append((dns_hammer, None))

    try:
        while True:
            # we spawn a batch, wait for finishers, then spawn another
            now = time.time()
            batch_end = min(now + 2.0, end)
            fork_time = 0.0
            fork_n = 0
            while cstack:
                c, account = cstack.pop()
                if c.start_time + start > batch_end:
                    # not due yet; push it back for the next batch
                    cstack.append((c, account))
                    break

                st = time.time()
                pid = c.replay_in_fork_with_delay(start, context, account)
                children[pid] = c
                t = time.time()
                elapsed = t - st
                fork_time += elapsed
                fork_n += 1
                print("forked %s in pid %s (in %fs)" % (c, pid,
                                                        elapsed),
                      file=sys.stderr)

            if fork_n:
                print(("forked %d times in %f seconds (avg %f)" %
                       (fork_n, fork_time, fork_time / fork_n)),
                      file=sys.stderr)
            elif cstack:
                debug(2, "no forks in batch ending %f" % batch_end)

            # reap finished children until close to the batch boundary
            while time.time() < batch_end - 1.0:
                time.sleep(0.01)
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != 10:  # no child processes
                        raise
                    break
                if pid:
                    c = children.pop(pid, None)
                    print(("process %d finished conversation %s;"
                           " %d to go" %
                           (pid, c, len(children))), file=sys.stderr)

            if time.time() >= end:
                print("time to stop", file=sys.stderr)
                break

    except Exception:
        print("EXCEPTION in parent", file=sys.stderr)
        traceback.print_exc()
    finally:
        # escalate from SIGTERM to SIGKILL while reaping the children
        for s in (15, 15, 9):
            print(("killing %d children with -%d" %
                   (len(children), s)), file=sys.stderr)
            for pid in children:
                try:
                    os.kill(pid, s)
                except OSError as e:
                    if e.errno != 3:  # don't fail if it has already died
                        raise
            time.sleep(0.5)
            end = time.time() + 1
            while children:
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != 10:
                        raise
                if pid != 0:
                    c = children.pop(pid, None)
                    print(("kill -%d %d KILLED conversation %s; "
                           "%d to go" %
                           (s, pid, c, len(children))),
                          file=sys.stderr)
                if time.time() >= end:
                    break

            if not children:
                break
            time.sleep(1)

        if children:
            print("%d children are missing" % len(children),
                  file=sys.stderr)

        # there may be stragglers that were forked just as ^C was hit
        # and don't appear in the list of children. We can get them
        # with killpg, but that will also kill us, so this is^H^H would be
        # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
        # so as not to have to fuss around writing signal handlers.
        try:
            os.killpg(0, 2)
        except KeyboardInterrupt:
            print("ignoring fake ^C", file=sys.stderr)
1604
1605
def openLdb(host, creds, lp):
    """Open an LDAP connection to the host, returning a SamDB handle.

    :param host: DC host name or address to connect to
    :param creds: credentials used to authenticate the connection
    :param lp: loadparm context
    :return: a connected SamDB instance with paged searches enabled
    """
    session = system_session()
    # don't call the local `ldb`: that would shadow the module-level
    # `ldb` import used elsewhere in this file (e.g. add_group_members)
    db = SamDB(url="ldap://%s" % host,
               session_info=session,
               options=['modules:paged_searches'],
               credentials=creds,
               lp=lp)
    return db
1614
1615
def ou_name(ldb, instance_id):
    """Generate an ou name from the instance id"""
    # every object created for an instance lives under this per-instance OU
    base_dn = ldb.domain_dn()
    return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id, base_dn)
1620
1621
def create_ou(ldb, instance_id):
    """Create an ou, all created user and machine accounts will belong to it.

    This allows all the created resources to be cleaned up easily.
    """
    def add_if_missing(dn):
        # add the OU, treating "entry already exists" (LDB status 68) as
        # success so repeated runs against the same server are harmless
        try:
            ldb.add({"dn": dn,
                     "objectclass": "organizationalunit"})
        except LdbError as e:
            (status, _) = e.args
            if status != 68:
                raise

    ou = ou_name(ldb, instance_id)
    # the shared parent container first, then the per-instance OU under it
    add_if_missing(ou.split(',', 1)[1])
    add_if_missing(ou)
    return ou
1645
1646
# ConversationAccounts holds details of the machine and user accounts
# associated with a conversation.
#
# We use a named tuple to reduce shared memory usage.
ConversationAccounts = namedtuple(
    'ConversationAccounts',
    'netbios_name machinepass username userpass')
1656
1657
def generate_replay_accounts(ldb, instance_id, number, password):
    """Generate a series of unique machine and user account names."""
    # names are derived deterministically from the instance id and the
    # sequence number; every account shares the one replay password
    return [ConversationAccounts(machine_name(instance_id, i),
                                 password,
                                 user_name(instance_id, i),
                                 password)
            for i in range(1, number + 1)]
1670
1671
def create_machine_account(ldb, instance_id, netbios_name, machinepass,
                           traffic_account=True):
    """Create a machine account via ldap."""
    dn = "cn=%s,%s" % (netbios_name, ou_name(ldb, instance_id))
    utf16pw = ('"%s"' % get_string(machinepass)).encode('utf-16-le')

    if traffic_account:
        # we set these bits for the machine account otherwise the replayed
        # traffic throws up NT_STATUS_NO_TRUST_SAM_ACCOUNT errors
        flags = UF_TRUSTED_FOR_DELEGATION | UF_SERVER_TRUST_ACCOUNT
    else:
        flags = UF_WORKSTATION_TRUST_ACCOUNT

    ldb.add({
        "dn": dn,
        "objectclass": "computer",
        "sAMAccountName": "%s$" % netbios_name,
        "userAccountControl": str(flags),
        "unicodePwd": utf16pw})
1695
1696
def create_user_account(ldb, instance_id, username, userpass):
    """Create a user account via ldap."""
    user_dn = "cn=%s,%s" % (username, ou_name(ldb, instance_id))
    utf16pw = ('"%s"' % get_string(userpass)).encode('utf-16-le')
    ldb.add({
        "dn": user_dn,
        "objectclass": "user",
        "sAMAccountName": username,
        "userAccountControl": str(UF_NORMAL_ACCOUNT),
        "unicodePwd": utf16pw
    })

    # grant user write permission to do things like write account SPN
    sd_utils.SDUtils(ldb).dacl_add_ace(user_dn, "(A;;WP;;;PS)")
1713
1714
def create_group(ldb, instance_id, name):
    """Create a group via ldap."""
    group_dn = "cn=%s,%s" % (name, ou_name(ldb, instance_id))
    ldb.add({
        "dn": group_dn,
        "objectclass": "group",
        "sAMAccountName": name,
    })
1725
1726
def user_name(instance_id, i):
    """Generate a user name based on the instance id"""
    prefix = "STGU"
    return "%s-%d-%d" % (prefix, instance_id, i)
1730
1731
def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
    """Search for objects of the given objectclass, returning the given
    attribute's values as a set of strings."""
    expression = "(objectClass={})".format(objectclass)
    res = ldb.search(expression=expression, attrs=[attr])
    return {str(entry[attr]) for entry in res}
1739
1740
def generate_users(ldb, instance_id, number, password):
    """Add users to the server"""
    # skip names that already exist so only the shortfall is created
    existing_objects = search_objectclass(ldb, objectclass='user')
    created = 0
    for i in range(number, 0, -1):
        name = user_name(instance_id, i)
        if name in existing_objects:
            continue
        create_user_account(ldb, instance_id, name, password)
        created += 1
        if created % 50 == 0:
            LOGGER.info("Created %u/%u users" % (created, number))

    return created
1754
1755
def machine_name(instance_id, i, traffic_account=True):
    """Generate a machine account name from instance id."""
    if not traffic_account:
        # Plain computer accounts generated to simulate a semi-realistic
        # network. These use the default computer userAccountControl
        # flags, so a distinct name prefix keeps them out of packet
        # generation.
        return "PC-%d-%d" % (instance_id, i)
    # traffic accounts correspond to a given user, and use different
    # userAccountControl flags to ensure packets get processed correctly
    # by the DC
    return "STGM-%d-%d" % (instance_id, i)
1769
1770
def generate_machine_accounts(ldb, instance_id, number, password,
                              traffic_account=True):
    """Add machine accounts to the server"""
    existing_objects = search_objectclass(ldb, objectclass='computer')
    created = 0
    for i in range(number, 0, -1):
        name = machine_name(instance_id, i, traffic_account)
        # computer sAMAccountNames carry a trailing '$'
        if name + "$" in existing_objects:
            continue
        create_machine_account(ldb, instance_id, name, password,
                               traffic_account)
        created += 1
        if created % 50 == 0:
            LOGGER.info("Created %u/%u machine accounts" % (created, number))

    return created
1786
1787
def group_name(instance_id, i):
    """Generate a group name from instance id."""
    prefix = "STGG"
    return "%s-%d-%d" % (prefix, instance_id, i)
1791
1792
def generate_groups(ldb, instance_id, number):
    """Create the required number of groups on the server."""
    existing_objects = search_objectclass(ldb, objectclass='group')
    created = 0
    for i in range(number, 0, -1):
        name = group_name(instance_id, i)
        if name in existing_objects:
            continue
        create_group(ldb, instance_id, name)
        created += 1
        if created % 1000 == 0:
            LOGGER.info("Created %u/%u groups" % (created, number))

    return created
1806
1807
def clean_up_accounts(ldb, instance_id):
    """Remove the created accounts and groups from the server."""
    # a single tree delete of the per-instance OU removes everything
    try:
        ldb.delete(ou_name(ldb, instance_id), ["tree_delete:1"])
    except LdbError as e:
        (status, _) = e.args
        # status 32 is "no such object": nothing was there to clean up
        if status != 32:
            raise
1818
1819
def generate_users_and_groups(ldb, instance_id, password,
                              number_of_users, number_of_groups,
                              group_memberships, max_members,
                              machine_accounts, traffic_accounts=True):
    """Generate the required users and groups, allocating the users to
       those groups.

    Creates the per-instance OU first, then users, machine accounts and
    (optionally) groups beneath it, then assigns users to groups using a
    randomised weighted distribution. Accounts that already exist on the
    server are skipped, so only the shortfall is created.
    """
    memberships_added = 0
    groups_added = 0
    computers_added = 0

    # everything is created under one OU so it can later be removed with
    # a single tree delete (see clean_up_accounts)
    create_ou(ldb, instance_id)

    LOGGER.info("Generating dummy user accounts")
    users_added = generate_users(ldb, instance_id, number_of_users, password)

    LOGGER.info("Generating dummy machine accounts")
    computers_added = generate_machine_accounts(ldb, instance_id,
                                                machine_accounts, password,
                                                traffic_accounts)

    if number_of_groups > 0:
        LOGGER.info("Generating dummy groups")
        groups_added = generate_groups(ldb, instance_id, number_of_groups)

    if group_memberships > 0:
        LOGGER.info("Assigning users to groups")
        # compute all user->group assignments in memory first, so the DB
        # writes can be batched per group (see add_users_to_groups)
        assignments = GroupAssignments(number_of_groups,
                                       groups_added,
                                       number_of_users,
                                       users_added,
                                       group_memberships,
                                       max_members)
        LOGGER.info("Adding users to groups")
        add_users_to_groups(ldb, instance_id, assignments)
        memberships_added = assignments.total()

    # assign_groups only records memberships involving a newly-added user
    # or group, so adding groups without adding any users leaves the new
    # groups empty
    if (groups_added > 0 and users_added == 0 and
       number_of_groups != groups_added):
        LOGGER.warning("The added groups will contain no members")

    LOGGER.info("Added %d users (%d machines), %d groups and %d memberships" %
                (users_added, computers_added, groups_added,
                 memberships_added))
1863
1864
class GroupAssignments(object):
    """Randomised assignment of users to groups.

    Weighted probability distributions (Pareto for users, power-law for
    groups) are used so that a few users belong to many groups and a few
    groups contain most of the users. The assignments are computed
    entirely in memory, keyed by group, so they can be written to the DB
    in batches.
    """

    def __init__(self, number_of_groups, groups_added, number_of_users,
                 users_added, group_memberships, max_members):

        # number of unique (user, group) memberships assigned so far
        self.count = 0
        self.generate_group_distribution(number_of_groups)
        self.generate_user_distribution(number_of_users, group_memberships)
        self.max_members = max_members
        # key=group-ID, value=list of user-IDs belonging to that group
        self.assignments = defaultdict(list)
        self.assign_groups(number_of_groups, groups_added, number_of_users,
                           users_added, group_memberships)

    def cumulative_distribution(self, weights):
        # make sure the probabilities conform to a cumulative distribution
        # spread between 0.0 and 1.0. Dividing by the weighted total gives each
        # probability a proportional share of 1.0. Higher probabilities get a
        # bigger share, so are more likely to be picked. We use the cumulative
        # value, so we can use random.random() as a simple index into the list
        dist = []
        total = sum(weights)
        if total == 0:
            # every weight has been zeroed out (see cap_group_membership)
            return None

        cumulative = 0.0
        for probability in weights:
            cumulative += probability
            dist.append(cumulative / total)
        return dist

    def generate_user_distribution(self, num_users, num_memberships):
        """Probability distribution of a user belonging to a group.
        """
        # Assign a weighted probability to each user. Use the Pareto
        # Distribution so that some users are in a lot of groups, and the
        # bulk of users are in only a few groups. If we're assigning a large
        # number of group memberships, use a higher shape. This means slightly
        # fewer outlying users that are in large numbers of groups. The aim is
        # to have no users belonging to more than ~500 groups.
        if num_memberships > 5000000:
            shape = 3.0
        elif num_memberships > 2000000:
            shape = 2.5
        elif num_memberships > 300000:
            shape = 2.25
        else:
            shape = 1.75

        weights = []
        for x in range(1, num_users + 1):
            p = random.paretovariate(shape)
            weights.append(p)

        # convert the weights to a cumulative distribution between 0.0 and 1.0
        self.user_dist = self.cumulative_distribution(weights)

    def generate_group_distribution(self, n):
        """Probability distribution of a group containing a user."""

        # Assign a weighted probability to each user. Probability decreases
        # as the group-ID increases
        weights = []
        for x in range(1, n + 1):
            p = 1 / (x**1.3)
            weights.append(p)

        # convert the weights to a cumulative distribution between 0.0 and 1.0
        # (the raw weights are also kept, so cap_group_membership() can zero
        # out a full group and rebuild the distribution)
        self.group_weights = weights
        self.group_dist = self.cumulative_distribution(weights)

    def generate_random_membership(self):
        """Returns a randomly generated user-group membership"""

        # the list items are cumulative distribution values between 0.0 and
        # 1.0, which makes random() a handy way to index the list to get a
        # weighted random user/group. (Here the user/group returned are
        # zero-based array indexes)
        user = bisect.bisect(self.user_dist, random.random())
        group = bisect.bisect(self.group_dist, random.random())

        return user, group

    def users_in_group(self, group):
        # the list of user-IDs assigned to the given group-ID
        return self.assignments[group]

    def get_groups(self):
        # all group-IDs that have at least one assignment recorded
        return self.assignments.keys()

    def cap_group_membership(self, group, max_members):
        """Prevent the group's membership from exceeding the max specified"""
        num_members = len(self.assignments[group])
        if num_members >= max_members:
            LOGGER.info("Group {0} has {1} members".format(group, num_members))

            # remove this group and then recalculate the cumulative
            # distribution, so this group is no longer selected
            # (group-IDs are 1-based, the weights list is 0-based)
            self.group_weights[group - 1] = 0
            new_dist = self.cumulative_distribution(self.group_weights)
            self.group_dist = new_dist

    def add_assignment(self, user, group):
        # the assignments are stored in a dictionary where key=group,
        # value=list-of-users-in-group (indexing by group-ID allows us to
        # optimize for DB membership writes)
        if user not in self.assignments[group]:
            self.assignments[group].append(user)
            self.count += 1

        # check if there's a cap on how big the groups can grow
        if self.max_members:
            self.cap_group_membership(group, self.max_members)

    def assign_groups(self, number_of_groups, groups_added,
                      number_of_users, users_added, group_memberships):
        """Allocate users to groups.

        The intention is to have a few users that belong to most groups, while
        the majority of users belong to a few groups.

        A few groups will contain most users, with the remaining only having a
        few users.
        """

        if group_memberships <= 0:
            return

        # Calculate the number of group memberships required, scaled down
        # by the proportion of users that actually had to be created
        group_memberships = math.ceil(
            float(group_memberships) *
            (float(users_added) / float(number_of_users)))

        if self.max_members:
            group_memberships = min(group_memberships,
                                    self.max_members * number_of_groups)

        # highest zero-based index of a user/group that existed before this
        # run; anything above these indexes was newly added
        existing_users  = number_of_users  - users_added  - 1
        existing_groups = number_of_groups - groups_added - 1
        while self.total() < group_memberships:
            user, group = self.generate_random_membership()

            # only record memberships involving at least one newly-added
            # user or group (presumably pairs of pre-existing users and
            # groups were set up by an earlier run — NOTE(review): confirm)
            if group > existing_groups or user > existing_users:
                # the + 1 converts the array index to the corresponding
                # group or user number
                self.add_assignment(user + 1, group + 1)

    def total(self):
        # number of unique memberships recorded by add_assignment()
        return self.count
2011
2012
def add_users_to_groups(db, instance_id, assignments):
    """Takes the assignments of users to groups and applies them to the DB."""

    total = assignments.total()
    memberships_written = 0
    chunks_written = 0

    for group in assignments.get_groups():
        members = assignments.users_in_group(group)
        if not members:
            continue

        # Write at most 1000 members per modify. Fewer DB modifies would
        # be more efficient, but pushing 10K+ users into a group in one
        # operation becomes inefficient memory-wise.
        for start in range(0, len(members), 1000):
            batch = members[start:start + 1000]
            add_group_members(db, instance_id, group, batch)

            memberships_written += len(batch)
            chunks_written += 1
            if chunks_written % 50 == 0:
                LOGGER.info("Added %u/%u memberships" % (memberships_written,
                                                         total))
2036
def add_group_members(db, instance_id, group, users_in_group):
    """Adds the given users to group specified."""

    ou = ou_name(db, instance_id)

    msg = ldb.Message()
    msg.dn = ldb.Dn(db, "cn=%s,%s" % (group_name(instance_id, group), ou))

    for user in users_in_group:
        user_dn = "cn=%s,%s" % (user_name(instance_id, user), ou)
        # a unique element name per user lets a single Message carry
        # several FLAG_MOD_ADD operations on the same "member" attribute
        msg["member-" + str(user)] = ldb.MessageElement(user_dn,
                                                        ldb.FLAG_MOD_ADD,
                                                        "member")

    db.modify(msg)
2055
2056
def generate_stats(statsdir, timing_file):
    """Generate and print the summary stats for a run.

    Reads the tab-separated stats files found in statsdir, optionally
    copies the valid lines to timing_file, then prints success/failure
    rates and per-operation latency statistics to stdout.
    """
    first      = sys.float_info.max
    last       = 0
    successful = 0
    failed     = 0
    latencies  = {}
    failures   = {}
    unique_conversations = set()

    if timing_file is not None:
        tw = timing_file.write
    else:
        def tw(x):
            pass

    tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")

    for filename in os.listdir(statsdir):
        path = os.path.join(statsdir, filename)
        with open(path, 'r') as f:
            for line in f:
                try:
                    fields       = line.rstrip('\n').split('\t')
                    conversation = fields[1]
                    protocol     = fields[2]
                    packet_type  = fields[3]
                    latency      = float(fields[4])
                    # fields[0] is the packet's completion time; subtracting
                    # the latency recovers the earliest observed start time
                    first        = min(float(fields[0]) - latency, first)
                    last         = max(float(fields[0]), last)

                    latencies.setdefault(protocol, {}).setdefault(
                        packet_type, []).append(latency)

                    fail_counts = failures.setdefault(protocol, {})
                    fail_counts.setdefault(packet_type, 0)

                    if fields[5] == 'True':
                        successful += 1
                    else:
                        failed += 1
                        fail_counts[packet_type] += 1

                    unique_conversations.add(conversation)

                    tw(line)
                except (ValueError, IndexError):
                    # not a valid line print and ignore
                    print(line, file=sys.stderr)

    conversations = len(unique_conversations)

    duration = last - first
    # guard against dividing by zero (or a negative duration) when there
    # was no valid data, or it all arrived in the same instant
    if successful == 0 or duration <= 0:
        success_rate = 0
    else:
        success_rate = successful / duration
    if failed == 0 or duration <= 0:
        failure_rate = 0
    else:
        failure_rate = failed / duration

    print("Total conversations:   %10d" % conversations)
    print("Successful operations: %10d (%.3f per second)"
          % (successful, success_rate))
    print("Failed operations:     %10d (%.3f per second)"
          % (failed, failure_rate))

    print("Protocol    Op Code  Description                               "
          " Count       Failed         Mean       Median          "
          "95%        Range          Max")

    protocols = sorted(latencies.keys())
    for protocol in protocols:
        packet_types = sorted(latencies[protocol], key=opcode_key)
        for packet_type in packet_types:
            values     = latencies[protocol][packet_type]
            values     = sorted(values)
            count      = len(values)
            failed     = failures[protocol][packet_type]
            mean       = sum(values) / count
            median     = calc_percentile(values, 0.50)
            percentile = calc_percentile(values, 0.95)
            rng        = values[-1] - values[0]
            maxv       = values[-1]
            desc       = OP_DESCRIPTIONS.get((protocol, packet_type), '')
            # isatty is a method: without the () the attribute was always
            # truthy, making the tab-separated branch below unreachable
            if sys.stdout.isatty():
                print("%-12s   %4s  %-35s %12d %12d %12.6f "
                      "%12.6f %12.6f %12.6f %12.6f"
                      % (protocol,
                         packet_type,
                         desc,
                         count,
                         failed,
                         mean,
                         median,
                         percentile,
                         rng,
                         maxv))
            else:
                print("%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f"
                      % (protocol,
                         packet_type,
                         desc,
                         count,
                         failed,
                         mean,
                         median,
                         percentile,
                         rng,
                         maxv))
2175
2176
def opcode_key(v):
    """Sort key for the operation code to ensure that it sorts numerically.

    Numeric opcodes are zero-padded to three digits; anything that can't
    be converted to an int (e.g. a named LDAP operation) sorts as-is.
    """
    try:
        return "%03d" % int(v)
    except (ValueError, TypeError):
        # narrow the original bare `except:` so real errors (e.g.
        # KeyboardInterrupt) are not swallowed
        return v
2183
2184
def calc_percentile(values, percentile):
    """Calculate the specified percentile from the list of values.

    Assumes the list is sorted in ascending order.
    """
    if not values:
        return 0

    # fractional rank of the requested percentile within the list
    rank = (len(values) - 1) * percentile
    lower = math.floor(rank)
    upper = math.ceil(rank)
    if lower == upper:
        # the rank falls exactly on an element: no interpolation needed
        return values[int(rank)]

    # linearly interpolate between the two neighbouring values
    below = values[int(lower)] * (upper - rank)
    above = values[int(upper)] * (rank - lower)
    return below + above
2201
2202
def mk_masked_dir(*path):
    """Create a directory with group/other permissions masked off.

    In a testenv we end up with 0777 directories that look an alarming
    green colour with ls. Use umask to avoid that.

    :param path: path components, joined as for os.path.join
    :return: the path of the directory created
    """
    # py3 os.mkdir can take a mode directly, but setting the umask also
    # keeps the process-wide behaviour identical to the original
    d = os.path.join(*path)
    mask = os.umask(0o077)
    try:
        os.mkdir(d)
    finally:
        # restore the caller's umask even if mkdir raises (previously a
        # failed mkdir left the process umask stuck at 0o077)
        os.umask(mask)
    return d