1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
3 #
4 # Copyright (C) Catalyst IT Ltd. 2017
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19 from __future__ import print_function, division
20
21 import time
22 import os
23 import random
24 import json
25 import math
26 import sys
27 import signal
28 import itertools
29
30 from collections import OrderedDict, Counter, defaultdict
31 from samba.emulate import traffic_packets
32 from samba.samdb import SamDB
33 import ldb
34 from ldb import LdbError
35 from samba.dcerpc import ClientConnection
36 from samba.dcerpc import security, drsuapi, lsa
37 from samba.dcerpc import netlogon
38 from samba.dcerpc.netlogon import netr_Authenticator
39 from samba.dcerpc import srvsvc
40 from samba.dcerpc import samr
41 from samba.drs_utils import drs_DsBind
42 import traceback
43 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
44 from samba.auth import system_session
45 from samba.dsdb import (
46     UF_NORMAL_ACCOUNT,
47     UF_SERVER_TRUST_ACCOUNT,
48     UF_TRUSTED_FOR_DELEGATION
49 )
50 from samba.dcerpc.misc import SEC_CHAN_BDC
51 from samba import gensec
52 from samba import sd_utils
53 from samba.compat import get_string
54 from samba.logger import get_samba_logger
55 import bisect
56
57 SLEEP_OVERHEAD = 3e-4
58
59 # we don't use None, because it complicates [de]serialisation
60 NON_PACKET = '-'
61
62 CLIENT_CLUES = {
63     ('dns', '0'): 1.0,      # query
64     ('smb', '0x72'): 1.0,   # Negotiate protocol
65     ('ldap', '0'): 1.0,     # bind
66     ('ldap', '3'): 1.0,     # searchRequest
67     ('ldap', '2'): 1.0,     # unbindRequest
68     ('cldap', '3'): 1.0,
69     ('dcerpc', '11'): 1.0,  # bind
70     ('dcerpc', '14'): 1.0,  # Alter_context
71     ('nbns', '0'): 1.0,     # query
72 }
73
74 SERVER_CLUES = {
75     ('dns', '1'): 1.0,      # response
76     ('ldap', '1'): 1.0,     # bind response
77     ('ldap', '4'): 1.0,     # search result
78     ('ldap', '5'): 1.0,     # search done
79     ('cldap', '5'): 1.0,
80     ('dcerpc', '12'): 1.0,  # bind_ack
81     ('dcerpc', '13'): 1.0,  # bind_nak
82     ('dcerpc', '15'): 1.0,  # Alter_context response
83 }
84
85 SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}
86
87 WAIT_SCALE = 10.0
88 WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
89 NO_WAIT_LOG_TIME_RANGE = (-10, -3)
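
# How these constants are used (see TrafficModel.learn and
# construct_conversation below): a pause between client packets longer than
# WAIT_THRESHOLD seconds is recorded as an extra 'wait:n' state, and replay
# turns that bucket back into a randomised sleep.  A rough sketch of the
# round trip, with an assumed idle time of 2.5 seconds and packet_rate 1:
#
#     elapsed = 2.5
#     n = int(math.log(max(1.0, elapsed * WAIT_SCALE)))    # -> 'wait:3'
#     wait = math.exp(n + random.random()) / WAIT_SCALE    # ~2.0 to ~5.5 s
#
# Gaps below the threshold get a small log-uniform delay drawn from
# NO_WAIT_LOG_TIME_RANGE instead.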
90
91 # DEBUG_LEVEL can be changed by scripts with -d
92 DEBUG_LEVEL = 0
93
94 LOGGER = get_samba_logger(name=__name__)
95
96
97 def debug(level, msg, *args):
98     """Print a formatted debug message to standard error.
99
100
101     :param level: The debug level, message will be printed if it is <= the
102                   currently set debug level. The debug level can be set with
103                   the -d option.
104     :param msg:   The message to be logged, can contain C-Style format
105                   specifiers
106     :param args:  The parameters required by the format specifiers
107     """
108     if level <= DEBUG_LEVEL:
109         if not args:
110             print(msg, file=sys.stderr)
111         else:
112             print(msg % tuple(args), file=sys.stderr)
113
114
115 def debug_lineno(*args):
116     """ Print an unformatted log message to stderr, contaning the line number
117     """
118     tb = traceback.extract_stack(limit=2)
119     print((" %s:" "\033[01;33m"
120            "%s " "\033[00m" % (tb[0][2], tb[0][1])), end=' ',
121           file=sys.stderr)
122     for a in args:
123         print(a, file=sys.stderr)
124     print(file=sys.stderr)
125     sys.stderr.flush()
126
127
128 def random_colour_print():
129     """Return a function that prints a randomly coloured line to stderr"""
130     n = 18 + random.randrange(214)
131     prefix = "\033[38;5;%dm" % n
132
133     def p(*args):
134         for a in args:
135             print("%s%s\033[00m" % (prefix, a), file=sys.stderr)
136
137     return p
138
139
140 class FakePacketError(Exception):
141     pass
142
143
144 class Packet(object):
145     """Details of a network packet"""
146     def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
147                  protocol, opcode, desc, extra):
148
149         self.timestamp = timestamp
150         self.ip_protocol = ip_protocol
151         self.stream_number = stream_number
152         self.src = src
153         self.dest = dest
154         self.protocol = protocol
155         self.opcode = opcode
156         self.desc = desc
157         self.extra = extra
158         if self.src < self.dest:
159             self.endpoints = (self.src, self.dest)
160         else:
161             self.endpoints = (self.dest, self.src)
162
163     @classmethod
164     def from_line(cls, line):
165         fields = line.rstrip('\n').split('\t')
166         (timestamp,
167          ip_protocol,
168          stream_number,
169          src,
170          dest,
171          protocol,
172          opcode,
173          desc) = fields[:8]
174         extra = fields[8:]
175
176         timestamp = float(timestamp)
177         src = int(src)
178         dest = int(dest)
179
180         return Packet(timestamp, ip_protocol, stream_number, src, dest,
181                       protocol, opcode, desc, extra)
182
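    # A sketch of the tab-separated line format that as_summary() emits and
    # from_line() parses (field values below are made up for illustration):
    #
    #     line = ('1487921562.592126\t06\t3\t2\t1\t'
    #             'ldap\t3\tsearchRequest\tDC,DC\tcn')
    #     p = Packet.from_line(line)
    #     # p.protocol == 'ldap', p.opcode == '3', p.src == 2, p.dest == 1,
    #     # p.extra == ['DC,DC', 'cn']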
183     def as_summary(self, time_offset=0.0):
184         """Format the packet as a traffic_summary line.
185         """
186         extra = '\t'.join(self.extra)
187         t = self.timestamp + time_offset
188         return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
189                 (t,
190                  self.ip_protocol,
191                  self.stream_number or '',
192                  self.src,
193                  self.dest,
194                  self.protocol,
195                  self.opcode,
196                  self.desc,
197                  extra))
198
199     def __str__(self):
200         return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
201                 (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
202                  self.stream_number, self.protocol, self.opcode, self.desc,
203                  ('«' + ' '.join(self.extra) + '»' if self.extra else '')))
204
205     def __repr__(self):
206         return "<Packet @%s>" % self
207
208     def copy(self):
209         return self.__class__(self.timestamp,
210                               self.ip_protocol,
211                               self.stream_number,
212                               self.src,
213                               self.dest,
214                               self.protocol,
215                               self.opcode,
216                               self.desc,
217                               self.extra)
218
219     def as_packet_type(self):
220         t = '%s:%s' % (self.protocol, self.opcode)
221         return t
222
223     def client_score(self):
224         """A positive number means we think it is a client; a negative number
225         means we think it is a server. Zero means no idea. range: -1 to 1.
226         """
227         key = (self.protocol, self.opcode)
228         if key in CLIENT_CLUES:
229             return CLIENT_CLUES[key]
230         if key in SERVER_CLUES:
231             return -SERVER_CLUES[key]
232         return 0.0
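
    # For example, a DNS query ('dns', '0') scores +1.0 (almost certainly
    # sent by the client), a DCE/RPC bind_ack ('dcerpc', '12') scores -1.0
    # (almost certainly sent by the server), and opcodes in neither table
    # score 0.0.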
233
234     def play(self, conversation, context):
235         """Send the packet over the network, if required.
236
237         Some packets are ignored, i.e. those for protocols that are not
238         handled, server response messages, and messages that are generated
239         by the protocol layer associated with other packets.
240         """
241         fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
242         try:
243             fn = getattr(traffic_packets, fn_name)
244
245         except AttributeError as e:
246             print("Conversation(%s) Missing handler %s" %
247                   (conversation.conversation_id, fn_name),
248                   file=sys.stderr)
249             return
250
251         # Don't display a message for kerberos packets; they're not directly
252         # generated, they're used to indicate that kerberos should be used.
253         if self.protocol != "kerberos":
254             debug(2, "Conversation(%s) Calling handler %s" %
255                      (conversation.conversation_id, fn_name))
256
257         start = time.time()
258         try:
259             if fn(self, conversation, context):
260                 # Only collect timing data for functions that generate
261                 # network traffic, or fail
262                 end = time.time()
263                 duration = end - start
264                 print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
265                       (end, conversation.conversation_id, self.protocol,
266                        self.opcode, duration))
267         except Exception as e:
268             end = time.time()
269             duration = end - start
270             print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
271                   (end, conversation.conversation_id, self.protocol,
272                    self.opcode, duration, e))
273
274     def __cmp__(self, other):
275         return self.timestamp - other.timestamp
276
277     def is_really_a_packet(self, missing_packet_stats=None):
278         """Is the packet one that can be ignored?
279
280         If so removing it will have no effect on the replay
281         """
282         if self.protocol in SKIPPED_PROTOCOLS:
283             # Ignore any packets for the protocols we're not interested in.
284             return False
285         if self.protocol == "ldap" and self.opcode == '':
286             # skip ldap continuation packets
287             return False
288
289         fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
290         fn = getattr(traffic_packets, fn_name, None)
291         if not fn:
292             print("missing packet %s" % fn_name, file=sys.stderr)
293             return False
294         if fn is traffic_packets.null_packet:
295             return False
296         return True
297
298
299 class ReplayContext(object):
300     """State/Context for an individual conversation between an simulated client
301        and a server.
302     """
303
304     def __init__(self,
305                  server=None,
306                  lp=None,
307                  creds=None,
308                  badpassword_frequency=None,
309                  prefer_kerberos=None,
310                  tempdir=None,
311                  statsdir=None,
312                  ou=None,
313                  base_dn=None,
314                  domain=None,
315                  domain_sid=None):
316
317         self.server                   = server
318         self.ldap_connections         = []
319         self.dcerpc_connections       = []
320         self.lsarpc_connections       = []
321         self.lsarpc_connections_named = []
322         self.drsuapi_connections      = []
323         self.srvsvc_connections       = []
324         self.samr_contexts            = []
325         self.netlogon_connection      = None
326         self.creds                    = creds
327         self.lp                       = lp
328         self.prefer_kerberos          = prefer_kerberos
329         self.ou                       = ou
330         self.base_dn                  = base_dn
331         self.domain                   = domain
332         self.statsdir                 = statsdir
333         self.global_tempdir           = tempdir
334         self.domain_sid               = domain_sid
335         self.realm                    = lp.get('realm')
336
337         # Bad password attempt controls
338         self.badpassword_frequency    = badpassword_frequency
339         self.last_lsarpc_bad          = False
340         self.last_lsarpc_named_bad    = False
341         self.last_simple_bind_bad     = False
342         self.last_bind_bad            = False
343         self.last_srvsvc_bad          = False
344         self.last_drsuapi_bad         = False
345         self.last_netlogon_bad        = False
346         self.last_samlogon_bad        = False
347         self.generate_ldap_search_tables()
348         self.next_conversation_id = itertools.count()
349
350     def generate_ldap_search_tables(self):
351         session = system_session()
352
353         db = SamDB(url="ldap://%s" % self.server,
354                    session_info=session,
355                    credentials=self.creds,
356                    lp=self.lp)
357
358         res = db.search(db.domain_dn(),
359                         scope=ldb.SCOPE_SUBTREE,
360                         controls=["paged_results:1:1000"],
361                         attrs=['dn'])
362
363         # find a list of dns for each pattern
364         # e.g. CN,CN,CN,DC,DC
365         dn_map = {}
366         attribute_clue_map = {
367             'invocationId': []
368         }
369
370         for r in res:
371             dn = str(r.dn)
372             pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
373             dns = dn_map.setdefault(pattern, [])
374             dns.append(dn)
375             if dn.startswith('CN=NTDS Settings,'):
376                 attribute_clue_map['invocationId'].append(dn)
377
378         # extend the map in case we are working with a different
379         # number of DC components.
380         # for k, v in self.dn_map.items():
381         #     print >>sys.stderr, k, len(v)
382
383         for k in list(dn_map.keys()):
384             if k[-3:] != ',DC':
385                 continue
386             p = k[:-3]
387             while p[-3:] == ',DC':
388                 p = p[:-3]
389             for i in range(5):
390                 p += ',DC'
391                 if p != k and p in dn_map:
392                     print('dn_map collision %s %s' % (k, p),
393                           file=sys.stderr)
394                     continue
395                 dn_map[p] = dn_map[k]
396
397         self.dn_map = dn_map
398         self.attribute_clue_map = attribute_clue_map
399
400     def generate_process_local_config(self, account, conversation):
401         if account is None:
402             return
403         self.netbios_name             = account.netbios_name
404         self.machinepass              = account.machinepass
405         self.username                 = account.username
406         self.userpass                 = account.userpass
407
408         self.tempdir = mk_masked_dir(self.global_tempdir,
409                                      'conversation-%d' %
410                                      conversation.conversation_id)
411
412         self.lp.set("private dir", self.tempdir)
413         self.lp.set("lock dir", self.tempdir)
414         self.lp.set("state directory", self.tempdir)
415         self.lp.set("tls verify peer", "no_check")
416
417         # If the domain was not specified, check for the environment
418         # variable.
419         if self.domain is None:
420             self.domain = os.environ["DOMAIN"]
421
422         self.remoteAddress = "/root/ncalrpc_as_system"
423         self.samlogon_dn   = ("cn=%s,%s" %
424                               (self.netbios_name, self.ou))
425         self.user_dn       = ("cn=%s,%s" %
426                               (self.username, self.ou))
427
428         self.generate_machine_creds()
429         self.generate_user_creds()
430
431     def with_random_bad_credentials(self, f, good, bad, failed_last_time):
432         """Execute the supplied logon function, randomly choosing the
433            bad credentials.
434
435            Based on the frequency in badpassword_frequency randomly perform the
436            function with the supplied bad credentials.
437            If run with bad credentials, the function is re-run with the good
438            credentials.
439            failed_last_time is used to prevent consecutive bad credential
440            attempts. So the over all bad credential frequency will be lower
441            than that requested, but not significantly.
442         """
443         if not failed_last_time:
444             if (self.badpassword_frequency and self.badpassword_frequency > 0
445                 and random.random() < self.badpassword_frequency):
446                 try:
447                     f(bad)
448                 except Exception:
449                     # Ignore any exceptions as the operation may fail
450                     # as it's being performed with bad credentials
451                     pass
452                 failed_last_time = True
453             else:
454                 failed_last_time = False
455
456         result = f(good)
457         return (result, failed_last_time)
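
    # A minimal sketch of the calling pattern used by the get_*_connection
    # helpers below (the names here are illustrative, not part of this
    # module):
    #
    #     def connect(creds):
    #         return some_rpc_binding(self.server, self.lp, creds)
    #
    #     (conn, self.last_foo_bad) = self.with_random_bad_credentials(
    #         connect, self.user_creds, self.user_creds_bad,
    #         self.last_foo_bad)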
458
459     def generate_user_creds(self):
460         """Generate the conversation specific user Credentials.
461
462         Each Conversation has an associated user account used to simulate
463         any non Administrative user traffic.
464
465         Generates user credentials with good and bad passwords and ldap
466         simple bind credentials with good and bad passwords.
467         """
468         self.user_creds = Credentials()
469         self.user_creds.guess(self.lp)
470         self.user_creds.set_workstation(self.netbios_name)
471         self.user_creds.set_password(self.userpass)
472         self.user_creds.set_username(self.username)
473         self.user_creds.set_domain(self.domain)
474         if self.prefer_kerberos:
475             self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
476         else:
477             self.user_creds.set_kerberos_state(DONT_USE_KERBEROS)
478
479         self.user_creds_bad = Credentials()
480         self.user_creds_bad.guess(self.lp)
481         self.user_creds_bad.set_workstation(self.netbios_name)
482         self.user_creds_bad.set_password(self.userpass[:-4])
483         self.user_creds_bad.set_username(self.username)
484         if self.prefer_kerberos:
485             self.user_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
486         else:
487             self.user_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
488
489         # Credentials for ldap simple bind.
490         self.simple_bind_creds = Credentials()
491         self.simple_bind_creds.guess(self.lp)
492         self.simple_bind_creds.set_workstation(self.netbios_name)
493         self.simple_bind_creds.set_password(self.userpass)
494         self.simple_bind_creds.set_username(self.username)
495         self.simple_bind_creds.set_gensec_features(
496             self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
497         if self.prefer_kerberos:
498             self.simple_bind_creds.set_kerberos_state(MUST_USE_KERBEROS)
499         else:
500             self.simple_bind_creds.set_kerberos_state(DONT_USE_KERBEROS)
501         self.simple_bind_creds.set_bind_dn(self.user_dn)
502
503         self.simple_bind_creds_bad = Credentials()
504         self.simple_bind_creds_bad.guess(self.lp)
505         self.simple_bind_creds_bad.set_workstation(self.netbios_name)
506         self.simple_bind_creds_bad.set_password(self.userpass[:-4])
507         self.simple_bind_creds_bad.set_username(self.username)
508         self.simple_bind_creds_bad.set_gensec_features(
509             self.simple_bind_creds_bad.get_gensec_features() |
510             gensec.FEATURE_SEAL)
511         if self.prefer_kerberos:
512             self.simple_bind_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
513         else:
514             self.simple_bind_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
515         self.simple_bind_creds_bad.set_bind_dn(self.user_dn)
516
517     def generate_machine_creds(self):
518         """Generate the conversation specific machine Credentials.
519
520         Each Conversation has an associated machine account.
521
522         Generates machine credentials with good and bad passwords.
523         """
524
525         self.machine_creds = Credentials()
526         self.machine_creds.guess(self.lp)
527         self.machine_creds.set_workstation(self.netbios_name)
528         self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
529         self.machine_creds.set_password(self.machinepass)
530         self.machine_creds.set_username(self.netbios_name + "$")
531         self.machine_creds.set_domain(self.domain)
532         if self.prefer_kerberos:
533             self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
534         else:
535             self.machine_creds.set_kerberos_state(DONT_USE_KERBEROS)
536
537         self.machine_creds_bad = Credentials()
538         self.machine_creds_bad.guess(self.lp)
539         self.machine_creds_bad.set_workstation(self.netbios_name)
540         self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
541         self.machine_creds_bad.set_password(self.machinepass[:-4])
542         self.machine_creds_bad.set_username(self.netbios_name + "$")
543         if self.prefer_kerberos:
544             self.machine_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
545         else:
546             self.machine_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
547
548     def get_matching_dn(self, pattern, attributes=None):
549         # If the pattern is an empty string, we assume ROOTDSE.
550         # Otherwise we try adding or removing DC suffixes, then
551         # shorter leading patterns until we hit one.
552         # e.g. if there is no CN,CN,CN,CN,DC,DC
553         # we first try       CN,CN,CN,CN,DC
554         # and                CN,CN,CN,CN,DC,DC,DC
555         # then change to        CN,CN,CN,DC,DC
556         # and as a last resort we use the base_dn
557         attr_clue = self.attribute_clue_map.get(attributes)
558         if attr_clue:
559             return random.choice(attr_clue)
560
561         pattern = pattern.upper()
562         while pattern:
563             if pattern in self.dn_map:
564                 return random.choice(self.dn_map[pattern])
565             # chop one off the front and try it all again.
566             pattern = pattern[3:]
567
568         return self.base_dn
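
    # For example, a captured pattern like 'CN,CN,CN,CN,DC,DC' will match an
    # entry that generate_ldap_search_tables() added with a different number
    # of trailing DC components; failing that, the pattern is shortened from
    # the front ('CN,CN,CN,DC,DC', 'CN,CN,DC,DC', ...), and self.base_dn is
    # the final fallback.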
569
570     def get_dcerpc_connection(self, new=False):
571         guid = '12345678-1234-abcd-ef00-01234567cffb'  # RPC_NETLOGON UUID
572         if self.dcerpc_connections and not new:
573             return self.dcerpc_connections[-1]
574         c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
575                              (guid, 1), self.lp)
576         self.dcerpc_connections.append(c)
577         return c
578
579     def get_srvsvc_connection(self, new=False):
580         if self.srvsvc_connections and not new:
581             return self.srvsvc_connections[-1]
582
583         def connect(creds):
584             return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
585                                  self.lp,
586                                  creds)
587
588         (c, self.last_srvsvc_bad) = \
589             self.with_random_bad_credentials(connect,
590                                              self.user_creds,
591                                              self.user_creds_bad,
592                                              self.last_srvsvc_bad)
593
594         self.srvsvc_connections.append(c)
595         return c
596
597     def get_lsarpc_connection(self, new=False):
598         if self.lsarpc_connections and not new:
599             return self.lsarpc_connections[-1]
600
601         def connect(creds):
602             binding_options = 'schannel,seal,sign'
603             return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
604                               (self.server, binding_options),
605                               self.lp,
606                               creds)
607
608         (c, self.last_lsarpc_bad) = \
609             self.with_random_bad_credentials(connect,
610                                              self.machine_creds,
611                                              self.machine_creds_bad,
612                                              self.last_lsarpc_bad)
613
614         self.lsarpc_connections.append(c)
615         return c
616
617     def get_lsarpc_named_pipe_connection(self, new=False):
618         if self.lsarpc_connections_named and not new:
619             return self.lsarpc_connections_named[-1]
620
621         def connect(creds):
622             return lsa.lsarpc("ncacn_np:%s" % (self.server),
623                               self.lp,
624                               creds)
625
626         (c, self.last_lsarpc_named_bad) = \
627             self.with_random_bad_credentials(connect,
628                                              self.machine_creds,
629                                              self.machine_creds_bad,
630                                              self.last_lsarpc_named_bad)
631
632         self.lsarpc_connections_named.append(c)
633         return c
634
635     def get_drsuapi_connection_pair(self, new=False, unbind=False):
636         """get a (drs, drs_handle) tuple"""
637         if self.drsuapi_connections and not new:
638             c = self.drsuapi_connections[-1]
639             return c
640
641         def connect(creds):
642             binding_options = 'seal'
643             binding_string = "ncacn_ip_tcp:%s[%s]" %\
644                              (self.server, binding_options)
645             return drsuapi.drsuapi(binding_string, self.lp, creds)
646
647         (drs, self.last_drsuapi_bad) = \
648             self.with_random_bad_credentials(connect,
649                                              self.user_creds,
650                                              self.user_creds_bad,
651                                              self.last_drsuapi_bad)
652
653         (drs_handle, supported_extensions) = drs_DsBind(drs)
654         c = (drs, drs_handle)
655         self.drsuapi_connections.append(c)
656         return c
657
658     def get_ldap_connection(self, new=False, simple=False):
659         if self.ldap_connections and not new:
660             return self.ldap_connections[-1]
661
662         def simple_bind(creds):
663             """
664             To run simple bind against Windows, we need to run
665             following commands in PowerShell:
666
667                 Install-windowsfeature ADCS-Cert-Authority
668                 Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
669                 Restart-Computer
670
671             """
672             return SamDB('ldaps://%s' % self.server,
673                          credentials=creds,
674                          lp=self.lp)
675
676         def sasl_bind(creds):
677             return SamDB('ldap://%s' % self.server,
678                          credentials=creds,
679                          lp=self.lp)
680         if simple:
681             (samdb, self.last_simple_bind_bad) = \
682                 self.with_random_bad_credentials(simple_bind,
683                                                  self.simple_bind_creds,
684                                                  self.simple_bind_creds_bad,
685                                                  self.last_simple_bind_bad)
686         else:
687             (samdb, self.last_bind_bad) = \
688                 self.with_random_bad_credentials(sasl_bind,
689                                                  self.user_creds,
690                                                  self.user_creds_bad,
691                                                  self.last_bind_bad)
692
693         self.ldap_connections.append(samdb)
694         return samdb
695
696     def get_samr_context(self, new=False):
697         if not self.samr_contexts or new:
698             self.samr_contexts.append(
699                 SamrContext(self.server, lp=self.lp, creds=self.creds))
700         return self.samr_contexts[-1]
701
702     def get_netlogon_connection(self):
703
704         if self.netlogon_connection:
705             return self.netlogon_connection
706
707         def connect(creds):
708             return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
709                                      (self.server),
710                                      self.lp,
711                                      creds)
712         (c, self.last_netlogon_bad) = \
713             self.with_random_bad_credentials(connect,
714                                              self.machine_creds,
715                                              self.machine_creds_bad,
716                                              self.last_netlogon_bad)
717         self.netlogon_connection = c
718         return c
719
720     def guess_a_dns_lookup(self):
721         return (self.realm, 'A')
722
723     def get_authenticator(self):
724         auth = self.machine_creds.new_client_authenticator()
725         current  = netr_Authenticator()
726         current.cred.data = [x if isinstance(x, int) else ord(x) for x in auth["credential"]]
727         current.timestamp = auth["timestamp"]
728
729         subsequent = netr_Authenticator()
730         return (current, subsequent)
731
732
733 class SamrContext(object):
734     """State/Context associated with a samr connection.
735     """
736     def __init__(self, server, lp=None, creds=None):
737         self.connection    = None
738         self.handle        = None
739         self.domain_handle = None
740         self.domain_sid    = None
741         self.group_handle  = None
742         self.user_handle   = None
743         self.rids          = None
744         self.server        = server
745         self.lp            = lp
746         self.creds         = creds
747
748     def get_connection(self):
749         if not self.connection:
750             self.connection = samr.samr(
751                 "ncacn_ip_tcp:%s[seal]" % (self.server),
752                 lp_ctx=self.lp,
753                 credentials=self.creds)
754
755         return self.connection
756
757     def get_handle(self):
758         if not self.handle:
759             c = self.get_connection()
760             self.handle = c.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
761         return self.handle
762
763
764 class Conversation(object):
765     """Details of a converation between a simulated client and a server."""
766     conversation_id = None
767
768     def __init__(self, start_time=None, endpoints=None):
769         self.start_time = start_time
770         self.endpoints = endpoints
771         self.packets = []
772         self.msg = random_colour_print()
773         self.client_balance = 0.0
774
775     def __cmp__(self, other):
776         if self.start_time is None:
777             if other.start_time is None:
778                 return 0
779             return -1
780         if other.start_time is None:
781             return 1
782         return self.start_time - other.start_time
783
784     def add_packet(self, packet):
785         """Add a packet object to this conversation, making a local copy with
786         a conversation-relative timestamp."""
787         p = packet.copy()
788
789         if self.start_time is None:
790             self.start_time = p.timestamp
791
792         if self.endpoints is None:
793             self.endpoints = p.endpoints
794
795         if p.endpoints != self.endpoints:
796             raise FakePacketError("Conversation endpoints %s don't match "
797                                   "packet endpoints %s" %
798                                   (self.endpoints, p.endpoints))
799
800         p.timestamp -= self.start_time
801
802         if p.src == p.endpoints[0]:
803             self.client_balance -= p.client_score()
804         else:
805             self.client_balance += p.client_score()
806
807         if p.is_really_a_packet():
808             self.packets.append(p)
809
810     def add_short_packet(self, timestamp, protocol, opcode, extra,
811                          client=True):
812         """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
813         (possibly empty) list of extra data. If client is True, assume
814         this packet is from the client to the server.
815         """
816         src, dest = self.guess_client_server()
817         if not client:
818             src, dest = dest, src
819         key = (protocol, opcode)
820         desc = OP_DESCRIPTIONS[key] if key in OP_DESCRIPTIONS else ''
821         if protocol in IP_PROTOCOLS:
822             ip_protocol = IP_PROTOCOLS[protocol]
823         else:
824             ip_protocol = '06'
825         packet = Packet(timestamp - self.start_time, ip_protocol,
826                         '', src, dest,
827                         protocol, opcode, desc, extra)
828         # XXX we're assuming the timestamp is already adjusted for
829         # this conversation?
830         # XXX should we adjust client balance for guessed packets?
831         if packet.src == packet.endpoints[0]:
832             self.client_balance -= packet.client_score()
833         else:
834             self.client_balance += packet.client_score()
835         if packet.is_really_a_packet():
836             self.packets.append(packet)
837
838     def __str__(self):
839         return ("<Conversation %s %s starting %.3f %d packets>" %
840                 (self.conversation_id, self.endpoints, self.start_time,
841                  len(self.packets)))
842
843     __repr__ = __str__
844
845     def __iter__(self):
846         return iter(self.packets)
847
848     def __len__(self):
849         return len(self.packets)
850
851     def get_duration(self):
852         if len(self.packets) < 2:
853             return 0
854         return self.packets[-1].timestamp - self.packets[0].timestamp
855
856     def replay_as_summary_lines(self):
857         lines = []
858         for p in self.packets:
859             lines.append(p.as_summary(self.start_time))
860         return lines
861
862     def replay_in_fork_with_delay(self, start, context=None, account=None):
863         """Fork a new process and replay the conversation.
864         """
865         def signal_handler(signal, frame):
866             """Signal handler closes standard out and error.
867
868             Triggered by a SIGTERM, it ensures that the log messages are
869             flushed to disk and not lost.
870             """
871             sys.stderr.close()
872             sys.stdout.close()
873             os._exit(0)
874
875         t = self.start_time
876         now = time.time() - start
877         gap = t - now
878         # we are replaying strictly in order, so it is safe to sleep
879         # in the main process if the gap is big enough. This reduces
880         # the number of concurrent threads, which allows us to make
881         # larger loads.
882         if gap > 0.15 and False:  # the 'and False' keeps this path disabled
883             print("sleeping for %f in main process" % (gap - 0.1),
884                   file=sys.stderr)
885             time.sleep(gap - 0.1)
886             now = time.time() - start
887             gap = t - now
888             print("gap is now %f" % gap, file=sys.stderr)
889
890         self.conversation_id = next(context.next_conversation_id)
891         pid = os.fork()
892         if pid != 0:
893             return pid
894         pid = os.getpid()
895         signal.signal(signal.SIGTERM, signal_handler)
896         # we must never return, or we'll end up running parts of the
897         # parent's clean-up code. So we work in a try...finally, and
898         # try to print any exceptions.
899
900         try:
901             context.generate_process_local_config(account, self)
902             sys.stdin.close()
903             os.close(0)
904             filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
905                                     self.conversation_id)
906             sys.stdout.close()
907             sys.stdout = open(filename, 'w')
908
909             sleep_time = gap - SLEEP_OVERHEAD
910             if sleep_time > 0:
911                 time.sleep(sleep_time)
912
913             miss = t - (time.time() - start)
914             self.msg("starting %s [miss %.3f pid %d]" % (self, miss, pid))
915             self.replay(context)
916         except Exception:
917             print(("EXCEPTION in child PID %d, conversation %s" % (pid, self)),
918                   file=sys.stderr)
919             traceback.print_exc(file=sys.stderr)
920         finally:
921             sys.stderr.close()
922             sys.stdout.close()
923             os._exit(0)
924
925     def replay(self, context=None):
926         start = time.time()
927
928         for p in self.packets:
929             now = time.time() - start
930             gap = p.timestamp - now
931             sleep_time = gap - SLEEP_OVERHEAD
932             if sleep_time > 0:
933                 time.sleep(sleep_time)
934
935             miss = p.timestamp - (time.time() - start)
936             if context is None:
937                 self.msg("packet %s [miss %.3f pid %d]" % (p, miss,
938                                                            os.getpid()))
939                 continue
940             p.play(self, context)
941
942     def guess_client_server(self, server_clue=None):
943         """Have a go at deciding who is the server and who is the client.
944         Returns (client, server).
945         """
946         a, b = self.endpoints
947
948         if self.client_balance < 0:
949             return (a, b)
950
951         # in the absence of a clue, we will fall through to assuming
952         # the lowest number is the server (which is usually true).
953
954         if self.client_balance == 0 and server_clue == b:
955             return (a, b)
956
957         return (b, a)
958
959     def forget_packets_outside_window(self, s, e):
960         """Prune any packets outside the timne window we're interested in
961
962         :param s: start of the window
963         :param e: end of the window
964         """
965         self.packets = [p for p in self.packets if s <= p.timestamp <= e]
966         self.start_time = self.packets[0].timestamp if self.packets else None
967
968     def renormalise_times(self, start_time):
969         """Adjust the packet start times relative to the new start time."""
970         for p in self.packets:
971             p.timestamp -= start_time
972
973         if self.start_time is not None:
974             self.start_time -= start_time
975
976
977 class DnsHammer(Conversation):
978     """A lightweight conversation that generates a lot of dns:0 packets on
979     the fly"""
980
981     def __init__(self, dns_rate, duration):
982         n = int(dns_rate * duration)
983         self.times = [random.uniform(0, duration) for i in range(n)]
984         self.times.sort()
985         self.rate = dns_rate
986         self.duration = duration
987         self.start_time = 0
988         self.msg = random_colour_print()
989
990     def __str__(self):
991         return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
992                 (len(self.times), self.duration, self.rate))
993
994     def replay_in_fork_with_delay(self, start, context=None, account=None):
995         return Conversation.replay_in_fork_with_delay(self,
996                                                       start,
997                                                       context,
998                                                       account)
999
1000     def replay(self, context=None):
1001         start = time.time()
1002         fn = traffic_packets.packet_dns_0
1003         for t in self.times:
1004             now = time.time() - start
1005             gap = t - now
1006             sleep_time = gap - SLEEP_OVERHEAD
1007             if sleep_time > 0:
1008                 time.sleep(sleep_time)
1009
1010             if context is None:
1011                 miss = t - (time.time() - start)
1012                 self.msg("packet %s [miss %.3f pid %d]" % (t, miss,
1013                                                            os.getpid()))
1014                 continue
1015
1016             packet_start = time.time()
1017             try:
1018                 fn(self, self, context)
1019                 end = time.time()
1020                 duration = end - packet_start
1021                 print("%f\tDNS\tdns\t0\t%f\tTrue\t" % (end, duration))
1022             except Exception as e:
1023                 end = time.time()
1024                 duration = end - packet_start
1025                 print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" % (end, duration, e))
1026
1027
1028 def ingest_summaries(files, dns_mode='count'):
1029     """Load a summary traffic summary file and generated Converations from it.
1030     """
1031
1032     dns_counts = defaultdict(int)
1033     packets = []
1034     for f in files:
1035         if isinstance(f, str):
1036             f = open(f)
1037         print("Ingesting %s" % (f.name,), file=sys.stderr)
1038         for line in f:
1039             p = Packet.from_line(line)
1040             if p.protocol == 'dns' and dns_mode != 'include':
1041                 dns_counts[p.opcode] += 1
1042             else:
1043                 packets.append(p)
1044
1045         f.close()
1046
1047     if not packets:
1048         return [], 0
1049
1050     start_time = min(p.timestamp for p in packets)
1051     last_packet = max(p.timestamp for p in packets)
1052
1053     print("gathering packets into conversations", file=sys.stderr)
1054     conversations = OrderedDict()
1055     for p in packets:
1056         p.timestamp -= start_time
1057         c = conversations.get(p.endpoints)
1058         if c is None:
1059             c = Conversation()
1060             conversations[p.endpoints] = c
1061         c.add_packet(p)
1062
1063     # We only care about conversations with actual traffic, so we
1064     # filter out conversations with nothing to say. We do that here,
1065     # rather than earlier, because those empty packets contain useful
1066     # hints as to which end of the conversation was the client.
1067     conversation_list = []
1068     for c in conversations.values():
1069         if len(c) != 0:
1070             conversation_list.append(c)
1071
1072     # This is obviously not correct, as many conversations will appear
1073     # to start roughly simultaneously at the beginning of the snapshot.
1074     # To which we say: oh well, so be it.
1075     duration = float(last_packet - start_time)
1076     mean_interval = len(conversations) / duration
1077
1078     return conversation_list, mean_interval, duration, dns_counts
1079
1080
1081 def guess_server_address(conversations):
1082     # we guess the most common address.
1083     addresses = Counter()
1084     for c in conversations:
1085         addresses.update(c.endpoints)
1086     if addresses:
1087         return addresses.most_common(1)[0][0]
1088
1089
1090 def stringify_keys(x):
1091     y = {}
1092     for k, v in x.items():
1093         k2 = '\t'.join(k)
1094         y[k2] = v
1095     return y
1096
1097
1098 def unstringify_keys(x):
1099     y = {}
1100     for k, v in x.items():
1101         t = tuple(str(k).split('\t'))
1102         y[t] = v
1103     return y
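
# These helpers round-trip tuple keys through tab-joined strings (JSON
# object keys must be strings), e.g. with illustrative values:
#
#     stringify_keys({('ldap', '3'): 7})   == {'ldap\t3': 7}
#     unstringify_keys({'ldap\t3': 7})     == {('ldap', '3'): 7}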
1104
1105
1106 class TrafficModel(object):
1107     def __init__(self, n=3):
1108         self.ngrams = {}
1109         self.query_details = {}
1110         self.n = n
1111         self.dns_opcounts = defaultdict(int)
1112         self.cumulative_duration = 0.0
1113         self.conversation_rate = [0, 1]
1114
1115     def learn(self, conversations, dns_opcounts={}):
1116         prev = 0.0
1117         cum_duration = 0.0
1118         key = (NON_PACKET,) * (self.n - 1)
1119
1120         server = guess_server_address(conversations)
1121
1122         for k, v in dns_opcounts.items():
1123             self.dns_opcounts[k] += v
1124
1125         if len(conversations) > 1:
1126             elapsed =\
1127                 conversations[-1].start_time - conversations[0].start_time
1128             self.conversation_rate[0] = len(conversations)
1129             self.conversation_rate[1] = elapsed
1130
1131         for c in conversations:
1132             client, server = c.guess_client_server(server)
1133             cum_duration += c.get_duration()
1134             key = (NON_PACKET,) * (self.n - 1)
1135             for p in c:
1136                 if p.src != client:
1137                     continue
1138
1139                 elapsed = p.timestamp - prev
1140                 prev = p.timestamp
1141                 if elapsed > WAIT_THRESHOLD:
1142                     # add the wait as an extra state
1143                     wait = 'wait:%d' % (math.log(max(1.0,
1144                                                      elapsed * WAIT_SCALE)))
1145                     self.ngrams.setdefault(key, []).append(wait)
1146                     key = key[1:] + (wait,)
1147
1148                 short_p = p.as_packet_type()
1149                 self.query_details.setdefault(short_p,
1150                                               []).append(tuple(p.extra))
1151                 self.ngrams.setdefault(key, []).append(short_p)
1152                 key = key[1:] + (short_p,)
1153
1154         self.cumulative_duration += cum_duration
1155         # add in the end
1156         self.ngrams.setdefault(key, []).append(NON_PACKET)
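
    # After learn() the model is a simple n-gram table: each key is the
    # previous n-1 client-side states and each value lists the states that
    # were observed to follow (duplicates preserved, so random.choice picks
    # in proportion to frequency).  Illustrative shape only, for n=3:
    #
    #     self.ngrams == {
    #         ('-', '-'): ['ldap:3', 'dns:0', 'wait:2', ...],
    #         ('-', 'ldap:3'): ['ldap:2', '-', ...],
    #     }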
1157
1158     def save(self, f):
1159         ngrams = {}
1160         for k, v in self.ngrams.items():
1161             k = '\t'.join(k)
1162             ngrams[k] = dict(Counter(v))
1163
1164         query_details = {}
1165         for k, v in self.query_details.items():
1166             query_details[k] = dict(Counter('\t'.join(x) if x else '-'
1167                                             for x in v))
1168
1169         d = {
1170             'ngrams': ngrams,
1171             'query_details': query_details,
1172             'cumulative_duration': self.cumulative_duration,
1173             'conversation_rate': self.conversation_rate,
1174         }
1175         d['dns'] = self.dns_opcounts
1176
1177         if isinstance(f, str):
1178             f = open(f, 'w')
1179
1180         json.dump(d, f, indent=2)
1181
1182     def load(self, f):
1183         if isinstance(f, str):
1184             f = open(f)
1185
1186         d = json.load(f)
1187
1188         for k, v in d['ngrams'].items():
1189             k = tuple(str(k).split('\t'))
1190             values = self.ngrams.setdefault(k, [])
1191             for p, count in v.items():
1192                 values.extend([str(p)] * count)
1193
1194         for k, v in d['query_details'].items():
1195             values = self.query_details.setdefault(str(k), [])
1196             for p, count in v.items():
1197                 if p == '-':
1198                     values.extend([()] * count)
1199                 else:
1200                     values.extend([tuple(str(p).split('\t'))] * count)
1201
1202         if 'dns' in d:
1203             for k, v in d['dns'].items():
1204                 self.dns_opcounts[k] += v
1205
1206         self.cumulative_duration = d['cumulative_duration']
1207         self.conversation_rate = d['conversation_rate']
1208
1209     def construct_conversation(self, timestamp=0.0, client=2, server=1,
1210                                hard_stop=None, packet_rate=1):
1211         """Construct a individual converation from the model."""
1212
1213         c = Conversation(timestamp, (server, client))
1214
1215         key = (NON_PACKET,) * (self.n - 1)
1216
1217         while key in self.ngrams:
1218             p = random.choice(self.ngrams.get(key, NON_PACKET))
1219             if p == NON_PACKET:
1220                 break
1221             if p in self.query_details:
1222                 extra = random.choice(self.query_details[p])
1223             else:
1224                 extra = []
1225
1226             protocol, opcode = p.split(':', 1)
1227             if protocol == 'wait':
1228                 log_wait_time = int(opcode) + random.random()
1229                 wait = math.exp(log_wait_time) / (WAIT_SCALE * packet_rate)
1230                 timestamp += wait
1231             else:
1232                 log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
1233                 wait = math.exp(log_wait) / packet_rate
1234                 timestamp += wait
1235                 if hard_stop is not None and timestamp > hard_stop:
1236                     break
1237                 c.add_short_packet(timestamp, protocol, opcode, extra)
1238
1239             key = key[1:] + (p,)
1240
1241         return c
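
    # A minimal usage sketch (the model file name and argument values are
    # assumed, not part of this module):
    #
    #     model = TrafficModel()
    #     model.load('traffic-model.json')
    #     c = model.construct_conversation(timestamp=0.0, packet_rate=2)
    #     lines = c.replay_as_summary_lines()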
1242
1243     def generate_conversations(self, rate, duration, packet_rate=1):
1244         """Generate a list of conversations from the model."""
1245
1246         # We run the simulation for at least ten times as long as our
1247         # desired duration, and take a section near the start.
1248         rate_n, rate_t  = self.conversation_rate
1249
1250         duration2 = max(rate_t, duration * 2)
1251         n = rate * duration2 * rate_n / rate_t
1252
1253         server = 1
1254         client = 2
1255
1256         conversations = []
1257         end = duration2
1258         start = end - duration
1259
1260         while client < n + 2:
1261             start = random.uniform(0, duration2)
1262             c = self.construct_conversation(start,
1263                                             client,
1264                                             server,
1265                                             hard_stop=(duration2 * 5),
1266                                             packet_rate=packet_rate)
1267
1268             c.forget_packets_outside_window(start, end)
1269             c.renormalise_times(start)
1270             if len(c) != 0:
1271                 conversations.append(c)
1272             client += 1
1273
1274         print(("we have %d conversations at rate %f" %
1275                (len(conversations), rate)), file=sys.stderr)
1276         conversations.sort()
1277         return conversations
1278
1279
1280 IP_PROTOCOLS = {
1281     'dns': '11',
1282     'rpc_netlogon': '06',
1283     'kerberos': '06',      # ratio 16248:258
1284     'smb': '06',
1285     'smb2': '06',
1286     'ldap': '06',
1287     'cldap': '11',
1288     'lsarpc': '06',
1289     'samr': '06',
1290     'dcerpc': '06',
1291     'epm': '06',
1292     'drsuapi': '06',
1293     'browser': '11',
1294     'smb_netlogon': '11',
1295     'srvsvc': '06',
1296     'nbns': '11',
1297 }
1298
1299 OP_DESCRIPTIONS = {
1300     ('browser', '0x01'): 'Host Announcement (0x01)',
1301     ('browser', '0x02'): 'Request Announcement (0x02)',
1302     ('browser', '0x08'): 'Browser Election Request (0x08)',
1303     ('browser', '0x09'): 'Get Backup List Request (0x09)',
1304     ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
1305     ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
1306     ('cldap', '3'): 'searchRequest',
1307     ('cldap', '5'): 'searchResDone',
1308     ('dcerpc', '0'): 'Request',
1309     ('dcerpc', '11'): 'Bind',
1310     ('dcerpc', '12'): 'Bind_ack',
1311     ('dcerpc', '13'): 'Bind_nak',
1312     ('dcerpc', '14'): 'Alter_context',
1313     ('dcerpc', '15'): 'Alter_context_resp',
1314     ('dcerpc', '16'): 'AUTH3',
1315     ('dcerpc', '2'): 'Response',
1316     ('dns', '0'): 'query',
1317     ('dns', '1'): 'response',
1318     ('drsuapi', '0'): 'DsBind',
1319     ('drsuapi', '12'): 'DsCrackNames',
1320     ('drsuapi', '13'): 'DsWriteAccountSpn',
1321     ('drsuapi', '1'): 'DsUnbind',
1322     ('drsuapi', '2'): 'DsReplicaSync',
1323     ('drsuapi', '3'): 'DsGetNCChanges',
1324     ('drsuapi', '4'): 'DsReplicaUpdateRefs',
1325     ('epm', '3'): 'Map',
1326     ('kerberos', ''): '',
1327     ('ldap', '0'): 'bindRequest',
1328     ('ldap', '1'): 'bindResponse',
1329     ('ldap', '2'): 'unbindRequest',
1330     ('ldap', '3'): 'searchRequest',
1331     ('ldap', '4'): 'searchResEntry',
1332     ('ldap', '5'): 'searchResDone',
1333     ('ldap', ''): '*** Unknown ***',
1334     ('lsarpc', '14'): 'lsa_LookupNames',
1335     ('lsarpc', '15'): 'lsa_LookupSids',
1336     ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
1337     ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
1338     ('lsarpc', '6'): 'lsa_OpenPolicy',
1339     ('lsarpc', '76'): 'lsa_LookupSids3',
1340     ('lsarpc', '77'): 'lsa_LookupNames4',
1341     ('nbns', '0'): 'query',
1342     ('nbns', '1'): 'response',
1343     ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
1344     ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
1345     ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
1346     ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
1347     ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
1348     ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
1349     ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
1350     ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
1351     ('samr', '0',): 'Connect',
1352     ('samr', '16'): 'GetAliasMembership',
1353     ('samr', '17'): 'LookupNames',
1354     ('samr', '18'): 'LookupRids',
1355     ('samr', '19'): 'OpenGroup',
1356     ('samr', '1'): 'Close',
1357     ('samr', '25'): 'QueryGroupMember',
1358     ('samr', '34'): 'OpenUser',
1359     ('samr', '36'): 'QueryUserInfo',
1360     ('samr', '39'): 'GetGroupsForUser',
1361     ('samr', '3'): 'QuerySecurity',
1362     ('samr', '5'): 'LookupDomain',
1363     ('samr', '64'): 'Connect5',
1364     ('samr', '6'): 'EnumDomains',
1365     ('samr', '7'): 'OpenDomain',
1366     ('samr', '8'): 'QueryDomainInfo',
1367     ('smb', '0x04'): 'Close (0x04)',
1368     ('smb', '0x24'): 'Locking AndX (0x24)',
1369     ('smb', '0x2e'): 'Read AndX (0x2e)',
1370     ('smb', '0x32'): 'Trans2 (0x32)',
1371     ('smb', '0x71'): 'Tree Disconnect (0x71)',
1372     ('smb', '0x72'): 'Negotiate Protocol (0x72)',
1373     ('smb', '0x73'): 'Session Setup AndX (0x73)',
1374     ('smb', '0x74'): 'Logoff AndX (0x74)',
1375     ('smb', '0x75'): 'Tree Connect AndX (0x75)',
1376     ('smb', '0xa2'): 'NT Create AndX (0xa2)',
1377     ('smb2', '0'): 'NegotiateProtocol',
1378     ('smb2', '11'): 'Ioctl',
1379     ('smb2', '14'): 'Find',
1380     ('smb2', '16'): 'GetInfo',
1381     ('smb2', '18'): 'Break',
1382     ('smb2', '1'): 'SessionSetup',
1383     ('smb2', '2'): 'SessionLogoff',
1384     ('smb2', '3'): 'TreeConnect',
1385     ('smb2', '4'): 'TreeDisconnect',
1386     ('smb2', '5'): 'Create',
1387     ('smb2', '6'): 'Close',
1388     ('smb2', '8'): 'Read',
1389     ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
1390     ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
1391                                'user unknown (0x17)'),
1392     ('srvsvc', '16'): 'NetShareGetInfo',
1393     ('srvsvc', '21'): 'NetSrvGetInfo',
1394 }
1395
1396
1397 def expand_short_packet(p, timestamp, src, dest, extra):
1398     protocol, opcode = p.split(':', 1)
1399     desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
1400     ip_protocol = IP_PROTOCOLS.get(protocol, '06')
1401
1402     line = [timestamp, ip_protocol, '', src, dest, protocol, opcode, desc]
1403     line.extend(extra)
1404     return '\t'.join(line)
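
# For example (values assumed for illustration):
#
#     expand_short_packet('ldap:3', '0.000000', '2', '1', ['DC,DC'])
#     # -> '0.000000\t06\t\t2\t1\tldap\t3\tsearchRequest\tDC,DC'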
1405
1406
1407 def replay(conversations,
1408            host=None,
1409            creds=None,
1410            lp=None,
1411            accounts=None,
1412            dns_rate=0,
1413            duration=None,
1414            **kwargs):
1415
1416     context = ReplayContext(server=host,
1417                             creds=creds,
1418                             lp=lp,
1419                             **kwargs)
1420
1421     if len(accounts) < len(conversations):
1422         print(("we have %d accounts but %d conversations" %
1423                (len(accounts), len(conversations))), file=sys.stderr)
1424
1425     cstack = list(zip(
1426         sorted(conversations, key=lambda x: x.start_time, reverse=True),
1427         accounts))
1428
1429     # Set the process group so that the calling scripts are not killed
1430     # when the forked child processes are killed.
1431     os.setpgrp()
1432
1433     start = time.time()
1434
1435     if duration is None:
1436         # end 1 second after the last packet of the last conversation
1437         # to start. Conversations other than the last could still be
1438         # going, but we don't care.
1439         duration = cstack[0][0].packets[-1].timestamp + 1.0
1440         print("We will stop after %.1f seconds" % duration,
1441               file=sys.stderr)
1442
1443     end = start + duration
1444
1445     LOGGER.info("Replaying traffic for %d conversations over %d seconds"
1446                 % (len(conversations), duration))
1447
1448     children = {}
1449     if dns_rate:
1450         dns_hammer = DnsHammer(dns_rate, duration)
1451         cstack.append((dns_hammer, None))
1452
1453     try:
1454         while True:
1455             # we spawn a batch, wait for finishers, then spawn another
1456             now = time.time()
1457             batch_end = min(now + 2.0, end)
1458             fork_time = 0.0
1459             fork_n = 0
1460             while cstack:
1461                 c, account = cstack.pop()
1462                 if c.start_time + start > batch_end:
1463                     cstack.append((c, account))
1464                     break
1465
1466                 st = time.time()
1467                 pid = c.replay_in_fork_with_delay(start, context, account)
1468                 children[pid] = c
1469                 t = time.time()
1470                 elapsed = t - st
1471                 fork_time += elapsed
1472                 fork_n += 1
1473                 print("forked %s in pid %s (in %fs)" % (c, pid,
1474                                                         elapsed),
1475                       file=sys.stderr)
1476
1477             if fork_n:
1478                 print(("forked %d times in %f seconds (avg %f)" %
1479                        (fork_n, fork_time, fork_time / fork_n)),
1480                       file=sys.stderr)
1481             elif cstack:
1482                 debug(2, "no forks in batch ending %f" % batch_end)
1483
1484             while time.time() < batch_end - 1.0:
1485                 time.sleep(0.01)
1486                 try:
1487                     pid, status = os.waitpid(-1, os.WNOHANG)
1488                 except OSError as e:
1489                     if e.errno != 10:  # no child processes
1490                         raise
1491                     break
1492                 if pid:
1493                     c = children.pop(pid, None)
1494                     print(("process %d finished conversation %s;"
1495                            " %d to go" %
1496                            (pid, c, len(children))), file=sys.stderr)
1497
1498             if time.time() >= end:
1499                 print("time to stop", file=sys.stderr)
1500                 break
1501
1502     except Exception:
1503         print("EXCEPTION in parent", file=sys.stderr)
1504         traceback.print_exc()
1505     finally:
1506         for s in (15, 15, 9):
1507             print(("killing %d children with -%d" %
1508                    (len(children), s)), file=sys.stderr)
1509             for pid in children:
1510                 try:
1511                     os.kill(pid, s)
1512                 except OSError as e:
1513                     if e.errno != 3:  # don't fail if it has already died
1514                         raise
1515             time.sleep(0.5)
1516             end = time.time() + 1
1517             while children:
1518                 try:
1519                     pid, status = os.waitpid(-1, os.WNOHANG)
1520                 except OSError as e:
1521                     if e.errno != 10:
1522                         raise
1523                 if pid != 0:
1524                     c = children.pop(pid, None)
1525                     print(("kill -%d %d KILLED conversation %s; "
1526                            "%d to go" %
1527                            (s, pid, c, len(children))),
1528                           file=sys.stderr)
1529                 if time.time() >= end:
1530                     break
1531
1532             if not children:
1533                 break
1534             time.sleep(1)
1535
1536         if children:
1537             print("%d children are missing" % len(children),
1538                   file=sys.stderr)
1539
1540         # there may be stragglers that were forked just as ^C was hit
1541         # and don't appear in the list of children. We can get them
1542         # with killpg, but that will also kill us. So we cheat: we send
1543         # ourselves SIGINT (as if ^C were pressed) and catch the resulting
1544         # KeyboardInterrupt, so we don't have to write signal handlers.
1545         try:
1546             os.killpg(0, signal.SIGINT)
1547         except KeyboardInterrupt:
1548             print("ignoring fake ^C", file=sys.stderr)
1549
1550
1551 def openLdb(host, creds, lp):
1552     session = system_session()
1553     sam_db = SamDB(url="ldap://%s" % host,
1554                    session_info=session,
1555                    options=['modules:paged_searches'],
1556                    credentials=creds,
1557                    lp=lp)
1558     return sam_db
1559
1560
1561 def ou_name(ldb, instance_id):
1562     """Generate an ou name from the instance id"""
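    # e.g. instance_id 7 with an (illustrative) domain DN of DC=example,DC=com
    # gives "ou=instance-7,ou=traffic_replay,DC=example,DC=com"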
1563     return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id,
1564                                                     ldb.domain_dn())
1565
1566
1567 def create_ou(ldb, instance_id):
1568     """Create an ou; all created user and machine accounts will belong to it.
1569
1570     This allows all the created resources to be cleaned up easily.
1571     """
1572     ou = ou_name(ldb, instance_id)
1573     try:
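        # create the parent "ou=traffic_replay,<domain>" container first;
        # ou.split(',', 1)[1] strips the leading "ou=instance-<id>," component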
1574         ldb.add({"dn": ou.split(',', 1)[1],
1575                  "objectclass": "organizationalunit"})
1576     except LdbError as e:
1577         (status, _) = e.args
1578         # ignore already exists
1579         if status != 68:
1580             raise
1581     try:
1582         ldb.add({"dn": ou,
1583                  "objectclass": "organizationalunit"})
1584     except LdbError as e:
1585         (status, _) = e.args
1586         # ignore already exists
1587         if status != 68:
1588             raise
1589     return ou
1590
1591
1592 class ConversationAccounts(object):
1593     """Details of the machine and user accounts associated with a conversation.
1594     """
1595     def __init__(self, netbios_name, machinepass, username, userpass):
1596         self.netbios_name = netbios_name
1597         self.machinepass  = machinepass
1598         self.username     = username
1599         self.userpass     = userpass
1600
1601
1602 def generate_replay_accounts(ldb, instance_id, number, password):
1603     """Generate a series of unique machine and user account names."""
1604
1605     generate_traffic_accounts(ldb, instance_id, number, password)
1606     accounts = []
1607     for i in range(1, number + 1):
1608         netbios_name = "STGM-%d-%d" % (instance_id, i)
1609         username     = "STGU-%d-%d" % (instance_id, i)
1610
1611         account = ConversationAccounts(netbios_name, password, username,
1612                                        password)
1613         accounts.append(account)
1614     return accounts
1615
1616
1617 def generate_traffic_accounts(ldb, instance_id, number, password):
1618     """Create the specified number of user and machine accounts.
1619
1620     Accounts are not explicitly deleted between runs, so this function
1621     starts with the last account and iterates backwards, stopping either
1622     when it finds an already existing account or when it has generated
1623     all the required accounts.
1624     """
1625     print(("Generating machine and user accounts, "
1626            "as required for %d conversations" % number),
1627           file=sys.stderr)
1628     added = 0
1629     for i in range(number, 0, -1):
1630         try:
1631             netbios_name = "STGM-%d-%d" % (instance_id, i)
1632             create_machine_account(ldb, instance_id, netbios_name, password)
1633             added += 1
1634         except LdbError as e:
1635             (status, _) = e.args
1636             if status == 68:
1637                 break
1638             else:
1639                 raise
1640     if added > 0:
1641         print("Added %d new machine accounts" % added,
1642               file=sys.stderr)
1643
1644     added = 0
1645     for i in range(number, 0, -1):
1646         try:
1647             username = "STGU-%d-%d" % (instance_id, i)
1648             create_user_account(ldb, instance_id, username, password)
1649             added += 1
1650         except LdbError as e:
1651             (status, _) = e.args
1652             if status == 68:
1653                 break
1654             else:
1655                 raise
1656
1657     if added > 0:
1658         print("Added %d new user accounts" % added,
1659               file=sys.stderr)
1660
1661
1662 def create_machine_account(ldb, instance_id, netbios_name, machinepass):
1663     """Create a machine account via ldap."""
1664
1665     ou = ou_name(ldb, instance_id)
1666     dn = "cn=%s,%s" % (netbios_name, ou)
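    # AD's unicodePwd attribute expects the password wrapped in double
    # quotes and encoded as UTF-16-LE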
1667     utf16pw = ('"%s"' % get_string(machinepass)).encode('utf-16-le')
1668
1669     start = time.time()
1670     ldb.add({
1671         "dn": dn,
1672         "objectclass": "computer",
1673         "sAMAccountName": "%s$" % netbios_name,
1674         "userAccountControl":
1675             str(UF_TRUSTED_FOR_DELEGATION | UF_SERVER_TRUST_ACCOUNT),
1676         "unicodePwd": utf16pw})
1677     end = time.time()
1678     duration = end - start
1679     LOGGER.info("%f\t0\tcreate\tmachine\t%f\tTrue\t" % (end, duration))
1680
1681
1682 def create_user_account(ldb, instance_id, username, userpass):
1683     """Create a user account via ldap."""
1684     ou = ou_name(ldb, instance_id)
1685     user_dn = "cn=%s,%s" % (username, ou)
1686     utf16pw = ('"%s"' % get_string(userpass)).encode('utf-16-le')
1687     start = time.time()
1688     ldb.add({
1689         "dn": user_dn,
1690         "objectclass": "user",
1691         "sAMAccountName": username,
1692         "userAccountControl": str(UF_NORMAL_ACCOUNT),
1693         "unicodePwd": utf16pw
1694     })
1695
1696     # grant user write permission to do things like write account SPN
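    # ("(A;;WP;;;PS)" is an SDDL ACE: Allow the WP (write property) right
    # to PS, the Principal Self SID)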
1697     sdutils = sd_utils.SDUtils(ldb)
1698     sdutils.dacl_add_ace(user_dn, "(A;;WP;;;PS)")
1699
1700     end = time.time()
1701     duration = end - start
1702     LOGGER.info("%f\t0\tcreate\tuser\t%f\tTrue\t" % (end, duration))
1703
1704
1705 def create_group(ldb, instance_id, name):
1706     """Create a group via ldap."""
1707
1708     ou = ou_name(ldb, instance_id)
1709     dn = "cn=%s,%s" % (name, ou)
1710     start = time.time()
1711     ldb.add({
1712         "dn": dn,
1713         "objectclass": "group",
1714         "sAMAccountName": name,
1715     })
1716     end = time.time()
1717     duration = end - start
1718     LOGGER.info("%f\t0\tcreate\tgroup\t%f\tTrue\t" % (end, duration))
1719
1720
1721 def user_name(instance_id, i):
1722     """Generate a user name based on the instance id"""
1723     return "STGU-%d-%d" % (instance_id, i)
1724
1725
1726 def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
1727     """Search by objectclass, returning the attr values as a set"""
1728     objs = ldb.search(
1729         expression="(objectClass={})".format(objectclass),
1730         attrs=[attr]
1731     )
1732     return {str(obj[attr]) for obj in objs}
1733
1734
1735 def generate_users(ldb, instance_id, number, password):
1736     """Add users to the server"""
1737     existing_objects = search_objectclass(ldb, objectclass='user')
1738     users = 0
1739     for i in range(number, 0, -1):
1740         name = user_name(instance_id, i)
1741         if name not in existing_objects:
1742             create_user_account(ldb, instance_id, name, password)
1743             users += 1
1744
1745     return users
1746
1747
1748 def group_name(instance_id, i):
1749     """Generate a group name from the instance id."""
1750     return "STGG-%d-%d" % (instance_id, i)
1751
1752
1753 def generate_groups(ldb, instance_id, number):
1754     """Create the required number of groups on the server."""
1755     existing_objects = search_objectclass(ldb, objectclass='group')
1756     groups = 0
1757     for i in range(number, 0, -1):
1758         name = group_name(instance_id, i)
1759         if name not in existing_objects:
1760             create_group(ldb, instance_id, name)
1761             groups += 1
1762
1763     return groups
1764
1765
1766 def clean_up_accounts(ldb, instance_id):
1767     """Remove the created accounts and groups from the server."""
1768     ou = ou_name(ldb, instance_id)
1769     try:
1770         ldb.delete(ou, ["tree_delete:1"])
1771     except LdbError as e:
1772         (status, _) = e.args
1773         # ignore does not exist
1774         if status != 32:
1775             raise
1776
1777
1778 def generate_users_and_groups(ldb, instance_id, password,
1779                               number_of_users, number_of_groups,
1780                               group_memberships):
1781     """Generate the required users and groups, allocating the users to
1782        those groups."""
1783     memberships_added = 0
1784     groups_added  = 0
1785
1786     create_ou(ldb, instance_id)
1787
1788     print("Generating dummy user accounts", file=sys.stderr)
1789     users_added = generate_users(ldb, instance_id, number_of_users, password)
1790
1791     if number_of_groups > 0:
1792         print("Generating dummy groups", file=sys.stderr)
1793         groups_added = generate_groups(ldb, instance_id, number_of_groups)
1794
1795     if group_memberships > 0:
1796         print("Assigning users to groups", file=sys.stderr)
1797         assignments = GroupAssignments(number_of_groups,
1798                                        groups_added,
1799                                        number_of_users,
1800                                        users_added,
1801                                        group_memberships)
1802         print("Adding users to groups", file=sys.stderr)
1803         add_users_to_groups(ldb, instance_id, assignments.assignments)
1804         memberships_added = assignments.total()
1805
1806     if (groups_added > 0 and users_added == 0 and
1807        number_of_groups != groups_added):
1808         print("Warning: the added groups will contain no members",
1809               file=sys.stderr)
1810
1811     print(("Added %d users, %d groups and %d group memberships" %
1812            (users_added, groups_added, memberships_added)),
1813           file=sys.stderr)
1814
1815
1816 class GroupAssignments(object):
1817     def __init__(self, number_of_groups, groups_added, number_of_users,
1818                  users_added, group_memberships):
1819
1820         self.generate_group_distribution(number_of_groups)
1821         self.generate_user_distribution(number_of_users, group_memberships)
1822         self.assignments = self.assign_groups(number_of_groups,
1823                                               groups_added,
1824                                               number_of_users,
1825                                               users_added,
1826                                               group_memberships)
1827
1828     def cumulative_distribution(self, weights):
1829         # make sure the probabilities conform to a cumulative distribution
1830         # spread between 0.0 and 1.0. Dividing by the weighted total gives each
1831         # probability a proportional share of 1.0. Higher probabilities get a
1832         # bigger share, so are more likely to be picked. We use the cumulative
1833         # value, so we can use random.random() as a simple index into the list
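        #
        # For example (illustrative numbers): weights [3, 1, 1] become
        # [0.6, 0.8, 1.0], so a random() value below 0.6 selects the first
        # item, 0.6-0.8 the second and 0.8-1.0 the third.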
1834         dist = []
1835         total = sum(weights)
1836         cumulative = 0.0
1837         for probability in weights:
1838             cumulative += probability
1839             dist.append(cumulative / total)
1840         return dist
1841
1842     def generate_user_distribution(self, num_users, num_memberships):
1843         """Probability distribution of a user belonging to a group.
1844         """
1845         # Assign a weighted probability to each user. Use the Pareto
1846         # Distribution so that some users are in a lot of groups, and the
1847         # bulk of users are in only a few groups. If we're assigning a large
1848         # number of group memberships, use a higher shape. This means slightly
1849         # fewer outlying users that are in large numbers of groups. The aim is
1850         # to have no users belonging to more than ~500 groups.
1851         if num_memberships > 5000000:
1852             shape = 3.0
1853         elif num_memberships > 2000000:
1854             shape = 2.5
1855         elif num_memberships > 300000:
1856             shape = 2.25
1857         else:
1858             shape = 1.75
1859
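        # random.paretovariate(shape) returns samples >= 1.0 with a heavy
        # tail; a larger shape makes extreme weights rarer, which is why the
        # shape rises with the number of memberships requested.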
1860         weights = []
1861         for x in range(1, num_users + 1):
1862             p = random.paretovariate(shape)
1863             weights.append(p)
1864
1865         # convert the weights to a cumulative distribution between 0.0 and 1.0
1866         self.user_dist = self.cumulative_distribution(weights)
1867
1868     def generate_group_distribution(self, n):
1869         """Probability distribution of a group containing a user."""
1870
1871         # Assign a weighted probability to each group. The probability
1872         # decreases as the group-ID increases (p = 1 / x**1.3).
1873         weights = []
1874         for x in range(1, n + 1):
1875             p = 1 / (x**1.3)
1876             weights.append(p)
1877
1878         # convert the weights to a cumulative distribution between 0.0 and 1.0
1879         self.group_dist = self.cumulative_distribution(weights)
1880
1881     def generate_random_membership(self):
1882         """Returns a randomly generated user-group membership"""
1883
1884         # the list items are cumulative distribution values between 0.0 and
1885         # 1.0, which makes random() a handy way to index the list to get a
1886         # weighted random user/group. (Here the user/group returned are
1887         # zero-based array indexes)
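        #
        # e.g. (illustrative) with user_dist == [0.6, 0.8, 1.0], a random()
        # value of 0.7 gives bisect.bisect(user_dist, 0.7) == 1, selecting
        # the second user.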
1888         user = bisect.bisect(self.user_dist, random.random())
1889         group = bisect.bisect(self.group_dist, random.random())
1890
1891         return user, group
1892
1893     def assign_groups(self, number_of_groups, groups_added,
1894                       number_of_users, users_added, group_memberships):
1895         """Allocate users to groups.
1896
1897         The intention is to have a few users that belong to most groups, while
1898         the majority of users belong to a few groups.
1899
1900         A few groups will contain most users, with the remaining only having a
1901         few users.
1902         """
1903
1904         assignments = set()
1905         if group_memberships <= 0:
1906             return assignments
1907
1908         # Calculate the number of group memberships required
1909         group_memberships = math.ceil(
1910             float(group_memberships) *
1911             (float(users_added) / float(number_of_users)))
1912
1913         existing_users  = number_of_users  - users_added  - 1
1914         existing_groups = number_of_groups - groups_added - 1
1915         while len(assignments) < group_memberships:
1916             user, group = self.generate_random_membership()
1917
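            # only keep pairs involving at least one newly-added user or
            # group; memberships between pre-existing users and groups are
            # assumed to have been created by an earlier run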
1918             if group > existing_groups or user > existing_users:
1919                 # the + 1 converts the array index to the corresponding
1920                 # group or user number
1921                 assignments.add(((user + 1), (group + 1)))
1922
1923         return assignments
1924
1925     def total(self):
1926         return len(self.assignments)
1927
1928
1929 def add_users_to_groups(db, instance_id, assignments):
1930     """Add users to their assigned groups.
1931
1932     Takes the set of (user, group) tuples generated by assign_groups and
1933     adds each user to its assigned group."""
1934
1935     ou = ou_name(db, instance_id)
1936
1937     def build_dn(name):
1938         return "cn=%s,%s" % (name, ou)
1939
1940     for (user, group) in assignments:
1941         user_dn  = build_dn(user_name(instance_id, user))
1942         group_dn = build_dn(group_name(instance_id, group))
1943
1944         m = ldb.Message()
1945         m.dn = ldb.Dn(db, group_dn)
1946         m["member"] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
1947         start = time.time()
1948         db.modify(m)
1949         end = time.time()
1950         duration = end - start
1951         LOGGER.info("%f\t0\tadd\tuser\t%f\tTrue\t" % (end, duration))
1952
1953
1954 def generate_stats(statsdir, timing_file):
1955     """Generate and print the summary stats for a run."""
1956     first      = sys.float_info.max
1957     last       = 0
1958     successful = 0
1959     failed     = 0
1960     latencies  = {}
1961     failures   = {}
1962     unique_conversations = set()
1963     conversations = 0
1964
1965     if timing_file is not None:
1966         tw = timing_file.write
1967     else:
1968         def tw(x):
1969             pass
1970
1971     tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")
1972
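    # Each stats line is tab-separated: timestamp, conversation id, protocol,
    # packet type / operation, latency in seconds, a 'True'/'False' success
    # flag and an optional error field (matching the header written above).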
1973     for filename in os.listdir(statsdir):
1974         path = os.path.join(statsdir, filename)
1975         with open(path, 'r') as f:
1976             for line in f:
1977                 try:
1978                     fields       = line.rstrip('\n').split('\t')
1979                     conversation = fields[1]
1980                     protocol     = fields[2]
1981                     packet_type  = fields[3]
1982                     latency      = float(fields[4])
1983                     first        = min(float(fields[0]) - latency, first)
1984                     last         = max(float(fields[0]), last)
1985
1986                     if protocol not in latencies:
1987                         latencies[protocol] = {}
1988                     if packet_type not in latencies[protocol]:
1989                         latencies[protocol][packet_type] = []
1990
1991                     latencies[protocol][packet_type].append(latency)
1992
1993                     if protocol not in failures:
1994                         failures[protocol] = {}
1995                     if packet_type not in failures[protocol]:
1996                         failures[protocol][packet_type] = 0
1997
1998                     if fields[5] == 'True':
1999                         successful += 1
2000                     else:
2001                         failed += 1
2002                         failures[protocol][packet_type] += 1
2003
2004                     if conversation not in unique_conversations:
2005                         unique_conversations.add(conversation)
2006                         conversations += 1
2007
2008                     tw(line)
2009                 except (ValueError, IndexError):
2010                     # not a valid line; print it to stderr and ignore it
2011                     print(line, file=sys.stderr)
2012                     pass
2013     duration = last - first
2014     # avoid dividing by zero if no valid lines were parsed (or they all
2015     # had the same timestamp, making the duration zero)
2016     if duration > 0:
2017         success_rate = successful / duration
2018         failure_rate = failed / duration
2019     else:
2020         success_rate = 0
2021         failure_rate = 0
2022
2023     print("Total conversations:   %10d" % conversations)
2024     print("Successful operations: %10d (%.3f per second)"
2025           % (successful, success_rate))
2026     print("Failed operations:     %10d (%.3f per second)"
2027           % (failed, failure_rate))
2028
2029     print("Protocol    Op Code  Description                               "
2030           " Count       Failed         Mean       Median          "
2031           "95%        Range          Max")
2032
2033     protocols = sorted(latencies.keys())
2034     for protocol in protocols:
2035         packet_types = sorted(latencies[protocol], key=opcode_key)
2036         for packet_type in packet_types:
2037             values     = latencies[protocol][packet_type]
2038             values     = sorted(values)
2039             count      = len(values)
2040             failed     = failures[protocol][packet_type]
2041             mean       = sum(values) / count
2042             median     = calc_percentile(values, 0.50)
2043             percentile = calc_percentile(values, 0.95)
2044             rng        = values[-1] - values[0]
2045             maxv       = values[-1]
2046             desc       = OP_DESCRIPTIONS.get((protocol, packet_type), '')
2047             if sys.stdout.isatty():
2048                 print("%-12s   %4s  %-35s %12d %12d %12.6f "
2049                       "%12.6f %12.6f %12.6f %12.6f"
2050                       % (protocol,
2051                          packet_type,
2052                          desc,
2053                          count,
2054                          failed,
2055                          mean,
2056                          median,
2057                          percentile,
2058                          rng,
2059                          maxv))
2060             else:
2061                 print("%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f"
2062                       % (protocol,
2063                          packet_type,
2064                          desc,
2065                          count,
2066                          failed,
2067                          mean,
2068                          median,
2069                          percentile,
2070                          rng,
2071                          maxv))
2072
2073
2074 def opcode_key(v):
2075     """Sort key for the operation code to ensure that it sorts numerically"""
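    # e.g. opcode_key('7') == '007', which sorts before '040' ('40'), while
    # a hex opcode such as '0x72' makes int() raise and is returned as-is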
2076     try:
2077         return "%03d" % int(v)
2078     except ValueError:
2079         return v
2080
2081
2082 def calc_percentile(values, percentile):
2083     """Calculate the specified percentile from the list of values.
2084
2085     Assumes the list is sorted in ascending order.
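
    Illustrative example: for values [1, 2, 3, 4] and percentile 0.95,
    k = (4 - 1) * 0.95 = 2.85, so the result interpolates between
    values[2] and values[3]: 3 * 0.15 + 4 * 0.85 = 3.85.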
2086     """
2087
2088     if not values:
2089         return 0
2090     k = (len(values) - 1) * percentile
2091     f = math.floor(k)
2092     c = math.ceil(k)
2093     if f == c:
2094         return values[int(k)]
2095     d0 = values[int(f)] * (c - k)
2096     d1 = values[int(c)] * (k - f)
2097     return d0 + d1
2098
2099
2100 def mk_masked_dir(*path):
2101     """In a testenv we end up with 0777 directories that look an alarming
2102     green colour with ls. Use umask to avoid that."""
2103     d = os.path.join(*path)
2104     mask = os.umask(0o077)
2105     os.mkdir(d)
2106     os.umask(mask)
2107     return d