traffic: remove useless branch in stats report
[samba.git] / python / samba / emulate / traffic.py
1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
3 #
4 # Copyright (C) Catalyst IT Ltd. 2017
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19 from __future__ import print_function, division
20
21 import time
22 import os
23 import random
24 import json
25 import math
26 import sys
27 import signal
28 from errno import ECHILD, ESRCH
29
30 from collections import OrderedDict, Counter, defaultdict, namedtuple
31 from samba.emulate import traffic_packets
32 from samba.samdb import SamDB
33 import ldb
34 from ldb import LdbError
35 from samba.dcerpc import ClientConnection
36 from samba.dcerpc import security, drsuapi, lsa
37 from samba.dcerpc import netlogon
38 from samba.dcerpc.netlogon import netr_Authenticator
39 from samba.dcerpc import srvsvc
40 from samba.dcerpc import samr
41 from samba.drs_utils import drs_DsBind
42 import traceback
43 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
44 from samba.auth import system_session
45 from samba.dsdb import (
46     UF_NORMAL_ACCOUNT,
47     UF_SERVER_TRUST_ACCOUNT,
48     UF_TRUSTED_FOR_DELEGATION,
49     UF_WORKSTATION_TRUST_ACCOUNT
50 )
51 from samba.dcerpc.misc import SEC_CHAN_BDC
52 from samba import gensec
53 from samba import sd_utils
54 from samba.compat import get_string
55 from samba.logger import get_samba_logger
56 import bisect
57
58 CURRENT_MODEL_VERSION = 2   # save as this
59 REQUIRED_MODEL_VERSION = 2  # load accepts this or greater
60 SLEEP_OVERHEAD = 3e-4
61
62 # we don't use None, because it complicates [de]serialisation
63 NON_PACKET = '-'
64
65 CLIENT_CLUES = {
66     ('dns', '0'): 1.0,      # query
67     ('smb', '0x72'): 1.0,   # Negotiate protocol
68     ('ldap', '0'): 1.0,     # bind
69     ('ldap', '3'): 1.0,     # searchRequest
70     ('ldap', '2'): 1.0,     # unbindRequest
71     ('cldap', '3'): 1.0,
72     ('dcerpc', '11'): 1.0,  # bind
73     ('dcerpc', '14'): 1.0,  # Alter_context
74     ('nbns', '0'): 1.0,     # query
75 }
76
77 SERVER_CLUES = {
78     ('dns', '1'): 1.0,      # response
79     ('ldap', '1'): 1.0,     # bind response
80     ('ldap', '4'): 1.0,     # search result
81     ('ldap', '5'): 1.0,     # search done
82     ('cldap', '5'): 1.0,
83     ('dcerpc', '12'): 1.0,  # bind_ack
84     ('dcerpc', '13'): 1.0,  # bind_nak
85     ('dcerpc', '15'): 1.0,  # Alter_context response
86 }
87
88 SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}
89
90 WAIT_SCALE = 10.0
91 WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
92 NO_WAIT_LOG_TIME_RANGE = (-10, -3)
93
94 # DEBUG_LEVEL can be changed by scripts with -d
95 DEBUG_LEVEL = 0
96
97 LOGGER = get_samba_logger(name=__name__)
98
99
100 def debug(level, msg, *args):
101     """Print a formatted debug message to standard error.
102
103
104     :param level: The debug level, message will be printed if it is <= the
105                   currently set debug level. The debug level can be set with
106                   the -d option.
107     :param msg:   The message to be logged, can contain C-Style format
108                   specifiers
109     :param args:  The parameters required by the format specifiers
110     """
111     if level <= DEBUG_LEVEL:
112         if not args:
113             print(msg, file=sys.stderr)
114         else:
115             print(msg % tuple(args), file=sys.stderr)
116
117
118 def debug_lineno(*args):
119     """ Print an unformatted log message to stderr, contaning the line number
120     """
121     tb = traceback.extract_stack(limit=2)
122     print((" %s:" "\033[01;33m"
123            "%s " "\033[00m" % (tb[0][2], tb[0][1])), end=' ',
124           file=sys.stderr)
125     for a in args:
126         print(a, file=sys.stderr)
127     print(file=sys.stderr)
128     sys.stderr.flush()
129
130
131 def random_colour_print(seeds):
132     """Return a function that prints a coloured line to stderr. The colour
133     of the line depends on a sort of hash of the integer arguments."""
134     if seeds:
135         s = 214
136         for x in seeds:
137             s += 17
138             s *= x
139             s %= 214
140         prefix = "\033[38;5;%dm" % (18 + s)
141
142         def p(*args):
143             if DEBUG_LEVEL > 0:
144                 for a in args:
145                     print("%s%s\033[00m" % (prefix, a), file=sys.stderr)
146     else:
147         def p(*args):
148             if DEBUG_LEVEL > 0:
149                 for a in args:
150                     print(a, file=sys.stderr)
151
152     return p
153
154
155 class FakePacketError(Exception):
156     pass
157
158
159 class Packet(object):
160     """Details of a network packet"""
161     __slots__ = ('timestamp',
162                  'ip_protocol',
163                  'stream_number',
164                  'src',
165                  'dest',
166                  'protocol',
167                  'opcode',
168                  'desc',
169                  'extra',
170                  'endpoints')
171     def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
172                  protocol, opcode, desc, extra):
173         self.timestamp = timestamp
174         self.ip_protocol = ip_protocol
175         self.stream_number = stream_number
176         self.src = src
177         self.dest = dest
178         self.protocol = protocol
179         self.opcode = opcode
180         self.desc = desc
181         self.extra = extra
182         if self.src < self.dest:
183             self.endpoints = (self.src, self.dest)
184         else:
185             self.endpoints = (self.dest, self.src)
186
187     @classmethod
188     def from_line(cls, line):
189         fields = line.rstrip('\n').split('\t')
190         (timestamp,
191          ip_protocol,
192          stream_number,
193          src,
194          dest,
195          protocol,
196          opcode,
197          desc) = fields[:8]
198         extra = fields[8:]
199
200         timestamp = float(timestamp)
201         src = int(src)
202         dest = int(dest)
203
204         return cls(timestamp, ip_protocol, stream_number, src, dest,
205                    protocol, opcode, desc, extra)
206
207     def as_summary(self, time_offset=0.0):
208         """Format the packet as a traffic_summary line.
209         """
210         extra = '\t'.join(self.extra)
211         t = self.timestamp + time_offset
212         return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
213                 (t,
214                  self.ip_protocol,
215                  self.stream_number or '',
216                  self.src,
217                  self.dest,
218                  self.protocol,
219                  self.opcode,
220                  self.desc,
221                  extra))
222
223     def __str__(self):
224         return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
225                 (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
226                  self.stream_number, self.protocol, self.opcode, self.desc,
227                  ('«' + ' '.join(self.extra) + '»' if self.extra else '')))
228
229     def __repr__(self):
230         return "<Packet @%s>" % self
231
232     def copy(self):
233         return self.__class__(self.timestamp,
234                               self.ip_protocol,
235                               self.stream_number,
236                               self.src,
237                               self.dest,
238                               self.protocol,
239                               self.opcode,
240                               self.desc,
241                               self.extra)
242
243     def as_packet_type(self):
244         t = '%s:%s' % (self.protocol, self.opcode)
245         return t
246
247     def client_score(self):
248         """A positive number means we think it is a client; a negative number
249         means we think it is a server. Zero means no idea. range: -1 to 1.
250         """
251         key = (self.protocol, self.opcode)
252         if key in CLIENT_CLUES:
253             return CLIENT_CLUES[key]
254         if key in SERVER_CLUES:
255             return -SERVER_CLUES[key]
256         return 0.0
257
258     def play(self, conversation, context):
259         """Send the packet over the network, if required.
260
261         Some packets are ignored, i.e. for  protocols not handled,
262         server response messages, or messages that are generated by the
263         protocol layer associated with other packets.
264         """
265         fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
266         try:
267             fn = getattr(traffic_packets, fn_name)
268
269         except AttributeError as e:
270             print("Conversation(%s) Missing handler %s" %
271                   (conversation.conversation_id, fn_name),
272                   file=sys.stderr)
273             return
274
275         # Don't display a message for kerberos packets, they're not directly
276         # generated they're used to indicate kerberos should be used
277         if self.protocol != "kerberos":
278             debug(2, "Conversation(%s) Calling handler %s" %
279                      (conversation.conversation_id, fn_name))
280
281         start = time.time()
282         try:
283             if fn(self, conversation, context):
284                 # Only collect timing data for functions that generate
285                 # network traffic, or fail
286                 end = time.time()
287                 duration = end - start
288                 print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
289                       (end, conversation.conversation_id, self.protocol,
290                        self.opcode, duration))
291         except Exception as e:
292             end = time.time()
293             duration = end - start
294             print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
295                   (end, conversation.conversation_id, self.protocol,
296                    self.opcode, duration, e))
297
298     def __cmp__(self, other):
299         return self.timestamp - other.timestamp
300
301     def is_really_a_packet(self, missing_packet_stats=None):
302         return is_a_real_packet(self.protocol, self.opcode)
303
304
305 def is_a_real_packet(protocol, opcode):
306     """Is the packet one that can be ignored?
307
308     If so removing it will have no effect on the replay
309     """
310     if protocol in SKIPPED_PROTOCOLS:
311         # Ignore any packets for the protocols we're not interested in.
312         return False
313     if protocol == "ldap" and opcode == '':
314         # skip ldap continuation packets
315         return False
316
317     fn_name = 'packet_%s_%s' % (protocol, opcode)
318     fn = getattr(traffic_packets, fn_name, None)
319     if fn is None:
320         LOGGER.debug("missing packet %s" % fn_name, file=sys.stderr)
321         return False
322     if fn is traffic_packets.null_packet:
323         return False
324     return True
325
326
327 def is_a_traffic_generating_packet(protocol, opcode):
328     """Return true if a packet generates traffic in its own right. Some of
329     these will generate traffic in certain contexts (e.g. ldap unbind
330     after a bind) but not if the conversation consists only of these packets.
331     """
332     if protocol == 'wait':
333         return False
334
335     if (protocol, opcode) in (
336             ('kerberos', ''),
337             ('ldap', '2'),
338             ('dcerpc', '15'),
339             ('dcerpc', '16')):
340         return False
341
342     return is_a_real_packet(protocol, opcode)
343
344
345 class ReplayContext(object):
346     """State/Context for a conversation between an simulated client and a
347        server. Some of the context is shared amongst all conversations
348        and should be generated before the fork, while other context is
349        specific to a particular conversation and should be generated
350        *after* the fork, in generate_process_local_config().
351     """
352     def __init__(self,
353                  server=None,
354                  lp=None,
355                  creds=None,
356                  badpassword_frequency=None,
357                  prefer_kerberos=None,
358                  tempdir=None,
359                  statsdir=None,
360                  ou=None,
361                  base_dn=None,
362                  domain=os.environ.get("DOMAIN"),
363                  domain_sid=None):
364         self.server                   = server
365         self.netlogon_connection      = None
366         self.creds                    = creds
367         self.lp                       = lp
368         self.prefer_kerberos          = prefer_kerberos
369         self.ou                       = ou
370         self.base_dn                  = base_dn
371         self.domain                   = domain
372         self.statsdir                 = statsdir
373         self.global_tempdir           = tempdir
374         self.domain_sid               = domain_sid
375         self.realm                    = lp.get('realm')
376
377         # Bad password attempt controls
378         self.badpassword_frequency    = badpassword_frequency
379         self.last_lsarpc_bad          = False
380         self.last_lsarpc_named_bad    = False
381         self.last_simple_bind_bad     = False
382         self.last_bind_bad            = False
383         self.last_srvsvc_bad          = False
384         self.last_drsuapi_bad         = False
385         self.last_netlogon_bad        = False
386         self.last_samlogon_bad        = False
387         self.generate_ldap_search_tables()
388
389     def generate_ldap_search_tables(self):
390         session = system_session()
391
392         db = SamDB(url="ldap://%s" % self.server,
393                    session_info=session,
394                    credentials=self.creds,
395                    lp=self.lp)
396
397         res = db.search(db.domain_dn(),
398                         scope=ldb.SCOPE_SUBTREE,
399                         controls=["paged_results:1:1000"],
400                         attrs=['dn'])
401
402         # find a list of dns for each pattern
403         # e.g. CN,CN,CN,DC,DC
404         dn_map = {}
405         attribute_clue_map = {
406             'invocationId': []
407         }
408
409         for r in res:
410             dn = str(r.dn)
411             pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
412             dns = dn_map.setdefault(pattern, [])
413             dns.append(dn)
414             if dn.startswith('CN=NTDS Settings,'):
415                 attribute_clue_map['invocationId'].append(dn)
416
417         # extend the map in case we are working with a different
418         # number of DC components.
419         # for k, v in self.dn_map.items():
420         #     print >>sys.stderr, k, len(v)
421
422         for k in list(dn_map.keys()):
423             if k[-3:] != ',DC':
424                 continue
425             p = k[:-3]
426             while p[-3:] == ',DC':
427                 p = p[:-3]
428             for i in range(5):
429                 p += ',DC'
430                 if p != k and p in dn_map:
431                     print('dn_map collison %s %s' % (k, p),
432                           file=sys.stderr)
433                     continue
434                 dn_map[p] = dn_map[k]
435
436         self.dn_map = dn_map
437         self.attribute_clue_map = attribute_clue_map
438
439     def generate_process_local_config(self, account, conversation):
440         self.ldap_connections         = []
441         self.dcerpc_connections       = []
442         self.lsarpc_connections       = []
443         self.lsarpc_connections_named = []
444         self.drsuapi_connections      = []
445         self.srvsvc_connections       = []
446         self.samr_contexts            = []
447         self.netbios_name             = account.netbios_name
448         self.machinepass              = account.machinepass
449         self.username                 = account.username
450         self.userpass                 = account.userpass
451
452         self.tempdir = mk_masked_dir(self.global_tempdir,
453                                      'conversation-%d' %
454                                      conversation.conversation_id)
455
456         self.lp.set("private dir", self.tempdir)
457         self.lp.set("lock dir", self.tempdir)
458         self.lp.set("state directory", self.tempdir)
459         self.lp.set("tls verify peer", "no_check")
460
461         self.remoteAddress = "/root/ncalrpc_as_system"
462         self.samlogon_dn   = ("cn=%s,%s" %
463                               (self.netbios_name, self.ou))
464         self.user_dn       = ("cn=%s,%s" %
465                               (self.username, self.ou))
466
467         self.generate_machine_creds()
468         self.generate_user_creds()
469
470     def with_random_bad_credentials(self, f, good, bad, failed_last_time):
471         """Execute the supplied logon function, randomly choosing the
472            bad credentials.
473
474            Based on the frequency in badpassword_frequency randomly perform the
475            function with the supplied bad credentials.
476            If run with bad credentials, the function is re-run with the good
477            credentials.
478            failed_last_time is used to prevent consecutive bad credential
479            attempts. So the over all bad credential frequency will be lower
480            than that requested, but not significantly.
481         """
482         if not failed_last_time:
483             if (self.badpassword_frequency and
484                 random.random() < self.badpassword_frequency):
485                 try:
486                     f(bad)
487                 except Exception:
488                     # Ignore any exceptions as the operation may fail
489                     # as it's being performed with bad credentials
490                     pass
491                 failed_last_time = True
492             else:
493                 failed_last_time = False
494
495         result = f(good)
496         return (result, failed_last_time)
497
498     def generate_user_creds(self):
499         """Generate the conversation specific user Credentials.
500
501         Each Conversation has an associated user account used to simulate
502         any non Administrative user traffic.
503
504         Generates user credentials with good and bad passwords and ldap
505         simple bind credentials with good and bad passwords.
506         """
507         self.user_creds = Credentials()
508         self.user_creds.guess(self.lp)
509         self.user_creds.set_workstation(self.netbios_name)
510         self.user_creds.set_password(self.userpass)
511         self.user_creds.set_username(self.username)
512         self.user_creds.set_domain(self.domain)
513         if self.prefer_kerberos:
514             self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
515         else:
516             self.user_creds.set_kerberos_state(DONT_USE_KERBEROS)
517
518         self.user_creds_bad = Credentials()
519         self.user_creds_bad.guess(self.lp)
520         self.user_creds_bad.set_workstation(self.netbios_name)
521         self.user_creds_bad.set_password(self.userpass[:-4])
522         self.user_creds_bad.set_username(self.username)
523         if self.prefer_kerberos:
524             self.user_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
525         else:
526             self.user_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
527
528         # Credentials for ldap simple bind.
529         self.simple_bind_creds = Credentials()
530         self.simple_bind_creds.guess(self.lp)
531         self.simple_bind_creds.set_workstation(self.netbios_name)
532         self.simple_bind_creds.set_password(self.userpass)
533         self.simple_bind_creds.set_username(self.username)
534         self.simple_bind_creds.set_gensec_features(
535             self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
536         if self.prefer_kerberos:
537             self.simple_bind_creds.set_kerberos_state(MUST_USE_KERBEROS)
538         else:
539             self.simple_bind_creds.set_kerberos_state(DONT_USE_KERBEROS)
540         self.simple_bind_creds.set_bind_dn(self.user_dn)
541
542         self.simple_bind_creds_bad = Credentials()
543         self.simple_bind_creds_bad.guess(self.lp)
544         self.simple_bind_creds_bad.set_workstation(self.netbios_name)
545         self.simple_bind_creds_bad.set_password(self.userpass[:-4])
546         self.simple_bind_creds_bad.set_username(self.username)
547         self.simple_bind_creds_bad.set_gensec_features(
548             self.simple_bind_creds_bad.get_gensec_features() |
549             gensec.FEATURE_SEAL)
550         if self.prefer_kerberos:
551             self.simple_bind_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
552         else:
553             self.simple_bind_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
554         self.simple_bind_creds_bad.set_bind_dn(self.user_dn)
555
556     def generate_machine_creds(self):
557         """Generate the conversation specific machine Credentials.
558
559         Each Conversation has an associated machine account.
560
561         Generates machine credentials with good and bad passwords.
562         """
563
564         self.machine_creds = Credentials()
565         self.machine_creds.guess(self.lp)
566         self.machine_creds.set_workstation(self.netbios_name)
567         self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
568         self.machine_creds.set_password(self.machinepass)
569         self.machine_creds.set_username(self.netbios_name + "$")
570         self.machine_creds.set_domain(self.domain)
571         if self.prefer_kerberos:
572             self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
573         else:
574             self.machine_creds.set_kerberos_state(DONT_USE_KERBEROS)
575
576         self.machine_creds_bad = Credentials()
577         self.machine_creds_bad.guess(self.lp)
578         self.machine_creds_bad.set_workstation(self.netbios_name)
579         self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
580         self.machine_creds_bad.set_password(self.machinepass[:-4])
581         self.machine_creds_bad.set_username(self.netbios_name + "$")
582         if self.prefer_kerberos:
583             self.machine_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
584         else:
585             self.machine_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
586
587     def get_matching_dn(self, pattern, attributes=None):
588         # If the pattern is an empty string, we assume ROOTDSE,
589         # Otherwise we try adding or removing DC suffixes, then
590         # shorter leading patterns until we hit one.
591         # e.g if there is no CN,CN,CN,CN,DC,DC
592         # we first try       CN,CN,CN,CN,DC
593         # and                CN,CN,CN,CN,DC,DC,DC
594         # then change to        CN,CN,CN,DC,DC
595         # and as last resort we use the base_dn
596         attr_clue = self.attribute_clue_map.get(attributes)
597         if attr_clue:
598             return random.choice(attr_clue)
599
600         pattern = pattern.upper()
601         while pattern:
602             if pattern in self.dn_map:
603                 return random.choice(self.dn_map[pattern])
604             # chop one off the front and try it all again.
605             pattern = pattern[3:]
606
607         return self.base_dn
608
609     def get_dcerpc_connection(self, new=False):
610         guid = '12345678-1234-abcd-ef00-01234567cffb'  # RPC_NETLOGON UUID
611         if self.dcerpc_connections and not new:
612             return self.dcerpc_connections[-1]
613         c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
614                              (guid, 1), self.lp)
615         self.dcerpc_connections.append(c)
616         return c
617
618     def get_srvsvc_connection(self, new=False):
619         if self.srvsvc_connections and not new:
620             return self.srvsvc_connections[-1]
621
622         def connect(creds):
623             return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
624                                  self.lp,
625                                  creds)
626
627         (c, self.last_srvsvc_bad) = \
628             self.with_random_bad_credentials(connect,
629                                              self.user_creds,
630                                              self.user_creds_bad,
631                                              self.last_srvsvc_bad)
632
633         self.srvsvc_connections.append(c)
634         return c
635
636     def get_lsarpc_connection(self, new=False):
637         if self.lsarpc_connections and not new:
638             return self.lsarpc_connections[-1]
639
640         def connect(creds):
641             binding_options = 'schannel,seal,sign'
642             return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
643                               (self.server, binding_options),
644                               self.lp,
645                               creds)
646
647         (c, self.last_lsarpc_bad) = \
648             self.with_random_bad_credentials(connect,
649                                              self.machine_creds,
650                                              self.machine_creds_bad,
651                                              self.last_lsarpc_bad)
652
653         self.lsarpc_connections.append(c)
654         return c
655
656     def get_lsarpc_named_pipe_connection(self, new=False):
657         if self.lsarpc_connections_named and not new:
658             return self.lsarpc_connections_named[-1]
659
660         def connect(creds):
661             return lsa.lsarpc("ncacn_np:%s" % (self.server),
662                               self.lp,
663                               creds)
664
665         (c, self.last_lsarpc_named_bad) = \
666             self.with_random_bad_credentials(connect,
667                                              self.machine_creds,
668                                              self.machine_creds_bad,
669                                              self.last_lsarpc_named_bad)
670
671         self.lsarpc_connections_named.append(c)
672         return c
673
674     def get_drsuapi_connection_pair(self, new=False, unbind=False):
675         """get a (drs, drs_handle) tuple"""
676         if self.drsuapi_connections and not new:
677             c = self.drsuapi_connections[-1]
678             return c
679
680         def connect(creds):
681             binding_options = 'seal'
682             binding_string = "ncacn_ip_tcp:%s[%s]" %\
683                              (self.server, binding_options)
684             return drsuapi.drsuapi(binding_string, self.lp, creds)
685
686         (drs, self.last_drsuapi_bad) = \
687             self.with_random_bad_credentials(connect,
688                                              self.user_creds,
689                                              self.user_creds_bad,
690                                              self.last_drsuapi_bad)
691
692         (drs_handle, supported_extensions) = drs_DsBind(drs)
693         c = (drs, drs_handle)
694         self.drsuapi_connections.append(c)
695         return c
696
697     def get_ldap_connection(self, new=False, simple=False):
698         if self.ldap_connections and not new:
699             return self.ldap_connections[-1]
700
701         def simple_bind(creds):
702             """
703             To run simple bind against Windows, we need to run
704             following commands in PowerShell:
705
706                 Install-windowsfeature ADCS-Cert-Authority
707                 Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
708                 Restart-Computer
709
710             """
711             return SamDB('ldaps://%s' % self.server,
712                          credentials=creds,
713                          lp=self.lp)
714
715         def sasl_bind(creds):
716             return SamDB('ldap://%s' % self.server,
717                          credentials=creds,
718                          lp=self.lp)
719         if simple:
720             (samdb, self.last_simple_bind_bad) = \
721                 self.with_random_bad_credentials(simple_bind,
722                                                  self.simple_bind_creds,
723                                                  self.simple_bind_creds_bad,
724                                                  self.last_simple_bind_bad)
725         else:
726             (samdb, self.last_bind_bad) = \
727                 self.with_random_bad_credentials(sasl_bind,
728                                                  self.user_creds,
729                                                  self.user_creds_bad,
730                                                  self.last_bind_bad)
731
732         self.ldap_connections.append(samdb)
733         return samdb
734
735     def get_samr_context(self, new=False):
736         if not self.samr_contexts or new:
737             self.samr_contexts.append(
738                 SamrContext(self.server, lp=self.lp, creds=self.creds))
739         return self.samr_contexts[-1]
740
741     def get_netlogon_connection(self):
742
743         if self.netlogon_connection:
744             return self.netlogon_connection
745
746         def connect(creds):
747             return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
748                                      (self.server),
749                                      self.lp,
750                                      creds)
751         (c, self.last_netlogon_bad) = \
752             self.with_random_bad_credentials(connect,
753                                              self.machine_creds,
754                                              self.machine_creds_bad,
755                                              self.last_netlogon_bad)
756         self.netlogon_connection = c
757         return c
758
759     def guess_a_dns_lookup(self):
760         return (self.realm, 'A')
761
762     def get_authenticator(self):
763         auth = self.machine_creds.new_client_authenticator()
764         current  = netr_Authenticator()
765         current.cred.data = [x if isinstance(x, int) else ord(x)
766                              for x in auth["credential"]]
767         current.timestamp = auth["timestamp"]
768
769         subsequent = netr_Authenticator()
770         return (current, subsequent)
771
772     def write_stats(self, filename, **kwargs):
773         """Write arbitrary key/value pairs to a file in our stats directory in
774         order for them to be picked up later by another process working out
775         statistics."""
776         filename = os.path.join(self.statsdir, filename)
777         f = open(filename, 'w')
778         for k, v in kwargs.items():
779             print("%s: %s" % (k, v), file=f)
780         f.close()
781
782
783 class SamrContext(object):
784     """State/Context associated with a samr connection.
785     """
786     def __init__(self, server, lp=None, creds=None):
787         self.connection    = None
788         self.handle        = None
789         self.domain_handle = None
790         self.domain_sid    = None
791         self.group_handle  = None
792         self.user_handle   = None
793         self.rids          = None
794         self.server        = server
795         self.lp            = lp
796         self.creds         = creds
797
798     def get_connection(self):
799         if not self.connection:
800             self.connection = samr.samr(
801                 "ncacn_ip_tcp:%s[seal]" % (self.server),
802                 lp_ctx=self.lp,
803                 credentials=self.creds)
804
805         return self.connection
806
807     def get_handle(self):
808         if not self.handle:
809             c = self.get_connection()
810             self.handle = c.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
811         return self.handle
812
813
814 class Conversation(object):
815     """Details of a converation between a simulated client and a server."""
816     def __init__(self, start_time=None, endpoints=None, seq=(),
817                  conversation_id=None):
818         self.start_time = start_time
819         self.endpoints = endpoints
820         self.packets = []
821         self.msg = random_colour_print(endpoints)
822         self.client_balance = 0.0
823         self.conversation_id = conversation_id
824         for p in seq:
825             self.add_short_packet(*p)
826
827     def __cmp__(self, other):
828         if self.start_time is None:
829             if other.start_time is None:
830                 return 0
831             return -1
832         if other.start_time is None:
833             return 1
834         return self.start_time - other.start_time
835
836     def add_packet(self, packet):
837         """Add a packet object to this conversation, making a local copy with
838         a conversation-relative timestamp."""
839         p = packet.copy()
840
841         if self.start_time is None:
842             self.start_time = p.timestamp
843
844         if self.endpoints is None:
845             self.endpoints = p.endpoints
846
847         if p.endpoints != self.endpoints:
848             raise FakePacketError("Conversation endpoints %s don't match"
849                                   "packet endpoints %s" %
850                                   (self.endpoints, p.endpoints))
851
852         p.timestamp -= self.start_time
853
854         if p.src == p.endpoints[0]:
855             self.client_balance -= p.client_score()
856         else:
857             self.client_balance += p.client_score()
858
859         if p.is_really_a_packet():
860             self.packets.append(p)
861
862     def add_short_packet(self, timestamp, protocol, opcode, extra,
863                          client=True):
864         """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
865         (possibly empty) list of extra data. If client is True, assume
866         this packet is from the client to the server.
867         """
868         src, dest = self.guess_client_server()
869         if not client:
870             src, dest = dest, src
871         key = (protocol, opcode)
872         desc = OP_DESCRIPTIONS[key] if key in OP_DESCRIPTIONS else ''
873         if protocol in IP_PROTOCOLS:
874             ip_protocol = IP_PROTOCOLS[protocol]
875         else:
876             ip_protocol = '06'
877         packet = Packet(timestamp - self.start_time, ip_protocol,
878                         '', src, dest,
879                         protocol, opcode, desc, extra)
880         # XXX we're assuming the timestamp is already adjusted for
881         # this conversation?
882         # XXX should we adjust client balance for guessed packets?
883         if packet.src == packet.endpoints[0]:
884             self.client_balance -= packet.client_score()
885         else:
886             self.client_balance += packet.client_score()
887         if packet.is_really_a_packet():
888             self.packets.append(packet)
889
890     def __str__(self):
891         return ("<Conversation %s %s starting %.3f %d packets>" %
892                 (self.conversation_id, self.endpoints, self.start_time,
893                  len(self.packets)))
894
895     __repr__ = __str__
896
897     def __iter__(self):
898         return iter(self.packets)
899
900     def __len__(self):
901         return len(self.packets)
902
903     def get_duration(self):
904         if len(self.packets) < 2:
905             return 0
906         return self.packets[-1].timestamp - self.packets[0].timestamp
907
908     def replay_as_summary_lines(self):
909         lines = []
910         for p in self.packets:
911             lines.append(p.as_summary(self.start_time))
912         return lines
913
914     def replay_with_delay(self, start, context=None, account=None):
915         """Replay the conversation at the right time.
916         (We're already in a fork)."""
917         # first we sleep until the first packet
918         t = self.start_time
919         now = time.time() - start
920         gap = t - now
921         sleep_time = gap - SLEEP_OVERHEAD
922         if sleep_time > 0:
923             time.sleep(sleep_time)
924
925         miss = (time.time() - start) - t
926         self.msg("starting %s [miss %.3f]" % (self, miss))
927
928         max_gap = 0.0
929         max_sleep_miss = 0.0
930         # packet times are relative to conversation start
931         p_start = time.time()
932         for p in self.packets:
933             now = time.time() - p_start
934             gap = now - p.timestamp
935             if gap > max_gap:
936                 max_gap = gap
937             if gap < 0:
938                 sleep_time = -gap - SLEEP_OVERHEAD
939                 if sleep_time > 0:
940                     time.sleep(sleep_time)
941                     t = time.time() - p_start
942                     if t - p.timestamp > max_sleep_miss:
943                         max_sleep_miss = t - p.timestamp
944
945             p.play(self, context)
946
947         return max_gap, miss, max_sleep_miss
948
949     def guess_client_server(self, server_clue=None):
950         """Have a go at deciding who is the server and who is the client.
951         returns (client, server)
952         """
953         a, b = self.endpoints
954
955         if self.client_balance < 0:
956             return (a, b)
957
958         # in the absense of a clue, we will fall through to assuming
959         # the lowest number is the server (which is usually true).
960
961         if self.client_balance == 0 and server_clue == b:
962             return (a, b)
963
964         return (b, a)
965
966     def forget_packets_outside_window(self, s, e):
967         """Prune any packets outside the timne window we're interested in
968
969         :param s: start of the window
970         :param e: end of the window
971         """
972         self.packets = [p for p in self.packets if s <= p.timestamp <= e]
973         self.start_time = self.packets[0].timestamp if self.packets else None
974
975     def renormalise_times(self, start_time):
976         """Adjust the packet start times relative to the new start time."""
977         for p in self.packets:
978             p.timestamp -= start_time
979
980         if self.start_time is not None:
981             self.start_time -= start_time
982
983
984 class DnsHammer(Conversation):
985     """A lightweight conversation that generates a lot of dns:0 packets on
986     the fly"""
987
988     def __init__(self, dns_rate, duration):
989         n = int(dns_rate * duration)
990         self.times = [random.uniform(0, duration) for i in range(n)]
991         self.times.sort()
992         self.rate = dns_rate
993         self.duration = duration
994         self.start_time = 0
995         self.msg = random_colour_print()
996
997     def __str__(self):
998         return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
999                 (len(self.times), self.duration, self.rate))
1000
1001     def replay(self, context=None):
1002         start = time.time()
1003         fn = traffic_packets.packet_dns_0
1004         for t in self.times:
1005             now = time.time() - start
1006             gap = t - now
1007             sleep_time = gap - SLEEP_OVERHEAD
1008             if sleep_time > 0:
1009                 time.sleep(sleep_time)
1010
1011             packet_start = time.time()
1012             try:
1013                 fn(None, None, context)
1014                 end = time.time()
1015                 duration = end - packet_start
1016                 print("%f\tDNS\tdns\t0\t%f\tTrue\t" % (end, duration))
1017             except Exception as e:
1018                 end = time.time()
1019                 duration = end - packet_start
1020                 print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" % (end, duration, e))
1021
1022
1023 def ingest_summaries(files, dns_mode='count'):
1024     """Load a summary traffic summary file and generated Converations from it.
1025     """
1026
1027     dns_counts = defaultdict(int)
1028     packets = []
1029     for f in files:
1030         if isinstance(f, str):
1031             f = open(f)
1032         print("Ingesting %s" % (f.name,), file=sys.stderr)
1033         for line in f:
1034             p = Packet.from_line(line)
1035             if p.protocol == 'dns' and dns_mode != 'include':
1036                 dns_counts[p.opcode] += 1
1037             else:
1038                 packets.append(p)
1039
1040         f.close()
1041
1042     if not packets:
1043         return [], 0
1044
1045     start_time = min(p.timestamp for p in packets)
1046     last_packet = max(p.timestamp for p in packets)
1047
1048     print("gathering packets into conversations", file=sys.stderr)
1049     conversations = OrderedDict()
1050     for i, p in enumerate(packets):
1051         p.timestamp -= start_time
1052         c = conversations.get(p.endpoints)
1053         if c is None:
1054             c = Conversation(conversation_id=(i + 2))
1055             conversations[p.endpoints] = c
1056         c.add_packet(p)
1057
1058     # We only care about conversations with actual traffic, so we
1059     # filter out conversations with nothing to say. We do that here,
1060     # rather than earlier, because those empty packets contain useful
1061     # hints as to which end of the conversation was the client.
1062     conversation_list = []
1063     for c in conversations.values():
1064         if len(c) != 0:
1065             conversation_list.append(c)
1066
1067     # This is obviously not correct, as many conversations will appear
1068     # to start roughly simultaneously at the beginning of the snapshot.
1069     # To which we say: oh well, so be it.
1070     duration = float(last_packet - start_time)
1071     mean_interval = len(conversations) / duration
1072
1073     return conversation_list, mean_interval, duration, dns_counts
1074
1075
1076 def guess_server_address(conversations):
1077     # we guess the most common address.
1078     addresses = Counter()
1079     for c in conversations:
1080         addresses.update(c.endpoints)
1081     if addresses:
1082         return addresses.most_common(1)[0]
1083
1084
1085 def stringify_keys(x):
1086     y = {}
1087     for k, v in x.items():
1088         k2 = '\t'.join(k)
1089         y[k2] = v
1090     return y
1091
1092
1093 def unstringify_keys(x):
1094     y = {}
1095     for k, v in x.items():
1096         t = tuple(str(k).split('\t'))
1097         y[t] = v
1098     return y
1099
1100
1101 class TrafficModel(object):
1102     def __init__(self, n=3):
1103         self.ngrams = {}
1104         self.query_details = {}
1105         self.n = n
1106         self.dns_opcounts = defaultdict(int)
1107         self.cumulative_duration = 0.0
1108         self.packet_rate = [0, 1]
1109
1110     def learn(self, conversations, dns_opcounts={}):
1111         prev = 0.0
1112         cum_duration = 0.0
1113         key = (NON_PACKET,) * (self.n - 1)
1114
1115         server = guess_server_address(conversations)
1116
1117         for k, v in dns_opcounts.items():
1118             self.dns_opcounts[k] += v
1119
1120         if len(conversations) > 1:
1121             first = conversations[0].start_time
1122             total = 0
1123             last = first + 0.1
1124             for c in conversations:
1125                 total += len(c)
1126                 last = max(last, c.packets[-1].timestamp)
1127
1128             self.packet_rate[0] = total
1129             self.packet_rate[1] = last - first
1130
1131         for c in conversations:
1132             client, server = c.guess_client_server(server)
1133             cum_duration += c.get_duration()
1134             key = (NON_PACKET,) * (self.n - 1)
1135             for p in c:
1136                 if p.src != client:
1137                     continue
1138
1139                 elapsed = p.timestamp - prev
1140                 prev = p.timestamp
1141                 if elapsed > WAIT_THRESHOLD:
1142                     # add the wait as an extra state
1143                     wait = 'wait:%d' % (math.log(max(1.0,
1144                                                      elapsed * WAIT_SCALE)))
1145                     self.ngrams.setdefault(key, []).append(wait)
1146                     key = key[1:] + (wait,)
1147
1148                 short_p = p.as_packet_type()
1149                 self.query_details.setdefault(short_p,
1150                                               []).append(tuple(p.extra))
1151                 self.ngrams.setdefault(key, []).append(short_p)
1152                 key = key[1:] + (short_p,)
1153
1154         self.cumulative_duration += cum_duration
1155         # add in the end
1156         self.ngrams.setdefault(key, []).append(NON_PACKET)
1157
1158     def save(self, f):
1159         ngrams = {}
1160         for k, v in self.ngrams.items():
1161             k = '\t'.join(k)
1162             ngrams[k] = dict(Counter(v))
1163
1164         query_details = {}
1165         for k, v in self.query_details.items():
1166             query_details[k] = dict(Counter('\t'.join(x) if x else '-'
1167                                             for x in v))
1168
1169         d = {
1170             'ngrams': ngrams,
1171             'query_details': query_details,
1172             'cumulative_duration': self.cumulative_duration,
1173             'packet_rate': self.packet_rate,
1174             'version': CURRENT_MODEL_VERSION
1175         }
1176         d['dns'] = self.dns_opcounts
1177
1178         if isinstance(f, str):
1179             f = open(f, 'w')
1180
1181         json.dump(d, f, indent=2)
1182
1183     def load(self, f):
1184         if isinstance(f, str):
1185             f = open(f)
1186
1187         d = json.load(f)
1188
1189         try:
1190             version = d["version"]
1191             if version < REQUIRED_MODEL_VERSION:
1192                 raise ValueError("the model file is version %d; "
1193                                  "version %d is required" %
1194                                  (version, REQUIRED_MODEL_VERSION))
1195         except KeyError:
1196                 raise ValueError("the model file lacks a version number; "
1197                                  "version %d is required" %
1198                                  (REQUIRED_MODEL_VERSION))
1199
1200         for k, v in d['ngrams'].items():
1201             k = tuple(str(k).split('\t'))
1202             values = self.ngrams.setdefault(k, [])
1203             for p, count in v.items():
1204                 values.extend([str(p)] * count)
1205             values.sort()
1206
1207         for k, v in d['query_details'].items():
1208             values = self.query_details.setdefault(str(k), [])
1209             for p, count in v.items():
1210                 if p == '-':
1211                     values.extend([()] * count)
1212                 else:
1213                     values.extend([tuple(str(p).split('\t'))] * count)
1214             values.sort()
1215
1216         if 'dns' in d:
1217             for k, v in d['dns'].items():
1218                 self.dns_opcounts[k] += v
1219
1220         self.cumulative_duration = d['cumulative_duration']
1221         self.packet_rate = d['packet_rate']
1222
1223     def construct_conversation_sequence(self, timestamp=0.0,
1224                                         hard_stop=None,
1225                                         replay_speed=1,
1226                                         ignore_before=0):
1227         """Construct an individual conversation packet sequence from the
1228         model.
1229         """
1230         c = []
1231         key = (NON_PACKET,) * (self.n - 1)
1232         if ignore_before is None:
1233             ignore_before = timestamp - 1
1234
1235         while True:
1236             p = random.choice(self.ngrams.get(key, (NON_PACKET,)))
1237             if p == NON_PACKET:
1238                 break
1239
1240             if p in self.query_details:
1241                 extra = random.choice(self.query_details[p])
1242             else:
1243                 extra = []
1244
1245             protocol, opcode = p.split(':', 1)
1246             if protocol == 'wait':
1247                 log_wait_time = int(opcode) + random.random()
1248                 wait = math.exp(log_wait_time) / (WAIT_SCALE * replay_speed)
1249                 timestamp += wait
1250             else:
1251                 log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
1252                 wait = math.exp(log_wait) / replay_speed
1253                 timestamp += wait
1254                 if hard_stop is not None and timestamp > hard_stop:
1255                     break
1256                 if timestamp >= ignore_before:
1257                     c.append((timestamp, protocol, opcode, extra))
1258
1259             key = key[1:] + (p,)
1260
1261         return c
1262
1263     def generate_conversation_sequences(self, scale, duration, replay_speed=1):
1264         """Generate a list of conversation descriptions from the model."""
1265
1266         # We run the simulation for ten times as long as our desired
1267         # duration, and take the section at the end.
1268         lead_in = 9 * duration
1269         rate_n, rate_t  = self.packet_rate
1270         target_packets = int(duration * scale * rate_n / rate_t)
1271
1272         conversations = []
1273         n_packets = 0
1274
1275         while n_packets < target_packets:
1276             start = random.uniform(-lead_in, duration)
1277             c = self.construct_conversation_sequence(start,
1278                                                      hard_stop=duration,
1279                                                      replay_speed=replay_speed,
1280                                                      ignore_before=0)
1281             # will these "packets" generate actual traffic?
1282             # some (e.g. ldap unbind) will not generate anything
1283             # if the previous packets are not there, and if the
1284             # conversation only has those it wastes a process doing nothing.
1285             for timestamp, protocol, opcode, extra in c:
1286                 if is_a_traffic_generating_packet(protocol, opcode):
1287                     break
1288             else:
1289                 continue
1290
1291             conversations.append(c)
1292             n_packets += len(c)
1293
1294         print(("we have %d packets (target %d) in %d conversations at scale %f"
1295                % (n_packets, target_packets, len(conversations), scale)),
1296               file=sys.stderr)
1297         conversations.sort()  # sorts by first element == start time
1298         return conversations
1299
1300
1301 def seq_to_conversations(seq, server=1, client=2):
1302     conversations = []
1303     for s in seq:
1304         if s:
1305             c = Conversation(s[0][0], (server, client), s)
1306             client += 1
1307             conversations.append(c)
1308     return conversations
1309
1310
1311 IP_PROTOCOLS = {
1312     'dns': '11',
1313     'rpc_netlogon': '06',
1314     'kerberos': '06',      # ratio 16248:258
1315     'smb': '06',
1316     'smb2': '06',
1317     'ldap': '06',
1318     'cldap': '11',
1319     'lsarpc': '06',
1320     'samr': '06',
1321     'dcerpc': '06',
1322     'epm': '06',
1323     'drsuapi': '06',
1324     'browser': '11',
1325     'smb_netlogon': '11',
1326     'srvsvc': '06',
1327     'nbns': '11',
1328 }
1329
1330 OP_DESCRIPTIONS = {
1331     ('browser', '0x01'): 'Host Announcement (0x01)',
1332     ('browser', '0x02'): 'Request Announcement (0x02)',
1333     ('browser', '0x08'): 'Browser Election Request (0x08)',
1334     ('browser', '0x09'): 'Get Backup List Request (0x09)',
1335     ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
1336     ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
1337     ('cldap', '3'): 'searchRequest',
1338     ('cldap', '5'): 'searchResDone',
1339     ('dcerpc', '0'): 'Request',
1340     ('dcerpc', '11'): 'Bind',
1341     ('dcerpc', '12'): 'Bind_ack',
1342     ('dcerpc', '13'): 'Bind_nak',
1343     ('dcerpc', '14'): 'Alter_context',
1344     ('dcerpc', '15'): 'Alter_context_resp',
1345     ('dcerpc', '16'): 'AUTH3',
1346     ('dcerpc', '2'): 'Response',
1347     ('dns', '0'): 'query',
1348     ('dns', '1'): 'response',
1349     ('drsuapi', '0'): 'DsBind',
1350     ('drsuapi', '12'): 'DsCrackNames',
1351     ('drsuapi', '13'): 'DsWriteAccountSpn',
1352     ('drsuapi', '1'): 'DsUnbind',
1353     ('drsuapi', '2'): 'DsReplicaSync',
1354     ('drsuapi', '3'): 'DsGetNCChanges',
1355     ('drsuapi', '4'): 'DsReplicaUpdateRefs',
1356     ('epm', '3'): 'Map',
1357     ('kerberos', ''): '',
1358     ('ldap', '0'): 'bindRequest',
1359     ('ldap', '1'): 'bindResponse',
1360     ('ldap', '2'): 'unbindRequest',
1361     ('ldap', '3'): 'searchRequest',
1362     ('ldap', '4'): 'searchResEntry',
1363     ('ldap', '5'): 'searchResDone',
1364     ('ldap', ''): '*** Unknown ***',
1365     ('lsarpc', '14'): 'lsa_LookupNames',
1366     ('lsarpc', '15'): 'lsa_LookupSids',
1367     ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
1368     ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
1369     ('lsarpc', '6'): 'lsa_OpenPolicy',
1370     ('lsarpc', '76'): 'lsa_LookupSids3',
1371     ('lsarpc', '77'): 'lsa_LookupNames4',
1372     ('nbns', '0'): 'query',
1373     ('nbns', '1'): 'response',
1374     ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
1375     ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
1376     ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
1377     ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
1378     ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
1379     ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
1380     ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
1381     ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
1382     ('samr', '0',): 'Connect',
1383     ('samr', '16'): 'GetAliasMembership',
1384     ('samr', '17'): 'LookupNames',
1385     ('samr', '18'): 'LookupRids',
1386     ('samr', '19'): 'OpenGroup',
1387     ('samr', '1'): 'Close',
1388     ('samr', '25'): 'QueryGroupMember',
1389     ('samr', '34'): 'OpenUser',
1390     ('samr', '36'): 'QueryUserInfo',
1391     ('samr', '39'): 'GetGroupsForUser',
1392     ('samr', '3'): 'QuerySecurity',
1393     ('samr', '5'): 'LookupDomain',
1394     ('samr', '64'): 'Connect5',
1395     ('samr', '6'): 'EnumDomains',
1396     ('samr', '7'): 'OpenDomain',
1397     ('samr', '8'): 'QueryDomainInfo',
1398     ('smb', '0x04'): 'Close (0x04)',
1399     ('smb', '0x24'): 'Locking AndX (0x24)',
1400     ('smb', '0x2e'): 'Read AndX (0x2e)',
1401     ('smb', '0x32'): 'Trans2 (0x32)',
1402     ('smb', '0x71'): 'Tree Disconnect (0x71)',
1403     ('smb', '0x72'): 'Negotiate Protocol (0x72)',
1404     ('smb', '0x73'): 'Session Setup AndX (0x73)',
1405     ('smb', '0x74'): 'Logoff AndX (0x74)',
1406     ('smb', '0x75'): 'Tree Connect AndX (0x75)',
1407     ('smb', '0xa2'): 'NT Create AndX (0xa2)',
1408     ('smb2', '0'): 'NegotiateProtocol',
1409     ('smb2', '11'): 'Ioctl',
1410     ('smb2', '14'): 'Find',
1411     ('smb2', '16'): 'GetInfo',
1412     ('smb2', '18'): 'Break',
1413     ('smb2', '1'): 'SessionSetup',
1414     ('smb2', '2'): 'SessionLogoff',
1415     ('smb2', '3'): 'TreeConnect',
1416     ('smb2', '4'): 'TreeDisconnect',
1417     ('smb2', '5'): 'Create',
1418     ('smb2', '6'): 'Close',
1419     ('smb2', '8'): 'Read',
1420     ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
1421     ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
1422                                'user unknown (0x17)'),
1423     ('srvsvc', '16'): 'NetShareGetInfo',
1424     ('srvsvc', '21'): 'NetSrvGetInfo',
1425 }
1426
1427
1428 def expand_short_packet(p, timestamp, src, dest, extra):
1429     protocol, opcode = p.split(':', 1)
1430     desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
1431     ip_protocol = IP_PROTOCOLS.get(protocol, '06')
1432
1433     line = [timestamp, ip_protocol, '', src, dest, protocol, opcode, desc]
1434     line.extend(extra)
1435     return '\t'.join(line)
1436
1437
1438 def flushing_signal_handler(signal, frame):
1439     """Signal handler closes standard out and error.
1440
1441     Triggered by a sigterm, ensures that the log messages are flushed
1442     to disk and not lost.
1443     """
1444     sys.stderr.close()
1445     sys.stdout.close()
1446     os._exit(0)
1447
1448
1449 def replay_seq_in_fork(cs, start, context, account, client_id, server_id=1):
1450     """Fork a new process and replay the conversation sequence."""
1451     # We will need to reseed the random number generator or all the
1452     # clients will end up using the same sequence of random
1453     # numbers. random.randint() is mixed in so the initial seed will
1454     # have an effect here.
1455     seed = client_id * 1000 + random.randint(0, 999)
1456
1457     # flush our buffers so messages won't be written by both sides
1458     sys.stdout.flush()
1459     sys.stderr.flush()
1460     pid = os.fork()
1461     if pid != 0:
1462         return pid
1463
1464     # we must never return, or we'll end up running parts of the
1465     # parent's clean-up code. So we work in a try...finally, and
1466     # try to print any exceptions.
1467     try:
1468         random.seed(seed)
1469         endpoints = (server_id, client_id)
1470         status = 0
1471         t = cs[0][0]
1472         c = Conversation(t, endpoints, seq=cs, conversation_id=client_id)
1473         signal.signal(signal.SIGTERM, flushing_signal_handler)
1474
1475         context.generate_process_local_config(account, c)
1476         sys.stdin.close()
1477         os.close(0)
1478         filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
1479                                 c.conversation_id)
1480         f = open(filename, 'w')
1481         try:
1482             sys.stdout.close()
1483             os.close(1)
1484         except IOError as e:
1485             LOGGER.info("stdout closing failed with %s" % e)
1486             pass
1487
1488         sys.stdout = f
1489         now = time.time() - start
1490         gap = t - now
1491         sleep_time = gap - SLEEP_OVERHEAD
1492         if sleep_time > 0:
1493             time.sleep(sleep_time)
1494
1495         max_lag, start_lag, max_sleep_miss = c.replay_with_delay(start=start,
1496                                                                  context=context)
1497         print("Maximum lag: %f" % max_lag)
1498         print("Start lag: %f" % start_lag)
1499         print("Max sleep miss: %f" % max_sleep_miss)
1500
1501     except Exception:
1502         status = 1
1503         print(("EXCEPTION in child PID %d, conversation %s" % (os.getpid(), c)),
1504               file=sys.stderr)
1505         traceback.print_exc(sys.stderr)
1506         sys.stderr.flush()
1507     finally:
1508         sys.stderr.close()
1509         sys.stdout.close()
1510         os._exit(status)
1511
1512
1513 def dnshammer_in_fork(dns_rate, duration):
1514     sys.stdout.flush()
1515     sys.stderr.flush()
1516     pid = os.fork()
1517     if pid != 0:
1518         return pid
1519     try:
1520         status = 0
1521         signal.signal(signal.SIGTERM, flushing_signal_handler)
1522         hammer = DnsHammer(dns_rate, duration)
1523         hammer.replay()
1524     except Exception:
1525         status = 1
1526         print(("EXCEPTION in child PID %d, the DNS hammer" % (os.getpid())),
1527               file=sys.stderr)
1528         traceback.print_exc(sys.stderr)
1529     finally:
1530         sys.stderr.close()
1531         sys.stdout.close()
1532         os._exit(status)
1533
1534
1535 def replay(conversation_seq,
1536            host=None,
1537            creds=None,
1538            lp=None,
1539            accounts=None,
1540            dns_rate=0,
1541            duration=None,
1542            latency_timeout=1.0,
1543            stop_on_any_error=False,
1544            **kwargs):
1545
1546     context = ReplayContext(server=host,
1547                             creds=creds,
1548                             lp=lp,
1549                             **kwargs)
1550
1551     if len(accounts) < len(conversation_seq):
1552         raise ValueError(("we have %d accounts but %d conversations" %
1553                           (len(accounts), len(conversation_seq))))
1554
1555     # Set the process group so that the calling scripts are not killed
1556     # when the forked child processes are killed.
1557     os.setpgrp()
1558
1559     # we delay the start by a bit to allow all the forks to get up and
1560     # running.
1561     delay = len(conversation_seq) * 0.02
1562     start = time.time() + delay
1563
1564     if duration is None:
1565         # end slightly after the last packet of the last conversation
1566         # to start. Conversations other than the last could still be
1567         # going, but we don't care.
1568         duration = conversation_seq[-1][-1][0] + latency_timeout
1569
1570     print("We will start in %.1f seconds" % delay,
1571           file=sys.stderr)
1572     print("We will stop after %.1f seconds" % (duration + delay),
1573           file=sys.stderr)
1574     print("runtime %.1f seconds" % duration,
1575           file=sys.stderr)
1576
1577     # give one second grace for packets to finish before killing begins
1578     end = start + duration + 1.0
1579
1580     LOGGER.info("Replaying traffic for %u conversations over %d seconds"
1581           % (len(conversation_seq), duration))
1582
1583     context.write_stats('intentions',
1584                         Planned_conversations=len(conversation_seq),
1585                         Planned_packets=sum(len(x) for x in conversation_seq))
1586
1587     children = {}
1588     try:
1589         if dns_rate:
1590             pid = dnshammer_in_fork(dns_rate, duration)
1591             children[pid] = 1
1592
1593         for i, cs in enumerate(conversation_seq):
1594             account = accounts[i]
1595             client_id = i + 2
1596             pid = replay_seq_in_fork(cs, start, context, account, client_id)
1597             children[pid] = client_id
1598
1599         # HERE, we are past all the forks
1600         t = time.time()
1601         print("all forks done in %.1f seconds, waiting %.1f" %
1602               (t - start + delay, t - start),
1603               file=sys.stderr)
1604
1605         while time.time() < end and children:
1606             time.sleep(0.003)
1607             try:
1608                 pid, status = os.waitpid(-1, os.WNOHANG)
1609             except OSError as e:
1610                 if e.errno != ECHILD:  # no child processes
1611                     raise
1612                 break
1613             if pid:
1614                 c = children.pop(pid, None)
1615                 if DEBUG_LEVEL > 0:
1616                     print(("process %d finished conversation %d;"
1617                            " %d to go" %
1618                            (pid, c, len(children))), file=sys.stderr)
1619                 if stop_on_any_error and status != 0:
1620                     break
1621
1622     except Exception:
1623         print("EXCEPTION in parent", file=sys.stderr)
1624         traceback.print_exc()
1625     finally:
1626         context.write_stats('unfinished',
1627                             Unfinished_conversations=len(children))
1628
1629         for s in (15, 15, 9):
1630             print(("killing %d children with -%d" %
1631                    (len(children), s)), file=sys.stderr)
1632             for pid in children:
1633                 try:
1634                     os.kill(pid, s)
1635                 except OSError as e:
1636                     if e.errno != ESRCH:  # don't fail if it has already died
1637                         raise
1638             time.sleep(0.5)
1639             end = time.time() + 1
1640             while children:
1641                 try:
1642                     pid, status = os.waitpid(-1, os.WNOHANG)
1643                 except OSError as e:
1644                     if e.errno != ECHILD:
1645                         raise
1646                 if pid != 0:
1647                     c = children.pop(pid, None)
1648                     if c is None:
1649                         print("children is %s, no pid found" % children)
1650                         sys.stderr.flush()
1651                         sys.stdout.flush()
1652                         os._exit(1)
1653                     print(("kill -%d %d KILLED conversation; "
1654                            "%d to go" %
1655                            (s, pid, len(children))),
1656                           file=sys.stderr)
1657                 if time.time() >= end:
1658                     break
1659
1660             if not children:
1661                 break
1662             time.sleep(1)
1663
1664         if children:
1665             print("%d children are missing" % len(children),
1666                   file=sys.stderr)
1667
1668         # there may be stragglers that were forked just as ^C was hit
1669         # and don't appear in the list of children. We can get them
1670         # with killpg, but that will also kill us, so this is^H^H would be
1671         # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
1672         # so as not to have to fuss around writing signal handlers.
1673         try:
1674             os.killpg(0, 2)
1675         except KeyboardInterrupt:
1676             print("ignoring fake ^C", file=sys.stderr)
1677
1678
1679 def openLdb(host, creds, lp):
1680     session = system_session()
1681     ldb = SamDB(url="ldap://%s" % host,
1682                 session_info=session,
1683                 options=['modules:paged_searches'],
1684                 credentials=creds,
1685                 lp=lp)
1686     return ldb
1687
1688
1689 def ou_name(ldb, instance_id):
1690     """Generate an ou name from the instance id"""
1691     return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id,
1692                                                     ldb.domain_dn())
1693
1694
1695 def create_ou(ldb, instance_id):
1696     """Create an ou, all created user and machine accounts will belong to it.
1697
1698     This allows all the created resources to be cleaned up easily.
1699     """
1700     ou = ou_name(ldb, instance_id)
1701     try:
1702         ldb.add({"dn": ou.split(',', 1)[1],
1703                  "objectclass": "organizationalunit"})
1704     except LdbError as e:
1705         (status, _) = e.args
1706         # ignore already exists
1707         if status != 68:
1708             raise
1709     try:
1710         ldb.add({"dn": ou,
1711                  "objectclass": "organizationalunit"})
1712     except LdbError as e:
1713         (status, _) = e.args
1714         # ignore already exists
1715         if status != 68:
1716             raise
1717     return ou
1718
1719
1720 # ConversationAccounts holds details of the machine and user accounts
1721 # associated with a conversation.
1722 #
1723 # We use a named tuple to reduce shared memory usage.
1724 ConversationAccounts = namedtuple('ConversationAccounts',
1725                                   ('netbios_name',
1726                                    'machinepass',
1727                                    'username',
1728                                    'userpass'))
1729
1730
1731 def generate_replay_accounts(ldb, instance_id, number, password):
1732     """Generate a series of unique machine and user account names."""
1733
1734     accounts = []
1735     for i in range(1, number + 1):
1736         netbios_name = machine_name(instance_id, i)
1737         username = user_name(instance_id, i)
1738
1739         account = ConversationAccounts(netbios_name, password, username,
1740                                        password)
1741         accounts.append(account)
1742     return accounts
1743
1744
1745 def create_machine_account(ldb, instance_id, netbios_name, machinepass,
1746                            traffic_account=True):
1747     """Create a machine account via ldap."""
1748
1749     ou = ou_name(ldb, instance_id)
1750     dn = "cn=%s,%s" % (netbios_name, ou)
1751     utf16pw = ('"%s"' % get_string(machinepass)).encode('utf-16-le')
1752
1753     if traffic_account:
1754         # we set these bits for the machine account otherwise the replayed
1755         # traffic throws up NT_STATUS_NO_TRUST_SAM_ACCOUNT errors
1756         account_controls = str(UF_TRUSTED_FOR_DELEGATION |
1757                                UF_SERVER_TRUST_ACCOUNT)
1758
1759     else:
1760         account_controls = str(UF_WORKSTATION_TRUST_ACCOUNT)
1761
1762     ldb.add({
1763         "dn": dn,
1764         "objectclass": "computer",
1765         "sAMAccountName": "%s$" % netbios_name,
1766         "userAccountControl": account_controls,
1767         "unicodePwd": utf16pw})
1768
1769
1770 def create_user_account(ldb, instance_id, username, userpass):
1771     """Create a user account via ldap."""
1772     ou = ou_name(ldb, instance_id)
1773     user_dn = "cn=%s,%s" % (username, ou)
1774     utf16pw = ('"%s"' % get_string(userpass)).encode('utf-16-le')
1775     ldb.add({
1776         "dn": user_dn,
1777         "objectclass": "user",
1778         "sAMAccountName": username,
1779         "userAccountControl": str(UF_NORMAL_ACCOUNT),
1780         "unicodePwd": utf16pw
1781     })
1782
1783     # grant user write permission to do things like write account SPN
1784     sdutils = sd_utils.SDUtils(ldb)
1785     sdutils.dacl_add_ace(user_dn, "(A;;WP;;;PS)")
1786
1787
1788 def create_group(ldb, instance_id, name):
1789     """Create a group via ldap."""
1790
1791     ou = ou_name(ldb, instance_id)
1792     dn = "cn=%s,%s" % (name, ou)
1793     ldb.add({
1794         "dn": dn,
1795         "objectclass": "group",
1796         "sAMAccountName": name,
1797     })
1798
1799
1800 def user_name(instance_id, i):
1801     """Generate a user name based in the instance id"""
1802     return "STGU-%d-%d" % (instance_id, i)
1803
1804
1805 def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
1806     """Seach objectclass, return attr in a set"""
1807     objs = ldb.search(
1808         expression="(objectClass={})".format(objectclass),
1809         attrs=[attr]
1810     )
1811     return {str(obj[attr]) for obj in objs}
1812
1813
1814 def generate_users(ldb, instance_id, number, password):
1815     """Add users to the server"""
1816     existing_objects = search_objectclass(ldb, objectclass='user')
1817     users = 0
1818     for i in range(number, 0, -1):
1819         name = user_name(instance_id, i)
1820         if name not in existing_objects:
1821             create_user_account(ldb, instance_id, name, password)
1822             users += 1
1823             if users % 50 == 0:
1824                 LOGGER.info("Created %u/%u users" % (users, number))
1825
1826     return users
1827
1828
1829 def machine_name(instance_id, i, traffic_account=True):
1830     """Generate a machine account name from instance id."""
1831     if traffic_account:
1832         # traffic accounts correspond to a given user, and use different
1833         # userAccountControl flags to ensure packets get processed correctly
1834         # by the DC
1835         return "STGM-%d-%d" % (instance_id, i)
1836     else:
1837         # Otherwise we're just generating computer accounts to simulate a
1838         # semi-realistic network. These use the default computer
1839         # userAccountControl flags, so we use a different account name so that
1840         # we don't try to use them when generating packets
1841         return "PC-%d-%d" % (instance_id, i)
1842
1843
1844 def generate_machine_accounts(ldb, instance_id, number, password,
1845                               traffic_account=True):
1846     """Add machine accounts to the server"""
1847     existing_objects = search_objectclass(ldb, objectclass='computer')
1848     added = 0
1849     for i in range(number, 0, -1):
1850         name = machine_name(instance_id, i, traffic_account)
1851         if name + "$" not in existing_objects:
1852             create_machine_account(ldb, instance_id, name, password,
1853                                    traffic_account)
1854             added += 1
1855             if added % 50 == 0:
1856                 LOGGER.info("Created %u/%u machine accounts" % (added, number))
1857
1858     return added
1859
1860
1861 def group_name(instance_id, i):
1862     """Generate a group name from instance id."""
1863     return "STGG-%d-%d" % (instance_id, i)
1864
1865
1866 def generate_groups(ldb, instance_id, number):
1867     """Create the required number of groups on the server."""
1868     existing_objects = search_objectclass(ldb, objectclass='group')
1869     groups = 0
1870     for i in range(number, 0, -1):
1871         name = group_name(instance_id, i)
1872         if name not in existing_objects:
1873             create_group(ldb, instance_id, name)
1874             groups += 1
1875             if groups % 1000 == 0:
1876                 LOGGER.info("Created %u/%u groups" % (groups, number))
1877
1878     return groups
1879
1880
1881 def clean_up_accounts(ldb, instance_id):
1882     """Remove the created accounts and groups from the server."""
1883     ou = ou_name(ldb, instance_id)
1884     try:
1885         ldb.delete(ou, ["tree_delete:1"])
1886     except LdbError as e:
1887         (status, _) = e.args
1888         # ignore does not exist
1889         if status != 32:
1890             raise
1891
1892
1893 def generate_users_and_groups(ldb, instance_id, password,
1894                               number_of_users, number_of_groups,
1895                               group_memberships, max_members,
1896                               machine_accounts, traffic_accounts=True):
1897     """Generate the required users and groups, allocating the users to
1898        those groups."""
1899     memberships_added = 0
1900     groups_added = 0
1901     computers_added = 0
1902
1903     create_ou(ldb, instance_id)
1904
1905     LOGGER.info("Generating dummy user accounts")
1906     users_added = generate_users(ldb, instance_id, number_of_users, password)
1907
1908     LOGGER.info("Generating dummy machine accounts")
1909     computers_added = generate_machine_accounts(ldb, instance_id,
1910                                                 machine_accounts, password,
1911                                                 traffic_accounts)
1912
1913     if number_of_groups > 0:
1914         LOGGER.info("Generating dummy groups")
1915         groups_added = generate_groups(ldb, instance_id, number_of_groups)
1916
1917     if group_memberships > 0:
1918         LOGGER.info("Assigning users to groups")
1919         assignments = GroupAssignments(number_of_groups,
1920                                        groups_added,
1921                                        number_of_users,
1922                                        users_added,
1923                                        group_memberships,
1924                                        max_members)
1925         LOGGER.info("Adding users to groups")
1926         add_users_to_groups(ldb, instance_id, assignments)
1927         memberships_added = assignments.total()
1928
1929     if (groups_added > 0 and users_added == 0 and
1930        number_of_groups != groups_added):
1931         LOGGER.warning("The added groups will contain no members")
1932
1933     LOGGER.info("Added %d users (%d machines), %d groups and %d memberships" %
1934                 (users_added, computers_added, groups_added,
1935                  memberships_added))
1936
1937
1938 class GroupAssignments(object):
1939     def __init__(self, number_of_groups, groups_added, number_of_users,
1940                  users_added, group_memberships, max_members):
1941
1942         self.count = 0
1943         self.generate_group_distribution(number_of_groups)
1944         self.generate_user_distribution(number_of_users, group_memberships)
1945         self.max_members = max_members
1946         self.assignments = defaultdict(list)
1947         self.assign_groups(number_of_groups, groups_added, number_of_users,
1948                            users_added, group_memberships)
1949
1950     def cumulative_distribution(self, weights):
1951         # make sure the probabilities conform to a cumulative distribution
1952         # spread between 0.0 and 1.0. Dividing by the weighted total gives each
1953         # probability a proportional share of 1.0. Higher probabilities get a
1954         # bigger share, so are more likely to be picked. We use the cumulative
1955         # value, so we can use random.random() as a simple index into the list
1956         dist = []
1957         total = sum(weights)
1958         if total == 0:
1959             return None
1960
1961         cumulative = 0.0
1962         for probability in weights:
1963             cumulative += probability
1964             dist.append(cumulative / total)
1965         return dist
1966
1967     def generate_user_distribution(self, num_users, num_memberships):
1968         """Probability distribution of a user belonging to a group.
1969         """
1970         # Assign a weighted probability to each user. Use the Pareto
1971         # Distribution so that some users are in a lot of groups, and the
1972         # bulk of users are in only a few groups. If we're assigning a large
1973         # number of group memberships, use a higher shape. This means slightly
1974         # fewer outlying users that are in large numbers of groups. The aim is
1975         # to have no users belonging to more than ~500 groups.
1976         if num_memberships > 5000000:
1977             shape = 3.0
1978         elif num_memberships > 2000000:
1979             shape = 2.5
1980         elif num_memberships > 300000:
1981             shape = 2.25
1982         else:
1983             shape = 1.75
1984
1985         weights = []
1986         for x in range(1, num_users + 1):
1987             p = random.paretovariate(shape)
1988             weights.append(p)
1989
1990         # convert the weights to a cumulative distribution between 0.0 and 1.0
1991         self.user_dist = self.cumulative_distribution(weights)
1992
1993     def generate_group_distribution(self, n):
1994         """Probability distribution of a group containing a user."""
1995
1996         # Assign a weighted probability to each user. Probability decreases
1997         # as the group-ID increases
1998         weights = []
1999         for x in range(1, n + 1):
2000             p = 1 / (x**1.3)
2001             weights.append(p)
2002
2003         # convert the weights to a cumulative distribution between 0.0 and 1.0
2004         self.group_weights = weights
2005         self.group_dist = self.cumulative_distribution(weights)
2006
2007     def generate_random_membership(self):
2008         """Returns a randomly generated user-group membership"""
2009
2010         # the list items are cumulative distribution values between 0.0 and
2011         # 1.0, which makes random() a handy way to index the list to get a
2012         # weighted random user/group. (Here the user/group returned are
2013         # zero-based array indexes)
2014         user = bisect.bisect(self.user_dist, random.random())
2015         group = bisect.bisect(self.group_dist, random.random())
2016
2017         return user, group
2018
2019     def users_in_group(self, group):
2020         return self.assignments[group]
2021
2022     def get_groups(self):
2023         return self.assignments.keys()
2024
2025     def cap_group_membership(self, group, max_members):
2026         """Prevent the group's membership from exceeding the max specified"""
2027         num_members = len(self.assignments[group])
2028         if num_members >= max_members:
2029             LOGGER.info("Group {0} has {1} members".format(group, num_members))
2030
2031             # remove this group and then recalculate the cumulative
2032             # distribution, so this group is no longer selected
2033             self.group_weights[group - 1] = 0
2034             new_dist = self.cumulative_distribution(self.group_weights)
2035             self.group_dist = new_dist
2036
2037     def add_assignment(self, user, group):
2038         # the assignments are stored in a dictionary where key=group,
2039         # value=list-of-users-in-group (indexing by group-ID allows us to
2040         # optimize for DB membership writes)
2041         if user not in self.assignments[group]:
2042             self.assignments[group].append(user)
2043             self.count += 1
2044
2045         # check if there'a cap on how big the groups can grow
2046         if self.max_members:
2047             self.cap_group_membership(group, self.max_members)
2048
2049     def assign_groups(self, number_of_groups, groups_added,
2050                       number_of_users, users_added, group_memberships):
2051         """Allocate users to groups.
2052
2053         The intention is to have a few users that belong to most groups, while
2054         the majority of users belong to a few groups.
2055
2056         A few groups will contain most users, with the remaining only having a
2057         few users.
2058         """
2059
2060         if group_memberships <= 0:
2061             return
2062
2063         # Calculate the number of group menberships required
2064         group_memberships = math.ceil(
2065             float(group_memberships) *
2066             (float(users_added) / float(number_of_users)))
2067
2068         if self.max_members:
2069             group_memberships = min(group_memberships,
2070                                     self.max_members * number_of_groups)
2071
2072         existing_users  = number_of_users  - users_added  - 1
2073         existing_groups = number_of_groups - groups_added - 1
2074         while self.total() < group_memberships:
2075             user, group = self.generate_random_membership()
2076
2077             if group > existing_groups or user > existing_users:
2078                 # the + 1 converts the array index to the corresponding
2079                 # group or user number
2080                 self.add_assignment(user + 1, group + 1)
2081
2082     def total(self):
2083         return self.count
2084
2085
2086 def add_users_to_groups(db, instance_id, assignments):
2087     """Takes the assignments of users to groups and applies them to the DB."""
2088
2089     total = assignments.total()
2090     count = 0
2091     added = 0
2092
2093     for group in assignments.get_groups():
2094         users_in_group = assignments.users_in_group(group)
2095         if len(users_in_group) == 0:
2096             continue
2097
2098         # Split up the users into chunks, so we write no more than 1K at a
2099         # time. (Minimizing the DB modifies is more efficient, but writing
2100         # 10K+ users to a single group becomes inefficient memory-wise)
2101         for chunk in range(0, len(users_in_group), 1000):
2102             chunk_of_users = users_in_group[chunk:chunk + 1000]
2103             add_group_members(db, instance_id, group, chunk_of_users)
2104
2105             added += len(chunk_of_users)
2106             count += 1
2107             if count % 50 == 0:
2108                 LOGGER.info("Added %u/%u memberships" % (added, total))
2109
2110 def add_group_members(db, instance_id, group, users_in_group):
2111     """Adds the given users to group specified."""
2112
2113     ou = ou_name(db, instance_id)
2114
2115     def build_dn(name):
2116         return("cn=%s,%s" % (name, ou))
2117
2118     group_dn = build_dn(group_name(instance_id, group))
2119     m = ldb.Message()
2120     m.dn = ldb.Dn(db, group_dn)
2121
2122     for user in users_in_group:
2123         user_dn = build_dn(user_name(instance_id, user))
2124         idx = "member-" + str(user)
2125         m[idx] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
2126
2127     db.modify(m)
2128
2129
2130 def generate_stats(statsdir, timing_file):
2131     """Generate and print the summary stats for a run."""
2132     first      = sys.float_info.max
2133     last       = 0
2134     successful = 0
2135     failed     = 0
2136     latencies  = {}
2137     failures   = Counter()
2138     unique_conversations = set()
2139     if timing_file is not None:
2140         tw = timing_file.write
2141     else:
2142         def tw(x):
2143             pass
2144
2145     tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")
2146
2147     float_values = {
2148         'Maximum lag': 0,
2149         'Start lag': 0,
2150         'Max sleep miss': 0,
2151     }
2152     int_values = {
2153         'Planned_conversations': 0,
2154         'Planned_packets': 0,
2155         'Unfinished_conversations': 0,
2156     }
2157
2158     for filename in os.listdir(statsdir):
2159         path = os.path.join(statsdir, filename)
2160         with open(path, 'r') as f:
2161             for line in f:
2162                 try:
2163                     fields       = line.rstrip('\n').split('\t')
2164                     conversation = fields[1]
2165                     protocol     = fields[2]
2166                     packet_type  = fields[3]
2167                     latency      = float(fields[4])
2168                     t = float(fields[0])
2169                     first        = min(t - latency, first)
2170                     last         = max(t, last)
2171
2172                     op = (protocol, packet_type)
2173                     latencies.setdefault(op, []).append(latency)
2174                     if fields[5] == 'True':
2175                         successful += 1
2176                     else:
2177                         failed += 1
2178                         failures[op] += 1
2179
2180                     unique_conversations.add(conversation)
2181
2182                     tw(line)
2183                 except (ValueError, IndexError):
2184                     if ':' in line:
2185                         k, v = line.split(':', 1)
2186                         if k in float_values:
2187                             float_values[k] = max(float(v),
2188                                                   float_values[k])
2189                         elif k in int_values:
2190                             int_values[k] = max(int(v),
2191                                                 int_values[k])
2192                         else:
2193                             print(line, file=sys.stderr)
2194                     else:
2195                         # not a valid line print and ignore
2196                         print(line, file=sys.stderr)
2197
2198     duration = last - first
2199     if successful == 0:
2200         success_rate = 0
2201     else:
2202         success_rate = successful / duration
2203     if failed == 0:
2204         failure_rate = 0
2205     else:
2206         failure_rate = failed / duration
2207
2208     conversations = len(unique_conversations)
2209
2210     print("Total conversations:   %10d" % conversations)
2211     print("Successful operations: %10d (%.3f per second)"
2212           % (successful, success_rate))
2213     print("Failed operations:     %10d (%.3f per second)"
2214           % (failed, failure_rate))
2215
2216     for k, v in sorted(float_values.items()):
2217         print("%-28s %f" % (k.replace('_', ' ') + ':', v))
2218     for k, v in sorted(int_values.items()):
2219         print("%-28s %d" % (k.replace('_', ' ') + ':', v))
2220
2221     print("Protocol    Op Code  Description                               "
2222           " Count       Failed         Mean       Median          "
2223           "95%        Range          Max")
2224
2225     ops = {}
2226     for proto, packet in latencies:
2227         if proto not in ops:
2228             ops[proto] = set()
2229         ops[proto].add(packet)
2230     protocols = sorted(ops.keys())
2231
2232     for protocol in protocols:
2233         packet_types = sorted(ops[protocol], key=opcode_key)
2234         for packet_type in packet_types:
2235             op = (protocol, packet_type)
2236             values     = latencies[op]
2237             values     = sorted(values)
2238             count      = len(values)
2239             failed     = failures[op]
2240             mean       = sum(values) / count
2241             median     = calc_percentile(values, 0.50)
2242             percentile = calc_percentile(values, 0.95)
2243             rng        = values[-1] - values[0]
2244             maxv       = values[-1]
2245             desc       = OP_DESCRIPTIONS.get(op, '')
2246             print("%-12s   %4s  %-35s %12d %12d %12.6f "
2247                   "%12.6f %12.6f %12.6f %12.6f"
2248                   % (protocol,
2249                      packet_type,
2250                      desc,
2251                      count,
2252                      failed,
2253                      mean,
2254                      median,
2255                      percentile,
2256                      rng,
2257                      maxv))
2258
2259
2260 def opcode_key(v):
2261     """Sort key for the operation code to ensure that it sorts numerically"""
2262     try:
2263         return "%03d" % int(v)
2264     except ValueError:
2265         return v
2266
2267
2268 def calc_percentile(values, percentile):
2269     """Calculate the specified percentile from the list of values.
2270
2271     Assumes the list is sorted in ascending order.
2272     """
2273
2274     if not values:
2275         return 0
2276     k = (len(values) - 1) * percentile
2277     f = math.floor(k)
2278     c = math.ceil(k)
2279     if f == c:
2280         return values[int(k)]
2281     d0 = values[int(f)] * (c - k)
2282     d1 = values[int(c)] * (k - f)
2283     return d0 + d1
2284
2285
2286 def mk_masked_dir(*path):
2287     """In a testenv we end up with 0777 directories that look an alarming
2288     green colour with ls. Use umask to avoid that."""
2289     # py3 os.mkdir can do this
2290     d = os.path.join(*path)
2291     mask = os.umask(0o077)
2292     os.mkdir(d)
2293     os.umask(mask)
2294     return d