b6097cdc1201058360402dafad71d07390cacada
[garming/samba-autobuild/.git] / python / samba / emulate / traffic.py
1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
3 #
4 # Copyright (C) Catalyst IT Ltd. 2017
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19 from __future__ import print_function, division
20
21 import time
22 import os
23 import random
24 import json
25 import math
26 import sys
27 import signal
28
29 from collections import OrderedDict, Counter, defaultdict, namedtuple
30 from samba.emulate import traffic_packets
31 from samba.samdb import SamDB
32 import ldb
33 from ldb import LdbError
34 from samba.dcerpc import ClientConnection
35 from samba.dcerpc import security, drsuapi, lsa
36 from samba.dcerpc import netlogon
37 from samba.dcerpc.netlogon import netr_Authenticator
38 from samba.dcerpc import srvsvc
39 from samba.dcerpc import samr
40 from samba.drs_utils import drs_DsBind
41 import traceback
42 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
43 from samba.auth import system_session
44 from samba.dsdb import (
45     UF_NORMAL_ACCOUNT,
46     UF_SERVER_TRUST_ACCOUNT,
47     UF_TRUSTED_FOR_DELEGATION,
48     UF_WORKSTATION_TRUST_ACCOUNT
49 )
50 from samba.dcerpc.misc import SEC_CHAN_BDC
51 from samba import gensec
52 from samba import sd_utils
53 from samba.compat import get_string
54 from samba.logger import get_samba_logger
55 import bisect
56
# Estimated fixed cost (in seconds) of performing a sleep at all.
# NOTE(review): presumably subtracted when scheduling packet waits — the
# consumer is outside this chunk; confirm against the replay scheduler.
SLEEP_OVERHEAD = 3e-4

# we don't use None, because it complicates [de]serialisation
NON_PACKET = '-'

# (protocol, opcode) pairs whose presence suggests the sender is the
# client side of a conversation; values are confidence weights used by
# Packet.client_score() (all 1.0 at present).
CLIENT_CLUES = {
    ('dns', '0'): 1.0,      # query
    ('smb', '0x72'): 1.0,   # Negotiate protocol
    ('ldap', '0'): 1.0,     # bind
    ('ldap', '3'): 1.0,     # searchRequest
    ('ldap', '2'): 1.0,     # unbindRequest
    ('cldap', '3'): 1.0,
    ('dcerpc', '11'): 1.0,  # bind
    ('dcerpc', '14'): 1.0,  # Alter_context
    ('nbns', '0'): 1.0,     # query
}

# (protocol, opcode) pairs whose presence suggests the sender is the
# server side; used negatively by Packet.client_score().
SERVER_CLUES = {
    ('dns', '1'): 1.0,      # response
    ('ldap', '1'): 1.0,     # bind response
    ('ldap', '4'): 1.0,     # search result
    ('ldap', '5'): 1.0,     # search done
    ('cldap', '5'): 1.0,
    ('dcerpc', '12'): 1.0,  # bind_ack
    ('dcerpc', '13'): 1.0,  # bind_nak
    ('dcerpc', '15'): 1.0,  # Alter_context response
}

# Protocols the replay engine makes no attempt to reproduce; packets for
# these are dropped by is_a_real_packet().
SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}

# NOTE(review): wait-scheduling knobs — the code consuming these is not
# in this chunk; confirm their semantics against the replay loop.
WAIT_SCALE = 10.0
WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
NO_WAIT_LOG_TIME_RANGE = (-10, -3)

# DEBUG_LEVEL can be changed by scripts with -d
DEBUG_LEVEL = 0

# Module-level logger shared by the helpers below.
LOGGER = get_samba_logger(name=__name__)
95
96
def debug(level, msg, *args):
    """Write a formatted debug message to standard error.

    :param level: the message is emitted only when level <= DEBUG_LEVEL
                  (DEBUG_LEVEL can be raised with the -d option)
    :param msg:   text, optionally containing C-style format specifiers
    :param args:  values consumed by the format specifiers in msg
    """
    if level > DEBUG_LEVEL:
        return
    text = msg % tuple(args) if args else msg
    print(text, file=sys.stderr)
113
114
def debug_lineno(*args):
    """Print each argument to stderr on its own line, prefixed by the
    calling function's name and its (colour-highlighted) line number.
    """
    frame = traceback.extract_stack(limit=2)[0]
    header = (" %s:" "\033[01;33m"
              "%s " "\033[00m" % (frame[2], frame[1]))
    print(header, end=' ', file=sys.stderr)
    for arg in args:
        print(arg, file=sys.stderr)
    print(file=sys.stderr)
    sys.stderr.flush()
126
127
def random_colour_print(seeds):
    """Return a printer function that writes lines to stderr when
    debugging is enabled.  The colour used depends on a simple hash of
    the integer seeds; with no seeds the output is uncoloured."""
    if not seeds:
        def printer(*args):
            if DEBUG_LEVEL > 0:
                for arg in args:
                    print(arg, file=sys.stderr)
        return printer

    # fold the seeds into a terminal colour index (18..231)
    acc = 214
    for seed in seeds:
        acc = ((acc + 17) * seed) % 214
    colour_prefix = "\033[38;5;%dm" % (18 + acc)

    def printer(*args):
        if DEBUG_LEVEL > 0:
            for arg in args:
                print("%s%s\033[00m" % (colour_prefix, arg), file=sys.stderr)
    return printer
150
151
class FakePacketError(Exception):
    """Raised when a packet cannot be treated as part of a conversation
    (e.g. its endpoints do not match the conversation's endpoints)."""
154
155
class Packet(object):
    """Details of a single captured (or synthesised) network packet."""
    __slots__ = ('timestamp',
                 'ip_protocol',
                 'stream_number',
                 'src',
                 'dest',
                 'protocol',
                 'opcode',
                 'desc',
                 'extra',
                 'endpoints')

    def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
                 protocol, opcode, desc, extra):
        """
        :param timestamp: capture time in seconds (float)
        :param ip_protocol: transport protocol name (e.g. 'tcp')
        :param stream_number: transport stream id (may be empty/falsy)
        :param src: integer id of the sending endpoint
        :param dest: integer id of the receiving endpoint
        :param protocol: application protocol name (e.g. 'ldap')
        :param opcode: protocol-specific operation code, as a string
        :param desc: human-readable description
        :param extra: list of additional protocol-specific fields
        """
        self.timestamp = timestamp
        self.ip_protocol = ip_protocol
        self.stream_number = stream_number
        self.src = src
        self.dest = dest
        self.protocol = protocol
        self.opcode = opcode
        self.desc = desc
        self.extra = extra
        # endpoints are stored in canonical (sorted) order so both
        # directions of a conversation compare equal
        if self.src < self.dest:
            self.endpoints = (self.src, self.dest)
        else:
            self.endpoints = (self.dest, self.src)

    @classmethod
    def from_line(cls, line):
        """Parse one tab-separated traffic-summary line into a Packet.

        The first 8 fields are fixed; any further fields land in extra.
        """
        fields = line.rstrip('\n').split('\t')
        (timestamp,
         ip_protocol,
         stream_number,
         src,
         dest,
         protocol,
         opcode,
         desc) = fields[:8]
        extra = fields[8:]

        timestamp = float(timestamp)
        src = int(src)
        dest = int(dest)

        return cls(timestamp, ip_protocol, stream_number, src, dest,
                   protocol, opcode, desc, extra)

    def as_summary(self, time_offset=0.0):
        """Format the packet as a traffic_summary line.

        :param time_offset: added to the stored timestamp
        :return: a (timestamp, line) tuple
        """
        extra = '\t'.join(self.extra)
        t = self.timestamp + time_offset
        return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
                (t,
                 self.ip_protocol,
                 self.stream_number or '',
                 self.src,
                 self.dest,
                 self.protocol,
                 self.opcode,
                 self.desc,
                 extra))

    def __str__(self):
        return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
                (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
                 self.stream_number, self.protocol, self.opcode, self.desc,
                 ('«' + ' '.join(self.extra) + '»' if self.extra else '')))

    def __repr__(self):
        return "<Packet @%s>" % self

    def copy(self):
        """Return a new Packet carrying the same field values."""
        return self.__class__(self.timestamp,
                              self.ip_protocol,
                              self.stream_number,
                              self.src,
                              self.dest,
                              self.protocol,
                              self.opcode,
                              self.desc,
                              self.extra)

    def as_packet_type(self):
        """Return the 'protocol:opcode' key used in the model tables."""
        t = '%s:%s' % (self.protocol, self.opcode)
        return t

    def client_score(self):
        """A positive number means we think it is a client; a negative number
        means we think it is a server. Zero means no idea. range: -1 to 1.
        """
        key = (self.protocol, self.opcode)
        if key in CLIENT_CLUES:
            return CLIENT_CLUES[key]
        if key in SERVER_CLUES:
            return -SERVER_CLUES[key]
        return 0.0

    def play(self, conversation, context):
        """Send the packet over the network, if required.

        Some packets are ignored, i.e. for  protocols not handled,
        server response messages, or messages that are generated by the
        protocol layer associated with other packets.
        """
        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        try:
            fn = getattr(traffic_packets, fn_name)
        except AttributeError:
            print("Conversation(%s) Missing handler %s" %
                  (conversation.conversation_id, fn_name),
                  file=sys.stderr)
            return

        # Don't display a message for kerberos packets, they're not directly
        # generated they're used to indicate kerberos should be used
        if self.protocol != "kerberos":
            debug(2, "Conversation(%s) Calling handler %s" %
                     (conversation.conversation_id, fn_name))

        start = time.time()
        try:
            if fn(self, conversation, context):
                # Only collect timing data for functions that generate
                # network traffic, or fail
                end = time.time()
                duration = end - start
                print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
                      (end, conversation.conversation_id, self.protocol,
                       self.opcode, duration))
        except Exception as e:
            end = time.time()
            duration = end - start
            print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
                  (end, conversation.conversation_id, self.protocol,
                   self.opcode, duration, e))

    def __cmp__(self, other):
        # Python 2 ordering hook only; Python 3 ignores __cmp__ entirely
        # (rich comparisons below provide ordering there).
        return self.timestamp - other.timestamp

    def __lt__(self, other):
        # Bug fix: under Python 3 __cmp__ has no effect, so sorting a list
        # of Packets raised TypeError.  Order packets by timestamp.
        return self.timestamp < other.timestamp

    def is_really_a_packet(self, missing_packet_stats=None):
        """True when replaying this packet would generate real traffic."""
        return is_a_real_packet(self.protocol, self.opcode)
300
301
def is_a_real_packet(protocol, opcode):
    """Decide whether a (protocol, opcode) pair produces real traffic.

    Returns False for packets that can be dropped with no effect on the
    replay: unhandled protocols, ldap continuations, missing handlers,
    and handlers that deliberately do nothing.
    """
    if protocol in SKIPPED_PROTOCOLS:
        # Ignore any packets for the protocols we're not interested in.
        return False
    if protocol == "ldap" and opcode == '':
        # skip ldap continuation packets
        return False

    fn_name = 'packet_%s_%s' % (protocol, opcode)
    fn = getattr(traffic_packets, fn_name, None)
    if fn is None:
        # Bug fix: Logger.debug() takes no 'file' keyword, so the old
        # LOGGER.debug(..., file=sys.stderr) raised TypeError whenever a
        # handler was missing.  Use lazy %-style logging args instead.
        LOGGER.debug("missing packet %s", fn_name)
        return False
    if fn is traffic_packets.null_packet:
        # the handler exists but deliberately does nothing
        return False
    return True
322
323
class ReplayContext(object):
    """State/Context for a conversation between an simulated client and a
       server. Some of the context is shared amongst all conversations
       and should be generated before the fork, while other context is
       specific to a particular conversation and should be generated
       *after* the fork, in generate_process_local_config().
    """
    def __init__(self,
                 server=None,
                 lp=None,
                 creds=None,
                 badpassword_frequency=None,
                 prefer_kerberos=None,
                 tempdir=None,
                 statsdir=None,
                 ou=None,
                 base_dn=None,
                 domain=None,
                 domain_sid=None):
        """Set up the pre-fork (shared) context.

        :param server: DC hostname/address to replay traffic against
        :param lp: loadparm context
        :param creds: administrative Credentials used for setup queries
        :param badpassword_frequency: probability [0..1] of attempting a
               logon with bad credentials before the real one
        :param prefer_kerberos: force Kerberos (else NTLM) for all creds
        :param tempdir: parent directory for per-conversation temp dirs
        :param statsdir: directory for statistics output
        :param ou: OU containing the generated machine/user accounts
        :param base_dn: fallback DN for get_matching_dn()
        :param domain: NetBIOS domain name (may come from $DOMAIN later)
        :param domain_sid: SID of the domain
        """
        self.server                   = server
        self.netlogon_connection      = None
        self.creds                    = creds
        self.lp                       = lp
        self.prefer_kerberos          = prefer_kerberos
        self.ou                       = ou
        self.base_dn                  = base_dn
        self.domain                   = domain
        self.statsdir                 = statsdir
        self.global_tempdir           = tempdir
        self.domain_sid               = domain_sid
        self.realm                    = lp.get('realm')

        # Bad password attempt controls.  The last_*_bad flags prevent
        # two consecutive bad-password attempts on the same connection
        # type (see with_random_bad_credentials()).
        self.badpassword_frequency    = badpassword_frequency
        self.last_lsarpc_bad          = False
        self.last_lsarpc_named_bad    = False
        self.last_simple_bind_bad     = False
        self.last_bind_bad            = False
        self.last_srvsvc_bad          = False
        self.last_drsuapi_bad         = False
        self.last_netlogon_bad        = False
        self.last_samlogon_bad        = False
        self.generate_ldap_search_tables()

    def generate_ldap_search_tables(self):
        """Query the server for all DNs and build lookup tables mapping
        DN shape patterns (e.g. 'CN,CN,DC,DC') to lists of matching DNs,
        used later by get_matching_dn()."""
        session = system_session()

        db = SamDB(url="ldap://%s" % self.server,
                   session_info=session,
                   credentials=self.creds,
                   lp=self.lp)

        res = db.search(db.domain_dn(),
                        scope=ldb.SCOPE_SUBTREE,
                        controls=["paged_results:1:1000"],
                        attrs=['dn'])

        # find a list of dns for each pattern
        # e.g. CN,CN,CN,DC,DC
        dn_map = {}
        attribute_clue_map = {
            'invocationId': []
        }

        for r in res:
            dn = str(r.dn)
            pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
            dns = dn_map.setdefault(pattern, [])
            dns.append(dn)
            if dn.startswith('CN=NTDS Settings,'):
                attribute_clue_map['invocationId'].append(dn)

        # extend the map in case we are working with a different
        # number of DC components.
        # for k, v in self.dn_map.items():
        #     print >>sys.stderr, k, len(v)

        for k in list(dn_map.keys()):
            if k[-3:] != ',DC':
                continue
            p = k[:-3]
            while p[-3:] == ',DC':
                p = p[:-3]
            for i in range(5):
                p += ',DC'
                if p != k and p in dn_map:
                    # fixed typo in this diagnostic ("collison")
                    print('dn_map collision %s %s' % (k, p),
                          file=sys.stderr)
                    continue
                dn_map[p] = dn_map[k]

        self.dn_map = dn_map
        self.attribute_clue_map = attribute_clue_map

    def generate_process_local_config(self, account, conversation):
        """Set up the per-conversation state; called after the fork.

        :param account: object providing netbios_name, machinepass,
                        username and userpass for this conversation
        :param conversation: the Conversation this context belongs to
        """
        self.ldap_connections         = []
        self.dcerpc_connections       = []
        self.lsarpc_connections       = []
        self.lsarpc_connections_named = []
        self.drsuapi_connections      = []
        self.srvsvc_connections       = []
        self.samr_contexts            = []
        self.netbios_name             = account.netbios_name
        self.machinepass              = account.machinepass
        self.username                 = account.username
        self.userpass                 = account.userpass

        self.tempdir = mk_masked_dir(self.global_tempdir,
                                     'conversation-%d' %
                                     conversation.conversation_id)

        # each conversation gets a private smb.conf-style state area
        self.lp.set("private dir", self.tempdir)
        self.lp.set("lock dir", self.tempdir)
        self.lp.set("state directory", self.tempdir)
        self.lp.set("tls verify peer", "no_check")

        # If the domain was not specified, check for the environment
        # variable.
        if self.domain is None:
            self.domain = os.environ["DOMAIN"]

        self.remoteAddress = "/root/ncalrpc_as_system"
        self.samlogon_dn   = ("cn=%s,%s" %
                              (self.netbios_name, self.ou))
        self.user_dn       = ("cn=%s,%s" %
                              (self.username, self.ou))

        self.generate_machine_creds()
        self.generate_user_creds()

    def with_random_bad_credentials(self, f, good, bad, failed_last_time):
        """Execute the supplied logon function, randomly choosing the
           bad credentials.

           Based on the frequency in badpassword_frequency randomly perform the
           function with the supplied bad credentials.
           If run with bad credentials, the function is re-run with the good
           credentials.
           failed_last_time is used to prevent consecutive bad credential
           attempts. So the over all bad credential frequency will be lower
           than that requested, but not significantly.
        """
        if not failed_last_time:
            if (self.badpassword_frequency and self.badpassword_frequency > 0
                and random.random() < self.badpassword_frequency):
                try:
                    f(bad)
                # Bug fix: was a bare 'except:', which also swallowed
                # KeyboardInterrupt/SystemExit.
                except Exception:
                    # Ignore any exceptions as the operation may fail
                    # as it's being performed with bad credentials
                    pass
                failed_last_time = True
            else:
                failed_last_time = False

        result = f(good)
        return (result, failed_last_time)

    def generate_user_creds(self):
        """Generate the conversation specific user Credentials.

        Each Conversation has an associated user account used to simulate
        any non Administrative user traffic.

        Generates user credentials with good and bad passwords and ldap
        simple bind credentials with good and bad passwords.
        """
        self.user_creds = Credentials()
        self.user_creds.guess(self.lp)
        self.user_creds.set_workstation(self.netbios_name)
        self.user_creds.set_password(self.userpass)
        self.user_creds.set_username(self.username)
        self.user_creds.set_domain(self.domain)
        if self.prefer_kerberos:
            self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.user_creds.set_kerberos_state(DONT_USE_KERBEROS)

        # "bad" creds simply truncate the good password by 4 characters
        self.user_creds_bad = Credentials()
        self.user_creds_bad.guess(self.lp)
        self.user_creds_bad.set_workstation(self.netbios_name)
        self.user_creds_bad.set_password(self.userpass[:-4])
        self.user_creds_bad.set_username(self.username)
        if self.prefer_kerberos:
            self.user_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.user_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)

        # Credentials for ldap simple bind.
        self.simple_bind_creds = Credentials()
        self.simple_bind_creds.guess(self.lp)
        self.simple_bind_creds.set_workstation(self.netbios_name)
        self.simple_bind_creds.set_password(self.userpass)
        self.simple_bind_creds.set_username(self.username)
        self.simple_bind_creds.set_gensec_features(
            self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
        if self.prefer_kerberos:
            self.simple_bind_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.simple_bind_creds.set_kerberos_state(DONT_USE_KERBEROS)
        self.simple_bind_creds.set_bind_dn(self.user_dn)

        self.simple_bind_creds_bad = Credentials()
        self.simple_bind_creds_bad.guess(self.lp)
        self.simple_bind_creds_bad.set_workstation(self.netbios_name)
        self.simple_bind_creds_bad.set_password(self.userpass[:-4])
        self.simple_bind_creds_bad.set_username(self.username)
        self.simple_bind_creds_bad.set_gensec_features(
            self.simple_bind_creds_bad.get_gensec_features() |
            gensec.FEATURE_SEAL)
        if self.prefer_kerberos:
            self.simple_bind_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.simple_bind_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
        self.simple_bind_creds_bad.set_bind_dn(self.user_dn)

    def generate_machine_creds(self):
        """Generate the conversation specific machine Credentials.

        Each Conversation has an associated machine account.

        Generates machine credentials with good and bad passwords.
        """

        self.machine_creds = Credentials()
        self.machine_creds.guess(self.lp)
        self.machine_creds.set_workstation(self.netbios_name)
        self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds.set_password(self.machinepass)
        self.machine_creds.set_username(self.netbios_name + "$")
        self.machine_creds.set_domain(self.domain)
        if self.prefer_kerberos:
            self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.machine_creds.set_kerberos_state(DONT_USE_KERBEROS)

        # "bad" creds simply truncate the good password by 4 characters
        self.machine_creds_bad = Credentials()
        self.machine_creds_bad.guess(self.lp)
        self.machine_creds_bad.set_workstation(self.netbios_name)
        self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds_bad.set_password(self.machinepass[:-4])
        self.machine_creds_bad.set_username(self.netbios_name + "$")
        if self.prefer_kerberos:
            self.machine_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.machine_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)

    def get_matching_dn(self, pattern, attributes=None):
        """Return a random DN matching the given shape pattern.

        :param pattern: comma-joined component-type pattern, e.g.
                        'CN,CN,DC,DC' (case-insensitive)
        :param attributes: optional attribute name used as a clue
                           (currently only 'invocationId' is mapped)
        """
        # If the pattern is an empty string, we assume ROOTDSE,
        # Otherwise we try adding or removing DC suffixes, then
        # shorter leading patterns until we hit one.
        # e.g if there is no CN,CN,CN,CN,DC,DC
        # we first try       CN,CN,CN,CN,DC
        # and                CN,CN,CN,CN,DC,DC,DC
        # then change to        CN,CN,CN,DC,DC
        # and as last resort we use the base_dn
        attr_clue = self.attribute_clue_map.get(attributes)
        if attr_clue:
            return random.choice(attr_clue)

        pattern = pattern.upper()
        while pattern:
            if pattern in self.dn_map:
                return random.choice(self.dn_map[pattern])
            # chop one off the front and try it all again.
            pattern = pattern[3:]

        return self.base_dn

    def get_dcerpc_connection(self, new=False):
        """Return a dcerpc connection, reusing the last one unless new."""
        guid = '12345678-1234-abcd-ef00-01234567cffb'  # RPC_NETLOGON UUID
        if self.dcerpc_connections and not new:
            return self.dcerpc_connections[-1]
        c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
                             (guid, 1), self.lp)
        self.dcerpc_connections.append(c)
        return c

    def get_srvsvc_connection(self, new=False):
        """Return a srvsvc connection (user creds), reusing unless new."""
        if self.srvsvc_connections and not new:
            return self.srvsvc_connections[-1]

        def connect(creds):
            return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
                                 self.lp,
                                 creds)

        (c, self.last_srvsvc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_srvsvc_bad)

        self.srvsvc_connections.append(c)
        return c

    def get_lsarpc_connection(self, new=False):
        """Return a schannel lsarpc connection (machine creds)."""
        if self.lsarpc_connections and not new:
            return self.lsarpc_connections[-1]

        def connect(creds):
            binding_options = 'schannel,seal,sign'
            return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
                              (self.server, binding_options),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_bad)

        self.lsarpc_connections.append(c)
        return c

    def get_lsarpc_named_pipe_connection(self, new=False):
        """Return a named-pipe lsarpc connection (machine creds)."""
        if self.lsarpc_connections_named and not new:
            return self.lsarpc_connections_named[-1]

        def connect(creds):
            return lsa.lsarpc("ncacn_np:%s" % (self.server),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_named_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_named_bad)

        self.lsarpc_connections_named.append(c)
        return c

    def get_drsuapi_connection_pair(self, new=False, unbind=False):
        """get a (drs, drs_handle) tuple"""
        if self.drsuapi_connections and not new:
            c = self.drsuapi_connections[-1]
            return c

        def connect(creds):
            binding_options = 'seal'
            binding_string = "ncacn_ip_tcp:%s[%s]" %\
                             (self.server, binding_options)
            return drsuapi.drsuapi(binding_string, self.lp, creds)

        (drs, self.last_drsuapi_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_drsuapi_bad)

        (drs_handle, supported_extensions) = drs_DsBind(drs)
        c = (drs, drs_handle)
        self.drsuapi_connections.append(c)
        return c

    def get_ldap_connection(self, new=False, simple=False):
        """Return an ldap connection, via simple bind (ldaps) when simple
        is True, otherwise a SASL bind with the user credentials."""
        if self.ldap_connections and not new:
            return self.ldap_connections[-1]

        def simple_bind(creds):
            """
            To run simple bind against Windows, we need to run
            following commands in PowerShell:

                Install-windowsfeature ADCS-Cert-Authority
                Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
                Restart-Computer

            """
            return SamDB('ldaps://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)

        def sasl_bind(creds):
            return SamDB('ldap://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)
        if simple:
            (samdb, self.last_simple_bind_bad) = \
                self.with_random_bad_credentials(simple_bind,
                                                 self.simple_bind_creds,
                                                 self.simple_bind_creds_bad,
                                                 self.last_simple_bind_bad)
        else:
            (samdb, self.last_bind_bad) = \
                self.with_random_bad_credentials(sasl_bind,
                                                 self.user_creds,
                                                 self.user_creds_bad,
                                                 self.last_bind_bad)

        self.ldap_connections.append(samdb)
        return samdb

    def get_samr_context(self, new=False):
        """Return a SamrContext, reusing the last one unless new."""
        if not self.samr_contexts or new:
            self.samr_contexts.append(
                SamrContext(self.server, lp=self.lp, creds=self.creds))
        return self.samr_contexts[-1]

    def get_netlogon_connection(self):
        """Return the (single, cached) schannel netlogon connection."""
        if self.netlogon_connection:
            return self.netlogon_connection

        def connect(creds):
            return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
                                     (self.server),
                                     self.lp,
                                     creds)
        (c, self.last_netlogon_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_netlogon_bad)
        self.netlogon_connection = c
        return c

    def guess_a_dns_lookup(self):
        """Return a (name, type) pair for a plausible DNS query."""
        return (self.realm, 'A')

    def get_authenticator(self):
        """Build a (current, subsequent) netr_Authenticator pair for a
        netlogon call, using the machine credentials."""
        auth = self.machine_creds.new_client_authenticator()
        current  = netr_Authenticator()
        # auth["credential"] may hold ints (bytes) or chars (str)
        # depending on the Python version; normalise to a list of ints.
        current.cred.data = [x if isinstance(x, int) else ord(x)
                             for x in auth["credential"]]
        current.timestamp = auth["timestamp"]

        subsequent = netr_Authenticator()
        return (current, subsequent)
755
756
class SamrContext(object):
    """Holds the lazily-created state for a single samr RPC connection.
    """
    def __init__(self, server, lp=None, creds=None):
        self.server        = server
        self.lp            = lp
        self.creds         = creds
        # RPC state; each is created on demand by the get_* methods
        self.connection    = None
        self.handle        = None
        self.domain_handle = None
        self.domain_sid    = None
        self.group_handle  = None
        self.user_handle   = None
        self.rids          = None

    def get_connection(self):
        """Return the samr connection, creating it on first use."""
        if not self.connection:
            binding = "ncacn_ip_tcp:%s[seal]" % (self.server)
            self.connection = samr.samr(binding,
                                        lp_ctx=self.lp,
                                        credentials=self.creds)
        return self.connection

    def get_handle(self):
        """Return the samr handle, connecting on first use."""
        if not self.handle:
            conn = self.get_connection()
            self.handle = conn.Connect2(None,
                                        security.SEC_FLAG_MAXIMUM_ALLOWED)
        return self.handle
786
787
class Conversation(object):
    """Details of a conversation between a simulated client and a server."""
    def __init__(self, start_time=None, endpoints=None, seq=(),
                 conversation_id=None):
        self.start_time = start_time
        self.endpoints = endpoints
        self.packets = []
        self.msg = random_colour_print(endpoints)
        self.client_balance = 0.0
        self.conversation_id = conversation_id
        for p in seq:
            self.add_short_packet(*p)

    def __cmp__(self, other):
        # Python 2 ordering; Python 3 ignores __cmp__ (see __lt__ below).
        if self.start_time is None:
            if other.start_time is None:
                return 0
            return -1
        if other.start_time is None:
            return 1
        return self.start_time - other.start_time

    def __lt__(self, other):
        # Python 3 does not use __cmp__; without __lt__, sorting a list
        # of conversations (as generate_conversations() does) raises
        # TypeError. A None start_time sorts before any other.
        if self.start_time is None:
            return other.start_time is not None
        if other.start_time is None:
            return False
        return self.start_time < other.start_time

    def add_packet(self, packet):
        """Add a packet object to this conversation, making a local copy with
        a conversation-relative timestamp."""
        p = packet.copy()

        if self.start_time is None:
            self.start_time = p.timestamp

        if self.endpoints is None:
            self.endpoints = p.endpoints

        if p.endpoints != self.endpoints:
            raise FakePacketError("Conversation endpoints %s don't match "
                                  "packet endpoints %s" %
                                  (self.endpoints, p.endpoints))

        p.timestamp -= self.start_time

        # a packet from endpoint[0] pushes the balance towards "client",
        # one from the other side pushes it towards "server".
        if p.src == p.endpoints[0]:
            self.client_balance -= p.client_score()
        else:
            self.client_balance += p.client_score()

        if p.is_really_a_packet():
            self.packets.append(p)

    def add_short_packet(self, timestamp, protocol, opcode, extra,
                         client=True):
        """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
        (possibly empty) list of extra data. If client is True, assume
        this packet is from the client to the server.
        """
        src, dest = self.guess_client_server()
        if not client:
            src, dest = dest, src
        key = (protocol, opcode)
        desc = OP_DESCRIPTIONS[key] if key in OP_DESCRIPTIONS else ''
        if protocol in IP_PROTOCOLS:
            ip_protocol = IP_PROTOCOLS[protocol]
        else:
            ip_protocol = '06'
        packet = Packet(timestamp - self.start_time, ip_protocol,
                        '', src, dest,
                        protocol, opcode, desc, extra)
        # XXX we're assuming the timestamp is already adjusted for
        # this conversation?
        # XXX should we adjust client balance for guessed packets?
        if packet.src == packet.endpoints[0]:
            self.client_balance -= packet.client_score()
        else:
            self.client_balance += packet.client_score()
        if packet.is_really_a_packet():
            self.packets.append(packet)

    def __str__(self):
        return ("<Conversation %s %s starting %.3f %d packets>" %
                (self.conversation_id, self.endpoints, self.start_time,
                 len(self.packets)))

    __repr__ = __str__

    def __iter__(self):
        return iter(self.packets)

    def __len__(self):
        return len(self.packets)

    def get_duration(self):
        """Return the time between the first and last packets, or 0 when
        there are fewer than two packets."""
        if len(self.packets) < 2:
            return 0
        return self.packets[-1].timestamp - self.packets[0].timestamp

    def replay_as_summary_lines(self):
        """Render the conversation as a list of summary-format lines."""
        lines = []
        for p in self.packets:
            lines.append(p.as_summary(self.start_time))
        return lines

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        """Fork a new process and replay the conversation.
        """
        def signal_handler(signal, frame):
            """Signal handler closes standard out and error.

            Triggered by a sigterm, ensures that the log messages are flushed
            to disk and not lost.
            """
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

        t = self.start_time
        now = time.time() - start
        gap = t - now
        # we are replaying strictly in order, so it is safe to sleep
        # in the main process if the gap is big enough. This reduces
        # the number of concurrent threads, which allows us to make
        # larger loads.
        # (NOTE: this branch is deliberately disabled by "and False".)
        if gap > 0.15 and False:
            print("sleeping for %f in main process" % (gap - 0.1),
                  file=sys.stderr)
            time.sleep(gap - 0.1)
            now = time.time() - start
            gap = t - now
            print("gap is now %f" % gap, file=sys.stderr)

        self.conversation_id = next(context.next_conversation_id)
        pid = os.fork()
        if pid != 0:
            # in the parent: return the child's pid
            return pid
        pid = os.getpid()
        signal.signal(signal.SIGTERM, signal_handler)
        # we must never return, or we'll end up running parts of the
        # parent's clean-up code. So we work in a try...finally, and
        # try to print any exceptions.

        try:
            context.generate_process_local_config(account, self)
            sys.stdin.close()
            os.close(0)
            filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
                                    self.conversation_id)
            sys.stdout.close()
            sys.stdout = open(filename, 'w')

            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            miss = t - (time.time() - start)
            self.msg("starting %s [miss %.3f pid %d]" % (self, miss, pid))
            self.replay(context)
        except Exception:
            print(("EXCEPTION in child PID %d, conversation %s" % (pid, self)),
                  file=sys.stderr)
            # print_exc takes the stream as the `file` keyword argument;
            # passing sys.stderr positionally would set the traceback
            # *limit* instead and raise a TypeError.
            traceback.print_exc(file=sys.stderr)
        finally:
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

    def replay(self, context=None):
        """Replay the conversation's packets in order, sleeping between
        them to reproduce their recorded timing."""
        start = time.time()

        for p in self.packets:
            now = time.time() - start
            gap = p.timestamp - now
            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            miss = p.timestamp - (time.time() - start)
            if context is None:
                self.msg("packet %s [miss %.3f pid %d]" % (p, miss,
                                                           os.getpid()))
                continue
            p.play(self, context)

    def guess_client_server(self, server_clue=None):
        """Have a go at deciding who is the server and who is the client.
        returns (client, server)
        """
        a, b = self.endpoints

        if self.client_balance < 0:
            return (a, b)

        # in the absence of a clue, we will fall through to assuming
        # the lowest number is the server (which is usually true).

        if self.client_balance == 0 and server_clue == b:
            return (a, b)

        return (b, a)

    def forget_packets_outside_window(self, s, e):
        """Prune any packets outside the time window we're interested in

        :param s: start of the window
        :param e: end of the window
        """
        self.packets = [p for p in self.packets if s <= p.timestamp <= e]
        self.start_time = self.packets[0].timestamp if self.packets else None

    def renormalise_times(self, start_time):
        """Adjust the packet start times relative to the new start time."""
        for p in self.packets:
            p.timestamp -= start_time

        if self.start_time is not None:
            self.start_time -= start_time
1001
1002
class DnsHammer(Conversation):
    """A lightweight conversation that generates a lot of dns:0 packets on
    the fly"""

    def __init__(self, dns_rate, duration):
        packet_count = int(dns_rate * duration)
        self.times = sorted(random.uniform(0, duration)
                            for _ in range(packet_count))
        self.rate = dns_rate
        self.duration = duration
        self.start_time = 0
        self.msg = random_colour_print()

    def __str__(self):
        return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
                (len(self.times), self.duration, self.rate))

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        """Defer to the generic Conversation forking machinery."""
        return Conversation.replay_in_fork_with_delay(self, start,
                                                      context, account)

    def replay(self, context=None):
        """Fire off the dns:0 packets at their scheduled offsets, logging
        one tab-separated result line per packet."""
        begin = time.time()
        send_query = traffic_packets.packet_dns_0
        for scheduled in self.times:
            elapsed = time.time() - begin
            pause = (scheduled - elapsed) - SLEEP_OVERHEAD
            if pause > 0:
                time.sleep(pause)

            if context is None:
                miss = scheduled - (time.time() - begin)
                self.msg("packet %s [miss %.3f pid %d]" % (scheduled, miss,
                                                           os.getpid()))
                continue

            packet_start = time.time()
            try:
                send_query(self, self, context)
                end = time.time()
                print("%f\tDNS\tdns\t0\t%f\tTrue\t" %
                      (end, end - packet_start))
            except Exception as e:
                end = time.time()
                print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" %
                      (end, end - packet_start, e))
1052
1053
def ingest_summaries(files, dns_mode='count'):
    """Load summary traffic files and generate Conversations from them.

    :param files: a list of filenames or open file objects
    :param dns_mode: 'include' keeps dns packets in the conversations;
        any other value just counts them by opcode
    :return: a tuple of (conversation list, mean conversation rate,
        duration, dns opcode counts)
    """
    dns_counts = defaultdict(int)
    packets = []
    for f in files:
        if isinstance(f, str):
            f = open(f)
        print("Ingesting %s" % (f.name,), file=sys.stderr)
        for line in f:
            p = Packet.from_line(line)
            if p.protocol == 'dns' and dns_mode != 'include':
                dns_counts[p.opcode] += 1
            else:
                packets.append(p)

        f.close()

    if not packets:
        # return the same four element shape as the successful path,
        # so callers can unpack the result either way.
        return [], 0, 0, dns_counts

    start_time = min(p.timestamp for p in packets)
    last_packet = max(p.timestamp for p in packets)

    print("gathering packets into conversations", file=sys.stderr)
    conversations = OrderedDict()
    for i, p in enumerate(packets):
        p.timestamp -= start_time
        c = conversations.get(p.endpoints)
        if c is None:
            # i + 2 keeps the ids clear of the server id (1) used
            # elsewhere (see generate_conversations).
            c = Conversation(conversation_id=(i + 2))
            conversations[p.endpoints] = c
        c.add_packet(p)

    # We only care about conversations with actual traffic, so we
    # filter out conversations with nothing to say. We do that here,
    # rather than earlier, because those empty packets contain useful
    # hints as to which end of the conversation was the client.
    conversation_list = []
    for c in conversations.values():
        if len(c) != 0:
            conversation_list.append(c)

    # This is obviously not correct, as many conversations will appear
    # to start roughly simultaneously at the beginning of the snapshot.
    # To which we say: oh well, so be it.
    duration = float(last_packet - start_time)
    if duration > 0:
        mean_interval = len(conversations) / duration
    else:
        # all packets share a single timestamp; avoid dividing by zero
        mean_interval = 0

    return conversation_list, mean_interval, duration, dns_counts
1105
1106
def guess_server_address(conversations):
    """Guess the server address: the most commonly seen endpoint address.

    Returns None when there are no conversations at all.
    """
    addresses = Counter()
    for c in conversations:
        addresses.update(c.endpoints)
    if addresses:
        # most_common(1) returns [(address, count)]; callers want the
        # bare address to compare against conversation endpoints, so
        # strip the count.
        return addresses.most_common(1)[0][0]
1114
1115
def stringify_keys(x):
    """Return a copy of dict *x* with each tuple key joined into a single
    tab-separated string (the inverse of unstringify_keys)."""
    return {'\t'.join(key): value for key, value in x.items()}
1122
1123
def unstringify_keys(x):
    """Return a copy of dict *x* with each tab-separated string key split
    back into a tuple (the inverse of stringify_keys)."""
    return {tuple(str(key).split('\t')): value for key, value in x.items()}
1130
1131
class TrafficModel(object):
    """An n-gram model of client packet traffic, learnt from ingested
    summaries and used to generate statistically similar conversations."""

    def __init__(self, n=3):
        self.ngrams = {}
        self.query_details = {}
        self.n = n
        self.dns_opcounts = defaultdict(int)
        self.cumulative_duration = 0.0
        # [number of conversations seen, seconds they were spread over]
        self.conversation_rate = [0, 1]

    def learn(self, conversations, dns_opcounts=None):
        """Fold a list of conversations (and optional dns opcode counts)
        into the model.

        :param conversations: a list of Conversation objects
        :param dns_opcounts: an optional mapping of dns opcode to count
        """
        # a mutable default argument ({}) would be shared between calls;
        # use None as the sentinel instead.
        if dns_opcounts is None:
            dns_opcounts = {}
        prev = 0.0
        cum_duration = 0.0
        key = (NON_PACKET,) * (self.n - 1)

        server = guess_server_address(conversations)

        for k, v in dns_opcounts.items():
            self.dns_opcounts[k] += v

        if len(conversations) > 1:
            elapsed =\
                conversations[-1].start_time - conversations[0].start_time
            self.conversation_rate[0] = len(conversations)
            self.conversation_rate[1] = elapsed

        for c in conversations:
            client, server = c.guess_client_server(server)
            cum_duration += c.get_duration()
            key = (NON_PACKET,) * (self.n - 1)
            for p in c:
                # we model the client side of the conversation only
                if p.src != client:
                    continue

                elapsed = p.timestamp - prev
                prev = p.timestamp
                if elapsed > WAIT_THRESHOLD:
                    # add the wait as an extra state
                    wait = 'wait:%d' % (math.log(max(1.0,
                                                     elapsed * WAIT_SCALE)))
                    self.ngrams.setdefault(key, []).append(wait)
                    key = key[1:] + (wait,)

                short_p = p.as_packet_type()
                self.query_details.setdefault(short_p,
                                              []).append(tuple(p.extra))
                self.ngrams.setdefault(key, []).append(short_p)
                key = key[1:] + (short_p,)

        self.cumulative_duration += cum_duration
        # add in the end
        self.ngrams.setdefault(key, []).append(NON_PACKET)

    def save(self, f):
        """Save the model as JSON to *f* (a filename or open file)."""
        ngrams = {}
        for k, v in self.ngrams.items():
            k = '\t'.join(k)
            ngrams[k] = dict(Counter(v))

        query_details = {}
        for k, v in self.query_details.items():
            # '-' stands in for an empty extra tuple
            query_details[k] = dict(Counter('\t'.join(x) if x else '-'
                                            for x in v))

        d = {
            'ngrams': ngrams,
            'query_details': query_details,
            'cumulative_duration': self.cumulative_duration,
            'conversation_rate': self.conversation_rate,
        }
        d['dns'] = self.dns_opcounts

        if isinstance(f, str):
            f = open(f, 'w')

        json.dump(d, f, indent=2)

    def load(self, f):
        """Load a model previously written by save().

        :param f: a filename or open file object
        """
        if isinstance(f, str):
            f = open(f)

        d = json.load(f)

        for k, v in d['ngrams'].items():
            k = tuple(str(k).split('\t'))
            values = self.ngrams.setdefault(k, [])
            for p, count in v.items():
                values.extend([str(p)] * count)
            values.sort()

        for k, v in d['query_details'].items():
            values = self.query_details.setdefault(str(k), [])
            for p, count in v.items():
                if p == '-':
                    values.extend([()] * count)
                else:
                    values.extend([tuple(str(p).split('\t'))] * count)
            values.sort()

        if 'dns' in d:
            for k, v in d['dns'].items():
                self.dns_opcounts[k] += v

        self.cumulative_duration = d['cumulative_duration']
        self.conversation_rate = d['conversation_rate']

    def construct_conversation(self, timestamp=0.0, client=2, server=1,
                               hard_stop=None, replay_speed=1):
        """Construct an individual conversation from the model."""

        c = Conversation(timestamp, (server, client), conversation_id=client)

        key = (NON_PACKET,) * (self.n - 1)

        while key in self.ngrams:
            # the loop condition guarantees key is present, so the
            # .get() default is never actually used
            p = random.choice(self.ngrams.get(key, NON_PACKET))
            if p == NON_PACKET:
                break

            if p in self.query_details:
                extra = random.choice(self.query_details[p])
            else:
                extra = []

            protocol, opcode = p.split(':', 1)
            if protocol == 'wait':
                log_wait_time = int(opcode) + random.random()
                wait = math.exp(log_wait_time) / (WAIT_SCALE * replay_speed)
                timestamp += wait
            else:
                log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
                wait = math.exp(log_wait) / replay_speed
                timestamp += wait
                if hard_stop is not None and timestamp > hard_stop:
                    break
                c.add_short_packet(timestamp, protocol, opcode, extra)

            key = key[1:] + (p,)

        return c

    def generate_conversations(self, rate, duration, replay_speed=1):
        """Generate a list of conversations from the model."""

        # We run the simulation for at least ten times as long as our
        # desired duration, and take a section near the start.
        rate_n, rate_t = self.conversation_rate
        if rate_t <= 0:
            # a degenerate model (all learnt conversations started at
            # the same moment); avoid dividing by zero below.
            rate_t = 1

        duration2 = max(rate_t, duration * 2)
        n = rate * duration2 * rate_n / rate_t

        server = 1
        client = 2

        conversations = []
        end = duration2
        start = end - duration

        while client < n + 2:
            start = random.uniform(0, duration2)
            c = self.construct_conversation(start,
                                            client,
                                            server,
                                            hard_stop=(duration2 * 5),
                                            replay_speed=replay_speed)

            c.forget_packets_outside_window(start, end)
            c.renormalise_times(start)
            if len(c) != 0:
                conversations.append(c)
            client += 1

        print(("we have %d conversations at rate %f" %
               (len(conversations), rate)), file=sys.stderr)
        conversations.sort()
        return conversations
1307
1308
# Map each application protocol name to the IP protocol number it is
# carried over, as a two-digit hex string: '06' is TCP, '11' is UDP.
IP_PROTOCOLS = {
    'dns': '11',
    'rpc_netlogon': '06',
    'kerberos': '06',      # ratio 16248:258
    'smb': '06',
    'smb2': '06',
    'ldap': '06',
    'cldap': '11',
    'lsarpc': '06',
    'samr': '06',
    'dcerpc': '06',
    'epm': '06',
    'drsuapi': '06',
    'browser': '11',
    'smb_netlogon': '11',
    'srvsvc': '06',
    'nbns': '11',
}
1327
# Map (protocol, opcode) pairs to the human-readable operation name used
# in packet summary descriptions.
OP_DESCRIPTIONS = {
    ('browser', '0x01'): 'Host Announcement (0x01)',
    ('browser', '0x02'): 'Request Announcement (0x02)',
    ('browser', '0x08'): 'Browser Election Request (0x08)',
    ('browser', '0x09'): 'Get Backup List Request (0x09)',
    ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
    ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
    ('cldap', '3'): 'searchRequest',
    ('cldap', '5'): 'searchResDone',
    ('dcerpc', '0'): 'Request',
    ('dcerpc', '11'): 'Bind',
    ('dcerpc', '12'): 'Bind_ack',
    ('dcerpc', '13'): 'Bind_nak',
    ('dcerpc', '14'): 'Alter_context',
    ('dcerpc', '15'): 'Alter_context_resp',
    ('dcerpc', '16'): 'AUTH3',
    ('dcerpc', '2'): 'Response',
    ('dns', '0'): 'query',
    ('dns', '1'): 'response',
    ('drsuapi', '0'): 'DsBind',
    ('drsuapi', '12'): 'DsCrackNames',
    ('drsuapi', '13'): 'DsWriteAccountSpn',
    ('drsuapi', '1'): 'DsUnbind',
    ('drsuapi', '2'): 'DsReplicaSync',
    ('drsuapi', '3'): 'DsGetNCChanges',
    ('drsuapi', '4'): 'DsReplicaUpdateRefs',
    ('epm', '3'): 'Map',
    ('kerberos', ''): '',
    ('ldap', '0'): 'bindRequest',
    ('ldap', '1'): 'bindResponse',
    ('ldap', '2'): 'unbindRequest',
    ('ldap', '3'): 'searchRequest',
    ('ldap', '4'): 'searchResEntry',
    ('ldap', '5'): 'searchResDone',
    ('ldap', ''): '*** Unknown ***',
    ('lsarpc', '14'): 'lsa_LookupNames',
    ('lsarpc', '15'): 'lsa_LookupSids',
    ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
    ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
    ('lsarpc', '6'): 'lsa_OpenPolicy',
    ('lsarpc', '76'): 'lsa_LookupSids3',
    ('lsarpc', '77'): 'lsa_LookupNames4',
    ('nbns', '0'): 'query',
    ('nbns', '1'): 'response',
    ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
    ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
    ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
    ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
    ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
    ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
    ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
    ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
    ('samr', '0',): 'Connect',
    ('samr', '16'): 'GetAliasMembership',
    ('samr', '17'): 'LookupNames',
    ('samr', '18'): 'LookupRids',
    ('samr', '19'): 'OpenGroup',
    ('samr', '1'): 'Close',
    ('samr', '25'): 'QueryGroupMember',
    ('samr', '34'): 'OpenUser',
    ('samr', '36'): 'QueryUserInfo',
    ('samr', '39'): 'GetGroupsForUser',
    ('samr', '3'): 'QuerySecurity',
    ('samr', '5'): 'LookupDomain',
    ('samr', '64'): 'Connect5',
    ('samr', '6'): 'EnumDomains',
    ('samr', '7'): 'OpenDomain',
    ('samr', '8'): 'QueryDomainInfo',
    ('smb', '0x04'): 'Close (0x04)',
    ('smb', '0x24'): 'Locking AndX (0x24)',
    ('smb', '0x2e'): 'Read AndX (0x2e)',
    ('smb', '0x32'): 'Trans2 (0x32)',
    ('smb', '0x71'): 'Tree Disconnect (0x71)',
    ('smb', '0x72'): 'Negotiate Protocol (0x72)',
    ('smb', '0x73'): 'Session Setup AndX (0x73)',
    ('smb', '0x74'): 'Logoff AndX (0x74)',
    ('smb', '0x75'): 'Tree Connect AndX (0x75)',
    ('smb', '0xa2'): 'NT Create AndX (0xa2)',
    ('smb2', '0'): 'NegotiateProtocol',
    ('smb2', '11'): 'Ioctl',
    ('smb2', '14'): 'Find',
    ('smb2', '16'): 'GetInfo',
    ('smb2', '18'): 'Break',
    ('smb2', '1'): 'SessionSetup',
    ('smb2', '2'): 'SessionLogoff',
    ('smb2', '3'): 'TreeConnect',
    ('smb2', '4'): 'TreeDisconnect',
    ('smb2', '5'): 'Create',
    ('smb2', '6'): 'Close',
    ('smb2', '8'): 'Read',
    ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
    ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
                               'user unknown (0x17)'),
    ('srvsvc', '16'): 'NetShareGetInfo',
    ('srvsvc', '21'): 'NetSrvGetInfo',
}
1424
1425
def expand_short_packet(p, timestamp, src, dest, extra):
    """Expand a short 'protocol:opcode' description plus its extra data
    into a full tab-separated summary line."""
    protocol, opcode = p.split(':', 1)

    fields = [timestamp,
              IP_PROTOCOLS.get(protocol, '06'),
              '',
              src,
              dest,
              protocol,
              opcode,
              OP_DESCRIPTIONS.get((protocol, opcode), '')]
    fields.extend(extra)
    return '\t'.join(fields)
1434
1435
def replay(conversations,
           host=None,
           creds=None,
           lp=None,
           accounts=None,
           dns_rate=0,
           duration=None,
           **kwargs):
    """Replay the given conversations against a server, forking one child
    process per conversation.

    :param conversations: the Conversation objects to replay
    :param host: the server to direct the traffic at
    :param creds: credentials for the connections
    :param lp: the loadparm context
    :param accounts: a list of ConversationAccounts, one per conversation
    :param dns_rate: extra dns:0 packets per second to generate (0: none)
    :param duration: stop after this many seconds (default: one second
        after the last conversation starts)
    """
    context = ReplayContext(server=host,
                            creds=creds,
                            lp=lp,
                            **kwargs)

    if len(accounts) < len(conversations):
        # report the *lengths*; formatting the sequences themselves
        # with %d would raise a TypeError.
        print(("we have %d accounts but %d conversations" %
               (len(accounts), len(conversations))), file=sys.stderr)

    cstack = list(zip(
        sorted(conversations, key=lambda x: x.start_time, reverse=True),
        accounts))

    # Set the process group so that the calling scripts are not killed
    # when the forked child processes are killed.
    os.setpgrp()

    start = time.time()

    if duration is None:
        # end 1 second after the last packet of the last conversation
        # to start. Conversations other than the last could still be
        # going, but we don't care.
        duration = cstack[0][0].packets[-1].timestamp + 1.0
        print("We will stop after %.1f seconds" % duration,
              file=sys.stderr)

    end = start + duration

    LOGGER.info("Replaying traffic for %u conversations over %d seconds"
          % (len(conversations), duration))

    children = {}
    if dns_rate:
        dns_hammer = DnsHammer(dns_rate, duration)
        cstack.append((dns_hammer, None))

    try:
        while True:
            # we spawn a batch, wait for finishers, then spawn another
            now = time.time()
            batch_end = min(now + 2.0, end)
            fork_time = 0.0
            fork_n = 0
            while cstack:
                c, account = cstack.pop()
                if c.start_time + start > batch_end:
                    cstack.append((c, account))
                    break

                st = time.time()
                pid = c.replay_in_fork_with_delay(start, context, account)
                children[pid] = c
                t = time.time()
                elapsed = t - st
                fork_time += elapsed
                fork_n += 1
                print("forked %s in pid %s (in %fs)" % (c, pid,
                                                        elapsed),
                      file=sys.stderr)

            if fork_n:
                print(("forked %d times in %f seconds (avg %f)" %
                       (fork_n, fork_time, fork_time / fork_n)),
                      file=sys.stderr)
            elif cstack:
                debug(2, "no forks in batch ending %f" % batch_end)

            while time.time() < batch_end - 1.0:
                time.sleep(0.01)
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != 10:  # no child processes
                        raise
                    break
                if pid:
                    c = children.pop(pid, None)
                    print(("process %d finished conversation %s;"
                           " %d to go" %
                           (pid, c, len(children))), file=sys.stderr)

            if time.time() >= end:
                print("time to stop", file=sys.stderr)
                break

    except Exception:
        print("EXCEPTION in parent", file=sys.stderr)
        traceback.print_exc()
    finally:
        # escalate from SIGTERM to SIGKILL while reaping children
        for s in (15, 15, 9):
            print(("killing %d children with -%d" %
                   (len(children), s)), file=sys.stderr)
            for pid in children:
                try:
                    os.kill(pid, s)
                except OSError as e:
                    if e.errno != 3:  # don't fail if it has already died
                        raise
            time.sleep(0.5)
            end = time.time() + 1
            while children:
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != 10:
                        raise
                if pid != 0:
                    c = children.pop(pid, None)
                    print(("kill -%d %d KILLED conversation %s; "
                           "%d to go" %
                           (s, pid, c, len(children))),
                          file=sys.stderr)
                if time.time() >= end:
                    break

            if not children:
                break
            time.sleep(1)

        if children:
            print("%d children are missing" % len(children),
                  file=sys.stderr)

        # there may be stragglers that were forked just as ^C was hit
        # and don't appear in the list of children. We can get them
        # with killpg, but that will also kill us, so this is^H^H would be
        # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
        # so as not to have to fuss around writing signal handlers.
        try:
            os.killpg(0, 2)
        except KeyboardInterrupt:
            print("ignoring fake ^C", file=sys.stderr)
1578
1579
def openLdb(host, creds, lp):
    """Open a SamDB connection over LDAP to *host*, using paged searches.

    (The local name deliberately avoids shadowing the module-level
    ``ldb`` import.)
    """
    sam_db = SamDB(url="ldap://%s" % host,
                   session_info=system_session(),
                   options=['modules:paged_searches'],
                   credentials=creds,
                   lp=lp)
    return sam_db
1588
1589
def ou_name(ldb, instance_id):
    """Generate an ou name from the instance id"""
    base_dn = ldb.domain_dn()
    return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id, base_dn)
1594
1595
def create_ou(ldb, instance_id):
    """Create an ou, all created user and machine accounts will belong to it.

    This allows all the created resources to be cleaned up easily.
    """
    def add_if_missing(dn):
        # add an organizationalunit, ignoring "already exists" (68 is
        # LDB_ERR_ENTRY_ALREADY_EXISTS)
        try:
            ldb.add({"dn": dn,
                     "objectclass": "organizationalunit"})
        except LdbError as e:
            (status, _) = e.args
            if status != 68:
                raise

    ou = ou_name(ldb, instance_id)
    # create the parent ou=traffic_replay container first, then the
    # instance-specific child
    add_if_missing(ou.split(',', 1)[1])
    add_if_missing(ou)
    return ou
1619
1620
# ConversationAccounts holds details of the machine and user accounts
# associated with a conversation.
#
# We use a named tuple to reduce shared memory usage.
ConversationAccounts = namedtuple(
    'ConversationAccounts',
    ['netbios_name', 'machinepass', 'username', 'userpass'])
1630
1631
def generate_replay_accounts(ldb, instance_id, number, password):
    """Generate a series of unique machine and user account names."""
    # every conversation gets a paired machine and user account; both
    # share the same replay password
    return [ConversationAccounts(machine_name(instance_id, i),
                                 password,
                                 user_name(instance_id, i),
                                 password)
            for i in range(1, number + 1)]
1644
1645
def create_machine_account(ldb, instance_id, netbios_name, machinepass,
                           traffic_account=True):
    """Create a machine account via ldap."""
    dn = "cn=%s,%s" % (netbios_name, ou_name(ldb, instance_id))
    # unicodePwd must be the quoted password encoded as UTF-16-LE
    utf16pw = ('"%s"' % get_string(machinepass)).encode('utf-16-le')

    if traffic_account:
        # we set these bits for the machine account otherwise the replayed
        # traffic throws up NT_STATUS_NO_TRUST_SAM_ACCOUNT errors
        flags = UF_TRUSTED_FOR_DELEGATION | UF_SERVER_TRUST_ACCOUNT
    else:
        flags = UF_WORKSTATION_TRUST_ACCOUNT

    ldb.add({"dn": dn,
             "objectclass": "computer",
             "sAMAccountName": "%s$" % netbios_name,
             "userAccountControl": str(flags),
             "unicodePwd": utf16pw})
1669
1670
def create_user_account(ldb, instance_id, username, userpass):
    """Create a user account via ldap."""
    user_dn = "cn=%s,%s" % (username, ou_name(ldb, instance_id))
    # unicodePwd must be the quoted password encoded as UTF-16-LE
    utf16pw = ('"%s"' % get_string(userpass)).encode('utf-16-le')
    ldb.add({"dn": user_dn,
             "objectclass": "user",
             "sAMAccountName": username,
             "userAccountControl": str(UF_NORMAL_ACCOUNT),
             "unicodePwd": utf16pw})

    # grant user write permission to do things like write account SPN
    sd_utils.SDUtils(ldb).dacl_add_ace(user_dn, "(A;;WP;;;PS)")
1687
1688
def create_group(ldb, instance_id, name):
    """Create a group via ldap."""
    dn = "cn=%s,%s" % (name, ou_name(ldb, instance_id))
    ldb.add({"dn": dn,
             "objectclass": "group",
             "sAMAccountName": name})
1699
1700
def user_name(instance_id, i):
    """Generate a user name from the instance id and user number."""
    fmt = "STGU-%d-%d"
    return fmt % (instance_id, i)
1704
1705
def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
    """Search for objects of the given class, return attr values in a set."""
    expression = "(objectClass={})".format(objectclass)
    results = ldb.search(expression=expression, attrs=[attr])
    names = set()
    for obj in results:
        names.add(str(obj[attr]))
    return names
1713
1714
def generate_users(ldb, instance_id, number, password):
    """Add users to the server"""
    existing = search_objectclass(ldb, objectclass='user')
    created = 0
    # count downwards so the highest-numbered accounts are created first
    for i in range(number, 0, -1):
        name = user_name(instance_id, i)
        if name in existing:
            continue
        create_user_account(ldb, instance_id, name, password)
        created += 1
        if created % 50 == 0:
            LOGGER.info("Created %u/%u users" % (created, number))

    return created
1728
1729
def machine_name(instance_id, i, traffic_account=True):
    """Generate a machine account name from instance id.

    Traffic accounts (the default) correspond to a given user and use
    different userAccountControl flags so replayed packets are processed
    correctly by the DC. Plain computer accounts simulate a semi-realistic
    network with default flags, and get a distinct prefix so they are never
    picked up when generating packets.
    """
    prefix = "STGM" if traffic_account else "PC"
    return "%s-%d-%d" % (prefix, instance_id, i)
1743
1744
def generate_machine_accounts(ldb, instance_id, number, password,
                              traffic_account=True):
    """Add machine accounts to the server"""
    existing = search_objectclass(ldb, objectclass='computer')
    created = 0
    for i in range(number, 0, -1):
        name = machine_name(instance_id, i, traffic_account)
        # computer sAMAccountNames carry a trailing '$'
        if name + "$" in existing:
            continue
        create_machine_account(ldb, instance_id, name, password,
                               traffic_account)
        created += 1
        if created % 50 == 0:
            LOGGER.info("Created %u/%u machine accounts" % (created, number))

    return created
1760
1761
def group_name(instance_id, i):
    """Generate a group name from the instance id and group number."""
    fmt = "STGG-%d-%d"
    return fmt % (instance_id, i)
1765
1766
def generate_groups(ldb, instance_id, number):
    """Create the required number of groups on the server."""
    existing = search_objectclass(ldb, objectclass='group')
    created = 0
    for i in range(number, 0, -1):
        name = group_name(instance_id, i)
        if name in existing:
            continue
        create_group(ldb, instance_id, name)
        created += 1
        if created % 1000 == 0:
            LOGGER.info("Created %u/%u groups" % (created, number))

    return created
1780
1781
def clean_up_accounts(ldb, instance_id):
    """Remove the created accounts and groups from the server."""
    try:
        # a tree delete removes the instance OU and everything under it
        ldb.delete(ou_name(ldb, instance_id), ["tree_delete:1"])
    except LdbError as e:
        (status, _) = e.args
        # LDAP error 32 (no such object) means there is nothing to clean up
        if status != 32:
            raise
1792
1793
def generate_users_and_groups(ldb, instance_id, password,
                              number_of_users, number_of_groups,
                              group_memberships, max_members,
                              machine_accounts, traffic_accounts=True):
    """Generate the required users and groups, allocating the users to
       those groups."""
    groups_added = 0
    memberships_added = 0

    create_ou(ldb, instance_id)

    LOGGER.info("Generating dummy user accounts")
    users_added = generate_users(ldb, instance_id, number_of_users, password)

    LOGGER.info("Generating dummy machine accounts")
    computers_added = generate_machine_accounts(ldb, instance_id,
                                                machine_accounts, password,
                                                traffic_accounts)

    if number_of_groups > 0:
        LOGGER.info("Generating dummy groups")
        groups_added = generate_groups(ldb, instance_id, number_of_groups)

    if group_memberships > 0:
        # plan the user->group allocation, then write it to the DB
        LOGGER.info("Assigning users to groups")
        assignments = GroupAssignments(number_of_groups, groups_added,
                                       number_of_users, users_added,
                                       group_memberships, max_members)
        LOGGER.info("Adding users to groups")
        add_users_to_groups(ldb, instance_id, assignments)
        memberships_added = assignments.total()

    # warn when we created new groups but no new users to put in them
    if (groups_added > 0 and users_added == 0 and
       number_of_groups != groups_added):
        LOGGER.warning("The added groups will contain no members")

    LOGGER.info("Added %d users (%d machines), %d groups and %d memberships" %
                (users_added, computers_added, groups_added,
                 memberships_added))
1837
1838
class GroupAssignments(object):
    """Plans which users belong to which groups.

    The full set of user->group assignments is computed up-front using
    weighted random distributions, so that add_users_to_groups() can then
    apply them to the DB in efficient per-group batches.
    """

    def __init__(self, number_of_groups, groups_added, number_of_users,
                 users_added, group_memberships, max_members):

        # running total of memberships assigned so far
        self.count = 0
        self.generate_group_distribution(number_of_groups)
        self.generate_user_distribution(number_of_users, group_memberships)
        # optional cap on how many members a single group may have
        self.max_members = max_members
        # maps group-number -> list of user-numbers in that group
        self.assignments = defaultdict(list)
        self.assign_groups(number_of_groups, groups_added, number_of_users,
                           users_added, group_memberships)

    def cumulative_distribution(self, weights):
        """Turn a list of weights into a cumulative distribution.

        Returns a list of ascending values ending at 1.0, or None if the
        weights sum to zero.
        """
        # make sure the probabilities conform to a cumulative distribution
        # spread between 0.0 and 1.0. Dividing by the weighted total gives each
        # probability a proportional share of 1.0. Higher probabilities get a
        # bigger share, so are more likely to be picked. We use the cumulative
        # value, so we can use random.random() as a simple index into the list
        dist = []
        total = sum(weights)
        if total == 0:
            return None

        cumulative = 0.0
        for probability in weights:
            cumulative += probability
            dist.append(cumulative / total)
        return dist

    def generate_user_distribution(self, num_users, num_memberships):
        """Probability distribution of a user belonging to a group.
        """
        # Assign a weighted probability to each user. Use the Pareto
        # Distribution so that some users are in a lot of groups, and the
        # bulk of users are in only a few groups. If we're assigning a large
        # number of group memberships, use a higher shape. This means slightly
        # fewer outlying users that are in large numbers of groups. The aim is
        # to have no users belonging to more than ~500 groups.
        if num_memberships > 5000000:
            shape = 3.0
        elif num_memberships > 2000000:
            shape = 2.5
        elif num_memberships > 300000:
            shape = 2.25
        else:
            shape = 1.75

        weights = []
        for x in range(1, num_users + 1):
            p = random.paretovariate(shape)
            weights.append(p)

        # convert the weights to a cumulative distribution between 0.0 and 1.0
        self.user_dist = self.cumulative_distribution(weights)

    def generate_group_distribution(self, n):
        """Probability distribution of a group containing a user."""

        # Assign a weighted probability to each user. Probability decreases
        # as the group-ID increases
        weights = []
        for x in range(1, n + 1):
            p = 1 / (x**1.3)
            weights.append(p)

        # convert the weights to a cumulative distribution between 0.0 and 1.0
        # (keep the raw weights too, so cap_group_membership() can zero out
        # a full group and recompute the distribution)
        self.group_weights = weights
        self.group_dist = self.cumulative_distribution(weights)

    def generate_random_membership(self):
        """Returns a randomly generated user-group membership"""

        # the list items are cumulative distribution values between 0.0 and
        # 1.0, which makes random() a handy way to index the list to get a
        # weighted random user/group. (Here the user/group returned are
        # zero-based array indexes)
        user = bisect.bisect(self.user_dist, random.random())
        group = bisect.bisect(self.group_dist, random.random())

        return user, group

    def users_in_group(self, group):
        """Return the list of user-numbers assigned to the given group."""
        return self.assignments[group]

    def get_groups(self):
        """Return the group-numbers that have assignments recorded."""
        return self.assignments.keys()

    def cap_group_membership(self, group, max_members):
        """Prevent the group's membership from exceeding the max specified"""
        num_members = len(self.assignments[group])
        if num_members >= max_members:
            LOGGER.info("Group {0} has {1} members".format(group, num_members))

            # remove this group and then recalculate the cumulative
            # distribution, so this group is no longer selected
            self.group_weights[group - 1] = 0
            new_dist = self.cumulative_distribution(self.group_weights)
            self.group_dist = new_dist

    def add_assignment(self, user, group):
        """Record that the given user belongs to the given group."""
        # the assignments are stored in a dictionary where key=group,
        # value=list-of-users-in-group (indexing by group-ID allows us to
        # optimize for DB membership writes)
        if user not in self.assignments[group]:
            self.assignments[group].append(user)
            self.count += 1

        # check if there's a cap on how big the groups can grow
        if self.max_members:
            self.cap_group_membership(group, self.max_members)

    def assign_groups(self, number_of_groups, groups_added,
                      number_of_users, users_added, group_memberships):
        """Allocate users to groups.

        The intention is to have a few users that belong to most groups, while
        the majority of users belong to a few groups.

        A few groups will contain most users, with the remaining only having a
        few users.
        """

        if group_memberships <= 0:
            return

        # Calculate the number of group memberships required, pro-rated by
        # the fraction of users that were actually added this run
        group_memberships = math.ceil(
            float(group_memberships) *
            (float(users_added) / float(number_of_users)))

        if self.max_members:
            group_memberships = min(group_memberships,
                                    self.max_members * number_of_groups)

        existing_users  = number_of_users  - users_added  - 1
        existing_groups = number_of_groups - groups_added - 1
        while self.total() < group_memberships:
            user, group = self.generate_random_membership()

            # only record memberships that involve at least one newly-added
            # user or group; pre-existing pairs are assumed already done
            if group > existing_groups or user > existing_users:
                # the + 1 converts the array index to the corresponding
                # group or user number
                self.add_assignment(user + 1, group + 1)

    def total(self):
        """Return the total number of memberships assigned."""
        return self.count
1985
1986
def add_users_to_groups(db, instance_id, assignments):
    """Takes the assignments of users to groups and applies them to the DB."""
    total = assignments.total()
    done = 0
    batches = 0

    for group in assignments.get_groups():
        members = assignments.users_in_group(group)
        if not members:
            continue

        # Split up the users into chunks, so we write no more than 1K at a
        # time. (Minimizing the DB modifies is more efficient, but writing
        # 10K+ users to a single group becomes inefficient memory-wise)
        for start in range(0, len(members), 1000):
            batch = members[start:start + 1000]
            add_group_members(db, instance_id, group, batch)

            done += len(batch)
            batches += 1
            if batches % 50 == 0:
                LOGGER.info("Added %u/%u memberships" % (done, total))
2010
def add_group_members(db, instance_id, group, users_in_group):
    """Adds the given users to group specified."""
    ou = ou_name(db, instance_id)

    msg = ldb.Message()
    msg.dn = ldb.Dn(db, "cn=%s,%s" % (group_name(instance_id, group), ou))

    for user in users_in_group:
        user_dn = "cn=%s,%s" % (user_name(instance_id, user), ou)
        # a unique element name per user so each member add is distinct
        idx = "member-" + str(user)
        msg[idx] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")

    # a single modify applies the whole batch of member additions
    db.modify(msg)
2029
2030
def generate_stats(statsdir, timing_file):
    """Generate and print the summary stats for a run.

    Reads every file in *statsdir* (tab-separated per-packet timing lines),
    optionally tees the raw lines to *timing_file*, and prints per-protocol /
    per-opcode latency statistics to stdout.
    """
    first      = sys.float_info.max
    last       = 0
    successful = 0
    failed     = 0
    latencies  = {}
    failures   = {}
    unique_conversations = set()
    conversations = 0

    # tee raw timing lines to timing_file, or discard them if none given
    if timing_file is not None:
        tw = timing_file.write
    else:
        def tw(x):
            pass

    tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")

    for filename in os.listdir(statsdir):
        path = os.path.join(statsdir, filename)
        with open(path, 'r') as f:
            for line in f:
                try:
                    fields       = line.rstrip('\n').split('\t')
                    conversation = fields[1]
                    protocol     = fields[2]
                    packet_type  = fields[3]
                    latency      = float(fields[4])
                    # fields[0] is the completion time; subtracting the
                    # latency gives the earliest operation start time
                    first        = min(float(fields[0]) - latency, first)
                    last         = max(float(fields[0]), last)

                    if protocol not in latencies:
                        latencies[protocol] = {}
                    if packet_type not in latencies[protocol]:
                        latencies[protocol][packet_type] = []

                    latencies[protocol][packet_type].append(latency)

                    if protocol not in failures:
                        failures[protocol] = {}
                    if packet_type not in failures[protocol]:
                        failures[protocol][packet_type] = 0

                    if fields[5] == 'True':
                        successful += 1
                    else:
                        failed += 1
                        failures[protocol][packet_type] += 1

                    if conversation not in unique_conversations:
                        unique_conversations.add(conversation)
                        conversations += 1

                    tw(line)
                except (ValueError, IndexError):
                    # not a valid line: print it and ignore
                    print(line, file=sys.stderr)
    duration = last - first
    # guard against a zero/negative duration (no valid lines, or all work
    # completed at the same instant) as well as zero counts
    if successful == 0 or duration <= 0:
        success_rate = 0
    else:
        success_rate = successful / duration
    if failed == 0 or duration <= 0:
        failure_rate = 0
    else:
        failure_rate = failed / duration

    print("Total conversations:   %10d" % conversations)
    print("Successful operations: %10d (%.3f per second)"
          % (successful, success_rate))
    print("Failed operations:     %10d (%.3f per second)"
          % (failed, failure_rate))

    print("Protocol    Op Code  Description                               "
          " Count       Failed         Mean       Median          "
          "95%        Range          Max")

    protocols = sorted(latencies.keys())
    for protocol in protocols:
        packet_types = sorted(latencies[protocol], key=opcode_key)
        for packet_type in packet_types:
            values     = sorted(latencies[protocol][packet_type])
            count      = len(values)
            failed     = failures[protocol][packet_type]
            mean       = sum(values) / count
            median     = calc_percentile(values, 0.50)
            percentile = calc_percentile(values, 0.95)
            rng        = values[-1] - values[0]
            maxv       = values[-1]
            desc       = OP_DESCRIPTIONS.get((protocol, packet_type), '')
            # isatty must be *called*; testing the bound method itself is
            # always truthy, which made the tab-separated branch unreachable
            if sys.stdout.isatty():
                print("%-12s   %4s  %-35s %12d %12d %12.6f "
                      "%12.6f %12.6f %12.6f %12.6f"
                      % (protocol,
                         packet_type,
                         desc,
                         count,
                         failed,
                         mean,
                         median,
                         percentile,
                         rng,
                         maxv))
            else:
                print("%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f"
                      % (protocol,
                         packet_type,
                         desc,
                         count,
                         failed,
                         mean,
                         median,
                         percentile,
                         rng,
                         maxv))
2149
2150
def opcode_key(v):
    """Sort key for the operation code to ensure that it sorts numerically.

    Numeric opcodes are zero-padded to three digits so "2" sorts before
    "10"; non-numeric opcodes are returned unchanged and sort lexically.
    """
    try:
        return "%03d" % int(v)
    except (ValueError, TypeError):
        # narrowed from a bare except: int() raises ValueError for
        # non-numeric strings and TypeError for non-string/number types
        return v
2157
2158
def calc_percentile(values, percentile):
    """Calculate the specified percentile from the list of values.

    Assumes the list is sorted in ascending order.
    """
    if not values:
        return 0
    # fractional rank of the requested percentile within the list
    rank = (len(values) - 1) * percentile
    lower = math.floor(rank)
    upper = math.ceil(rank)
    if lower == upper:
        # the rank lands exactly on an element
        return values[int(rank)]
    # otherwise linearly interpolate between the two neighbouring elements
    below = values[int(lower)] * (upper - rank)
    above = values[int(upper)] * (rank - lower)
    return below + above
2175
2176
def mk_masked_dir(*path):
    """In a testenv we end up with 0777 directories that look an alarming
    green colour with ls. Use umask to avoid that."""
    # py3 os.mkdir can do this
    d = os.path.join(*path)
    mask = os.umask(0o077)
    try:
        os.mkdir(d)
    finally:
        # always restore the previous umask, even when mkdir raises
        # (previously a failed mkdir left the process umask at 0o077)
        os.umask(mask)
    return d