traffic_replay: Generate machine accounts as well as users
[samba.git] / python / samba / emulate / traffic.py
1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
3 #
4 # Copyright (C) Catalyst IT Ltd. 2017
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19 from __future__ import print_function, division
20
21 import time
22 import os
23 import random
24 import json
25 import math
26 import sys
27 import signal
28 import itertools
29
30 from collections import OrderedDict, Counter, defaultdict
31 from samba.emulate import traffic_packets
32 from samba.samdb import SamDB
33 import ldb
34 from ldb import LdbError
35 from samba.dcerpc import ClientConnection
36 from samba.dcerpc import security, drsuapi, lsa
37 from samba.dcerpc import netlogon
38 from samba.dcerpc.netlogon import netr_Authenticator
39 from samba.dcerpc import srvsvc
40 from samba.dcerpc import samr
41 from samba.drs_utils import drs_DsBind
42 import traceback
43 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
44 from samba.auth import system_session
45 from samba.dsdb import (
46     UF_NORMAL_ACCOUNT,
47     UF_SERVER_TRUST_ACCOUNT,
48     UF_TRUSTED_FOR_DELEGATION
49 )
50 from samba.dcerpc.misc import SEC_CHAN_BDC
51 from samba import gensec
52 from samba import sd_utils
53 from samba.compat import get_string
54 from samba.logger import get_samba_logger
55 import bisect
56
# Estimated fixed cost (in seconds) of performing a sleep, subtracted
# when timing replayed packets.
SLEEP_OVERHEAD = 3e-4

# we don't use None, because it complicates [de]serialisation
NON_PACKET = '-'

# Heuristic weights for guessing which endpoint of a conversation is the
# client: a (protocol, opcode) hit here votes for "client" with the given
# strength (see Packet.client_score()).
CLIENT_CLUES = {
    ('dns', '0'): 1.0,      # query
    ('smb', '0x72'): 1.0,   # Negotiate protocol
    ('ldap', '0'): 1.0,     # bind
    ('ldap', '3'): 1.0,     # searchRequest
    ('ldap', '2'): 1.0,     # unbindRequest
    ('cldap', '3'): 1.0,
    ('dcerpc', '11'): 1.0,  # bind
    ('dcerpc', '14'): 1.0,  # Alter_context
    ('nbns', '0'): 1.0,     # query
}

# Like CLIENT_CLUES, but a hit here votes for "server".
SERVER_CLUES = {
    ('dns', '1'): 1.0,      # response
    ('ldap', '1'): 1.0,     # bind response
    ('ldap', '4'): 1.0,     # search result
    ('ldap', '5'): 1.0,     # search done
    ('cldap', '5'): 1.0,
    ('dcerpc', '12'): 1.0,  # bind_ack
    ('dcerpc', '13'): 1.0,  # bind_nak
    ('dcerpc', '15'): 1.0,  # Alter_context response
}

# Protocols that the replay engine does not emulate; their packets are
# dropped during model building (see Packet.is_really_a_packet()).
SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}

# Scaling applied when deciding whether a gap between packets is worth
# sleeping for during replay.
WAIT_SCALE = 10.0
WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
# Log-space range (powers of ten, seconds) used for sub-threshold gaps.
NO_WAIT_LOG_TIME_RANGE = (-10, -3)

# DEBUG_LEVEL can be changed by scripts with -d
DEBUG_LEVEL = 0

# Module-level logger shared by the replay machinery.
LOGGER = get_samba_logger(name=__name__)
95
96
def debug(level, msg, *args):
    """Write a formatted debug message to standard error.

    :param level: Message is emitted only if this is <= the current
                  module DEBUG_LEVEL (settable via the -d option).
    :param msg:   The message, optionally containing C-style format
                  specifiers.
    :param args:  Values consumed by the format specifiers, if any.
    """
    if level > DEBUG_LEVEL:
        return
    if args:
        print(msg % tuple(args), file=sys.stderr)
    else:
        print(msg, file=sys.stderr)
114
def debug_lineno(*args):
    """Print an unformatted log message to stderr, prefixed with the
    calling function's name and (colour-highlighted) line number.
    """
    # limit=2 captures the caller's frame; [0] is that caller.
    caller = traceback.extract_stack(limit=2)[0]
    header = " %s:" "\033[01;33m" "%s " "\033[00m" % (caller[2], caller[1])
    print(header, end=' ', file=sys.stderr)
    for arg in args:
        print(arg, file=sys.stderr)
    print(file=sys.stderr)
    sys.stderr.flush()
127
def random_colour_print():
    """Return a function that prints its arguments to stderr, each line
    wrapped in one randomly chosen ANSI 256-colour escape."""
    # Colours 18..231 avoid the basic 16 and the final greyscale ramp.
    colour = 18 + random.randrange(214)
    escape = "\033[38;5;%dm" % colour

    def printer(*args):
        for arg in args:
            print("%s%s\033[00m" % (escape, arg), file=sys.stderr)

    return printer
139
class FakePacketError(Exception):
    """Raised when a synthesised packet is inconsistent with its
    conversation (e.g. its endpoints don't match)."""
    pass
143
class Packet(object):
    """Details of a network packet"""

    def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
                 protocol, opcode, desc, extra):
        """Create a packet record.

        :param timestamp: time of the packet in seconds (float)
        :param ip_protocol: IP protocol number as a string (e.g. '06')
        :param stream_number: TCP/UDP stream identifier (may be '')
        :param src: integer id of the sending endpoint
        :param dest: integer id of the receiving endpoint
        :param protocol: application protocol name (e.g. 'ldap')
        :param opcode: protocol-specific operation code (string)
        :param desc: human readable description of the operation
        :param extra: list of additional protocol-specific fields
        """
        self.timestamp = timestamp
        self.ip_protocol = ip_protocol
        self.stream_number = stream_number
        self.src = src
        self.dest = dest
        self.protocol = protocol
        self.opcode = opcode
        self.desc = desc
        self.extra = extra
        # Endpoints are stored in sorted order so both directions of the
        # same conversation compare equal.
        if self.src < self.dest:
            self.endpoints = (self.src, self.dest)
        else:
            self.endpoints = (self.dest, self.src)

    @classmethod
    def from_line(cls, line):
        """Parse one tab-separated traffic-summary line into a packet.

        Fields beyond the first eight are collected into ``extra``.
        Uses ``cls`` (not a hard-coded class) so subclasses round-trip.
        """
        fields = line.rstrip('\n').split('\t')
        (timestamp,
         ip_protocol,
         stream_number,
         src,
         dest,
         protocol,
         opcode,
         desc) = fields[:8]
        extra = fields[8:]

        timestamp = float(timestamp)
        src = int(src)
        dest = int(dest)

        return cls(timestamp, ip_protocol, stream_number, src, dest,
                   protocol, opcode, desc, extra)

    def as_summary(self, time_offset=0.0):
        """Format the packet as a traffic_summary line.

        :param time_offset: added to the timestamp before formatting
        :return: a (timestamp, line) tuple
        """
        extra = '\t'.join(self.extra)
        t = self.timestamp + time_offset
        return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
                (t,
                 self.ip_protocol,
                 self.stream_number or '',
                 self.src,
                 self.dest,
                 self.protocol,
                 self.opcode,
                 self.desc,
                 extra))

    def __str__(self):
        return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
                (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
                 self.stream_number, self.protocol, self.opcode, self.desc,
                 ('«' + ' '.join(self.extra) + '»' if self.extra else '')))

    def __repr__(self):
        return "<Packet @%s>" % self

    def copy(self):
        """Return a new packet with the same field values."""
        return self.__class__(self.timestamp,
                              self.ip_protocol,
                              self.stream_number,
                              self.src,
                              self.dest,
                              self.protocol,
                              self.opcode,
                              self.desc,
                              self.extra)

    def as_packet_type(self):
        """Return the 'protocol:opcode' key identifying this packet type."""
        t = '%s:%s' % (self.protocol, self.opcode)
        return t

    def client_score(self):
        """A positive number means we think it is a client; a negative number
        means we think it is a server. Zero means no idea. range: -1 to 1.
        """
        key = (self.protocol, self.opcode)
        if key in CLIENT_CLUES:
            return CLIENT_CLUES[key]
        if key in SERVER_CLUES:
            return -SERVER_CLUES[key]
        return 0.0

    def play(self, conversation, context):
        """Send the packet over the network, if required.

        Some packets are ignored, i.e. for  protocols not handled,
        server response messages, or messages that are generated by the
        protocol layer associated with other packets.

        Timing results are written to stdout as tab-separated lines.
        """
        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        try:
            fn = getattr(traffic_packets, fn_name)

        except AttributeError:
            print("Conversation(%s) Missing handler %s" %
                  (conversation.conversation_id, fn_name),
                  file=sys.stderr)
            return

        # Don't display a message for kerberos packets, they're not directly
        # generated they're used to indicate kerberos should be used
        if self.protocol != "kerberos":
            debug(2, "Conversation(%s) Calling handler %s" %
                     (conversation.conversation_id, fn_name))

        start = time.time()
        try:
            if fn(self, conversation, context):
                # Only collect timing data for functions that generate
                # network traffic, or fail
                end = time.time()
                duration = end - start
                print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
                      (end, conversation.conversation_id, self.protocol,
                       self.opcode, duration))
        except Exception as e:
            # A failed replay is recorded (with the error) rather than
            # aborting the whole conversation.
            end = time.time()
            duration = end - start
            print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
                  (end, conversation.conversation_id, self.protocol,
                   self.opcode, duration, e))

    def __cmp__(self, other):
        # Python 2 only: orders packets by capture time.
        return self.timestamp - other.timestamp

    def is_really_a_packet(self, missing_packet_stats=None):
        """Return True if the packet actually matters for replay.

        Packets for which this returns False can be dropped with no
        effect on the replay (unhandled protocols, continuations, and
        null handlers).
        """
        if self.protocol in SKIPPED_PROTOCOLS:
            # Ignore any packets for the protocols we're not interested in.
            return False
        if self.protocol == "ldap" and self.opcode == '':
            # skip ldap continuation packets
            return False

        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        fn = getattr(traffic_packets, fn_name, None)
        if not fn:
            print("missing packet %s" % fn_name, file=sys.stderr)
            return False
        if fn is traffic_packets.null_packet:
            return False
        return True
298
class ReplayContext(object):
    """State/Context for an individual conversation between a simulated client
       and a server.
    """

    def __init__(self,
                 server=None,
                 lp=None,
                 creds=None,
                 badpassword_frequency=None,
                 prefer_kerberos=None,
                 tempdir=None,
                 statsdir=None,
                 ou=None,
                 base_dn=None,
                 domain=None,
                 domain_sid=None):
        """Create the shared state for one replayed conversation.

        The per-protocol connection lists act as caches: the most
        recently opened connection is reused unless a caller explicitly
        requests a new one.

        NOTE: the constructor contacts the server (via
        generate_ldap_search_tables) to build the dn lookup tables.
        """
        self.server                   = server
        self.ldap_connections         = []
        self.dcerpc_connections       = []
        self.lsarpc_connections       = []
        self.lsarpc_connections_named = []
        self.drsuapi_connections      = []
        self.srvsvc_connections       = []
        self.samr_contexts            = []
        self.netlogon_connection      = None
        self.creds                    = creds
        self.lp                       = lp
        self.prefer_kerberos          = prefer_kerberos
        self.ou                       = ou
        self.base_dn                  = base_dn
        self.domain                   = domain
        self.statsdir                 = statsdir
        self.global_tempdir           = tempdir
        self.domain_sid               = domain_sid
        self.realm                    = lp.get('realm')

        # Bad password attempt controls.  Each last_*_bad flag prevents
        # two consecutive bad-password attempts for that connection type
        # (see with_random_bad_credentials).
        self.badpassword_frequency    = badpassword_frequency
        self.last_lsarpc_bad          = False
        self.last_lsarpc_named_bad    = False
        self.last_simple_bind_bad     = False
        self.last_bind_bad            = False
        self.last_srvsvc_bad          = False
        self.last_drsuapi_bad         = False
        self.last_netlogon_bad        = False
        self.last_samlogon_bad        = False
        self.generate_ldap_search_tables()
        self.next_conversation_id = itertools.count()

    def generate_ldap_search_tables(self):
        """Build self.dn_map and self.attribute_clue_map from the server.

        dn_map maps a comma-joined prefix pattern (e.g. 'CN,CN,DC,DC')
        to the list of DNs in the domain matching that shape, so that
        replayed searches can pick a plausible real DN.
        """
        session = system_session()

        db = SamDB(url="ldap://%s" % self.server,
                   session_info=session,
                   credentials=self.creds,
                   lp=self.lp)

        res = db.search(db.domain_dn(),
                        scope=ldb.SCOPE_SUBTREE,
                        controls=["paged_results:1:1000"],
                        attrs=['dn'])

        # find a list of dns for each pattern
        # e.g. CN,CN,CN,DC,DC
        dn_map = {}
        attribute_clue_map = {
            'invocationId': []
        }

        for r in res:
            dn = str(r.dn)
            pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
            dns = dn_map.setdefault(pattern, [])
            dns.append(dn)
            if dn.startswith('CN=NTDS Settings,'):
                attribute_clue_map['invocationId'].append(dn)

        # extend the map in case we are working with a different
        # number of DC components.
        # for k, v in self.dn_map.items():
        #     print >>sys.stderr, k, len(v)

        for k in list(dn_map.keys()):
            if k[-3:] != ',DC':
                continue
            p = k[:-3]
            while p[-3:] == ',DC':
                p = p[:-3]
            for i in range(5):
                p += ',DC'
                # On a collision we keep the existing entry and keep
                # extending; the continue only skips this one pattern.
                if p != k and p in dn_map:
                    print('dn_map collision %s %s' % (k, p),
                          file=sys.stderr)
                    continue
                dn_map[p] = dn_map[k]

        self.dn_map = dn_map
        self.attribute_clue_map = attribute_clue_map

    def generate_process_local_config(self, account, conversation):
        """Set up per-process state for one conversation's account.

        :param account: object carrying netbios_name/machinepass/
                        username/userpass (may be None, in which case
                        nothing is configured)
        :param conversation: the conversation this process will replay
        """
        if account is None:
            return
        self.netbios_name             = account.netbios_name
        self.machinepass              = account.machinepass
        self.username                 = account.username
        self.userpass                 = account.userpass

        # mk_masked_dir is defined elsewhere in this module; presumably
        # it creates a private directory for this conversation.
        self.tempdir = mk_masked_dir(self.global_tempdir,
                                     'conversation-%d' %
                                     conversation.conversation_id)

        # Point all volatile samba state at the private tempdir so
        # concurrent conversations don't clash.
        self.lp.set("private dir", self.tempdir)
        self.lp.set("lock dir", self.tempdir)
        self.lp.set("state directory", self.tempdir)
        self.lp.set("tls verify peer", "no_check")

        # If the domain was not specified, check for the environment
        # variable.
        if self.domain is None:
            self.domain = os.environ["DOMAIN"]

        self.remoteAddress = "/root/ncalrpc_as_system"
        self.samlogon_dn   = ("cn=%s,%s" %
                              (self.netbios_name, self.ou))
        self.user_dn       = ("cn=%s,%s" %
                              (self.username, self.ou))

        self.generate_machine_creds()
        self.generate_user_creds()

    def with_random_bad_credentials(self, f, good, bad, failed_last_time):
        """Execute the supplied logon function, randomly choosing the
           bad credentials.

           Based on the frequency in badpassword_frequency randomly perform the
           function with the supplied bad credentials.
           If run with bad credentials, the function is re-run with the good
           credentials.
           failed_last_time is used to prevent consecutive bad credential
           attempts. So the over all bad credential frequency will be lower
           than that requested, but not significantly.

           :return: (result of f(good), updated failed_last_time flag)
        """
        if not failed_last_time:
            if (self.badpassword_frequency and self.badpassword_frequency > 0
                and random.random() < self.badpassword_frequency):
                try:
                    f(bad)
                except Exception:
                    # Ignore any exceptions as the operation may fail
                    # as it's being performed with bad credentials.
                    # (Deliberately not a bare except: so SystemExit and
                    # KeyboardInterrupt still propagate.)
                    pass
                failed_last_time = True
            else:
                failed_last_time = False

        result = f(good)
        return (result, failed_last_time)

    def generate_user_creds(self):
        """Generate the conversation specific user Credentials.

        Each Conversation has an associated user account used to simulate
        any non Administrative user traffic.

        Generates user credentials with good and bad passwords and ldap
        simple bind credentials with good and bad passwords.
        """
        self.user_creds = Credentials()
        self.user_creds.guess(self.lp)
        self.user_creds.set_workstation(self.netbios_name)
        self.user_creds.set_password(self.userpass)
        self.user_creds.set_username(self.username)
        self.user_creds.set_domain(self.domain)
        if self.prefer_kerberos:
            self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.user_creds.set_kerberos_state(DONT_USE_KERBEROS)

        # The bad password is derived by truncating the good one.
        # NOTE(review): unlike user_creds, the bad variant does not call
        # set_domain — confirm whether that is intentional.
        self.user_creds_bad = Credentials()
        self.user_creds_bad.guess(self.lp)
        self.user_creds_bad.set_workstation(self.netbios_name)
        self.user_creds_bad.set_password(self.userpass[:-4])
        self.user_creds_bad.set_username(self.username)
        if self.prefer_kerberos:
            self.user_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.user_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)

        # Credentials for ldap simple bind.
        self.simple_bind_creds = Credentials()
        self.simple_bind_creds.guess(self.lp)
        self.simple_bind_creds.set_workstation(self.netbios_name)
        self.simple_bind_creds.set_password(self.userpass)
        self.simple_bind_creds.set_username(self.username)
        self.simple_bind_creds.set_gensec_features(
            self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
        if self.prefer_kerberos:
            self.simple_bind_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.simple_bind_creds.set_kerberos_state(DONT_USE_KERBEROS)
        self.simple_bind_creds.set_bind_dn(self.user_dn)

        self.simple_bind_creds_bad = Credentials()
        self.simple_bind_creds_bad.guess(self.lp)
        self.simple_bind_creds_bad.set_workstation(self.netbios_name)
        self.simple_bind_creds_bad.set_password(self.userpass[:-4])
        self.simple_bind_creds_bad.set_username(self.username)
        self.simple_bind_creds_bad.set_gensec_features(
            self.simple_bind_creds_bad.get_gensec_features() |
            gensec.FEATURE_SEAL)
        if self.prefer_kerberos:
            self.simple_bind_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.simple_bind_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
        self.simple_bind_creds_bad.set_bind_dn(self.user_dn)

    def generate_machine_creds(self):
        """Generate the conversation specific machine Credentials.

        Each Conversation has an associated machine account.

        Generates machine credentials with good and bad passwords.
        """

        self.machine_creds = Credentials()
        self.machine_creds.guess(self.lp)
        self.machine_creds.set_workstation(self.netbios_name)
        self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds.set_password(self.machinepass)
        self.machine_creds.set_username(self.netbios_name + "$")
        self.machine_creds.set_domain(self.domain)
        if self.prefer_kerberos:
            self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.machine_creds.set_kerberos_state(DONT_USE_KERBEROS)

        # NOTE(review): as with user_creds_bad, no set_domain here —
        # confirm intentional.
        self.machine_creds_bad = Credentials()
        self.machine_creds_bad.guess(self.lp)
        self.machine_creds_bad.set_workstation(self.netbios_name)
        self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds_bad.set_password(self.machinepass[:-4])
        self.machine_creds_bad.set_username(self.netbios_name + "$")
        if self.prefer_kerberos:
            self.machine_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.machine_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)

    def get_matching_dn(self, pattern, attributes=None):
        """Return a real DN whose shape matches the requested pattern.

        If the pattern is an empty string, we assume ROOTDSE,
        Otherwise we try adding or removing DC suffixes, then
        shorter leading patterns until we hit one.
        e.g if there is no CN,CN,CN,CN,DC,DC
        we first try       CN,CN,CN,CN,DC
        and                CN,CN,CN,CN,DC,DC,DC
        then change to        CN,CN,CN,DC,DC
        and as last resort we use the base_dn
        """
        attr_clue = self.attribute_clue_map.get(attributes)
        if attr_clue:
            return random.choice(attr_clue)

        pattern = pattern.upper()
        while pattern:
            if pattern in self.dn_map:
                return random.choice(self.dn_map[pattern])
            # chop one off the front and try it all again.
            pattern = pattern[3:]

        return self.base_dn

    def get_dcerpc_connection(self, new=False):
        """Return a dcerpc connection, reusing the last one unless new."""
        guid = '12345678-1234-abcd-ef00-01234567cffb'  # RPC_NETLOGON UUID
        if self.dcerpc_connections and not new:
            return self.dcerpc_connections[-1]
        c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
                             (guid, 1), self.lp)
        self.dcerpc_connections.append(c)
        return c

    def get_srvsvc_connection(self, new=False):
        """Return a srvsvc connection (user creds), reusing unless new."""
        if self.srvsvc_connections and not new:
            return self.srvsvc_connections[-1]

        def connect(creds):
            return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
                                 self.lp,
                                 creds)

        (c, self.last_srvsvc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_srvsvc_bad)

        self.srvsvc_connections.append(c)
        return c

    def get_lsarpc_connection(self, new=False):
        """Return a schannel lsarpc connection (machine creds)."""
        if self.lsarpc_connections and not new:
            return self.lsarpc_connections[-1]

        def connect(creds):
            binding_options = 'schannel,seal,sign'
            return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
                              (self.server, binding_options),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_bad)

        self.lsarpc_connections.append(c)
        return c

    def get_lsarpc_named_pipe_connection(self, new=False):
        """Return an lsarpc named-pipe connection (machine creds)."""
        if self.lsarpc_connections_named and not new:
            return self.lsarpc_connections_named[-1]

        def connect(creds):
            return lsa.lsarpc("ncacn_np:%s" % (self.server),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_named_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_named_bad)

        self.lsarpc_connections_named.append(c)
        return c

    def get_drsuapi_connection_pair(self, new=False, unbind=False):
        """get a (drs, drs_handle) tuple

        The unbind parameter is currently unused; it is kept for
        call-site compatibility.
        """
        if self.drsuapi_connections and not new:
            c = self.drsuapi_connections[-1]
            return c

        def connect(creds):
            binding_options = 'seal'
            binding_string = "ncacn_ip_tcp:%s[%s]" %\
                             (self.server, binding_options)
            return drsuapi.drsuapi(binding_string, self.lp, creds)

        (drs, self.last_drsuapi_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_drsuapi_bad)

        (drs_handle, supported_extensions) = drs_DsBind(drs)
        c = (drs, drs_handle)
        self.drsuapi_connections.append(c)
        return c

    def get_ldap_connection(self, new=False, simple=False):
        """Return a SamDB ldap connection, via simple bind over ldaps
        when simple is True, otherwise SASL over ldap."""
        if self.ldap_connections and not new:
            return self.ldap_connections[-1]

        def simple_bind(creds):
            """
            To run simple bind against Windows, we need to run
            following commands in PowerShell:

                Install-windowsfeature ADCS-Cert-Authority
                Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
                Restart-Computer

            """
            return SamDB('ldaps://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)

        def sasl_bind(creds):
            return SamDB('ldap://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)
        if simple:
            (samdb, self.last_simple_bind_bad) = \
                self.with_random_bad_credentials(simple_bind,
                                                 self.simple_bind_creds,
                                                 self.simple_bind_creds_bad,
                                                 self.last_simple_bind_bad)
        else:
            (samdb, self.last_bind_bad) = \
                self.with_random_bad_credentials(sasl_bind,
                                                 self.user_creds,
                                                 self.user_creds_bad,
                                                 self.last_bind_bad)

        self.ldap_connections.append(samdb)
        return samdb

    def get_samr_context(self, new=False):
        """Return the current SamrContext, creating one if needed."""
        if not self.samr_contexts or new:
            self.samr_contexts.append(
                SamrContext(self.server, lp=self.lp, creds=self.creds))
        return self.samr_contexts[-1]

    def get_netlogon_connection(self):
        """Return the (singleton) schannel netlogon connection."""

        if self.netlogon_connection:
            return self.netlogon_connection

        def connect(creds):
            return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
                                     (self.server),
                                     self.lp,
                                     creds)
        (c, self.last_netlogon_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_netlogon_bad)
        self.netlogon_connection = c
        return c

    def guess_a_dns_lookup(self):
        """Return a plausible (name, type) pair for a DNS query."""
        return (self.realm, 'A')

    def get_authenticator(self):
        """Return a (current, subsequent) netlogon authenticator pair."""
        auth = self.machine_creds.new_client_authenticator()
        # The credential bytes may arrive as str (py2) or ints (py3).
        current  = netr_Authenticator()
        current.cred.data = [x if isinstance(x, int) else ord(x) for x in auth["credential"]]
        current.timestamp = auth["timestamp"]

        subsequent = netr_Authenticator()
        return (current, subsequent)
731
732
class SamrContext(object):
    """State/Context associated with a samr connection.
    """
    def __init__(self, server, lp=None, creds=None):
        """Record connection parameters; nothing is opened until
        get_connection/get_handle are called."""
        self.server        = server
        self.lp            = lp
        self.creds         = creds
        # Lazily created connection state
        self.connection    = None
        self.handle        = None
        self.domain_handle = None
        self.domain_sid    = None
        self.group_handle  = None
        self.user_handle   = None
        self.rids          = None

    def get_connection(self):
        """Return the samr connection, creating it on first use."""
        if not self.connection:
            binding = "ncacn_ip_tcp:%s[seal]" % (self.server)
            self.connection = samr.samr(binding,
                                        lp_ctx=self.lp,
                                        credentials=self.creds)

        return self.connection

    def get_handle(self):
        """Return the samr connect handle, creating it on first use."""
        if not self.handle:
            conn = self.get_connection()
            self.handle = conn.Connect2(None,
                                        security.SEC_FLAG_MAXIMUM_ALLOWED)
        return self.handle
763
class Conversation(object):
    """Details of a conversation between a simulated client and a server."""
    # conversation_id is assigned (from a shared counter) just before the
    # conversation is replayed in a fork.
    conversation_id = None

    def __init__(self, start_time=None, endpoints=None):
        self.start_time = start_time
        self.endpoints = endpoints
        self.packets = []
        self.msg = random_colour_print()
        # a negative balance suggests endpoints[0] is the client
        self.client_balance = 0.0

    def __cmp__(self, other):
        # Python 2 style comparison, ordering by start time with None
        # sorting first.  Python 3 ignores __cmp__; see __lt__ below.
        if self.start_time is None:
            if other.start_time is None:
                return 0
            return -1
        if other.start_time is None:
            return 1
        return self.start_time - other.start_time

    def __lt__(self, other):
        # Python 3 does not use __cmp__, but generate_conversations()
        # sorts lists of Conversations, so a rich comparison is needed
        # to avoid a TypeError there.
        return self.__cmp__(other) < 0

    def add_packet(self, packet):
        """Add a packet object to this conversation, making a local copy with
        a conversation-relative timestamp."""
        p = packet.copy()

        if self.start_time is None:
            self.start_time = p.timestamp

        if self.endpoints is None:
            self.endpoints = p.endpoints

        if p.endpoints != self.endpoints:
            raise FakePacketError("Conversation endpoints %s don't match "
                                  "packet endpoints %s" %
                                  (self.endpoints, p.endpoints))

        p.timestamp -= self.start_time

        if p.src == p.endpoints[0]:
            self.client_balance -= p.client_score()
        else:
            self.client_balance += p.client_score()

        if p.is_really_a_packet():
            self.packets.append(p)

    def add_short_packet(self, timestamp, protocol, opcode, extra,
                         client=True):
        """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
        (possibly empty) list of extra data. If client is True, assume
        this packet is from the client to the server.
        """
        src, dest = self.guess_client_server()
        if not client:
            src, dest = dest, src
        key = (protocol, opcode)
        desc = OP_DESCRIPTIONS[key] if key in OP_DESCRIPTIONS else ''
        if protocol in IP_PROTOCOLS:
            ip_protocol = IP_PROTOCOLS[protocol]
        else:
            ip_protocol = '06'
        packet = Packet(timestamp - self.start_time, ip_protocol,
                        '', src, dest,
                        protocol, opcode, desc, extra)
        # XXX we're assuming the timestamp is already adjusted for
        # this conversation?
        # XXX should we adjust client balance for guessed packets?
        if packet.src == packet.endpoints[0]:
            self.client_balance -= packet.client_score()
        else:
            self.client_balance += packet.client_score()
        if packet.is_really_a_packet():
            self.packets.append(packet)

    def __str__(self):
        return ("<Conversation %s %s starting %.3f %d packets>" %
                (self.conversation_id, self.endpoints, self.start_time,
                 len(self.packets)))

    __repr__ = __str__

    def __iter__(self):
        return iter(self.packets)

    def __len__(self):
        return len(self.packets)

    def get_duration(self):
        """Return the time between the first and the last packet, or 0 for
        conversations with fewer than two packets."""
        if len(self.packets) < 2:
            return 0
        return self.packets[-1].timestamp - self.packets[0].timestamp

    def replay_as_summary_lines(self):
        """Return the conversation's packets as summary-format lines."""
        lines = []
        for p in self.packets:
            lines.append(p.as_summary(self.start_time))
        return lines

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        """Fork a new process and replay the conversation.
        """
        def signal_handler(signal, frame):
            """Signal handler closes standard out and error.

            Triggered by a sigterm, ensures that the log messages are flushed
            to disk and not lost.
            """
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

        t = self.start_time
        now = time.time() - start
        gap = t - now
        # we are replaying strictly in order, so it is safe to sleep
        # in the main process if the gap is big enough. This reduces
        # the number of concurrent threads, which allows us to make
        # larger loads.
        # NOTE: this branch is deliberately disabled with `and False`.
        if gap > 0.15 and False:
            print("sleeping for %f in main process" % (gap - 0.1),
                  file=sys.stderr)
            time.sleep(gap - 0.1)
            now = time.time() - start
            gap = t - now
            print("gap is now %f" % gap, file=sys.stderr)

        self.conversation_id = next(context.next_conversation_id)
        pid = os.fork()
        if pid != 0:
            # we are the parent; hand the child's pid back to the caller
            return pid
        pid = os.getpid()
        signal.signal(signal.SIGTERM, signal_handler)
        # we must never return, or we'll end up running parts of the
        # parent's clean-up code. So we work in a try...finally, and
        # try to print any exceptions.

        try:
            context.generate_process_local_config(account, self)
            sys.stdin.close()
            os.close(0)
            filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
                                    self.conversation_id)
            sys.stdout.close()
            sys.stdout = open(filename, 'w')

            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            miss = t - (time.time() - start)
            self.msg("starting %s [miss %.3f pid %d]" % (self, miss, pid))
            self.replay(context)
        except Exception:
            print(("EXCEPTION in child PID %d, conversation %s" % (pid, self)),
                  file=sys.stderr)
            # print_exc takes the stream as the `file` keyword argument;
            # passing sys.stderr positionally would be interpreted as the
            # traceback limit and raise a TypeError.
            traceback.print_exc(file=sys.stderr)
        finally:
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

    def replay(self, context=None):
        """Play the conversation's packets in order, sleeping as required
        to honour their relative timestamps."""
        start = time.time()

        for p in self.packets:
            now = time.time() - start
            gap = p.timestamp - now
            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            miss = p.timestamp - (time.time() - start)
            if context is None:
                # dry run: just describe what would have been sent
                self.msg("packet %s [miss %.3f pid %d]" % (p, miss,
                                                           os.getpid()))
                continue
            p.play(self, context)

    def guess_client_server(self, server_clue=None):
        """Have a go at deciding who is the server and who is the client.
        returns (client, server)
        """
        a, b = self.endpoints

        if self.client_balance < 0:
            return (a, b)

        # in the absence of a clue, we will fall through to assuming
        # the lowest number is the server (which is usually true).

        if self.client_balance == 0 and server_clue == b:
            return (a, b)

        return (b, a)

    def forget_packets_outside_window(self, s, e):
        """Prune any packets outside the time window we're interested in

        :param s: start of the window
        :param e: end of the window
        """
        self.packets = [p for p in self.packets if s <= p.timestamp <= e]
        self.start_time = self.packets[0].timestamp if self.packets else None

    def renormalise_times(self, start_time):
        """Adjust the packet start times relative to the new start time."""
        for p in self.packets:
            p.timestamp -= start_time

        if self.start_time is not None:
            self.start_time -= start_time
975
976
class DnsHammer(Conversation):
    """A lightweight conversation that generates a lot of dns:0 packets on
    the fly"""

    def __init__(self, dns_rate, duration):
        # Pre-compute a sorted schedule of random send times over the run.
        n_packets = int(dns_rate * duration)
        self.times = sorted(random.uniform(0, duration)
                            for _ in range(n_packets))
        self.rate = dns_rate
        self.duration = duration
        self.start_time = 0
        self.msg = random_colour_print()

    def __str__(self):
        return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
                (len(self.times), self.duration, self.rate))

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        # Delegate to the base class; the DNS hammer needs no account.
        return Conversation.replay_in_fork_with_delay(self,
                                                      start,
                                                      context,
                                                      account)

    def replay(self, context=None):
        """Fire the scheduled dns:0 packets, pacing them in real time and
        printing a stats line per packet."""
        start = time.time()
        fire = traffic_packets.packet_dns_0
        for when in self.times:
            elapsed = time.time() - start
            pause = (when - elapsed) - SLEEP_OVERHEAD
            if pause > 0:
                time.sleep(pause)

            if context is None:
                # dry run: just report what would have happened
                miss = when - (time.time() - start)
                self.msg("packet %s [miss %.3f pid %d]" % (when, miss,
                                                           os.getpid()))
                continue

            packet_start = time.time()
            try:
                fire(self, self, context)
                end = time.time()
                duration = end - packet_start
                print("%f\tDNS\tdns\t0\t%f\tTrue\t" % (end, duration))
            except Exception as e:
                end = time.time()
                duration = end - packet_start
                print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" % (end, duration, e))
1026
1027
def ingest_summaries(files, dns_mode='count'):
    """Load summary traffic files and generate Conversations from them.

    :param files: a list of filenames or open file objects
    :param dns_mode: unless this is 'include', DNS packets are tallied
        separately in the returned dns_counts instead of joining
        conversations.
    :return: a (conversations, mean_interval, duration, dns_counts) tuple.
    """
    dns_counts = defaultdict(int)
    packets = []
    for f in files:
        if isinstance(f, str):
            f = open(f)
        print("Ingesting %s" % (f.name,), file=sys.stderr)
        for line in f:
            p = Packet.from_line(line)
            if p.protocol == 'dns' and dns_mode != 'include':
                dns_counts[p.opcode] += 1
            else:
                packets.append(p)

        f.close()

    if not packets:
        # return the same four-element shape as the normal path below,
        # so callers can always unpack four values.
        return [], 0, 0, dns_counts

    start_time = min(p.timestamp for p in packets)
    last_packet = max(p.timestamp for p in packets)

    print("gathering packets into conversations", file=sys.stderr)
    conversations = OrderedDict()
    for p in packets:
        p.timestamp -= start_time
        c = conversations.get(p.endpoints)
        if c is None:
            c = Conversation()
            conversations[p.endpoints] = c
        c.add_packet(p)

    # We only care about conversations with actual traffic, so we
    # filter out conversations with nothing to say. We do that here,
    # rather than earlier, because those empty packets contain useful
    # hints as to which end of the conversation was the client.
    conversation_list = []
    for c in conversations.values():
        if len(c) != 0:
            conversation_list.append(c)

    # This is obviously not correct, as many conversations will appear
    # to start roughly simultaneously at the beginning of the snapshot.
    # To which we say: oh well, so be it.
    duration = float(last_packet - start_time)
    # NOTE(review): named mean_interval, but this is conversations per
    # second (a rate) -- confirm that callers expect that.
    mean_interval = len(conversations) / duration

    return conversation_list, mean_interval, duration, dns_counts
1079
1080
def guess_server_address(conversations):
    """Return the most common endpoint address, or None if there are no
    addresses at all.

    The server takes part in every conversation, so it is very likely
    the most frequently seen address.
    """
    addresses = Counter()
    for c in conversations:
        addresses.update(c.endpoints)
    if addresses:
        # most_common(1) returns [(address, count)]; callers (e.g.
        # TrafficModel.learn) compare the result against endpoint
        # addresses, so return the address itself, not the tuple.
        return addresses.most_common(1)[0][0]
1088
1089
def stringify_keys(x):
    """Return a copy of dict *x* with each tuple key joined into a single
    tab-separated string (the inverse of unstringify_keys)."""
    return {'\t'.join(key): value for key, value in x.items()}
1096
1097
def unstringify_keys(x):
    """Return a copy of dict *x* with each tab-separated string key split
    back into a tuple (the inverse of stringify_keys)."""
    return {tuple(str(key).split('\t')): value for key, value in x.items()}
1104
1105
class TrafficModel(object):
    """An n-gram model of the packet types seen in captured conversations,
    which can then generate synthetic conversations with similar
    statistics."""

    def __init__(self, n=3):
        """Create an empty model.

        :param n: the n-gram order (the state is the previous n-1 packet
            types).
        """
        self.ngrams = {}
        self.query_details = {}
        self.n = n
        self.dns_opcounts = defaultdict(int)
        self.cumulative_duration = 0.0
        # stored as [conversation count, elapsed seconds]
        self.conversation_rate = [0, 1]

    def learn(self, conversations, dns_opcounts=None):
        """Accumulate the packet sequences of the given conversations into
        the model.

        :param conversations: a list of Conversation objects to learn from
        :param dns_opcounts: an optional mapping of dns opcode -> count,
            folded into self.dns_opcounts.
        """
        if dns_opcounts is None:
            # avoid a shared mutable default argument
            dns_opcounts = {}
        prev = 0.0
        cum_duration = 0.0
        key = (NON_PACKET,) * (self.n - 1)

        server = guess_server_address(conversations)

        for k, v in dns_opcounts.items():
            self.dns_opcounts[k] += v

        if len(conversations) > 1:
            elapsed =\
                conversations[-1].start_time - conversations[0].start_time
            self.conversation_rate[0] = len(conversations)
            self.conversation_rate[1] = elapsed

        for c in conversations:
            client, server = c.guess_client_server(server)
            cum_duration += c.get_duration()
            key = (NON_PACKET,) * (self.n - 1)
            for p in c:
                # we only model the client side of each conversation
                if p.src != client:
                    continue

                # NOTE(review): prev is not reset per conversation, so the
                # first gap of a conversation is measured from the previous
                # conversation's last packet -- confirm this is intended.
                elapsed = p.timestamp - prev
                prev = p.timestamp
                if elapsed > WAIT_THRESHOLD:
                    # add the wait as an extra state
                    wait = 'wait:%d' % (math.log(max(1.0,
                                                     elapsed * WAIT_SCALE)))
                    self.ngrams.setdefault(key, []).append(wait)
                    key = key[1:] + (wait,)

                short_p = p.as_packet_type()
                self.query_details.setdefault(short_p,
                                              []).append(tuple(p.extra))
                self.ngrams.setdefault(key, []).append(short_p)
                key = key[1:] + (short_p,)

        self.cumulative_duration += cum_duration
        # add in the end
        self.ngrams.setdefault(key, []).append(NON_PACKET)

    def save(self, f):
        """Write the model to f (a filename or a writable file object) as
        JSON, compressing the sample lists into counts."""
        ngrams = {}
        for k, v in self.ngrams.items():
            k = '\t'.join(k)
            ngrams[k] = dict(Counter(v))

        query_details = {}
        for k, v in self.query_details.items():
            query_details[k] = dict(Counter('\t'.join(x) if x else '-'
                                            for x in v))

        d = {
            'ngrams': ngrams,
            'query_details': query_details,
            'cumulative_duration': self.cumulative_duration,
            'conversation_rate': self.conversation_rate,
        }
        d['dns'] = self.dns_opcounts

        if isinstance(f, str):
            # we opened the file, so we also close it
            with open(f, 'w') as out:
                json.dump(d, out, indent=2)
        else:
            json.dump(d, f, indent=2)

    def load(self, f):
        """Load a model previously written by save(). f is a filename or a
        readable file object."""
        if isinstance(f, str):
            # we opened the file, so we also close it
            with open(f) as infile:
                d = json.load(infile)
        else:
            d = json.load(f)

        for k, v in d['ngrams'].items():
            k = tuple(str(k).split('\t'))
            values = self.ngrams.setdefault(k, [])
            for p, count in v.items():
                values.extend([str(p)] * count)

        for k, v in d['query_details'].items():
            values = self.query_details.setdefault(str(k), [])
            for p, count in v.items():
                if p == '-':
                    # '-' marks an empty extra-data tuple (see save())
                    values.extend([()] * count)
                else:
                    values.extend([tuple(str(p).split('\t'))] * count)

        if 'dns' in d:
            for k, v in d['dns'].items():
                self.dns_opcounts[k] += v

        self.cumulative_duration = d['cumulative_duration']
        self.conversation_rate = d['conversation_rate']

    def construct_conversation(self, timestamp=0.0, client=2, server=1,
                               hard_stop=None, packet_rate=1):
        """Construct an individual conversation from the model."""

        c = Conversation(timestamp, (server, client))

        key = (NON_PACKET,) * (self.n - 1)

        while key in self.ngrams:
            p = random.choice(self.ngrams.get(key, NON_PACKET))
            if p == NON_PACKET:
                break
            if p in self.query_details:
                extra = random.choice(self.query_details[p])
            else:
                extra = []

            protocol, opcode = p.split(':', 1)
            if protocol == 'wait':
                # a wait state: advance time without emitting a packet
                log_wait_time = int(opcode) + random.random()
                wait = math.exp(log_wait_time) / (WAIT_SCALE * packet_rate)
                timestamp += wait
            else:
                log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
                wait = math.exp(log_wait) / packet_rate
                timestamp += wait
                if hard_stop is not None and timestamp > hard_stop:
                    break
                c.add_short_packet(timestamp, protocol, opcode, extra)

            key = key[1:] + (p,)

        return c

    def generate_conversations(self, rate, duration, packet_rate=1):
        """Generate a list of conversations from the model."""

        # We run the simulation for at least ten times as long as our
        # desired duration, and take a section near the start.
        rate_n, rate_t  = self.conversation_rate

        duration2 = max(rate_t, duration * 2)
        n = rate * duration2 * rate_n / rate_t

        server = 1
        client = 2

        conversations = []
        end = duration2
        start = end - duration

        while client < n + 2:
            start = random.uniform(0, duration2)
            c = self.construct_conversation(start,
                                            client,
                                            server,
                                            hard_stop=(duration2 * 5),
                                            packet_rate=packet_rate)

            c.forget_packets_outside_window(start, end)
            c.renormalise_times(start)
            if len(c) != 0:
                conversations.append(c)
            client += 1

        print(("we have %d conversations at rate %f" %
               (len(conversations), rate)), file=sys.stderr)
        conversations.sort()
        return conversations
1278
1279
# Map of protocol name -> IP protocol number as used in summary lines
# (seemingly hexadecimal: '06' for TCP, '11' for UDP -- TODO confirm).
IP_PROTOCOLS = {
    'dns': '11',
    'rpc_netlogon': '06',
    'kerberos': '06',      # ratio 16248:258
    'smb': '06',
    'smb2': '06',
    'ldap': '06',
    'cldap': '11',
    'lsarpc': '06',
    'samr': '06',
    'dcerpc': '06',
    'epm': '06',
    'drsuapi': '06',
    'browser': '11',
    'smb_netlogon': '11',
    'srvsvc': '06',
    'nbns': '11',
}

# Human-readable descriptions for (protocol, opcode) pairs, used when
# expanding short packets into full summary lines.
OP_DESCRIPTIONS = {
    ('browser', '0x01'): 'Host Announcement (0x01)',
    ('browser', '0x02'): 'Request Announcement (0x02)',
    ('browser', '0x08'): 'Browser Election Request (0x08)',
    ('browser', '0x09'): 'Get Backup List Request (0x09)',
    ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
    ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
    ('cldap', '3'): 'searchRequest',
    ('cldap', '5'): 'searchResDone',
    ('dcerpc', '0'): 'Request',
    ('dcerpc', '11'): 'Bind',
    ('dcerpc', '12'): 'Bind_ack',
    ('dcerpc', '13'): 'Bind_nak',
    ('dcerpc', '14'): 'Alter_context',
    ('dcerpc', '15'): 'Alter_context_resp',
    ('dcerpc', '16'): 'AUTH3',
    ('dcerpc', '2'): 'Response',
    ('dns', '0'): 'query',
    ('dns', '1'): 'response',
    ('drsuapi', '0'): 'DsBind',
    ('drsuapi', '12'): 'DsCrackNames',
    ('drsuapi', '13'): 'DsWriteAccountSpn',
    ('drsuapi', '1'): 'DsUnbind',
    ('drsuapi', '2'): 'DsReplicaSync',
    ('drsuapi', '3'): 'DsGetNCChanges',
    ('drsuapi', '4'): 'DsReplicaUpdateRefs',
    ('epm', '3'): 'Map',
    ('kerberos', ''): '',
    ('ldap', '0'): 'bindRequest',
    ('ldap', '1'): 'bindResponse',
    ('ldap', '2'): 'unbindRequest',
    ('ldap', '3'): 'searchRequest',
    ('ldap', '4'): 'searchResEntry',
    ('ldap', '5'): 'searchResDone',
    ('ldap', ''): '*** Unknown ***',
    ('lsarpc', '14'): 'lsa_LookupNames',
    ('lsarpc', '15'): 'lsa_LookupSids',
    ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
    ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
    ('lsarpc', '6'): 'lsa_OpenPolicy',
    ('lsarpc', '76'): 'lsa_LookupSids3',
    ('lsarpc', '77'): 'lsa_LookupNames4',
    ('nbns', '0'): 'query',
    ('nbns', '1'): 'response',
    ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
    ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
    ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
    ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
    ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
    ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
    ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
    ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
    ('samr', '0',): 'Connect',
    ('samr', '16'): 'GetAliasMembership',
    ('samr', '17'): 'LookupNames',
    ('samr', '18'): 'LookupRids',
    ('samr', '19'): 'OpenGroup',
    ('samr', '1'): 'Close',
    ('samr', '25'): 'QueryGroupMember',
    ('samr', '34'): 'OpenUser',
    ('samr', '36'): 'QueryUserInfo',
    ('samr', '39'): 'GetGroupsForUser',
    ('samr', '3'): 'QuerySecurity',
    ('samr', '5'): 'LookupDomain',
    ('samr', '64'): 'Connect5',
    ('samr', '6'): 'EnumDomains',
    ('samr', '7'): 'OpenDomain',
    ('samr', '8'): 'QueryDomainInfo',
    ('smb', '0x04'): 'Close (0x04)',
    ('smb', '0x24'): 'Locking AndX (0x24)',
    ('smb', '0x2e'): 'Read AndX (0x2e)',
    ('smb', '0x32'): 'Trans2 (0x32)',
    ('smb', '0x71'): 'Tree Disconnect (0x71)',
    ('smb', '0x72'): 'Negotiate Protocol (0x72)',
    ('smb', '0x73'): 'Session Setup AndX (0x73)',
    ('smb', '0x74'): 'Logoff AndX (0x74)',
    ('smb', '0x75'): 'Tree Connect AndX (0x75)',
    ('smb', '0xa2'): 'NT Create AndX (0xa2)',
    ('smb2', '0'): 'NegotiateProtocol',
    ('smb2', '11'): 'Ioctl',
    ('smb2', '14'): 'Find',
    ('smb2', '16'): 'GetInfo',
    ('smb2', '18'): 'Break',
    ('smb2', '1'): 'SessionSetup',
    ('smb2', '2'): 'SessionLogoff',
    ('smb2', '3'): 'TreeConnect',
    ('smb2', '4'): 'TreeDisconnect',
    ('smb2', '5'): 'Create',
    ('smb2', '6'): 'Close',
    ('smb2', '8'): 'Read',
    ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
    ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
                               'user unknown (0x17)'),
    ('srvsvc', '16'): 'NetShareGetInfo',
    ('srvsvc', '21'): 'NetSrvGetInfo',
}
1395
1396
def expand_short_packet(p, timestamp, src, dest, extra):
    """Expand a 'protocol:opcode' short packet into a full tab-separated
    summary line."""
    protocol, opcode = p.split(':', 1)

    fields = [timestamp,
              IP_PROTOCOLS.get(protocol, '06'),
              '',
              src,
              dest,
              protocol,
              opcode,
              OP_DESCRIPTIONS.get((protocol, opcode), '')]
    fields.extend(extra)
    return '\t'.join(fields)
1405
1406
def replay(conversations,
           host=None,
           creds=None,
           lp=None,
           accounts=None,
           dns_rate=0,
           duration=None,
           **kwargs):
    """Replay the conversations against host, forking one child process
    per conversation.

    :param conversations: the Conversations to replay
    :param host: the server to direct traffic at
    :param creds: credentials used to build the replay context
    :param lp: a loadparm context
    :param accounts: a list of account objects, paired with conversations
    :param dns_rate: if non-zero, also run a DnsHammer at this rate
    :param duration: run for this many seconds (default: until one second
        after the last conversation starts)

    Remaining keyword arguments are passed to ReplayContext.
    """
    context = ReplayContext(server=host,
                            creds=creds,
                            lp=lp,
                            **kwargs)

    if len(accounts) < len(conversations):
        # use len() here -- formatting the lists themselves with %d
        # would raise a TypeError.
        print(("we have %d accounts but %d conversations" %
               (len(accounts), len(conversations))), file=sys.stderr)

    cstack = list(zip(
        sorted(conversations, key=lambda x: x.start_time, reverse=True),
        accounts))

    # Set the process group so that the calling scripts are not killed
    # when the forked child processes are killed.
    os.setpgrp()

    start = time.time()

    if duration is None:
        # end 1 second after the last packet of the last conversation
        # to start. Conversations other than the last could still be
        # going, but we don't care.
        duration = cstack[0][0].packets[-1].timestamp + 1.0
        print("We will stop after %.1f seconds" % duration,
              file=sys.stderr)

    end = start + duration

    LOGGER.info("Replaying traffic for %u conversations over %d seconds"
          % (len(conversations), duration))

    children = {}
    if dns_rate:
        dns_hammer = DnsHammer(dns_rate, duration)
        cstack.append((dns_hammer, None))

    try:
        while True:
            # we spawn a batch, wait for finishers, then spawn another
            now = time.time()
            batch_end = min(now + 2.0, end)
            fork_time = 0.0
            fork_n = 0
            while cstack:
                c, account = cstack.pop()
                if c.start_time + start > batch_end:
                    # not due yet: put it back for the next batch
                    cstack.append((c, account))
                    break

                st = time.time()
                pid = c.replay_in_fork_with_delay(start, context, account)
                children[pid] = c
                t = time.time()
                elapsed = t - st
                fork_time += elapsed
                fork_n += 1
                print("forked %s in pid %s (in %fs)" % (c, pid,
                                                        elapsed),
                      file=sys.stderr)

            if fork_n:
                print(("forked %d times in %f seconds (avg %f)" %
                       (fork_n, fork_time, fork_time / fork_n)),
                      file=sys.stderr)
            elif cstack:
                debug(2, "no forks in batch ending %f" % batch_end)

            # reap children until one second before the batch deadline
            while time.time() < batch_end - 1.0:
                time.sleep(0.01)
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != 10:  # no child processes
                        raise
                    break
                if pid:
                    c = children.pop(pid, None)
                    print(("process %d finished conversation %s;"
                           " %d to go" %
                           (pid, c, len(children))), file=sys.stderr)

            if time.time() >= end:
                print("time to stop", file=sys.stderr)
                break

    except Exception:
        print("EXCEPTION in parent", file=sys.stderr)
        traceback.print_exc()
    finally:
        # escalating shutdown: SIGTERM twice, then SIGKILL
        for s in (15, 15, 9):
            print(("killing %d children with -%d" %
                   (len(children), s)), file=sys.stderr)
            for pid in children:
                try:
                    os.kill(pid, s)
                except OSError as e:
                    if e.errno != 3:  # don't fail if it has already died
                        raise
            time.sleep(0.5)
            end = time.time() + 1
            while children:
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != 10:
                        raise
                if pid != 0:
                    c = children.pop(pid, None)
                    print(("kill -%d %d KILLED conversation %s; "
                           "%d to go" %
                           (s, pid, c, len(children))),
                          file=sys.stderr)
                if time.time() >= end:
                    break

            if not children:
                break
            time.sleep(1)

        if children:
            print("%d children are missing" % len(children),
                  file=sys.stderr)

        # there may be stragglers that were forked just as ^C was hit
        # and don't appear in the list of children. We can get them
        # with killpg, but that will also kill us, so this is^H^H would be
        # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
        # so as not to have to fuss around writing signal handlers.
        try:
            os.killpg(0, 2)
        except KeyboardInterrupt:
            print("ignoring fake ^C", file=sys.stderr)
1549
1550
def openLdb(host, creds, lp):
    """Open an LDAP connection to host and return a SamDB instance,
    with paged searches enabled.

    The local is named samdb (not ldb) so as not to shadow the ldb
    module imported at the top of this file.
    """
    session = system_session()
    samdb = SamDB(url="ldap://%s" % host,
                  session_info=session,
                  options=['modules:paged_searches'],
                  credentials=creds,
                  lp=lp)
    return samdb
1559
1560
def ou_name(ldb, instance_id):
    """Generate an ou name from the instance id"""
    base_dn = ldb.domain_dn()
    return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id, base_dn)
1565
1566
def create_ou(ldb, instance_id):
    """Create an ou, all created user and machine accounts will belong to it.

    This allows all the created resources to be cleaned up easily.
    """
    ou = ou_name(ldb, instance_id)
    # Add the parent container first, then the instance OU itself.
    # "Already exists" (LDB status 68) is fine for both.
    for dn in (ou.split(',', 1)[1], ou):
        try:
            ldb.add({"dn": dn,
                     "objectclass": "organizationalunit"})
        except LdbError as e:
            (status, _) = e.args
            if status != 68:
                raise
    return ou
1590
1591
class ConversationAccounts(object):
    """Details of the machine and user accounts associated with a conversation.
    """

    def __init__(self, netbios_name, machinepass, username, userpass):
        # machine (computer) account credentials
        self.netbios_name = netbios_name
        self.machinepass = machinepass
        # user account credentials
        self.username = username
        self.userpass = userpass
1600
1601
def generate_replay_accounts(ldb, instance_id, number, password):
    """Generate a series of unique machine and user account names."""

    # make sure the accounts exist in the directory first
    generate_traffic_accounts(ldb, instance_id, number, password)
    return [ConversationAccounts("STGM-%d-%d" % (instance_id, i),
                                 password,
                                 "STGU-%d-%d" % (instance_id, i),
                                 password)
            for i in range(1, number + 1)]
1615
1616
def generate_traffic_accounts(ldb, instance_id, number, password):
    """Create the specified number of user and machine accounts.

    As accounts are not explicitly deleted between runs. This function starts
    with the last account and iterates backwards stopping either when it
    finds an already existing account or it has generated all the required
    accounts.
    """
    print(("Generating machine and conversation accounts, "
           "as required for %d conversations" % number),
          file=sys.stderr)

    machines = 0
    for i in range(number, 0, -1):
        try:
            create_machine_account(ldb, instance_id,
                                   "STGM-%d-%d" % (instance_id, i),
                                   password)
        except LdbError as e:
            (status, _) = e.args
            # 68 is "already exists": earlier accounts must exist too, stop
            if status != 68:
                raise
            break
        machines += 1
        if machines % 50 == 0:
            LOGGER.info("Created %u/%u machine accounts" % (machines, number))
    if machines > 0:
        LOGGER.info("Added %d new machine accounts" % machines)

    users = 0
    for i in range(number, 0, -1):
        try:
            create_user_account(ldb, instance_id,
                                "STGU-%d-%d" % (instance_id, i),
                                password)
        except LdbError as e:
            (status, _) = e.args
            if status != 68:
                raise
            break
        users += 1
        if users % 50 == 0:
            LOGGER.info("Created %u/%u users" % (users, number))

    if users > 0:
        LOGGER.info("Added %d new user accounts" % users)
1663
1664
def create_machine_account(ldb, instance_id, netbios_name, machinepass):
    """Create a machine account via ldap."""

    ou = ou_name(ldb, instance_id)
    # unicodePwd must be the quoted password encoded as UTF-16-LE
    encoded_pw = ('"%s"' % get_string(machinepass)).encode('utf-16-le')

    attributes = {
        "dn": "cn=%s,%s" % (netbios_name, ou),
        "objectclass": "computer",
        "sAMAccountName": "%s$" % netbios_name,
        "userAccountControl":
            str(UF_TRUSTED_FOR_DELEGATION | UF_SERVER_TRUST_ACCOUNT),
        "unicodePwd": encoded_pw,
    }
    ldb.add(attributes)
1679
1680
def create_user_account(ldb, instance_id, username, userpass):
    """Create a user account via ldap."""
    ou = ou_name(ldb, instance_id)
    user_dn = "cn=%s,%s" % (username, ou)
    # unicodePwd must be the quoted password encoded as UTF-16-LE
    encoded_pw = ('"%s"' % get_string(userpass)).encode('utf-16-le')
    ldb.add({
        "dn": user_dn,
        "objectclass": "user",
        "sAMAccountName": username,
        "userAccountControl": str(UF_NORMAL_ACCOUNT),
        "unicodePwd": encoded_pw,
    })

    # grant user write permission to do things like write account SPN
    sd_utils.SDUtils(ldb).dacl_add_ace(user_dn, "(A;;WP;;;PS)")
1697
1698
def create_group(ldb, instance_id, name):
    """Create a group via ldap."""

    group_dn = "cn=%s,%s" % (name, ou_name(ldb, instance_id))
    ldb.add({
        "dn": group_dn,
        "objectclass": "group",
        "sAMAccountName": name,
    })
1709
1710
def user_name(instance_id, i):
    """Generate a user name based in the instance id"""
    return "STGU-{0:d}-{1:d}".format(instance_id, i)
1714
1715
def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
    """Seach objectclass, return attr in a set"""
    results = ldb.search(
        expression="(objectClass={})".format(objectclass),
        attrs=[attr]
    )
    return set(str(res[attr]) for res in results)
1723
1724
def generate_users(ldb, instance_id, number, password):
    """Add users to the server"""
    existing = search_objectclass(ldb, objectclass='user')
    created = 0
    for i in range(number, 0, -1):
        name = user_name(instance_id, i)
        if name in existing:
            # accounts survive between runs; skip ones already there
            continue
        create_user_account(ldb, instance_id, name, password)
        created += 1
        if created % 50 == 0:
            LOGGER.info("Created %u/%u users" % (created, number))

    return created
1738
1739
def generate_machine_accounts(ldb, instance_id, number, password):
    """Add machine accounts to the server"""
    existing = search_objectclass(ldb, objectclass='computer')
    created = 0
    for i in range(number, 0, -1):
        base = "STGM-%d-%d" % (instance_id, i)
        # the stored sAMAccountName of a computer carries a trailing '$'
        if base + "$" in existing:
            continue
        create_machine_account(ldb, instance_id, base, password)
        created += 1
        if created % 50 == 0:
            LOGGER.info("Created %u/%u machine accounts" % (created, number))

    return created
1754
1755
def group_name(instance_id, i):
    """Generate a group name from instance id."""
    return "STGG-{0:d}-{1:d}".format(instance_id, i)
1759
1760
def generate_groups(ldb, instance_id, number):
    """Create the required number of groups on the server."""
    existing = search_objectclass(ldb, objectclass='group')
    created = 0
    for i in range(number, 0, -1):
        name = group_name(instance_id, i)
        if name in existing:
            continue
        create_group(ldb, instance_id, name)
        created += 1
        if created % 1000 == 0:
            LOGGER.info("Created %u/%u groups" % (created, number))

    return created
1774
1775
def clean_up_accounts(ldb, instance_id):
    """Remove the created accounts and groups from the server."""
    # deleting the instance ou with tree_delete removes everything under it
    ou = ou_name(ldb, instance_id)
    try:
        ldb.delete(ou, ["tree_delete:1"])
    except LdbError as e:
        status = e.args[0]
        # 32 is LDAP "no such object" -- nothing to clean up
        if status != 32:
            raise
1786
1787
def generate_users_and_groups(ldb, instance_id, password,
                              number_of_users, number_of_groups,
                              group_memberships):
    """Generate the required users and groups, allocating the users to
       those groups."""
    create_ou(ldb, instance_id)

    LOGGER.info("Generating dummy user accounts")
    n_users = generate_users(ldb, instance_id, number_of_users, password)

    # assume there will be some overhang with more computer accounts than users
    LOGGER.info("Generating dummy machine accounts")
    n_computers = generate_machine_accounts(ldb, instance_id,
                                            int(1.25 * number_of_users),
                                            password)

    n_groups = 0
    if number_of_groups > 0:
        LOGGER.info("Generating dummy groups")
        n_groups = generate_groups(ldb, instance_id, number_of_groups)

    n_memberships = 0
    if group_memberships > 0:
        LOGGER.info("Assigning users to groups")
        plan = GroupAssignments(number_of_groups,
                                n_groups,
                                number_of_users,
                                n_users,
                                group_memberships)
        LOGGER.info("Adding users to groups")
        add_users_to_groups(ldb, instance_id, plan)
        n_memberships = plan.total()

    if n_groups > 0 and n_users == 0 and n_groups != number_of_groups:
        # only pre-existing users can populate the newly added groups
        LOGGER.warning("The added groups will contain no members")

    LOGGER.info("Added %d users (%d machines), %d groups and %d memberships" %
                (n_users, n_computers, n_groups, n_memberships))
1829
1830
class GroupAssignments(object):
    """Plan a weighted-random allocation of users to groups.

    A few users end up belonging to many groups and a few groups end up
    containing many users, while the bulk of users sit in only a few
    groups -- approximating a real-world membership distribution.
    """

    def __init__(self, number_of_groups, groups_added, number_of_users,
                 users_added, group_memberships):

        self.count = 0
        self.generate_group_distribution(number_of_groups)
        self.generate_user_distribution(number_of_users, group_memberships)
        self.assignments = self.assign_groups(number_of_groups,
                                              groups_added,
                                              number_of_users,
                                              users_added,
                                              group_memberships)

    def cumulative_distribution(self, weights):
        """Turn a list of weights into a cumulative distribution.

        Each weight gets a proportional share of 1.0 (bigger weights are
        more likely to be picked), and the running total lets a plain
        random.random() value act as an index into the list.
        """
        total = sum(weights)
        running = 0.0
        dist = []
        for w in weights:
            running += w
            dist.append(running / total)
        return dist

    def generate_user_distribution(self, num_users, num_memberships):
        """Probability distribution of a user belonging to a group.
        """
        # Weight each user with a Pareto-distributed sample so a handful
        # of users land in lots of groups while most are in only a few.
        # A larger membership total gets a higher shape parameter, which
        # trims the outliers; the aim is no user in more than ~500 groups.
        for threshold, value in ((5000000, 3.0),
                                 (2000000, 2.5),
                                 (300000, 2.25)):
            if num_memberships > threshold:
                shape = value
                break
        else:
            shape = 1.75

        weights = [random.paretovariate(shape) for _ in range(num_users)]

        # convert the weights to a cumulative distribution between 0.0 and 1.0
        self.user_dist = self.cumulative_distribution(weights)

    def generate_group_distribution(self, n):
        """Probability distribution of a group containing a user."""

        # weight decays with the group-ID, so low-numbered groups
        # accumulate the most members
        weights = [1 / (x**1.3) for x in range(1, n + 1)]

        # convert the weights to a cumulative distribution between 0.0 and 1.0
        self.group_dist = self.cumulative_distribution(weights)

    def generate_random_membership(self):
        """Returns a randomly generated user-group membership"""

        # bisecting the cumulative distributions with random() yields a
        # weighted random zero-based user/group index
        user_idx = bisect.bisect(self.user_dist, random.random())
        group_idx = bisect.bisect(self.group_dist, random.random())

        return user_idx, group_idx

    def users_in_group(self, group):
        """Return the list of user numbers planned for the given group."""
        return self.assignments[group]

    def get_groups(self):
        """Return the group numbers that received at least one member."""
        return self.assignments.keys()

    def assign_groups(self, number_of_groups, groups_added,
                      number_of_users, users_added, group_memberships):
        """Allocate users to groups.

        The intention is to have a few users that belong to most groups, while
        the majority of users belong to a few groups.

        A few groups will contain most users, with the remaining only having a
        few users.
        """

        if group_memberships <= 0:
            return {}

        # scale the requested memberships down to the share of users that
        # were actually newly added this run
        target = math.ceil(
            float(group_memberships) *
            (float(users_added) / float(number_of_users)))

        existing_users  = number_of_users  - users_added  - 1
        existing_groups = number_of_groups - groups_added - 1
        pairs = set()
        while len(pairs) < target:
            user, group = self.generate_random_membership()

            if group > existing_groups or user > existing_users:
                # the + 1 converts the array index to the corresponding
                # group or user number
                pairs.add(((user + 1), (group + 1)))

        # key by group-ID with a list of member users, which lets the DB
        # writes be batched per group
        by_group = defaultdict(list)
        for (user, group) in pairs:
            by_group[group].append(user)
            self.count += 1

        return by_group

    def total(self):
        """Total number of planned memberships."""
        return self.count
1957
1958
def add_users_to_groups(db, instance_id, assignments):
    """Takes the assignments of users to groups and applies them to the DB."""

    total = assignments.total()
    writes = 0
    members_added = 0

    for group in assignments.get_groups():
        members = assignments.users_in_group(group)
        if not members:
            continue

        # Write no more than 1K members per modify: batching keeps the
        # number of DB operations down, but an unbounded batch of 10K+
        # users in one group would be wasteful memory-wise.
        for start in range(0, len(members), 1000):
            batch = members[start:start + 1000]
            add_group_members(db, instance_id, group, batch)

            members_added += len(batch)
            writes += 1
            if writes % 50 == 0:
                LOGGER.info("Added %u/%u memberships" % (members_added, total))
1982
def add_group_members(db, instance_id, group, users_in_group):
    """Adds the given users to group specified."""

    ou = ou_name(db, instance_id)

    msg = ldb.Message()
    msg.dn = ldb.Dn(db, "cn=%s,%s" % (group_name(instance_id, group), ou))

    for user in users_in_group:
        member_dn = "cn=%s,%s" % (user_name(instance_id, user), ou)
        # each element needs a unique name so multiple member adds can
        # coexist in the one modify message
        element = ldb.MessageElement(member_dn, ldb.FLAG_MOD_ADD, "member")
        msg["member-" + str(user)] = element

    db.modify(msg)
2001
2002
def generate_stats(statsdir, timing_file):
    """Generate and print the summary stats for a run.

    Reads the tab-separated per-packet sample files found in statsdir,
    optionally copies the raw samples to timing_file, and prints overall
    success/failure rates plus a per-protocol/per-operation latency table.
    """
    first      = sys.float_info.max
    last       = 0
    successful = 0
    failed     = 0
    latencies  = {}
    failures   = {}
    unique_conversations = set()

    if timing_file is not None:
        tw = timing_file.write
    else:
        # no timing file requested: discard the raw samples
        def tw(x):
            pass

    tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")

    for filename in os.listdir(statsdir):
        path = os.path.join(statsdir, filename)
        with open(path, 'r') as f:
            for line in f:
                try:
                    fields       = line.rstrip('\n').split('\t')
                    conversation = fields[1]
                    protocol     = fields[2]
                    packet_type  = fields[3]
                    latency      = float(fields[4])
                    # track the span of the run: earliest send, latest reply
                    first        = min(float(fields[0]) - latency, first)
                    last         = max(float(fields[0]), last)

                    latencies.setdefault(protocol, {}).setdefault(
                        packet_type, []).append(latency)
                    failures.setdefault(protocol, {}).setdefault(
                        packet_type, 0)

                    if fields[5] == 'True':
                        successful += 1
                    else:
                        failed += 1
                        failures[protocol][packet_type] += 1

                    unique_conversations.add(conversation)

                    tw(line)
                except (ValueError, IndexError):
                    # not a valid line print and ignore
                    print(line, file=sys.stderr)

    conversations = len(unique_conversations)
    duration = last - first
    # guard against a zero/negative duration (no valid samples, or all
    # samples at the same instant) as well as zero counts
    if successful == 0 or duration <= 0:
        success_rate = 0
    else:
        success_rate = successful / duration
    if failed == 0 or duration <= 0:
        failure_rate = 0
    else:
        failure_rate = failed / duration

    print("Total conversations:   %10d" % conversations)
    print("Successful operations: %10d (%.3f per second)"
          % (successful, success_rate))
    print("Failed operations:     %10d (%.3f per second)"
          % (failed, failure_rate))

    print("Protocol    Op Code  Description                               "
          " Count       Failed         Mean       Median          "
          "95%        Range          Max")

    protocols = sorted(latencies.keys())
    for protocol in protocols:
        packet_types = sorted(latencies[protocol], key=opcode_key)
        for packet_type in packet_types:
            values     = latencies[protocol][packet_type]
            values     = sorted(values)
            count      = len(values)
            failed     = failures[protocol][packet_type]
            mean       = sum(values) / count
            median     = calc_percentile(values, 0.50)
            percentile = calc_percentile(values, 0.95)
            rng        = values[-1] - values[0]
            maxv       = values[-1]
            desc       = OP_DESCRIPTIONS.get((protocol, packet_type), '')
            # BUG FIX: isatty is a method; the original tested the bound
            # method object (always truthy), so the tab-separated branch
            # below was unreachable
            if sys.stdout.isatty():
                print("%-12s   %4s  %-35s %12d %12d %12.6f "
                      "%12.6f %12.6f %12.6f %12.6f"
                      % (protocol,
                         packet_type,
                         desc,
                         count,
                         failed,
                         mean,
                         median,
                         percentile,
                         rng,
                         maxv))
            else:
                print("%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f"
                      % (protocol,
                         packet_type,
                         desc,
                         count,
                         failed,
                         mean,
                         median,
                         percentile,
                         rng,
                         maxv))
2121
2122
def opcode_key(v):
    """Sort key for the operation code to ensure that it sorts numerically.

    Numeric codes are zero-padded to three digits; anything non-numeric
    is returned unchanged and sorts as a plain string.
    """
    try:
        return "%03d" % int(v)
    except (TypeError, ValueError):
        # narrow exception types: a bare except would also swallow
        # KeyboardInterrupt/SystemExit
        return v
2129
2130
def calc_percentile(values, percentile):
    """Calculate the specified percentile from the list of values.

    Assumes the list is sorted in ascending order.
    """

    if not values:
        return 0
    rank = (len(values) - 1) * percentile
    lower = math.floor(rank)
    upper = math.ceil(rank)
    if lower == upper:
        # the rank lands exactly on a sample
        return values[int(rank)]
    # otherwise interpolate linearly between the two bracketing samples
    return (values[int(lower)] * (upper - rank) +
            values[int(upper)] * (rank - lower))
2147
2148
def mk_masked_dir(*path):
    """In a testenv we end up with 0777 directories that look an alarming
    green colour with ls. Use umask to avoid that.

    Joins the path components, creates the directory with a 0o077 umask,
    and returns the created path.
    """
    d = os.path.join(*path)
    mask = os.umask(0o077)
    try:
        os.mkdir(d)
    finally:
        # BUG FIX: always restore the previous umask -- the original
        # leaked the 0o077 umask to the whole process if mkdir raised
        os.umask(mask)
    return d