python2 reduction: Merge remaining compat code into common
[samba.git] / python / samba / emulate / traffic.py
1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
3 #
4 # Copyright (C) Catalyst IT Ltd. 2017
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19 from __future__ import print_function, division
20
21 import time
22 import os
23 import random
24 import json
25 import math
26 import sys
27 import signal
28 from errno import ECHILD, ESRCH
29
30 from collections import OrderedDict, Counter, defaultdict, namedtuple
31 from dns.resolver import query as dns_query
32
33 from samba.emulate import traffic_packets
34 from samba.samdb import SamDB
35 import ldb
36 from ldb import LdbError
37 from samba.dcerpc import ClientConnection
38 from samba.dcerpc import security, drsuapi, lsa
39 from samba.dcerpc import netlogon
40 from samba.dcerpc.netlogon import netr_Authenticator
41 from samba.dcerpc import srvsvc
42 from samba.dcerpc import samr
43 from samba.drs_utils import drs_DsBind
44 import traceback
45 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
46 from samba.auth import system_session
47 from samba.dsdb import (
48     UF_NORMAL_ACCOUNT,
49     UF_SERVER_TRUST_ACCOUNT,
50     UF_TRUSTED_FOR_DELEGATION,
51     UF_WORKSTATION_TRUST_ACCOUNT
52 )
53 from samba.dcerpc.misc import SEC_CHAN_BDC
54 from samba import gensec
55 from samba import sd_utils
56 from samba.common import get_string
57 from samba.logger import get_samba_logger
58 import bisect
59
60 CURRENT_MODEL_VERSION = 2   # save as this
61 REQUIRED_MODEL_VERSION = 2  # load accepts this or greater
62 SLEEP_OVERHEAD = 3e-4
63
64 # we don't use None, because it complicates [de]serialisation
65 NON_PACKET = '-'
66
67 CLIENT_CLUES = {
68     ('dns', '0'): 1.0,      # query
69     ('smb', '0x72'): 1.0,   # Negotiate protocol
70     ('ldap', '0'): 1.0,     # bind
71     ('ldap', '3'): 1.0,     # searchRequest
72     ('ldap', '2'): 1.0,     # unbindRequest
73     ('cldap', '3'): 1.0,
74     ('dcerpc', '11'): 1.0,  # bind
75     ('dcerpc', '14'): 1.0,  # Alter_context
76     ('nbns', '0'): 1.0,     # query
77 }
78
79 SERVER_CLUES = {
80     ('dns', '1'): 1.0,      # response
81     ('ldap', '1'): 1.0,     # bind response
82     ('ldap', '4'): 1.0,     # search result
83     ('ldap', '5'): 1.0,     # search done
84     ('cldap', '5'): 1.0,
85     ('dcerpc', '12'): 1.0,  # bind_ack
86     ('dcerpc', '13'): 1.0,  # bind_nak
87     ('dcerpc', '15'): 1.0,  # Alter_context response
88 }
89
90 SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}
91
92 WAIT_SCALE = 10.0
93 WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
94 NO_WAIT_LOG_TIME_RANGE = (-10, -3)
95
96 # DEBUG_LEVEL can be changed by scripts with -d
97 DEBUG_LEVEL = 0
98
99 LOGGER = get_samba_logger(name=__name__)
100
101
102 def debug(level, msg, *args):
103     """Print a formatted debug message to standard error.
104
105
106     :param level: The debug level, message will be printed if it is <= the
107                   currently set debug level. The debug level can be set with
108                   the -d option.
109     :param msg:   The message to be logged, can contain C-Style format
110                   specifiers
111     :param args:  The parameters required by the format specifiers
112     """
113     if level <= DEBUG_LEVEL:
114         if not args:
115             print(msg, file=sys.stderr)
116         else:
117             print(msg % tuple(args), file=sys.stderr)
118
119
120 def debug_lineno(*args):
121     """ Print an unformatted log message to stderr, contaning the line number
122     """
123     tb = traceback.extract_stack(limit=2)
124     print((" %s:" "\033[01;33m"
125            "%s " "\033[00m" % (tb[0][2], tb[0][1])), end=' ',
126           file=sys.stderr)
127     for a in args:
128         print(a, file=sys.stderr)
129     print(file=sys.stderr)
130     sys.stderr.flush()
131
132
133 def random_colour_print(seeds):
134     """Return a function that prints a coloured line to stderr. The colour
135     of the line depends on a sort of hash of the integer arguments."""
136     if seeds:
137         s = 214
138         for x in seeds:
139             s += 17
140             s *= x
141             s %= 214
142         prefix = "\033[38;5;%dm" % (18 + s)
143
144         def p(*args):
145             if DEBUG_LEVEL > 0:
146                 for a in args:
147                     print("%s%s\033[00m" % (prefix, a), file=sys.stderr)
148     else:
149         def p(*args):
150             if DEBUG_LEVEL > 0:
151                 for a in args:
152                     print(a, file=sys.stderr)
153
154     return p
155
156
157 class FakePacketError(Exception):
158     pass
159
160
161 class Packet(object):
162     """Details of a network packet"""
163     __slots__ = ('timestamp',
164                  'ip_protocol',
165                  'stream_number',
166                  'src',
167                  'dest',
168                  'protocol',
169                  'opcode',
170                  'desc',
171                  'extra',
172                  'endpoints')
173     def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
174                  protocol, opcode, desc, extra):
175         self.timestamp = timestamp
176         self.ip_protocol = ip_protocol
177         self.stream_number = stream_number
178         self.src = src
179         self.dest = dest
180         self.protocol = protocol
181         self.opcode = opcode
182         self.desc = desc
183         self.extra = extra
184         if self.src < self.dest:
185             self.endpoints = (self.src, self.dest)
186         else:
187             self.endpoints = (self.dest, self.src)
188
189     @classmethod
190     def from_line(cls, line):
191         fields = line.rstrip('\n').split('\t')
192         (timestamp,
193          ip_protocol,
194          stream_number,
195          src,
196          dest,
197          protocol,
198          opcode,
199          desc) = fields[:8]
200         extra = fields[8:]
201
202         timestamp = float(timestamp)
203         src = int(src)
204         dest = int(dest)
205
206         return cls(timestamp, ip_protocol, stream_number, src, dest,
207                    protocol, opcode, desc, extra)
208
209     def as_summary(self, time_offset=0.0):
210         """Format the packet as a traffic_summary line.
211         """
212         extra = '\t'.join(self.extra)
213         t = self.timestamp + time_offset
214         return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
215                 (t,
216                  self.ip_protocol,
217                  self.stream_number or '',
218                  self.src,
219                  self.dest,
220                  self.protocol,
221                  self.opcode,
222                  self.desc,
223                  extra))
224
225     def __str__(self):
226         return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
227                 (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
228                  self.stream_number, self.protocol, self.opcode, self.desc,
229                  ('«' + ' '.join(self.extra) + '»' if self.extra else '')))
230
231     def __repr__(self):
232         return "<Packet @%s>" % self
233
234     def copy(self):
235         return self.__class__(self.timestamp,
236                               self.ip_protocol,
237                               self.stream_number,
238                               self.src,
239                               self.dest,
240                               self.protocol,
241                               self.opcode,
242                               self.desc,
243                               self.extra)
244
245     def as_packet_type(self):
246         t = '%s:%s' % (self.protocol, self.opcode)
247         return t
248
249     def client_score(self):
250         """A positive number means we think it is a client; a negative number
251         means we think it is a server. Zero means no idea. range: -1 to 1.
252         """
253         key = (self.protocol, self.opcode)
254         if key in CLIENT_CLUES:
255             return CLIENT_CLUES[key]
256         if key in SERVER_CLUES:
257             return -SERVER_CLUES[key]
258         return 0.0
259
260     def play(self, conversation, context):
261         """Send the packet over the network, if required.
262
263         Some packets are ignored, i.e. for  protocols not handled,
264         server response messages, or messages that are generated by the
265         protocol layer associated with other packets.
266         """
267         fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
268         try:
269             fn = getattr(traffic_packets, fn_name)
270
271         except AttributeError as e:
272             print("Conversation(%s) Missing handler %s" %
273                   (conversation.conversation_id, fn_name),
274                   file=sys.stderr)
275             return
276
277         # Don't display a message for kerberos packets, they're not directly
278         # generated they're used to indicate kerberos should be used
279         if self.protocol != "kerberos":
280             debug(2, "Conversation(%s) Calling handler %s" %
281                      (conversation.conversation_id, fn_name))
282
283         start = time.time()
284         try:
285             if fn(self, conversation, context):
286                 # Only collect timing data for functions that generate
287                 # network traffic, or fail
288                 end = time.time()
289                 duration = end - start
290                 print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
291                       (end, conversation.conversation_id, self.protocol,
292                        self.opcode, duration))
293         except Exception as e:
294             end = time.time()
295             duration = end - start
296             print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
297                   (end, conversation.conversation_id, self.protocol,
298                    self.opcode, duration, e))
299
300     def __cmp__(self, other):
301         return self.timestamp - other.timestamp
302
303     def is_really_a_packet(self, missing_packet_stats=None):
304         return is_a_real_packet(self.protocol, self.opcode)
305
306
307 def is_a_real_packet(protocol, opcode):
308     """Is the packet one that can be ignored?
309
310     If so removing it will have no effect on the replay
311     """
312     if protocol in SKIPPED_PROTOCOLS:
313         # Ignore any packets for the protocols we're not interested in.
314         return False
315     if protocol == "ldap" and opcode == '':
316         # skip ldap continuation packets
317         return False
318
319     fn_name = 'packet_%s_%s' % (protocol, opcode)
320     fn = getattr(traffic_packets, fn_name, None)
321     if fn is None:
322         LOGGER.debug("missing packet %s" % fn_name, file=sys.stderr)
323         return False
324     if fn is traffic_packets.null_packet:
325         return False
326     return True
327
328
329 def is_a_traffic_generating_packet(protocol, opcode):
330     """Return true if a packet generates traffic in its own right. Some of
331     these will generate traffic in certain contexts (e.g. ldap unbind
332     after a bind) but not if the conversation consists only of these packets.
333     """
334     if protocol == 'wait':
335         return False
336
337     if (protocol, opcode) in (
338             ('kerberos', ''),
339             ('ldap', '2'),
340             ('dcerpc', '15'),
341             ('dcerpc', '16')):
342         return False
343
344     return is_a_real_packet(protocol, opcode)
345
346
347 class ReplayContext(object):
348     """State/Context for a conversation between an simulated client and a
349        server. Some of the context is shared amongst all conversations
350        and should be generated before the fork, while other context is
351        specific to a particular conversation and should be generated
352        *after* the fork, in generate_process_local_config().
353     """
354     def __init__(self,
355                  server=None,
356                  lp=None,
357                  creds=None,
358                  total_conversations=None,
359                  badpassword_frequency=None,
360                  prefer_kerberos=None,
361                  tempdir=None,
362                  statsdir=None,
363                  ou=None,
364                  base_dn=None,
365                  domain=os.environ.get("DOMAIN"),
366                  domain_sid=None,
367                  instance_id=None):
368         self.server                   = server
369         self.netlogon_connection      = None
370         self.creds                    = creds
371         self.lp                       = lp
372         if prefer_kerberos:
373             self.kerberos_state = MUST_USE_KERBEROS
374         else:
375             self.kerberos_state = DONT_USE_KERBEROS
376         self.ou                       = ou
377         self.base_dn                  = base_dn
378         self.domain                   = domain
379         self.statsdir                 = statsdir
380         self.global_tempdir           = tempdir
381         self.domain_sid               = domain_sid
382         self.realm                    = lp.get('realm')
383         self.instance_id              = instance_id
384
385         # Bad password attempt controls
386         self.badpassword_frequency    = badpassword_frequency
387         self.last_lsarpc_bad          = False
388         self.last_lsarpc_named_bad    = False
389         self.last_simple_bind_bad     = False
390         self.last_bind_bad            = False
391         self.last_srvsvc_bad          = False
392         self.last_drsuapi_bad         = False
393         self.last_netlogon_bad        = False
394         self.last_samlogon_bad        = False
395         self.total_conversations      = total_conversations
396         self.generate_ldap_search_tables()
397
398     def generate_ldap_search_tables(self):
399         session = system_session()
400
401         db = SamDB(url="ldap://%s" % self.server,
402                    session_info=session,
403                    credentials=self.creds,
404                    lp=self.lp)
405
406         res = db.search(db.domain_dn(),
407                         scope=ldb.SCOPE_SUBTREE,
408                         controls=["paged_results:1:1000"],
409                         attrs=['dn'])
410
411         # find a list of dns for each pattern
412         # e.g. CN,CN,CN,DC,DC
413         dn_map = {}
414         attribute_clue_map = {
415             'invocationId': []
416         }
417
418         for r in res:
419             dn = str(r.dn)
420             pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
421             dns = dn_map.setdefault(pattern, [])
422             dns.append(dn)
423             if dn.startswith('CN=NTDS Settings,'):
424                 attribute_clue_map['invocationId'].append(dn)
425
426         # extend the map in case we are working with a different
427         # number of DC components.
428         # for k, v in self.dn_map.items():
429         #     print >>sys.stderr, k, len(v)
430
431         for k in list(dn_map.keys()):
432             if k[-3:] != ',DC':
433                 continue
434             p = k[:-3]
435             while p[-3:] == ',DC':
436                 p = p[:-3]
437             for i in range(5):
438                 p += ',DC'
439                 if p != k and p in dn_map:
440                     print('dn_map collison %s %s' % (k, p),
441                           file=sys.stderr)
442                     continue
443                 dn_map[p] = dn_map[k]
444
445         self.dn_map = dn_map
446         self.attribute_clue_map = attribute_clue_map
447
448         # pre-populate DN-based search filters (it's simplest to generate them
449         # once, when the test starts). These are used by guess_search_filter()
450         # to avoid full-scans
451         self.search_filters = {}
452
453         # lookup all the GPO DNs
454         res = db.search(db.domain_dn(), scope=ldb.SCOPE_SUBTREE, attrs=['dn'],
455                         expression='(objectclass=groupPolicyContainer)')
456         gpos_by_dn = "".join("(distinguishedName={0})".format(msg['dn']) for msg in res)
457
458         # a search for the 'gPCFileSysPath' attribute is probably a GPO search
459         # (as per the MS-GPOL spec) which searches for GPOs by DN
460         self.search_filters['gPCFileSysPath'] = "(|{0})".format(gpos_by_dn)
461
462         # likewise, a search for gpLink is probably the Domain SOM search part
463         # of the MS-GPOL, in which case it's looking up a few OUs by DN
464         ou_str = ""
465         for ou in ["Domain Controllers,", "traffic_replay,", ""]:
466             ou_str += "(distinguishedName={0}{1})".format(ou, db.domain_dn())
467         self.search_filters['gpLink'] = "(|{0})".format(ou_str)
468
469         # The CEP Web Service can query the AD DC to get pKICertificateTemplate
470         # objects (as per MS-WCCE)
471         self.search_filters['pKIExtendedKeyUsage'] = \
472             '(objectCategory=pKICertificateTemplate)'
473
474         # assume that anything querying the usnChanged is some kind of
475         # synchronization tool, e.g. AD Change Detection Connector
476         res = db.search('', scope=ldb.SCOPE_BASE, attrs=['highestCommittedUSN'])
477         self.search_filters['usnChanged'] = \
478             '(usnChanged>={0})'.format(res[0]['highestCommittedUSN'])
479
480     # The traffic_learner script doesn't preserve the LDAP search filter, and
481     # having no filter can result in a full DB scan. This is costly for a large
482     # DB, and not necessarily representative of real world traffic. As there
483     # several standard LDAP queries that get used by AD tools, we can apply
484     # some logic and guess what the search filter might have been originally.
485     def guess_search_filter(self, attrs, dn_sig, dn):
486
487         # there are some standard spec-based searches that query fairly unique
488         # attributes. Check if the search is likely one of these
489         for key in self.search_filters.keys():
490             if key in attrs:
491                 return self.search_filters[key]
492
493         # if it's the top-level domain, assume we're looking up a single user,
494         # e.g. like powershell Get-ADUser or a similar tool
495         if dn_sig == 'DC,DC':
496             random_user_id = random.random() % self.total_conversations
497             account_name = user_name(self.instance_id, random_user_id)
498             return '(&(sAMAccountName=%s)(objectClass=user))' % account_name
499
500         # otherwise just return everything in the sub-tree
501         return '(objectClass=*)'
502
503     def generate_process_local_config(self, account, conversation):
504         self.ldap_connections         = []
505         self.dcerpc_connections       = []
506         self.lsarpc_connections       = []
507         self.lsarpc_connections_named = []
508         self.drsuapi_connections      = []
509         self.srvsvc_connections       = []
510         self.samr_contexts            = []
511         self.netbios_name             = account.netbios_name
512         self.machinepass              = account.machinepass
513         self.username                 = account.username
514         self.userpass                 = account.userpass
515
516         self.tempdir = mk_masked_dir(self.global_tempdir,
517                                      'conversation-%d' %
518                                      conversation.conversation_id)
519
520         self.lp.set("private dir", self.tempdir)
521         self.lp.set("lock dir", self.tempdir)
522         self.lp.set("state directory", self.tempdir)
523         self.lp.set("tls verify peer", "no_check")
524
525         self.remoteAddress = "/root/ncalrpc_as_system"
526         self.samlogon_dn   = ("cn=%s,%s" %
527                               (self.netbios_name, self.ou))
528         self.user_dn       = ("cn=%s,%s" %
529                               (self.username, self.ou))
530
531         self.generate_machine_creds()
532         self.generate_user_creds()
533
534     def with_random_bad_credentials(self, f, good, bad, failed_last_time):
535         """Execute the supplied logon function, randomly choosing the
536            bad credentials.
537
538            Based on the frequency in badpassword_frequency randomly perform the
539            function with the supplied bad credentials.
540            If run with bad credentials, the function is re-run with the good
541            credentials.
542            failed_last_time is used to prevent consecutive bad credential
543            attempts. So the over all bad credential frequency will be lower
544            than that requested, but not significantly.
545         """
546         if not failed_last_time:
547             if (self.badpassword_frequency and
548                 random.random() < self.badpassword_frequency):
549                 try:
550                     f(bad)
551                 except Exception:
552                     # Ignore any exceptions as the operation may fail
553                     # as it's being performed with bad credentials
554                     pass
555                 failed_last_time = True
556             else:
557                 failed_last_time = False
558
559         result = f(good)
560         return (result, failed_last_time)
561
562     def generate_user_creds(self):
563         """Generate the conversation specific user Credentials.
564
565         Each Conversation has an associated user account used to simulate
566         any non Administrative user traffic.
567
568         Generates user credentials with good and bad passwords and ldap
569         simple bind credentials with good and bad passwords.
570         """
571         self.user_creds = Credentials()
572         self.user_creds.guess(self.lp)
573         self.user_creds.set_workstation(self.netbios_name)
574         self.user_creds.set_password(self.userpass)
575         self.user_creds.set_username(self.username)
576         self.user_creds.set_domain(self.domain)
577         self.user_creds.set_kerberos_state(self.kerberos_state)
578
579         self.user_creds_bad = Credentials()
580         self.user_creds_bad.guess(self.lp)
581         self.user_creds_bad.set_workstation(self.netbios_name)
582         self.user_creds_bad.set_password(self.userpass[:-4])
583         self.user_creds_bad.set_username(self.username)
584         self.user_creds_bad.set_kerberos_state(self.kerberos_state)
585
586         # Credentials for ldap simple bind.
587         self.simple_bind_creds = Credentials()
588         self.simple_bind_creds.guess(self.lp)
589         self.simple_bind_creds.set_workstation(self.netbios_name)
590         self.simple_bind_creds.set_password(self.userpass)
591         self.simple_bind_creds.set_username(self.username)
592         self.simple_bind_creds.set_gensec_features(
593             self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
594         self.simple_bind_creds.set_kerberos_state(self.kerberos_state)
595         self.simple_bind_creds.set_bind_dn(self.user_dn)
596
597         self.simple_bind_creds_bad = Credentials()
598         self.simple_bind_creds_bad.guess(self.lp)
599         self.simple_bind_creds_bad.set_workstation(self.netbios_name)
600         self.simple_bind_creds_bad.set_password(self.userpass[:-4])
601         self.simple_bind_creds_bad.set_username(self.username)
602         self.simple_bind_creds_bad.set_gensec_features(
603             self.simple_bind_creds_bad.get_gensec_features() |
604             gensec.FEATURE_SEAL)
605         self.simple_bind_creds_bad.set_kerberos_state(self.kerberos_state)
606         self.simple_bind_creds_bad.set_bind_dn(self.user_dn)
607
608     def generate_machine_creds(self):
609         """Generate the conversation specific machine Credentials.
610
611         Each Conversation has an associated machine account.
612
613         Generates machine credentials with good and bad passwords.
614         """
615
616         self.machine_creds = Credentials()
617         self.machine_creds.guess(self.lp)
618         self.machine_creds.set_workstation(self.netbios_name)
619         self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
620         self.machine_creds.set_password(self.machinepass)
621         self.machine_creds.set_username(self.netbios_name + "$")
622         self.machine_creds.set_domain(self.domain)
623         self.machine_creds.set_kerberos_state(self.kerberos_state)
624
625         self.machine_creds_bad = Credentials()
626         self.machine_creds_bad.guess(self.lp)
627         self.machine_creds_bad.set_workstation(self.netbios_name)
628         self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
629         self.machine_creds_bad.set_password(self.machinepass[:-4])
630         self.machine_creds_bad.set_username(self.netbios_name + "$")
631         self.machine_creds_bad.set_kerberos_state(self.kerberos_state)
632
633     def get_matching_dn(self, pattern, attributes=None):
634         # If the pattern is an empty string, we assume ROOTDSE,
635         # Otherwise we try adding or removing DC suffixes, then
636         # shorter leading patterns until we hit one.
637         # e.g if there is no CN,CN,CN,CN,DC,DC
638         # we first try       CN,CN,CN,CN,DC
639         # and                CN,CN,CN,CN,DC,DC,DC
640         # then change to        CN,CN,CN,DC,DC
641         # and as last resort we use the base_dn
642         attr_clue = self.attribute_clue_map.get(attributes)
643         if attr_clue:
644             return random.choice(attr_clue)
645
646         pattern = pattern.upper()
647         while pattern:
648             if pattern in self.dn_map:
649                 return random.choice(self.dn_map[pattern])
650             # chop one off the front and try it all again.
651             pattern = pattern[3:]
652
653         return self.base_dn
654
655     def get_dcerpc_connection(self, new=False):
656         guid = '12345678-1234-abcd-ef00-01234567cffb'  # RPC_NETLOGON UUID
657         if self.dcerpc_connections and not new:
658             return self.dcerpc_connections[-1]
659         c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
660                              (guid, 1), self.lp)
661         self.dcerpc_connections.append(c)
662         return c
663
664     def get_srvsvc_connection(self, new=False):
665         if self.srvsvc_connections and not new:
666             return self.srvsvc_connections[-1]
667
668         def connect(creds):
669             return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
670                                  self.lp,
671                                  creds)
672
673         (c, self.last_srvsvc_bad) = \
674             self.with_random_bad_credentials(connect,
675                                              self.user_creds,
676                                              self.user_creds_bad,
677                                              self.last_srvsvc_bad)
678
679         self.srvsvc_connections.append(c)
680         return c
681
682     def get_lsarpc_connection(self, new=False):
683         if self.lsarpc_connections and not new:
684             return self.lsarpc_connections[-1]
685
686         def connect(creds):
687             binding_options = 'schannel,seal,sign'
688             return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
689                               (self.server, binding_options),
690                               self.lp,
691                               creds)
692
693         (c, self.last_lsarpc_bad) = \
694             self.with_random_bad_credentials(connect,
695                                              self.machine_creds,
696                                              self.machine_creds_bad,
697                                              self.last_lsarpc_bad)
698
699         self.lsarpc_connections.append(c)
700         return c
701
702     def get_lsarpc_named_pipe_connection(self, new=False):
703         if self.lsarpc_connections_named and not new:
704             return self.lsarpc_connections_named[-1]
705
706         def connect(creds):
707             return lsa.lsarpc("ncacn_np:%s" % (self.server),
708                               self.lp,
709                               creds)
710
711         (c, self.last_lsarpc_named_bad) = \
712             self.with_random_bad_credentials(connect,
713                                              self.machine_creds,
714                                              self.machine_creds_bad,
715                                              self.last_lsarpc_named_bad)
716
717         self.lsarpc_connections_named.append(c)
718         return c
719
720     def get_drsuapi_connection_pair(self, new=False, unbind=False):
721         """get a (drs, drs_handle) tuple"""
722         if self.drsuapi_connections and not new:
723             c = self.drsuapi_connections[-1]
724             return c
725
726         def connect(creds):
727             binding_options = 'seal'
728             binding_string = "ncacn_ip_tcp:%s[%s]" %\
729                              (self.server, binding_options)
730             return drsuapi.drsuapi(binding_string, self.lp, creds)
731
732         (drs, self.last_drsuapi_bad) = \
733             self.with_random_bad_credentials(connect,
734                                              self.user_creds,
735                                              self.user_creds_bad,
736                                              self.last_drsuapi_bad)
737
738         (drs_handle, supported_extensions) = drs_DsBind(drs)
739         c = (drs, drs_handle)
740         self.drsuapi_connections.append(c)
741         return c
742
743     def get_ldap_connection(self, new=False, simple=False):
744         if self.ldap_connections and not new:
745             return self.ldap_connections[-1]
746
747         def simple_bind(creds):
748             """
749             To run simple bind against Windows, we need to run
750             following commands in PowerShell:
751
752                 Install-windowsfeature ADCS-Cert-Authority
753                 Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
754                 Restart-Computer
755
756             """
757             return SamDB('ldaps://%s' % self.server,
758                          credentials=creds,
759                          lp=self.lp)
760
761         def sasl_bind(creds):
762             return SamDB('ldap://%s' % self.server,
763                          credentials=creds,
764                          lp=self.lp)
765         if simple:
766             (samdb, self.last_simple_bind_bad) = \
767                 self.with_random_bad_credentials(simple_bind,
768                                                  self.simple_bind_creds,
769                                                  self.simple_bind_creds_bad,
770                                                  self.last_simple_bind_bad)
771         else:
772             (samdb, self.last_bind_bad) = \
773                 self.with_random_bad_credentials(sasl_bind,
774                                                  self.user_creds,
775                                                  self.user_creds_bad,
776                                                  self.last_bind_bad)
777
778         self.ldap_connections.append(samdb)
779         return samdb
780
781     def get_samr_context(self, new=False):
782         if not self.samr_contexts or new:
783             self.samr_contexts.append(
784                 SamrContext(self.server, lp=self.lp, creds=self.creds))
785         return self.samr_contexts[-1]
786
787     def get_netlogon_connection(self):
788
789         if self.netlogon_connection:
790             return self.netlogon_connection
791
792         def connect(creds):
793             return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
794                                      (self.server),
795                                      self.lp,
796                                      creds)
797         (c, self.last_netlogon_bad) = \
798             self.with_random_bad_credentials(connect,
799                                              self.machine_creds,
800                                              self.machine_creds_bad,
801                                              self.last_netlogon_bad)
802         self.netlogon_connection = c
803         return c
804
805     def guess_a_dns_lookup(self):
806         return (self.realm, 'A')
807
808     def get_authenticator(self):
809         auth = self.machine_creds.new_client_authenticator()
810         current  = netr_Authenticator()
811         current.cred.data = [x if isinstance(x, int) else ord(x)
812                              for x in auth["credential"]]
813         current.timestamp = auth["timestamp"]
814
815         subsequent = netr_Authenticator()
816         return (current, subsequent)
817
818     def write_stats(self, filename, **kwargs):
819         """Write arbitrary key/value pairs to a file in our stats directory in
820         order for them to be picked up later by another process working out
821         statistics."""
822         filename = os.path.join(self.statsdir, filename)
823         f = open(filename, 'w')
824         for k, v in kwargs.items():
825             print("%s: %s" % (k, v), file=f)
826         f.close()
827
828
829 class SamrContext(object):
830     """State/Context associated with a samr connection.
831     """
832     def __init__(self, server, lp=None, creds=None):
833         self.connection    = None
834         self.handle        = None
835         self.domain_handle = None
836         self.domain_sid    = None
837         self.group_handle  = None
838         self.user_handle   = None
839         self.rids          = None
840         self.server        = server
841         self.lp            = lp
842         self.creds         = creds
843
844     def get_connection(self):
845         if not self.connection:
846             self.connection = samr.samr(
847                 "ncacn_ip_tcp:%s[seal]" % (self.server),
848                 lp_ctx=self.lp,
849                 credentials=self.creds)
850
851         return self.connection
852
853     def get_handle(self):
854         if not self.handle:
855             c = self.get_connection()
856             self.handle = c.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
857         return self.handle
858
859
860 class Conversation(object):
861     """Details of a converation between a simulated client and a server."""
862     def __init__(self, start_time=None, endpoints=None, seq=(),
863                  conversation_id=None):
864         self.start_time = start_time
865         self.endpoints = endpoints
866         self.packets = []
867         self.msg = random_colour_print(endpoints)
868         self.client_balance = 0.0
869         self.conversation_id = conversation_id
870         for p in seq:
871             self.add_short_packet(*p)
872
873     def __cmp__(self, other):
874         if self.start_time is None:
875             if other.start_time is None:
876                 return 0
877             return -1
878         if other.start_time is None:
879             return 1
880         return self.start_time - other.start_time
881
882     def add_packet(self, packet):
883         """Add a packet object to this conversation, making a local copy with
884         a conversation-relative timestamp."""
885         p = packet.copy()
886
887         if self.start_time is None:
888             self.start_time = p.timestamp
889
890         if self.endpoints is None:
891             self.endpoints = p.endpoints
892
893         if p.endpoints != self.endpoints:
894             raise FakePacketError("Conversation endpoints %s don't match"
895                                   "packet endpoints %s" %
896                                   (self.endpoints, p.endpoints))
897
898         p.timestamp -= self.start_time
899
900         if p.src == p.endpoints[0]:
901             self.client_balance -= p.client_score()
902         else:
903             self.client_balance += p.client_score()
904
905         if p.is_really_a_packet():
906             self.packets.append(p)
907
908     def add_short_packet(self, timestamp, protocol, opcode, extra,
909                          client=True, skip_unused_packets=True):
910         """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
911         (possibly empty) list of extra data. If client is True, assume
912         this packet is from the client to the server.
913         """
914         if skip_unused_packets and not is_a_real_packet(protocol, opcode):
915             return
916
917         src, dest = self.guess_client_server()
918         if not client:
919             src, dest = dest, src
920         key = (protocol, opcode)
921         desc = OP_DESCRIPTIONS.get(key, '')
922         ip_protocol = IP_PROTOCOLS.get(protocol, '06')
923         packet = Packet(timestamp - self.start_time, ip_protocol,
924                         '', src, dest,
925                         protocol, opcode, desc, extra)
926         # XXX we're assuming the timestamp is already adjusted for
927         # this conversation?
928         # XXX should we adjust client balance for guessed packets?
929         if packet.src == packet.endpoints[0]:
930             self.client_balance -= packet.client_score()
931         else:
932             self.client_balance += packet.client_score()
933         if packet.is_really_a_packet():
934             self.packets.append(packet)
935
936     def __str__(self):
937         return ("<Conversation %s %s starting %.3f %d packets>" %
938                 (self.conversation_id, self.endpoints, self.start_time,
939                  len(self.packets)))
940
941     __repr__ = __str__
942
943     def __iter__(self):
944         return iter(self.packets)
945
946     def __len__(self):
947         return len(self.packets)
948
949     def get_duration(self):
950         if len(self.packets) < 2:
951             return 0
952         return self.packets[-1].timestamp - self.packets[0].timestamp
953
954     def replay_as_summary_lines(self):
955         return [p.as_summary(self.start_time) for p in self.packets]
956
957     def replay_with_delay(self, start, context=None, account=None):
958         """Replay the conversation at the right time.
959         (We're already in a fork)."""
960         # first we sleep until the first packet
961         t = self.start_time
962         now = time.time() - start
963         gap = t - now
964         sleep_time = gap - SLEEP_OVERHEAD
965         if sleep_time > 0:
966             time.sleep(sleep_time)
967
968         miss = (time.time() - start) - t
969         self.msg("starting %s [miss %.3f]" % (self, miss))
970
971         max_gap = 0.0
972         max_sleep_miss = 0.0
973         # packet times are relative to conversation start
974         p_start = time.time()
975         for p in self.packets:
976             now = time.time() - p_start
977             gap = now - p.timestamp
978             if gap > max_gap:
979                 max_gap = gap
980             if gap < 0:
981                 sleep_time = -gap - SLEEP_OVERHEAD
982                 if sleep_time > 0:
983                     time.sleep(sleep_time)
984                     t = time.time() - p_start
985                     if t - p.timestamp > max_sleep_miss:
986                         max_sleep_miss = t - p.timestamp
987
988             p.play(self, context)
989
990         return max_gap, miss, max_sleep_miss
991
992     def guess_client_server(self, server_clue=None):
993         """Have a go at deciding who is the server and who is the client.
994         returns (client, server)
995         """
996         a, b = self.endpoints
997
998         if self.client_balance < 0:
999             return (a, b)
1000
1001         # in the absense of a clue, we will fall through to assuming
1002         # the lowest number is the server (which is usually true).
1003
1004         if self.client_balance == 0 and server_clue == b:
1005             return (a, b)
1006
1007         return (b, a)
1008
1009     def forget_packets_outside_window(self, s, e):
1010         """Prune any packets outside the timne window we're interested in
1011
1012         :param s: start of the window
1013         :param e: end of the window
1014         """
1015         self.packets = [p for p in self.packets if s <= p.timestamp <= e]
1016         self.start_time = self.packets[0].timestamp if self.packets else None
1017
1018     def renormalise_times(self, start_time):
1019         """Adjust the packet start times relative to the new start time."""
1020         for p in self.packets:
1021             p.timestamp -= start_time
1022
1023         if self.start_time is not None:
1024             self.start_time -= start_time
1025
1026
1027 class DnsHammer(Conversation):
1028     """A lightweight conversation that generates a lot of dns:0 packets on
1029     the fly"""
1030
1031     def __init__(self, dns_rate, duration, query_file=None):
1032         n = int(dns_rate * duration)
1033         self.times = [random.uniform(0, duration) for i in range(n)]
1034         self.times.sort()
1035         self.rate = dns_rate
1036         self.duration = duration
1037         self.start_time = 0
1038         self.query_choices = self._get_query_choices(query_file=query_file)
1039
1040     def __str__(self):
1041         return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
1042                 (len(self.times), self.duration, self.rate))
1043
1044     def _get_query_choices(self, query_file=None):
1045         """
1046         Read dns query choices from a file, or return default
1047
1048         rname may contain format string like `{realm}`
1049         realm can be fetched from context.realm
1050         """
1051
1052         if query_file:
1053             with open(query_file, 'r') as f:
1054                 text = f.read()
1055             choices = []
1056             for line in text.splitlines():
1057                 line = line.strip()
1058                 if line and not line.startswith('#'):
1059                     args = line.split(',')
1060                     assert len(args) == 4
1061                     choices.append(args)
1062             return choices
1063         else:
1064             return [
1065                 (0, '{realm}', 'A', 'yes'),
1066                 (1, '{realm}', 'NS', 'yes'),
1067                 (2, '*.{realm}', 'A', 'no'),
1068                 (3, '*.{realm}', 'NS', 'no'),
1069                 (10, '_msdcs.{realm}', 'A', 'yes'),
1070                 (11, '_msdcs.{realm}', 'NS', 'yes'),
1071                 (20, 'nx.realm.com', 'A', 'no'),
1072                 (21, 'nx.realm.com', 'NS', 'no'),
1073                 (22, '*.nx.realm.com', 'A', 'no'),
1074                 (23, '*.nx.realm.com', 'NS', 'no'),
1075             ]
1076
1077     def replay(self, context=None):
1078         assert context
1079         assert context.realm
1080         start = time.time()
1081         for t in self.times:
1082             now = time.time() - start
1083             gap = t - now
1084             sleep_time = gap - SLEEP_OVERHEAD
1085             if sleep_time > 0:
1086                 time.sleep(sleep_time)
1087
1088             opcode, rname, rtype, exist = random.choice(self.query_choices)
1089             rname = rname.format(realm=context.realm)
1090             success = True
1091             packet_start = time.time()
1092             try:
1093                 answers = dns_query(rname, rtype)
1094                 if exist == 'yes' and not len(answers):
1095                     # expect answers but didn't get, fail
1096                     success = False
1097             except Exception:
1098                 success = False
1099             finally:
1100                 end = time.time()
1101                 duration = end - packet_start
1102                 print("%f\tDNS\tdns\t%s\t%f\t%s\t" % (end, opcode, duration, success))
1103
1104
1105 def ingest_summaries(files, dns_mode='count'):
1106     """Load a summary traffic summary file and generated Converations from it.
1107     """
1108
1109     dns_counts = defaultdict(int)
1110     packets = []
1111     for f in files:
1112         if isinstance(f, str):
1113             f = open(f)
1114         print("Ingesting %s" % (f.name,), file=sys.stderr)
1115         for line in f:
1116             p = Packet.from_line(line)
1117             if p.protocol == 'dns' and dns_mode != 'include':
1118                 dns_counts[p.opcode] += 1
1119             else:
1120                 packets.append(p)
1121
1122         f.close()
1123
1124     if not packets:
1125         return [], 0
1126
1127     start_time = min(p.timestamp for p in packets)
1128     last_packet = max(p.timestamp for p in packets)
1129
1130     print("gathering packets into conversations", file=sys.stderr)
1131     conversations = OrderedDict()
1132     for i, p in enumerate(packets):
1133         p.timestamp -= start_time
1134         c = conversations.get(p.endpoints)
1135         if c is None:
1136             c = Conversation(conversation_id=(i + 2))
1137             conversations[p.endpoints] = c
1138         c.add_packet(p)
1139
1140     # We only care about conversations with actual traffic, so we
1141     # filter out conversations with nothing to say. We do that here,
1142     # rather than earlier, because those empty packets contain useful
1143     # hints as to which end of the conversation was the client.
1144     conversation_list = []
1145     for c in conversations.values():
1146         if len(c) != 0:
1147             conversation_list.append(c)
1148
1149     # This is obviously not correct, as many conversations will appear
1150     # to start roughly simultaneously at the beginning of the snapshot.
1151     # To which we say: oh well, so be it.
1152     duration = float(last_packet - start_time)
1153     mean_interval = len(conversations) / duration
1154
1155     return conversation_list, mean_interval, duration, dns_counts
1156
1157
1158 def guess_server_address(conversations):
1159     # we guess the most common address.
1160     addresses = Counter()
1161     for c in conversations:
1162         addresses.update(c.endpoints)
1163     if addresses:
1164         return addresses.most_common(1)[0]
1165
1166
1167 def stringify_keys(x):
1168     y = {}
1169     for k, v in x.items():
1170         k2 = '\t'.join(k)
1171         y[k2] = v
1172     return y
1173
1174
1175 def unstringify_keys(x):
1176     y = {}
1177     for k, v in x.items():
1178         t = tuple(str(k).split('\t'))
1179         y[t] = v
1180     return y
1181
1182
1183 class TrafficModel(object):
1184     def __init__(self, n=3):
1185         self.ngrams = {}
1186         self.query_details = {}
1187         self.n = n
1188         self.dns_opcounts = defaultdict(int)
1189         self.cumulative_duration = 0.0
1190         self.packet_rate = [0, 1]
1191
1192     def learn(self, conversations, dns_opcounts={}):
1193         prev = 0.0
1194         cum_duration = 0.0
1195         key = (NON_PACKET,) * (self.n - 1)
1196
1197         server = guess_server_address(conversations)
1198
1199         for k, v in dns_opcounts.items():
1200             self.dns_opcounts[k] += v
1201
1202         if len(conversations) > 1:
1203             first = conversations[0].start_time
1204             total = 0
1205             last = first + 0.1
1206             for c in conversations:
1207                 total += len(c)
1208                 last = max(last, c.packets[-1].timestamp)
1209
1210             self.packet_rate[0] = total
1211             self.packet_rate[1] = last - first
1212
1213         for c in conversations:
1214             client, server = c.guess_client_server(server)
1215             cum_duration += c.get_duration()
1216             key = (NON_PACKET,) * (self.n - 1)
1217             for p in c:
1218                 if p.src != client:
1219                     continue
1220
1221                 elapsed = p.timestamp - prev
1222                 prev = p.timestamp
1223                 if elapsed > WAIT_THRESHOLD:
1224                     # add the wait as an extra state
1225                     wait = 'wait:%d' % (math.log(max(1.0,
1226                                                      elapsed * WAIT_SCALE)))
1227                     self.ngrams.setdefault(key, []).append(wait)
1228                     key = key[1:] + (wait,)
1229
1230                 short_p = p.as_packet_type()
1231                 self.query_details.setdefault(short_p,
1232                                               []).append(tuple(p.extra))
1233                 self.ngrams.setdefault(key, []).append(short_p)
1234                 key = key[1:] + (short_p,)
1235
1236         self.cumulative_duration += cum_duration
1237         # add in the end
1238         self.ngrams.setdefault(key, []).append(NON_PACKET)
1239
1240     def save(self, f):
1241         ngrams = {}
1242         for k, v in self.ngrams.items():
1243             k = '\t'.join(k)
1244             ngrams[k] = dict(Counter(v))
1245
1246         query_details = {}
1247         for k, v in self.query_details.items():
1248             query_details[k] = dict(Counter('\t'.join(x) if x else '-'
1249                                             for x in v))
1250
1251         d = {
1252             'ngrams': ngrams,
1253             'query_details': query_details,
1254             'cumulative_duration': self.cumulative_duration,
1255             'packet_rate': self.packet_rate,
1256             'version': CURRENT_MODEL_VERSION
1257         }
1258         d['dns'] = self.dns_opcounts
1259
1260         if isinstance(f, str):
1261             f = open(f, 'w')
1262
1263         json.dump(d, f, indent=2)
1264
1265     def load(self, f):
1266         if isinstance(f, str):
1267             f = open(f)
1268
1269         d = json.load(f)
1270
1271         try:
1272             version = d["version"]
1273             if version < REQUIRED_MODEL_VERSION:
1274                 raise ValueError("the model file is version %d; "
1275                                  "version %d is required" %
1276                                  (version, REQUIRED_MODEL_VERSION))
1277         except KeyError:
1278                 raise ValueError("the model file lacks a version number; "
1279                                  "version %d is required" %
1280                                  (REQUIRED_MODEL_VERSION))
1281
1282         for k, v in d['ngrams'].items():
1283             k = tuple(str(k).split('\t'))
1284             values = self.ngrams.setdefault(k, [])
1285             for p, count in v.items():
1286                 values.extend([str(p)] * count)
1287             values.sort()
1288
1289         for k, v in d['query_details'].items():
1290             values = self.query_details.setdefault(str(k), [])
1291             for p, count in v.items():
1292                 if p == '-':
1293                     values.extend([()] * count)
1294                 else:
1295                     values.extend([tuple(str(p).split('\t'))] * count)
1296             values.sort()
1297
1298         if 'dns' in d:
1299             for k, v in d['dns'].items():
1300                 self.dns_opcounts[k] += v
1301
1302         self.cumulative_duration = d['cumulative_duration']
1303         self.packet_rate = d['packet_rate']
1304
1305     def construct_conversation_sequence(self, timestamp=0.0,
1306                                         hard_stop=None,
1307                                         replay_speed=1,
1308                                         ignore_before=0,
1309                                         persistence=0):
1310         """Construct an individual conversation packet sequence from the
1311         model.
1312         """
1313         c = []
1314         key = (NON_PACKET,) * (self.n - 1)
1315         if ignore_before is None:
1316             ignore_before = timestamp - 1
1317
1318         while True:
1319             p = random.choice(self.ngrams.get(key, (NON_PACKET,)))
1320             if p == NON_PACKET:
1321                 if timestamp < ignore_before:
1322                     break
1323                 if random.random() > persistence:
1324                     print("ending after %s (persistence %.1f)" % (key, persistence),
1325                           file=sys.stderr)
1326                     break
1327
1328                 p = 'wait:%d' % random.randrange(5, 12)
1329                 print("trying %s instead of end" % p, file=sys.stderr)
1330
1331             if p in self.query_details:
1332                 extra = random.choice(self.query_details[p])
1333             else:
1334                 extra = []
1335
1336             protocol, opcode = p.split(':', 1)
1337             if protocol == 'wait':
1338                 log_wait_time = int(opcode) + random.random()
1339                 wait = math.exp(log_wait_time) / (WAIT_SCALE * replay_speed)
1340                 timestamp += wait
1341             else:
1342                 log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
1343                 wait = math.exp(log_wait) / replay_speed
1344                 timestamp += wait
1345                 if hard_stop is not None and timestamp > hard_stop:
1346                     break
1347                 if timestamp >= ignore_before:
1348                     c.append((timestamp, protocol, opcode, extra))
1349
1350             key = key[1:] + (p,)
1351             if key[-2][:5] == 'wait:' and key[-1][:5] == 'wait:':
1352                 # two waits in a row can only be caused by "persistence"
1353                 # tricks, and will not result in any packets being found.
1354                 # Instead we pretend this is a fresh start.
1355                 key = (NON_PACKET,) * (self.n - 1)
1356
1357         return c
1358
1359     def scale_to_packet_rate(self, scale):
1360         rate_n, rate_t  = self.packet_rate
1361         return scale * rate_n / rate_t
1362
1363     def packet_rate_to_scale(self, pps):
1364         rate_n, rate_t  = self.packet_rate
1365         return  pps * rate_t / rate_n
1366
1367     def generate_conversation_sequences(self, packet_rate, duration, replay_speed=1,
1368                                         persistence=0):
1369         """Generate a list of conversation descriptions from the model."""
1370
1371         # We run the simulation for ten times as long as our desired
1372         # duration, and take the section at the end.
1373         lead_in = 9 * duration
1374         target_packets = int(packet_rate * duration)
1375         conversations = []
1376         n_packets = 0
1377
1378         while n_packets < target_packets:
1379             start = random.uniform(-lead_in, duration)
1380             c = self.construct_conversation_sequence(start,
1381                                                      hard_stop=duration,
1382                                                      replay_speed=replay_speed,
1383                                                      ignore_before=0,
1384                                                      persistence=persistence)
1385             # will these "packets" generate actual traffic?
1386             # some (e.g. ldap unbind) will not generate anything
1387             # if the previous packets are not there, and if the
1388             # conversation only has those it wastes a process doing nothing.
1389             for timestamp, protocol, opcode, extra in c:
1390                 if is_a_traffic_generating_packet(protocol, opcode):
1391                     break
1392             else:
1393                 continue
1394
1395             conversations.append(c)
1396             n_packets += len(c)
1397
1398         scale = self.packet_rate_to_scale(packet_rate)
1399         print(("we have %d packets (target %d) in %d conversations at %.1f/s "
1400                "(scale %f)" % (n_packets, target_packets, len(conversations),
1401                                packet_rate, scale)),
1402               file=sys.stderr)
1403         conversations.sort()  # sorts by first element == start time
1404         return conversations
1405
1406
1407 def seq_to_conversations(seq, server=1, client=2):
1408     conversations = []
1409     for s in seq:
1410         if s:
1411             c = Conversation(s[0][0], (server, client), s)
1412             client += 1
1413             conversations.append(c)
1414     return conversations
1415
1416
1417 IP_PROTOCOLS = {
1418     'dns': '11',
1419     'rpc_netlogon': '06',
1420     'kerberos': '06',      # ratio 16248:258
1421     'smb': '06',
1422     'smb2': '06',
1423     'ldap': '06',
1424     'cldap': '11',
1425     'lsarpc': '06',
1426     'samr': '06',
1427     'dcerpc': '06',
1428     'epm': '06',
1429     'drsuapi': '06',
1430     'browser': '11',
1431     'smb_netlogon': '11',
1432     'srvsvc': '06',
1433     'nbns': '11',
1434 }
1435
1436 OP_DESCRIPTIONS = {
1437     ('browser', '0x01'): 'Host Announcement (0x01)',
1438     ('browser', '0x02'): 'Request Announcement (0x02)',
1439     ('browser', '0x08'): 'Browser Election Request (0x08)',
1440     ('browser', '0x09'): 'Get Backup List Request (0x09)',
1441     ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
1442     ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
1443     ('cldap', '3'): 'searchRequest',
1444     ('cldap', '5'): 'searchResDone',
1445     ('dcerpc', '0'): 'Request',
1446     ('dcerpc', '11'): 'Bind',
1447     ('dcerpc', '12'): 'Bind_ack',
1448     ('dcerpc', '13'): 'Bind_nak',
1449     ('dcerpc', '14'): 'Alter_context',
1450     ('dcerpc', '15'): 'Alter_context_resp',
1451     ('dcerpc', '16'): 'AUTH3',
1452     ('dcerpc', '2'): 'Response',
1453     ('dns', '0'): 'query',
1454     ('dns', '1'): 'response',
1455     ('drsuapi', '0'): 'DsBind',
1456     ('drsuapi', '12'): 'DsCrackNames',
1457     ('drsuapi', '13'): 'DsWriteAccountSpn',
1458     ('drsuapi', '1'): 'DsUnbind',
1459     ('drsuapi', '2'): 'DsReplicaSync',
1460     ('drsuapi', '3'): 'DsGetNCChanges',
1461     ('drsuapi', '4'): 'DsReplicaUpdateRefs',
1462     ('epm', '3'): 'Map',
1463     ('kerberos', ''): '',
1464     ('ldap', '0'): 'bindRequest',
1465     ('ldap', '1'): 'bindResponse',
1466     ('ldap', '2'): 'unbindRequest',
1467     ('ldap', '3'): 'searchRequest',
1468     ('ldap', '4'): 'searchResEntry',
1469     ('ldap', '5'): 'searchResDone',
1470     ('ldap', ''): '*** Unknown ***',
1471     ('lsarpc', '14'): 'lsa_LookupNames',
1472     ('lsarpc', '15'): 'lsa_LookupSids',
1473     ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
1474     ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
1475     ('lsarpc', '6'): 'lsa_OpenPolicy',
1476     ('lsarpc', '76'): 'lsa_LookupSids3',
1477     ('lsarpc', '77'): 'lsa_LookupNames4',
1478     ('nbns', '0'): 'query',
1479     ('nbns', '1'): 'response',
1480     ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
1481     ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
1482     ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
1483     ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
1484     ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
1485     ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
1486     ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
1487     ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
1488     ('samr', '0',): 'Connect',
1489     ('samr', '16'): 'GetAliasMembership',
1490     ('samr', '17'): 'LookupNames',
1491     ('samr', '18'): 'LookupRids',
1492     ('samr', '19'): 'OpenGroup',
1493     ('samr', '1'): 'Close',
1494     ('samr', '25'): 'QueryGroupMember',
1495     ('samr', '34'): 'OpenUser',
1496     ('samr', '36'): 'QueryUserInfo',
1497     ('samr', '39'): 'GetGroupsForUser',
1498     ('samr', '3'): 'QuerySecurity',
1499     ('samr', '5'): 'LookupDomain',
1500     ('samr', '64'): 'Connect5',
1501     ('samr', '6'): 'EnumDomains',
1502     ('samr', '7'): 'OpenDomain',
1503     ('samr', '8'): 'QueryDomainInfo',
1504     ('smb', '0x04'): 'Close (0x04)',
1505     ('smb', '0x24'): 'Locking AndX (0x24)',
1506     ('smb', '0x2e'): 'Read AndX (0x2e)',
1507     ('smb', '0x32'): 'Trans2 (0x32)',
1508     ('smb', '0x71'): 'Tree Disconnect (0x71)',
1509     ('smb', '0x72'): 'Negotiate Protocol (0x72)',
1510     ('smb', '0x73'): 'Session Setup AndX (0x73)',
1511     ('smb', '0x74'): 'Logoff AndX (0x74)',
1512     ('smb', '0x75'): 'Tree Connect AndX (0x75)',
1513     ('smb', '0xa2'): 'NT Create AndX (0xa2)',
1514     ('smb2', '0'): 'NegotiateProtocol',
1515     ('smb2', '11'): 'Ioctl',
1516     ('smb2', '14'): 'Find',
1517     ('smb2', '16'): 'GetInfo',
1518     ('smb2', '18'): 'Break',
1519     ('smb2', '1'): 'SessionSetup',
1520     ('smb2', '2'): 'SessionLogoff',
1521     ('smb2', '3'): 'TreeConnect',
1522     ('smb2', '4'): 'TreeDisconnect',
1523     ('smb2', '5'): 'Create',
1524     ('smb2', '6'): 'Close',
1525     ('smb2', '8'): 'Read',
1526     ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
1527     ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
1528                                'user unknown (0x17)'),
1529     ('srvsvc', '16'): 'NetShareGetInfo',
1530     ('srvsvc', '21'): 'NetSrvGetInfo',
1531 }
1532
1533
1534 def expand_short_packet(p, timestamp, src, dest, extra):
1535     protocol, opcode = p.split(':', 1)
1536     desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
1537     ip_protocol = IP_PROTOCOLS.get(protocol, '06')
1538
1539     line = [timestamp, ip_protocol, '', src, dest, protocol, opcode, desc]
1540     line.extend(extra)
1541     return '\t'.join(line)
1542
1543
1544 def flushing_signal_handler(signal, frame):
1545     """Signal handler closes standard out and error.
1546
1547     Triggered by a sigterm, ensures that the log messages are flushed
1548     to disk and not lost.
1549     """
1550     sys.stderr.close()
1551     sys.stdout.close()
1552     os._exit(0)
1553
1554
1555 def replay_seq_in_fork(cs, start, context, account, client_id, server_id=1):
1556     """Fork a new process and replay the conversation sequence."""
1557     # We will need to reseed the random number generator or all the
1558     # clients will end up using the same sequence of random
1559     # numbers. random.randint() is mixed in so the initial seed will
1560     # have an effect here.
1561     seed = client_id * 1000 + random.randint(0, 999)
1562
1563     # flush our buffers so messages won't be written by both sides
1564     sys.stdout.flush()
1565     sys.stderr.flush()
1566     pid = os.fork()
1567     if pid != 0:
1568         return pid
1569
1570     # we must never return, or we'll end up running parts of the
1571     # parent's clean-up code. So we work in a try...finally, and
1572     # try to print any exceptions.
1573     try:
1574         random.seed(seed)
1575         endpoints = (server_id, client_id)
1576         status = 0
1577         t = cs[0][0]
1578         c = Conversation(t, endpoints, seq=cs, conversation_id=client_id)
1579         signal.signal(signal.SIGTERM, flushing_signal_handler)
1580
1581         context.generate_process_local_config(account, c)
1582         sys.stdin.close()
1583         os.close(0)
1584         filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
1585                                 c.conversation_id)
1586         f = open(filename, 'w')
1587         try:
1588             sys.stdout.close()
1589             os.close(1)
1590         except IOError as e:
1591             LOGGER.info("stdout closing failed with %s" % e)
1592             pass
1593
1594         sys.stdout = f
1595         now = time.time() - start
1596         gap = t - now
1597         sleep_time = gap - SLEEP_OVERHEAD
1598         if sleep_time > 0:
1599             time.sleep(sleep_time)
1600
1601         max_lag, start_lag, max_sleep_miss = c.replay_with_delay(start=start,
1602                                                                  context=context)
1603         print("Maximum lag: %f" % max_lag)
1604         print("Start lag: %f" % start_lag)
1605         print("Max sleep miss: %f" % max_sleep_miss)
1606
1607     except Exception:
1608         status = 1
1609         print(("EXCEPTION in child PID %d, conversation %s" % (os.getpid(), c)),
1610               file=sys.stderr)
1611         traceback.print_exc(sys.stderr)
1612         sys.stderr.flush()
1613     finally:
1614         sys.stderr.close()
1615         sys.stdout.close()
1616         os._exit(status)
1617
1618
1619 def dnshammer_in_fork(dns_rate, duration, context, query_file=None):
1620     sys.stdout.flush()
1621     sys.stderr.flush()
1622     pid = os.fork()
1623     if pid != 0:
1624         return pid
1625
1626     sys.stdin.close()
1627     os.close(0)
1628
1629     try:
1630         sys.stdout.close()
1631         os.close(1)
1632     except IOError as e:
1633         LOGGER.warn("stdout closing failed with %s" % e)
1634         pass
1635     filename = os.path.join(context.statsdir, 'stats-dns')
1636     sys.stdout = open(filename, 'w')
1637
1638     try:
1639         status = 0
1640         signal.signal(signal.SIGTERM, flushing_signal_handler)
1641         hammer = DnsHammer(dns_rate, duration, query_file=query_file)
1642         hammer.replay(context=context)
1643     except Exception:
1644         status = 1
1645         print(("EXCEPTION in child PID %d, the DNS hammer" % (os.getpid())),
1646               file=sys.stderr)
1647         traceback.print_exc(sys.stderr)
1648     finally:
1649         sys.stderr.close()
1650         sys.stdout.close()
1651         os._exit(status)
1652
1653
1654 def replay(conversation_seq,
1655            host=None,
1656            creds=None,
1657            lp=None,
1658            accounts=None,
1659            dns_rate=0,
1660            dns_query_file=None,
1661            duration=None,
1662            latency_timeout=1.0,
1663            stop_on_any_error=False,
1664            **kwargs):
1665
1666     context = ReplayContext(server=host,
1667                             creds=creds,
1668                             lp=lp,
1669                             total_conversations=len(conversation_seq),
1670                             **kwargs)
1671
1672     if len(accounts) < len(conversation_seq):
1673         raise ValueError(("we have %d accounts but %d conversations" %
1674                           (len(accounts), len(conversation_seq))))
1675
1676     # Set the process group so that the calling scripts are not killed
1677     # when the forked child processes are killed.
1678     os.setpgrp()
1679
1680     # we delay the start by a bit to allow all the forks to get up and
1681     # running.
1682     delay = len(conversation_seq) * 0.02
1683     start = time.time() + delay
1684
1685     if duration is None:
1686         # end slightly after the last packet of the last conversation
1687         # to start. Conversations other than the last could still be
1688         # going, but we don't care.
1689         duration = conversation_seq[-1][-1][0] + latency_timeout
1690
1691     print("We will start in %.1f seconds" % delay,
1692           file=sys.stderr)
1693     print("We will stop after %.1f seconds" % (duration + delay),
1694           file=sys.stderr)
1695     print("runtime %.1f seconds" % duration,
1696           file=sys.stderr)
1697
1698     # give one second grace for packets to finish before killing begins
1699     end = start + duration + 1.0
1700
1701     LOGGER.info("Replaying traffic for %u conversations over %d seconds"
1702           % (len(conversation_seq), duration))
1703
1704     context.write_stats('intentions',
1705                         Planned_conversations=len(conversation_seq),
1706                         Planned_packets=sum(len(x) for x in conversation_seq))
1707
1708     children = {}
1709     try:
1710         if dns_rate:
1711             pid = dnshammer_in_fork(dns_rate, duration, context,
1712                                     query_file=dns_query_file)
1713             children[pid] = 1
1714
1715         for i, cs in enumerate(conversation_seq):
1716             account = accounts[i]
1717             client_id = i + 2
1718             pid = replay_seq_in_fork(cs, start, context, account, client_id)
1719             children[pid] = client_id
1720
1721         # HERE, we are past all the forks
1722         t = time.time()
1723         print("all forks done in %.1f seconds, waiting %.1f" %
1724               (t - start + delay, t - start),
1725               file=sys.stderr)
1726
1727         while time.time() < end and children:
1728             time.sleep(0.003)
1729             try:
1730                 pid, status = os.waitpid(-1, os.WNOHANG)
1731             except OSError as e:
1732                 if e.errno != ECHILD:  # no child processes
1733                     raise
1734                 break
1735             if pid:
1736                 c = children.pop(pid, None)
1737                 if DEBUG_LEVEL > 0:
1738                     print(("process %d finished conversation %d;"
1739                            " %d to go" %
1740                            (pid, c, len(children))), file=sys.stderr)
1741                 if stop_on_any_error and status != 0:
1742                     break
1743
1744     except Exception:
1745         print("EXCEPTION in parent", file=sys.stderr)
1746         traceback.print_exc()
1747     finally:
1748         context.write_stats('unfinished',
1749                             Unfinished_conversations=len(children))
1750
1751         for s in (15, 15, 9):
1752             print(("killing %d children with -%d" %
1753                    (len(children), s)), file=sys.stderr)
1754             for pid in children:
1755                 try:
1756                     os.kill(pid, s)
1757                 except OSError as e:
1758                     if e.errno != ESRCH:  # don't fail if it has already died
1759                         raise
1760             time.sleep(0.5)
1761             end = time.time() + 1
1762             while children:
1763                 try:
1764                     pid, status = os.waitpid(-1, os.WNOHANG)
1765                 except OSError as e:
1766                     if e.errno != ECHILD:
1767                         raise
1768                 if pid != 0:
1769                     c = children.pop(pid, None)
1770                     if c is None:
1771                         print("children is %s, no pid found" % children)
1772                         sys.stderr.flush()
1773                         sys.stdout.flush()
1774                         os._exit(1)
1775                     print(("kill -%d %d KILLED conversation; "
1776                            "%d to go" %
1777                            (s, pid, len(children))),
1778                           file=sys.stderr)
1779                 if time.time() >= end:
1780                     break
1781
1782             if not children:
1783                 break
1784             time.sleep(1)
1785
1786         if children:
1787             print("%d children are missing" % len(children),
1788                   file=sys.stderr)
1789
1790         # there may be stragglers that were forked just as ^C was hit
1791         # and don't appear in the list of children. We can get them
1792         # with killpg, but that will also kill us, so this is^H^H would be
1793         # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
1794         # so as not to have to fuss around writing signal handlers.
1795         try:
1796             os.killpg(0, 2)
1797         except KeyboardInterrupt:
1798             print("ignoring fake ^C", file=sys.stderr)
1799
1800
1801 def openLdb(host, creds, lp):
1802     session = system_session()
1803     ldb = SamDB(url="ldap://%s" % host,
1804                 session_info=session,
1805                 options=['modules:paged_searches'],
1806                 credentials=creds,
1807                 lp=lp)
1808     return ldb
1809
1810
1811 def ou_name(ldb, instance_id):
1812     """Generate an ou name from the instance id"""
1813     return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id,
1814                                                     ldb.domain_dn())
1815
1816
1817 def create_ou(ldb, instance_id):
1818     """Create an ou, all created user and machine accounts will belong to it.
1819
1820     This allows all the created resources to be cleaned up easily.
1821     """
1822     ou = ou_name(ldb, instance_id)
1823     try:
1824         ldb.add({"dn": ou.split(',', 1)[1],
1825                  "objectclass": "organizationalunit"})
1826     except LdbError as e:
1827         (status, _) = e.args
1828         # ignore already exists
1829         if status != 68:
1830             raise
1831     try:
1832         ldb.add({"dn": ou,
1833                  "objectclass": "organizationalunit"})
1834     except LdbError as e:
1835         (status, _) = e.args
1836         # ignore already exists
1837         if status != 68:
1838             raise
1839     return ou
1840
1841
1842 # ConversationAccounts holds details of the machine and user accounts
1843 # associated with a conversation.
1844 #
1845 # We use a named tuple to reduce shared memory usage.
1846 ConversationAccounts = namedtuple('ConversationAccounts',
1847                                   ('netbios_name',
1848                                    'machinepass',
1849                                    'username',
1850                                    'userpass'))
1851
1852
1853 def generate_replay_accounts(ldb, instance_id, number, password):
1854     """Generate a series of unique machine and user account names."""
1855
1856     accounts = []
1857     for i in range(1, number + 1):
1858         netbios_name = machine_name(instance_id, i)
1859         username = user_name(instance_id, i)
1860
1861         account = ConversationAccounts(netbios_name, password, username,
1862                                        password)
1863         accounts.append(account)
1864     return accounts
1865
1866
1867 def create_machine_account(ldb, instance_id, netbios_name, machinepass,
1868                            traffic_account=True):
1869     """Create a machine account via ldap."""
1870
1871     ou = ou_name(ldb, instance_id)
1872     dn = "cn=%s,%s" % (netbios_name, ou)
1873     utf16pw = ('"%s"' % get_string(machinepass)).encode('utf-16-le')
1874
1875     if traffic_account:
1876         # we set these bits for the machine account otherwise the replayed
1877         # traffic throws up NT_STATUS_NO_TRUST_SAM_ACCOUNT errors
1878         account_controls = str(UF_TRUSTED_FOR_DELEGATION |
1879                                UF_SERVER_TRUST_ACCOUNT)
1880
1881     else:
1882         account_controls = str(UF_WORKSTATION_TRUST_ACCOUNT)
1883
1884     ldb.add({
1885         "dn": dn,
1886         "objectclass": "computer",
1887         "sAMAccountName": "%s$" % netbios_name,
1888         "userAccountControl": account_controls,
1889         "unicodePwd": utf16pw})
1890
1891
1892 def create_user_account(ldb, instance_id, username, userpass):
1893     """Create a user account via ldap."""
1894     ou = ou_name(ldb, instance_id)
1895     user_dn = "cn=%s,%s" % (username, ou)
1896     utf16pw = ('"%s"' % get_string(userpass)).encode('utf-16-le')
1897     ldb.add({
1898         "dn": user_dn,
1899         "objectclass": "user",
1900         "sAMAccountName": username,
1901         "userAccountControl": str(UF_NORMAL_ACCOUNT),
1902         "unicodePwd": utf16pw
1903     })
1904
1905     # grant user write permission to do things like write account SPN
1906     sdutils = sd_utils.SDUtils(ldb)
1907     sdutils.dacl_add_ace(user_dn, "(A;;WP;;;PS)")
1908
1909
1910 def create_group(ldb, instance_id, name):
1911     """Create a group via ldap."""
1912
1913     ou = ou_name(ldb, instance_id)
1914     dn = "cn=%s,%s" % (name, ou)
1915     ldb.add({
1916         "dn": dn,
1917         "objectclass": "group",
1918         "sAMAccountName": name,
1919     })
1920
1921
1922 def user_name(instance_id, i):
1923     """Generate a user name based in the instance id"""
1924     return "STGU-%d-%d" % (instance_id, i)
1925
1926
1927 def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
1928     """Seach objectclass, return attr in a set"""
1929     objs = ldb.search(
1930         expression="(objectClass={})".format(objectclass),
1931         attrs=[attr]
1932     )
1933     return {str(obj[attr]) for obj in objs}
1934
1935
1936 def generate_users(ldb, instance_id, number, password):
1937     """Add users to the server"""
1938     existing_objects = search_objectclass(ldb, objectclass='user')
1939     users = 0
1940     for i in range(number, 0, -1):
1941         name = user_name(instance_id, i)
1942         if name not in existing_objects:
1943             create_user_account(ldb, instance_id, name, password)
1944             users += 1
1945             if users % 50 == 0:
1946                 LOGGER.info("Created %u/%u users" % (users, number))
1947
1948     return users
1949
1950
1951 def machine_name(instance_id, i, traffic_account=True):
1952     """Generate a machine account name from instance id."""
1953     if traffic_account:
1954         # traffic accounts correspond to a given user, and use different
1955         # userAccountControl flags to ensure packets get processed correctly
1956         # by the DC
1957         return "STGM-%d-%d" % (instance_id, i)
1958     else:
1959         # Otherwise we're just generating computer accounts to simulate a
1960         # semi-realistic network. These use the default computer
1961         # userAccountControl flags, so we use a different account name so that
1962         # we don't try to use them when generating packets
1963         return "PC-%d-%d" % (instance_id, i)
1964
1965
1966 def generate_machine_accounts(ldb, instance_id, number, password,
1967                               traffic_account=True):
1968     """Add machine accounts to the server"""
1969     existing_objects = search_objectclass(ldb, objectclass='computer')
1970     added = 0
1971     for i in range(number, 0, -1):
1972         name = machine_name(instance_id, i, traffic_account)
1973         if name + "$" not in existing_objects:
1974             create_machine_account(ldb, instance_id, name, password,
1975                                    traffic_account)
1976             added += 1
1977             if added % 50 == 0:
1978                 LOGGER.info("Created %u/%u machine accounts" % (added, number))
1979
1980     return added
1981
1982
1983 def group_name(instance_id, i):
1984     """Generate a group name from instance id."""
1985     return "STGG-%d-%d" % (instance_id, i)
1986
1987
1988 def generate_groups(ldb, instance_id, number):
1989     """Create the required number of groups on the server."""
1990     existing_objects = search_objectclass(ldb, objectclass='group')
1991     groups = 0
1992     for i in range(number, 0, -1):
1993         name = group_name(instance_id, i)
1994         if name not in existing_objects:
1995             create_group(ldb, instance_id, name)
1996             groups += 1
1997             if groups % 1000 == 0:
1998                 LOGGER.info("Created %u/%u groups" % (groups, number))
1999
2000     return groups
2001
2002
2003 def clean_up_accounts(ldb, instance_id):
2004     """Remove the created accounts and groups from the server."""
2005     ou = ou_name(ldb, instance_id)
2006     try:
2007         ldb.delete(ou, ["tree_delete:1"])
2008     except LdbError as e:
2009         (status, _) = e.args
2010         # ignore does not exist
2011         if status != 32:
2012             raise
2013
2014
2015 def generate_users_and_groups(ldb, instance_id, password,
2016                               number_of_users, number_of_groups,
2017                               group_memberships, max_members,
2018                               machine_accounts, traffic_accounts=True):
2019     """Generate the required users and groups, allocating the users to
2020        those groups."""
2021     memberships_added = 0
2022     groups_added = 0
2023     computers_added = 0
2024
2025     create_ou(ldb, instance_id)
2026
2027     LOGGER.info("Generating dummy user accounts")
2028     users_added = generate_users(ldb, instance_id, number_of_users, password)
2029
2030     LOGGER.info("Generating dummy machine accounts")
2031     computers_added = generate_machine_accounts(ldb, instance_id,
2032                                                 machine_accounts, password,
2033                                                 traffic_accounts)
2034
2035     if number_of_groups > 0:
2036         LOGGER.info("Generating dummy groups")
2037         groups_added = generate_groups(ldb, instance_id, number_of_groups)
2038
2039     if group_memberships > 0:
2040         LOGGER.info("Assigning users to groups")
2041         assignments = GroupAssignments(number_of_groups,
2042                                        groups_added,
2043                                        number_of_users,
2044                                        users_added,
2045                                        group_memberships,
2046                                        max_members)
2047         LOGGER.info("Adding users to groups")
2048         add_users_to_groups(ldb, instance_id, assignments)
2049         memberships_added = assignments.total()
2050
2051     if (groups_added > 0 and users_added == 0 and
2052        number_of_groups != groups_added):
2053         LOGGER.warning("The added groups will contain no members")
2054
2055     LOGGER.info("Added %d users (%d machines), %d groups and %d memberships" %
2056                 (users_added, computers_added, groups_added,
2057                  memberships_added))
2058
2059
2060 class GroupAssignments(object):
2061     def __init__(self, number_of_groups, groups_added, number_of_users,
2062                  users_added, group_memberships, max_members):
2063
2064         self.count = 0
2065         self.generate_group_distribution(number_of_groups)
2066         self.generate_user_distribution(number_of_users, group_memberships)
2067         self.max_members = max_members
2068         self.assignments = defaultdict(list)
2069         self.assign_groups(number_of_groups, groups_added, number_of_users,
2070                            users_added, group_memberships)
2071
2072     def cumulative_distribution(self, weights):
2073         # make sure the probabilities conform to a cumulative distribution
2074         # spread between 0.0 and 1.0. Dividing by the weighted total gives each
2075         # probability a proportional share of 1.0. Higher probabilities get a
2076         # bigger share, so are more likely to be picked. We use the cumulative
2077         # value, so we can use random.random() as a simple index into the list
2078         dist = []
2079         total = sum(weights)
2080         if total == 0:
2081             return None
2082
2083         cumulative = 0.0
2084         for probability in weights:
2085             cumulative += probability
2086             dist.append(cumulative / total)
2087         return dist
2088
2089     def generate_user_distribution(self, num_users, num_memberships):
2090         """Probability distribution of a user belonging to a group.
2091         """
2092         # Assign a weighted probability to each user. Use the Pareto
2093         # Distribution so that some users are in a lot of groups, and the
2094         # bulk of users are in only a few groups. If we're assigning a large
2095         # number of group memberships, use a higher shape. This means slightly
2096         # fewer outlying users that are in large numbers of groups. The aim is
2097         # to have no users belonging to more than ~500 groups.
2098         if num_memberships > 5000000:
2099             shape = 3.0
2100         elif num_memberships > 2000000:
2101             shape = 2.5
2102         elif num_memberships > 300000:
2103             shape = 2.25
2104         else:
2105             shape = 1.75
2106
2107         weights = []
2108         for x in range(1, num_users + 1):
2109             p = random.paretovariate(shape)
2110             weights.append(p)
2111
2112         # convert the weights to a cumulative distribution between 0.0 and 1.0
2113         self.user_dist = self.cumulative_distribution(weights)
2114
2115     def generate_group_distribution(self, n):
2116         """Probability distribution of a group containing a user."""
2117
2118         # Assign a weighted probability to each user. Probability decreases
2119         # as the group-ID increases
2120         weights = []
2121         for x in range(1, n + 1):
2122             p = 1 / (x**1.3)
2123             weights.append(p)
2124
2125         # convert the weights to a cumulative distribution between 0.0 and 1.0
2126         self.group_weights = weights
2127         self.group_dist = self.cumulative_distribution(weights)
2128
2129     def generate_random_membership(self):
2130         """Returns a randomly generated user-group membership"""
2131
2132         # the list items are cumulative distribution values between 0.0 and
2133         # 1.0, which makes random() a handy way to index the list to get a
2134         # weighted random user/group. (Here the user/group returned are
2135         # zero-based array indexes)
2136         user = bisect.bisect(self.user_dist, random.random())
2137         group = bisect.bisect(self.group_dist, random.random())
2138
2139         return user, group
2140
2141     def users_in_group(self, group):
2142         return self.assignments[group]
2143
2144     def get_groups(self):
2145         return self.assignments.keys()
2146
2147     def cap_group_membership(self, group, max_members):
2148         """Prevent the group's membership from exceeding the max specified"""
2149         num_members = len(self.assignments[group])
2150         if num_members >= max_members:
2151             LOGGER.info("Group {0} has {1} members".format(group, num_members))
2152
2153             # remove this group and then recalculate the cumulative
2154             # distribution, so this group is no longer selected
2155             self.group_weights[group - 1] = 0
2156             new_dist = self.cumulative_distribution(self.group_weights)
2157             self.group_dist = new_dist
2158
2159     def add_assignment(self, user, group):
2160         # the assignments are stored in a dictionary where key=group,
2161         # value=list-of-users-in-group (indexing by group-ID allows us to
2162         # optimize for DB membership writes)
2163         if user not in self.assignments[group]:
2164             self.assignments[group].append(user)
2165             self.count += 1
2166
2167         # check if there'a cap on how big the groups can grow
2168         if self.max_members:
2169             self.cap_group_membership(group, self.max_members)
2170
2171     def assign_groups(self, number_of_groups, groups_added,
2172                       number_of_users, users_added, group_memberships):
2173         """Allocate users to groups.
2174
2175         The intention is to have a few users that belong to most groups, while
2176         the majority of users belong to a few groups.
2177
2178         A few groups will contain most users, with the remaining only having a
2179         few users.
2180         """
2181
2182         if group_memberships <= 0:
2183             return
2184
2185         # Calculate the number of group menberships required
2186         group_memberships = math.ceil(
2187             float(group_memberships) *
2188             (float(users_added) / float(number_of_users)))
2189
2190         if self.max_members:
2191             group_memberships = min(group_memberships,
2192                                     self.max_members * number_of_groups)
2193
2194         existing_users  = number_of_users  - users_added  - 1
2195         existing_groups = number_of_groups - groups_added - 1
2196         while self.total() < group_memberships:
2197             user, group = self.generate_random_membership()
2198
2199             if group > existing_groups or user > existing_users:
2200                 # the + 1 converts the array index to the corresponding
2201                 # group or user number
2202                 self.add_assignment(user + 1, group + 1)
2203
2204     def total(self):
2205         return self.count
2206
2207
2208 def add_users_to_groups(db, instance_id, assignments):
2209     """Takes the assignments of users to groups and applies them to the DB."""
2210
2211     total = assignments.total()
2212     count = 0
2213     added = 0
2214
2215     for group in assignments.get_groups():
2216         users_in_group = assignments.users_in_group(group)
2217         if len(users_in_group) == 0:
2218             continue
2219
2220         # Split up the users into chunks, so we write no more than 1K at a
2221         # time. (Minimizing the DB modifies is more efficient, but writing
2222         # 10K+ users to a single group becomes inefficient memory-wise)
2223         for chunk in range(0, len(users_in_group), 1000):
2224             chunk_of_users = users_in_group[chunk:chunk + 1000]
2225             add_group_members(db, instance_id, group, chunk_of_users)
2226
2227             added += len(chunk_of_users)
2228             count += 1
2229             if count % 50 == 0:
2230                 LOGGER.info("Added %u/%u memberships" % (added, total))
2231
2232 def add_group_members(db, instance_id, group, users_in_group):
2233     """Adds the given users to group specified."""
2234
2235     ou = ou_name(db, instance_id)
2236
2237     def build_dn(name):
2238         return("cn=%s,%s" % (name, ou))
2239
2240     group_dn = build_dn(group_name(instance_id, group))
2241     m = ldb.Message()
2242     m.dn = ldb.Dn(db, group_dn)
2243
2244     for user in users_in_group:
2245         user_dn = build_dn(user_name(instance_id, user))
2246         idx = "member-" + str(user)
2247         m[idx] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
2248
2249     db.modify(m)
2250
2251
2252 def generate_stats(statsdir, timing_file):
2253     """Generate and print the summary stats for a run."""
2254     first      = sys.float_info.max
2255     last       = 0
2256     successful = 0
2257     failed     = 0
2258     latencies  = {}
2259     failures   = Counter()
2260     unique_conversations = set()
2261     if timing_file is not None:
2262         tw = timing_file.write
2263     else:
2264         def tw(x):
2265             pass
2266
2267     tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")
2268
2269     float_values = {
2270         'Maximum lag': 0,
2271         'Start lag': 0,
2272         'Max sleep miss': 0,
2273     }
2274     int_values = {
2275         'Planned_conversations': 0,
2276         'Planned_packets': 0,
2277         'Unfinished_conversations': 0,
2278     }
2279
2280     for filename in os.listdir(statsdir):
2281         path = os.path.join(statsdir, filename)
2282         with open(path, 'r') as f:
2283             for line in f:
2284                 try:
2285                     fields       = line.rstrip('\n').split('\t')
2286                     conversation = fields[1]
2287                     protocol     = fields[2]
2288                     packet_type  = fields[3]
2289                     latency      = float(fields[4])
2290                     t = float(fields[0])
2291                     first        = min(t - latency, first)
2292                     last         = max(t, last)
2293
2294                     op = (protocol, packet_type)
2295                     latencies.setdefault(op, []).append(latency)
2296                     if fields[5] == 'True':
2297                         successful += 1
2298                     else:
2299                         failed += 1
2300                         failures[op] += 1
2301
2302                     unique_conversations.add(conversation)
2303
2304                     tw(line)
2305                 except (ValueError, IndexError):
2306                     if ':' in line:
2307                         k, v = line.split(':', 1)
2308                         if k in float_values:
2309                             float_values[k] = max(float(v),
2310                                                   float_values[k])
2311                         elif k in int_values:
2312                             int_values[k] = max(int(v),
2313                                                 int_values[k])
2314                         else:
2315                             print(line, file=sys.stderr)
2316                     else:
2317                         # not a valid line print and ignore
2318                         print(line, file=sys.stderr)
2319
2320     duration = last - first
2321     if successful == 0:
2322         success_rate = 0
2323     else:
2324         success_rate = successful / duration
2325     if failed == 0:
2326         failure_rate = 0
2327     else:
2328         failure_rate = failed / duration
2329
2330     conversations = len(unique_conversations)
2331
2332     print("Total conversations:   %10d" % conversations)
2333     print("Successful operations: %10d (%.3f per second)"
2334           % (successful, success_rate))
2335     print("Failed operations:     %10d (%.3f per second)"
2336           % (failed, failure_rate))
2337
2338     for k, v in sorted(float_values.items()):
2339         print("%-28s %f" % (k.replace('_', ' ') + ':', v))
2340     for k, v in sorted(int_values.items()):
2341         print("%-28s %d" % (k.replace('_', ' ') + ':', v))
2342
2343     print("Protocol    Op Code  Description                               "
2344           " Count       Failed         Mean       Median          "
2345           "95%        Range          Max")
2346
2347     ops = {}
2348     for proto, packet in latencies:
2349         if proto not in ops:
2350             ops[proto] = set()
2351         ops[proto].add(packet)
2352     protocols = sorted(ops.keys())
2353
2354     for protocol in protocols:
2355         packet_types = sorted(ops[protocol], key=opcode_key)
2356         for packet_type in packet_types:
2357             op = (protocol, packet_type)
2358             values     = latencies[op]
2359             values     = sorted(values)
2360             count      = len(values)
2361             failed     = failures[op]
2362             mean       = sum(values) / count
2363             median     = calc_percentile(values, 0.50)
2364             percentile = calc_percentile(values, 0.95)
2365             rng        = values[-1] - values[0]
2366             maxv       = values[-1]
2367             desc       = OP_DESCRIPTIONS.get(op, '')
2368             print("%-12s   %4s  %-35s %12d %12d %12.6f "
2369                   "%12.6f %12.6f %12.6f %12.6f"
2370                   % (protocol,
2371                      packet_type,
2372                      desc,
2373                      count,
2374                      failed,
2375                      mean,
2376                      median,
2377                      percentile,
2378                      rng,
2379                      maxv))
2380
2381
2382 def opcode_key(v):
2383     """Sort key for the operation code to ensure that it sorts numerically"""
2384     try:
2385         return "%03d" % int(v)
2386     except ValueError:
2387         return v
2388
2389
2390 def calc_percentile(values, percentile):
2391     """Calculate the specified percentile from the list of values.
2392
2393     Assumes the list is sorted in ascending order.
2394     """
2395
2396     if not values:
2397         return 0
2398     k = (len(values) - 1) * percentile
2399     f = math.floor(k)
2400     c = math.ceil(k)
2401     if f == c:
2402         return values[int(k)]
2403     d0 = values[int(f)] * (c - k)
2404     d1 = values[int(c)] * (k - f)
2405     return d0 + d1
2406
2407
2408 def mk_masked_dir(*path):
2409     """In a testenv we end up with 0777 directories that look an alarming
2410     green colour with ls. Use umask to avoid that."""
2411     # py3 os.mkdir can do this
2412     d = os.path.join(*path)
2413     mask = os.umask(0o077)
2414     os.mkdir(d)
2415     os.umask(mask)
2416     return d