49ad49a55cc34b9c79119a494395a45103adb6f5
[samba.git] / python / samba / emulate / traffic.py
1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
3 #
4 # Copyright (C) Catalyst IT Ltd. 2017
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19 from __future__ import print_function, division
20
21 import time
22 import os
23 import random
24 import json
25 import math
26 import sys
27 import signal
28 import itertools
29
30 from collections import OrderedDict, Counter, defaultdict
31 from samba.emulate import traffic_packets
32 from samba.samdb import SamDB
33 import ldb
34 from ldb import LdbError
35 from samba.dcerpc import ClientConnection
36 from samba.dcerpc import security, drsuapi, lsa
37 from samba.dcerpc import netlogon
38 from samba.dcerpc.netlogon import netr_Authenticator
39 from samba.dcerpc import srvsvc
40 from samba.dcerpc import samr
41 from samba.drs_utils import drs_DsBind
42 import traceback
43 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
44 from samba.auth import system_session
45 from samba.dsdb import (
46     UF_NORMAL_ACCOUNT,
47     UF_SERVER_TRUST_ACCOUNT,
48     UF_TRUSTED_FOR_DELEGATION
49 )
50 from samba.dcerpc.misc import SEC_CHAN_BDC
51 from samba import gensec
52 from samba import sd_utils
53
# NOTE(review): not referenced in this part of the file; presumably the
# per-sleep overhead in seconds that replay scheduling allows for -- confirm.
SLEEP_OVERHEAD = 3e-4

# we don't use None, because it complicates [de]serialisation
NON_PACKET = '-'

# (protocol, opcode) pairs that mark the sender of a packet as a client.
# Packet.client_score() returns the mapped value for a matching packet.
CLIENT_CLUES = {
    ('dns', '0'): 1.0,      # query
    ('smb', '0x72'): 1.0,   # Negotiate protocol
    ('ldap', '0'): 1.0,     # bind
    ('ldap', '3'): 1.0,     # searchRequest
    ('ldap', '2'): 1.0,     # unbindRequest
    ('cldap', '3'): 1.0,
    ('dcerpc', '11'): 1.0,  # bind
    ('dcerpc', '14'): 1.0,  # Alter_context
    ('nbns', '0'): 1.0,     # query
}

# (protocol, opcode) pairs that mark the sender of a packet as a server.
# Packet.client_score() returns the negated mapped value for a matching
# packet.
SERVER_CLUES = {
    ('dns', '1'): 1.0,      # response
    ('ldap', '1'): 1.0,     # bind response
    ('ldap', '4'): 1.0,     # search result
    ('ldap', '5'): 1.0,     # search done
    ('cldap', '5'): 1.0,
    ('dcerpc', '12'): 1.0,  # bind_ack
    ('dcerpc', '13'): 1.0,  # bind_nak
    ('dcerpc', '15'): 1.0,  # Alter_context response
}

# Packet.is_really_a_packet() returns False for packets of these protocols.
SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}

# NOTE(review): the three WAIT/NO_WAIT values are not used in this part of
# the file; presumably they parameterise how pauses between packets are
# modelled -- confirm against their users further down.
WAIT_SCALE = 10.0
WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
NO_WAIT_LOG_TIME_RANGE = (-10, -3)

# DEBUG_LEVEL can be changed by scripts with -d
# debug() only prints messages whose level is <= this value.
DEBUG_LEVEL = 0
90
91
def debug(level, msg, *args):
    """Write a debug message to standard error if the level permits.

    :param level: Verbosity of this message; it is printed only when it is
                  <= the module's DEBUG_LEVEL (set with the -d option).
    :param msg:   The message text, optionally containing C-style format
                  specifiers.
    :param args:  Values for the format specifiers; when none are given the
                  message is printed verbatim, without % interpolation.
    """
    if level > DEBUG_LEVEL:
        return

    text = msg % args if args else msg
    print(text, file=sys.stderr)
108
109
def debug_lineno(*args):
    """Print an unformatted log message to stderr, containing the line number.

    The message is prefixed with the calling function's name and the
    (highlighted) line number of the call; each argument is then printed on
    its own line, followed by a blank line.
    """
    # With limit=2 the stack is [caller, this function]; entry 0 is the
    # caller, whose fields 1 and 2 are the line number and function name.
    caller = traceback.extract_stack(limit=2)[0]
    location = " %s:\033[01;33m%s \033[00m" % (caller[2], caller[1])
    print(location, end=' ', file=sys.stderr)

    for arg in args:
        print(arg, file=sys.stderr)

    print(file=sys.stderr)
    sys.stderr.flush()
121
122
def random_colour_print():
    """Return a function that prints a randomly coloured line to stderr.

    The colour is chosen once, when this function is called, so every line
    printed through the returned function shares it.
    """
    colour = random.randrange(214) + 18
    colour_on = "\033[38;5;%dm" % colour

    def coloured_print(*lines):
        # One output line per argument, each wrapped in colour on/off codes.
        for line in lines:
            print("%s%s\033[00m" % (colour_on, line), file=sys.stderr)

    return coloured_print
133
134
class FakePacketError(Exception):
    """Raised by Conversation.add_packet() when a packet's endpoints differ
    from the endpoints of the conversation it is being added to.
    """
    pass
137
138
139 class Packet(object):
140     """Details of a network packet"""
141     def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
142                  protocol, opcode, desc, extra):
143
144         self.timestamp = timestamp
145         self.ip_protocol = ip_protocol
146         self.stream_number = stream_number
147         self.src = src
148         self.dest = dest
149         self.protocol = protocol
150         self.opcode = opcode
151         self.desc = desc
152         self.extra = extra
153         if self.src < self.dest:
154             self.endpoints = (self.src, self.dest)
155         else:
156             self.endpoints = (self.dest, self.src)
157
158     @classmethod
159     def from_line(self, line):
160         fields = line.rstrip('\n').split('\t')
161         (timestamp,
162          ip_protocol,
163          stream_number,
164          src,
165          dest,
166          protocol,
167          opcode,
168          desc) = fields[:8]
169         extra = fields[8:]
170
171         timestamp = float(timestamp)
172         src = int(src)
173         dest = int(dest)
174
175         return Packet(timestamp, ip_protocol, stream_number, src, dest,
176                       protocol, opcode, desc, extra)
177
178     def as_summary(self, time_offset=0.0):
179         """Format the packet as a traffic_summary line.
180         """
181         extra = '\t'.join(self.extra)
182         t = self.timestamp + time_offset
183         return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
184                 (t,
185                  self.ip_protocol,
186                  self.stream_number or '',
187                  self.src,
188                  self.dest,
189                  self.protocol,
190                  self.opcode,
191                  self.desc,
192                  extra))
193
194     def __str__(self):
195         return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
196                 (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
197                  self.stream_number, self.protocol, self.opcode, self.desc,
198                  ('«' + ' '.join(self.extra) + '»' if self.extra else '')))
199
200     def __repr__(self):
201         return "<Packet @%s>" % self
202
203     def copy(self):
204         return self.__class__(self.timestamp,
205                               self.ip_protocol,
206                               self.stream_number,
207                               self.src,
208                               self.dest,
209                               self.protocol,
210                               self.opcode,
211                               self.desc,
212                               self.extra)
213
214     def as_packet_type(self):
215         t = '%s:%s' % (self.protocol, self.opcode)
216         return t
217
218     def client_score(self):
219         """A positive number means we think it is a client; a negative number
220         means we think it is a server. Zero means no idea. range: -1 to 1.
221         """
222         key = (self.protocol, self.opcode)
223         if key in CLIENT_CLUES:
224             return CLIENT_CLUES[key]
225         if key in SERVER_CLUES:
226             return -SERVER_CLUES[key]
227         return 0.0
228
229     def play(self, conversation, context):
230         """Send the packet over the network, if required.
231
232         Some packets are ignored, i.e. for  protocols not handled,
233         server response messages, or messages that are generated by the
234         protocol layer associated with other packets.
235         """
236         fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
237         try:
238             fn = getattr(traffic_packets, fn_name)
239
240         except AttributeError as e:
241             print("Conversation(%s) Missing handler %s" % \
242                   (conversation.conversation_id, fn_name),
243                   file=sys.stderr)
244             return
245
246         # Don't display a message for kerberos packets, they're not directly
247         # generated they're used to indicate kerberos should be used
248         if self.protocol != "kerberos":
249             debug(2, "Conversation(%s) Calling handler %s" %
250                      (conversation.conversation_id, fn_name))
251
252         start = time.time()
253         try:
254             if fn(self, conversation, context):
255                 # Only collect timing data for functions that generate
256                 # network traffic, or fail
257                 end = time.time()
258                 duration = end - start
259                 print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
260                       (end, conversation.conversation_id, self.protocol,
261                        self.opcode, duration))
262         except Exception as e:
263             end = time.time()
264             duration = end - start
265             print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
266                   (end, conversation.conversation_id, self.protocol,
267                    self.opcode, duration, e))
268
269     def __cmp__(self, other):
270         return self.timestamp - other.timestamp
271
272     def is_really_a_packet(self, missing_packet_stats=None):
273         """Is the packet one that can be ignored?
274
275         If so removing it will have no effect on the replay
276         """
277         if self.protocol in SKIPPED_PROTOCOLS:
278             # Ignore any packets for the protocols we're not interested in.
279             return False
280         if self.protocol == "ldap" and self.opcode == '':
281             # skip ldap continuation packets
282             return False
283
284         fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
285         fn = getattr(traffic_packets, fn_name, None)
286         if not fn:
287             print("missing packet %s" % fn_name, file=sys.stderr)
288             return False
289         if fn is traffic_packets.null_packet:
290             return False
291         return True
292
293
294 class ReplayContext(object):
295     """State/Context for an individual conversation between an simulated client
296        and a server.
297     """
298
299     def __init__(self,
300                  server=None,
301                  lp=None,
302                  creds=None,
303                  badpassword_frequency=None,
304                  prefer_kerberos=None,
305                  tempdir=None,
306                  statsdir=None,
307                  ou=None,
308                  base_dn=None,
309                  domain=None,
310                  domain_sid=None):
311
312         self.server                   = server
313         self.ldap_connections         = []
314         self.dcerpc_connections       = []
315         self.lsarpc_connections       = []
316         self.lsarpc_connections_named = []
317         self.drsuapi_connections      = []
318         self.srvsvc_connections       = []
319         self.samr_contexts            = []
320         self.netlogon_connection      = None
321         self.creds                    = creds
322         self.lp                       = lp
323         self.prefer_kerberos          = prefer_kerberos
324         self.ou                       = ou
325         self.base_dn                  = base_dn
326         self.domain                   = domain
327         self.statsdir                 = statsdir
328         self.global_tempdir           = tempdir
329         self.domain_sid               = domain_sid
330         self.realm                    = lp.get('realm')
331
332         # Bad password attempt controls
333         self.badpassword_frequency    = badpassword_frequency
334         self.last_lsarpc_bad          = False
335         self.last_lsarpc_named_bad    = False
336         self.last_simple_bind_bad     = False
337         self.last_bind_bad            = False
338         self.last_srvsvc_bad          = False
339         self.last_drsuapi_bad         = False
340         self.last_netlogon_bad        = False
341         self.last_samlogon_bad        = False
342         self.generate_ldap_search_tables()
343         self.next_conversation_id = itertools.count()
344
345     def generate_ldap_search_tables(self):
346         session = system_session()
347
348         db = SamDB(url="ldap://%s" % self.server,
349                    session_info=session,
350                    credentials=self.creds,
351                    lp=self.lp)
352
353         res = db.search(db.domain_dn(),
354                         scope=ldb.SCOPE_SUBTREE,
355                         controls=["paged_results:1:1000"],
356                         attrs=['dn'])
357
358         # find a list of dns for each pattern
359         # e.g. CN,CN,CN,DC,DC
360         dn_map = {}
361         attribute_clue_map = {
362             'invocationId': []
363         }
364
365         for r in res:
366             dn = str(r.dn)
367             pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
368             dns = dn_map.setdefault(pattern, [])
369             dns.append(dn)
370             if dn.startswith('CN=NTDS Settings,'):
371                 attribute_clue_map['invocationId'].append(dn)
372
373         # extend the map in case we are working with a different
374         # number of DC components.
375         # for k, v in self.dn_map.items():
376         #     print >>sys.stderr, k, len(v)
377
378         for k in list(dn_map.keys()):
379             if k[-3:] != ',DC':
380                 continue
381             p = k[:-3]
382             while p[-3:] == ',DC':
383                 p = p[:-3]
384             for i in range(5):
385                 p += ',DC'
386                 if p != k and p in dn_map:
387                     print('dn_map collison %s %s' % (k, p),
388                           file=sys.stderr)
389                     continue
390                 dn_map[p] = dn_map[k]
391
392         self.dn_map = dn_map
393         self.attribute_clue_map = attribute_clue_map
394
395     def generate_process_local_config(self, account, conversation):
396         if account is None:
397             return
398         self.netbios_name             = account.netbios_name
399         self.machinepass              = account.machinepass
400         self.username                 = account.username
401         self.userpass                 = account.userpass
402
403         self.tempdir = mk_masked_dir(self.global_tempdir,
404                                      'conversation-%d' %
405                                      conversation.conversation_id)
406
407         self.lp.set("private dir",     self.tempdir)
408         self.lp.set("lock dir",        self.tempdir)
409         self.lp.set("state directory", self.tempdir)
410         self.lp.set("tls verify peer", "no_check")
411
412         # If the domain was not specified, check for the environment
413         # variable.
414         if self.domain is None:
415             self.domain = os.environ["DOMAIN"]
416
417         self.remoteAddress = "/root/ncalrpc_as_system"
418         self.samlogon_dn   = ("cn=%s,%s" %
419                               (self.netbios_name, self.ou))
420         self.user_dn       = ("cn=%s,%s" %
421                               (self.username, self.ou))
422
423         self.generate_machine_creds()
424         self.generate_user_creds()
425
426     def with_random_bad_credentials(self, f, good, bad, failed_last_time):
427         """Execute the supplied logon function, randomly choosing the
428            bad credentials.
429
430            Based on the frequency in badpassword_frequency randomly perform the
431            function with the supplied bad credentials.
432            If run with bad credentials, the function is re-run with the good
433            credentials.
434            failed_last_time is used to prevent consecutive bad credential
435            attempts. So the over all bad credential frequency will be lower
436            than that requested, but not significantly.
437         """
438         if not failed_last_time:
439             if (self.badpassword_frequency > 0 and
440                random.random() < self.badpassword_frequency):
441                 try:
442                     f(bad)
443                 except:
444                     # Ignore any exceptions as the operation may fail
445                     # as it's being performed with bad credentials
446                     pass
447                 failed_last_time = True
448             else:
449                 failed_last_time = False
450
451         result = f(good)
452         return (result, failed_last_time)
453
454     def generate_user_creds(self):
455         """Generate the conversation specific user Credentials.
456
457         Each Conversation has an associated user account used to simulate
458         any non Administrative user traffic.
459
460         Generates user credentials with good and bad passwords and ldap
461         simple bind credentials with good and bad passwords.
462         """
463         self.user_creds = Credentials()
464         self.user_creds.guess(self.lp)
465         self.user_creds.set_workstation(self.netbios_name)
466         self.user_creds.set_password(self.userpass)
467         self.user_creds.set_username(self.username)
468         self.user_creds.set_domain(self.domain)
469         if self.prefer_kerberos:
470             self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
471         else:
472             self.user_creds.set_kerberos_state(DONT_USE_KERBEROS)
473
474         self.user_creds_bad = Credentials()
475         self.user_creds_bad.guess(self.lp)
476         self.user_creds_bad.set_workstation(self.netbios_name)
477         self.user_creds_bad.set_password(self.userpass[:-4])
478         self.user_creds_bad.set_username(self.username)
479         if self.prefer_kerberos:
480             self.user_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
481         else:
482             self.user_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
483
484         # Credentials for ldap simple bind.
485         self.simple_bind_creds = Credentials()
486         self.simple_bind_creds.guess(self.lp)
487         self.simple_bind_creds.set_workstation(self.netbios_name)
488         self.simple_bind_creds.set_password(self.userpass)
489         self.simple_bind_creds.set_username(self.username)
490         self.simple_bind_creds.set_gensec_features(
491             self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
492         if self.prefer_kerberos:
493             self.simple_bind_creds.set_kerberos_state(MUST_USE_KERBEROS)
494         else:
495             self.simple_bind_creds.set_kerberos_state(DONT_USE_KERBEROS)
496         self.simple_bind_creds.set_bind_dn(self.user_dn)
497
498         self.simple_bind_creds_bad = Credentials()
499         self.simple_bind_creds_bad.guess(self.lp)
500         self.simple_bind_creds_bad.set_workstation(self.netbios_name)
501         self.simple_bind_creds_bad.set_password(self.userpass[:-4])
502         self.simple_bind_creds_bad.set_username(self.username)
503         self.simple_bind_creds_bad.set_gensec_features(
504             self.simple_bind_creds_bad.get_gensec_features() |
505             gensec.FEATURE_SEAL)
506         if self.prefer_kerberos:
507             self.simple_bind_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
508         else:
509             self.simple_bind_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
510         self.simple_bind_creds_bad.set_bind_dn(self.user_dn)
511
512     def generate_machine_creds(self):
513         """Generate the conversation specific machine Credentials.
514
515         Each Conversation has an associated machine account.
516
517         Generates machine credentials with good and bad passwords.
518         """
519
520         self.machine_creds = Credentials()
521         self.machine_creds.guess(self.lp)
522         self.machine_creds.set_workstation(self.netbios_name)
523         self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
524         self.machine_creds.set_password(self.machinepass)
525         self.machine_creds.set_username(self.netbios_name + "$")
526         self.machine_creds.set_domain(self.domain)
527         if self.prefer_kerberos:
528             self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
529         else:
530             self.machine_creds.set_kerberos_state(DONT_USE_KERBEROS)
531
532         self.machine_creds_bad = Credentials()
533         self.machine_creds_bad.guess(self.lp)
534         self.machine_creds_bad.set_workstation(self.netbios_name)
535         self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
536         self.machine_creds_bad.set_password(self.machinepass[:-4])
537         self.machine_creds_bad.set_username(self.netbios_name + "$")
538         if self.prefer_kerberos:
539             self.machine_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
540         else:
541             self.machine_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
542
543     def get_matching_dn(self, pattern, attributes=None):
544         # If the pattern is an empty string, we assume ROOTDSE,
545         # Otherwise we try adding or removing DC suffixes, then
546         # shorter leading patterns until we hit one.
547         # e.g if there is no CN,CN,CN,CN,DC,DC
548         # we first try       CN,CN,CN,CN,DC
549         # and                CN,CN,CN,CN,DC,DC,DC
550         # then change to        CN,CN,CN,DC,DC
551         # and as last resort we use the base_dn
552         attr_clue = self.attribute_clue_map.get(attributes)
553         if attr_clue:
554             return random.choice(attr_clue)
555
556         pattern = pattern.upper()
557         while pattern:
558             if pattern in self.dn_map:
559                 return random.choice(self.dn_map[pattern])
560             # chop one off the front and try it all again.
561             pattern = pattern[3:]
562
563         return self.base_dn
564
565     def get_dcerpc_connection(self, new=False):
566         guid = '12345678-1234-abcd-ef00-01234567cffb'  # RPC_NETLOGON UUID
567         if self.dcerpc_connections and not new:
568             return self.dcerpc_connections[-1]
569         c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
570                              (guid, 1), self.lp)
571         self.dcerpc_connections.append(c)
572         return c
573
574     def get_srvsvc_connection(self, new=False):
575         if self.srvsvc_connections and not new:
576             return self.srvsvc_connections[-1]
577
578         def connect(creds):
579             return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
580                                  self.lp,
581                                  creds)
582
583         (c, self.last_srvsvc_bad) = \
584             self.with_random_bad_credentials(connect,
585                                              self.user_creds,
586                                              self.user_creds_bad,
587                                              self.last_srvsvc_bad)
588
589         self.srvsvc_connections.append(c)
590         return c
591
592     def get_lsarpc_connection(self, new=False):
593         if self.lsarpc_connections and not new:
594             return self.lsarpc_connections[-1]
595
596         def connect(creds):
597             binding_options = 'schannel,seal,sign'
598             return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
599                               (self.server, binding_options),
600                               self.lp,
601                               creds)
602
603         (c, self.last_lsarpc_bad) = \
604             self.with_random_bad_credentials(connect,
605                                              self.machine_creds,
606                                              self.machine_creds_bad,
607                                              self.last_lsarpc_bad)
608
609         self.lsarpc_connections.append(c)
610         return c
611
612     def get_lsarpc_named_pipe_connection(self, new=False):
613         if self.lsarpc_connections_named and not new:
614             return self.lsarpc_connections_named[-1]
615
616         def connect(creds):
617             return lsa.lsarpc("ncacn_np:%s" % (self.server),
618                               self.lp,
619                               creds)
620
621         (c, self.last_lsarpc_named_bad) = \
622             self.with_random_bad_credentials(connect,
623                                              self.machine_creds,
624                                              self.machine_creds_bad,
625                                              self.last_lsarpc_named_bad)
626
627         self.lsarpc_connections_named.append(c)
628         return c
629
630     def get_drsuapi_connection_pair(self, new=False, unbind=False):
631         """get a (drs, drs_handle) tuple"""
632         if self.drsuapi_connections and not new:
633             c = self.drsuapi_connections[-1]
634             return c
635
636         def connect(creds):
637             binding_options = 'seal'
638             binding_string = "ncacn_ip_tcp:%s[%s]" %\
639                              (self.server, binding_options)
640             return drsuapi.drsuapi(binding_string, self.lp, creds)
641
642         (drs, self.last_drsuapi_bad) = \
643             self.with_random_bad_credentials(connect,
644                                              self.user_creds,
645                                              self.user_creds_bad,
646                                              self.last_drsuapi_bad)
647
648         (drs_handle, supported_extensions) = drs_DsBind(drs)
649         c = (drs, drs_handle)
650         self.drsuapi_connections.append(c)
651         return c
652
653     def get_ldap_connection(self, new=False, simple=False):
654         if self.ldap_connections and not new:
655             return self.ldap_connections[-1]
656
657         def simple_bind(creds):
658             """
659             To run simple bind against Windows, we need to run
660             following commands in PowerShell:
661
662                 Install-windowsfeature ADCS-Cert-Authority
663                 Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
664                 Restart-Computer
665
666             """
667             return SamDB('ldaps://%s' % self.server,
668                          credentials=creds,
669                          lp=self.lp)
670
671         def sasl_bind(creds):
672             return SamDB('ldap://%s' % self.server,
673                          credentials=creds,
674                          lp=self.lp)
675         if simple:
676             (samdb, self.last_simple_bind_bad) = \
677                 self.with_random_bad_credentials(simple_bind,
678                                                  self.simple_bind_creds,
679                                                  self.simple_bind_creds_bad,
680                                                  self.last_simple_bind_bad)
681         else:
682             (samdb, self.last_bind_bad) = \
683                 self.with_random_bad_credentials(sasl_bind,
684                                                  self.user_creds,
685                                                  self.user_creds_bad,
686                                                  self.last_bind_bad)
687
688         self.ldap_connections.append(samdb)
689         return samdb
690
691     def get_samr_context(self, new=False):
692         if not self.samr_contexts or new:
693             self.samr_contexts.append(
694                 SamrContext(self.server, lp=self.lp, creds=self.creds))
695         return self.samr_contexts[-1]
696
697     def get_netlogon_connection(self):
698
699         if self.netlogon_connection:
700             return self.netlogon_connection
701
702         def connect(creds):
703             return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
704                                      (self.server),
705                                      self.lp,
706                                      creds)
707         (c, self.last_netlogon_bad) = \
708             self.with_random_bad_credentials(connect,
709                                              self.machine_creds,
710                                              self.machine_creds_bad,
711                                              self.last_netlogon_bad)
712         self.netlogon_connection = c
713         return c
714
715     def guess_a_dns_lookup(self):
716         return (self.realm, 'A')
717
718     def get_authenticator(self):
719         auth = self.machine_creds.new_client_authenticator()
720         current  = netr_Authenticator()
721         current.cred.data = [ord(x) for x in auth["credential"]]
722         current.timestamp = auth["timestamp"]
723
724         subsequent = netr_Authenticator()
725         return (current, subsequent)
726
727
class SamrContext(object):
    """State/Context associated with a samr connection.

    The connection and the connect handle are created lazily and cached.
    """
    def __init__(self, server, lp=None, creds=None):
        # How to reach the server.
        self.server = server
        self.lp = lp
        self.creds = creds

        # Filled in on demand.
        self.connection = None
        self.handle = None
        self.domain_handle = None
        self.domain_sid = None
        self.group_handle = None
        self.user_handle = None
        self.rids = None

    def get_connection(self):
        """Return the samr connection, opening it on first use."""
        if self.connection:
            return self.connection

        self.connection = samr.samr(
            "ncacn_ip_tcp:%s[seal]" % (self.server),
            lp_ctx=self.lp,
            credentials=self.creds)
        return self.connection

    def get_handle(self):
        """Return the Connect2 handle, obtaining it on first use."""
        if self.handle:
            return self.handle

        connection = self.get_connection()
        self.handle = connection.Connect2(
            None, security.SEC_FLAG_MAXIMUM_ALLOWED)
        return self.handle

757
758
class Conversation(object):
    """Details of a conversation between a simulated client and a server.

    Packets added with add_packet() are stored as copies whose
    timestamps are relative to the conversation's start time.
    ``client_balance`` accumulates each packet's ``client_score()``
    (subtracted when the packet was sent by ``endpoints[0]``, added
    otherwise) and is used by guess_client_server() to decide which
    endpoint is the client.

    Conversations order by ``start_time``; a start time of None sorts
    before any other value.
    """
    conversation_id = None

    def __init__(self, start_time=None, endpoints=None):
        self.start_time = start_time
        self.endpoints = endpoints
        self.packets = []
        self.msg = random_colour_print()
        self.client_balance = 0.0

    def __cmp__(self, other):
        """Three-way comparison by start time, returning -1, 0 or 1.

        An integer is returned (rather than the raw float difference)
        because the __cmp__ protocol requires an integer result.
        """
        if self.start_time is None:
            if other.start_time is None:
                return 0
            return -1
        if other.start_time is None:
            return 1
        diff = self.start_time - other.start_time
        return (diff > 0) - (diff < 0)

    def __lt__(self, other):
        """Rich comparison used by sort()/sorted().

        Python 3 never calls __cmp__, so without this method sorting a
        list of conversations raises TypeError there.
        """
        return self.__cmp__(other) < 0

    def add_packet(self, packet):
        """Add a packet object to this conversation, making a local copy with
        a conversation-relative timestamp.

        The first packet added fixes the start time and the endpoints if
        they were not given to the constructor.

        :param packet: the packet to copy and add
        :raises FakePacketError: if the packet's endpoints differ from
            the conversation's
        """
        p = packet.copy()

        if self.start_time is None:
            self.start_time = p.timestamp

        if self.endpoints is None:
            self.endpoints = p.endpoints

        if p.endpoints != self.endpoints:
            raise FakePacketError("Conversation endpoints %s don't match "
                                  "packet endpoints %s" %
                                  (self.endpoints, p.endpoints))

        p.timestamp -= self.start_time

        if p.src == p.endpoints[0]:
            self.client_balance -= p.client_score()
        else:
            self.client_balance += p.client_score()

        # Packets that are not "really" packets still contribute to the
        # client balance above, but are not kept for replay.
        if p.is_really_a_packet():
            self.packets.append(p)

    def add_short_packet(self, timestamp, protocol, opcode, extra,
                         client=True):
        """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
        (possibly empty) list of extra data. If client is True, assume
        this packet is from the client to the server.
        """
        src, dest = self.guess_client_server()
        if not client:
            src, dest = dest, src
        desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
        # protocols missing from IP_PROTOCOLS fall back to '06'
        ip_protocol = IP_PROTOCOLS.get(protocol, '06')
        packet = Packet(timestamp - self.start_time, ip_protocol,
                        '', src, dest,
                        protocol, opcode, desc, extra)
        # XXX we're assuming the timestamp is already adjusted for
        # this conversation?
        # XXX should we adjust client balance for guessed packets?
        if packet.src == packet.endpoints[0]:
            self.client_balance -= packet.client_score()
        else:
            self.client_balance += packet.client_score()
        if packet.is_really_a_packet():
            self.packets.append(packet)

    def __str__(self):
        return ("<Conversation %s %s starting %.3f %d packets>" %
                (self.conversation_id, self.endpoints, self.start_time,
                 len(self.packets)))

    __repr__ = __str__

    def __iter__(self):
        return iter(self.packets)

    def __len__(self):
        return len(self.packets)

    def get_duration(self):
        """Return the time between the first and last stored packets.

        Conversations with fewer than two packets have a duration of 0.
        """
        if len(self.packets) < 2:
            return 0
        return self.packets[-1].timestamp - self.packets[0].timestamp

    def replay_as_summary_lines(self):
        """Return each packet's summary line, offset by the start time."""
        return [p.as_summary(self.start_time) for p in self.packets]

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        """Fork a new process and replay the conversation.

        :param start: the time.time() value at which the whole replay
            began; the conversation starts at start + self.start_time
        :param context: the replay context; it supplies the conversation
            id counter, the stats directory and the per-process config
        :param account: passed to context.generate_process_local_config()
        :return: the child's pid, in the parent. The child never
            returns; it always ends in os._exit(0).
        """
        def signal_handler(signum, frame):
            """Signal handler closes standard out and error.

            Triggered by a sigterm, ensures that the log messages are flushed
            to disk and not lost.
            """
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

        t = self.start_time
        now = time.time() - start
        gap = t - now
        # we are replaying strictly in order, so it is safe to sleep
        # in the main process if the gap is big enough. This reduces
        # the number of concurrent threads, which allows us to make
        # larger loads.
        # NOTE: the "and False" keeps this branch switched off.
        if gap > 0.15 and False:
            print("sleeping for %f in main process" % (gap - 0.1),
                  file=sys.stderr)
            time.sleep(gap - 0.1)
            now = time.time() - start
            gap = t - now
            print("gap is now %f" % gap, file=sys.stderr)

        self.conversation_id = next(context.next_conversation_id)
        pid = os.fork()
        if pid != 0:
            return pid
        pid = os.getpid()
        signal.signal(signal.SIGTERM, signal_handler)
        # we must never return, or we'll end up running parts of the
        # parent's clean-up code. So we work in a try...finally, and
        # try to print any exceptions.

        try:
            context.generate_process_local_config(account, self)
            sys.stdin.close()
            os.close(0)
            # the child's stdout becomes its per-conversation stats file
            filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
                                    self.conversation_id)
            sys.stdout.close()
            sys.stdout = open(filename, 'w')

            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            miss = t - (time.time() - start)
            self.msg("starting %s [miss %.3f pid %d]" % (self, miss, pid))
            self.replay(context)
        except Exception:
            print(("EXCEPTION in child PID %d, conversation %s" % (pid, self)),
                  file=sys.stderr)
            traceback.print_exc(sys.stderr)
        finally:
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

    def replay(self, context=None):
        """Play each packet at its conversation-relative timestamp.

        With no context the packets are not played; a message is
        emitted through self.msg for each one instead.
        """
        start = time.time()

        for p in self.packets:
            now = time.time() - start
            gap = p.timestamp - now
            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            miss = p.timestamp - (time.time() - start)
            if context is None:
                self.msg("packet %s [miss %.3f pid %d]" % (p, miss,
                                                           os.getpid()))
                continue
            p.play(self, context)

    def guess_client_server(self, server_clue=None):
        """Have a go at deciding who is the server and who is the client.
        returns (client, server)
        """
        a, b = self.endpoints

        if self.client_balance < 0:
            return (a, b)

        # in the absence of a clue, we will fall through to assuming
        # the lowest number is the server (which is usually true).

        if self.client_balance == 0 and server_clue == b:
            return (a, b)

        return (b, a)

    def forget_packets_outside_window(self, s, e):
        """Prune any packets outside the time window we're interested in

        The start time becomes the timestamp of the first remaining
        packet, or None if no packets remain.

        :param s: start of the window
        :param e: end of the window
        """
        self.packets = [p for p in self.packets if s <= p.timestamp <= e]
        self.start_time = self.packets[0].timestamp if self.packets else None

    def renormalise_times(self, start_time):
        """Adjust the packet start times relative to the new start time."""
        for p in self.packets:
            p.timestamp -= start_time

        if self.start_time is not None:
            self.start_time -= start_time
970
971
class DnsHammer(Conversation):
    """A lightweight conversation that generates a lot of dns:0 packets on
    the fly

    Conversation.__init__ is not called, so an instance has no
    ``packets``, ``endpoints`` or ``client_balance`` attributes; it only
    carries the sorted list of send times in ``times``.
    """

    def __init__(self, dns_rate, duration):
        """Pick int(dns_rate * duration) random send times in [0, duration].

        :param dns_rate: packets per unit of duration
        :param duration: length of the window the packets are spread over
        """
        n = int(dns_rate * duration)
        self.times = [random.uniform(0, duration) for i in range(n)]
        self.times.sort()
        self.rate = dns_rate
        self.duration = duration
        self.start_time = 0
        self.msg = random_colour_print()

    def __str__(self):
        return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
                (len(self.times), self.duration, self.rate))

    # Conversation binds __repr__ to its own __str__, which reads
    # attributes this class never sets; rebind it so repr() works.
    __repr__ = __str__

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        """Fork and replay, exactly as Conversation does."""
        return Conversation.replay_in_fork_with_delay(self,
                                                      start,
                                                      context,
                                                      account)

    def replay(self, context=None):
        """Send a dns:0 packet at each of the chosen times.

        One tab-separated result line is printed to stdout per packet,
        containing the end time, the time taken, a success flag and,
        on failure, the exception. With no context nothing is sent and
        a message is emitted through self.msg instead.
        """
        start = time.time()
        fn = traffic_packets.packet_dns_0
        for t in self.times:
            now = time.time() - start
            gap = t - now
            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            if context is None:
                miss = t - (time.time() - start)
                self.msg("packet %s [miss %.3f pid %d]" % (t, miss,
                                                           os.getpid()))
                continue

            packet_start = time.time()
            try:
                # this object stands in for both the packet and the
                # conversation arguments
                fn(self, self, context)
                end = time.time()
                duration = end - packet_start
                print("%f\tDNS\tdns\t0\t%f\tTrue\t" % (end, duration))
            except Exception as e:
                end = time.time()
                duration = end - packet_start
                print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" % (end, duration, e))
1021
1022
def ingest_summaries(files, dns_mode='count'):
    """Load traffic summary files and generate Conversations from them.

    :param files: an iterable of file names or open file objects; each
        one is closed once it has been read
    :param dns_mode: unless this is 'include', dns packets are only
        counted per opcode rather than added to conversations
    :return: a tuple (conversation_list, mean_interval, duration,
        dns_counts). The same four-element shape is returned when no
        packets were found, with an empty list and zeroes.
    """

    dns_counts = defaultdict(int)
    packets = []
    for f in files:
        if isinstance(f, str):
            f = open(f)
        try:
            print("Ingesting %s" % (f.name,), file=sys.stderr)
            for line in f:
                p = Packet.from_line(line)
                if p.protocol == 'dns' and dns_mode != 'include':
                    dns_counts[p.opcode] += 1
                else:
                    packets.append(p)
        finally:
            # don't leak the handle if a line fails to parse
            f.close()

    if not packets:
        return [], 0, 0.0, dns_counts

    start_time = min(p.timestamp for p in packets)
    last_packet = max(p.timestamp for p in packets)

    print("gathering packets into conversations", file=sys.stderr)
    conversations = OrderedDict()
    for p in packets:
        p.timestamp -= start_time
        c = conversations.get(p.endpoints)
        if c is None:
            c = Conversation()
            conversations[p.endpoints] = c
        c.add_packet(p)

    # We only care about conversations with actual traffic, so we
    # filter out conversations with nothing to say. We do that here,
    # rather than earlier, because those empty packets contain useful
    # hints as to which end of the conversation was the client.
    conversation_list = [c for c in conversations.values() if len(c) != 0]

    # This is obviously not correct, as many conversations will appear
    # to start roughly simultaneously at the beginning of the snapshot.
    # To which we say: oh well, so be it.
    duration = float(last_packet - start_time)
    if duration > 0:
        mean_interval = len(conversations) / duration
    else:
        # every packet carries the same timestamp; avoid dividing by zero
        mean_interval = 0

    return conversation_list, mean_interval, duration, dns_counts
1074
1075
def guess_server_address(conversations):
    """Guess the server address: the endpoint seen in most conversations.

    :param conversations: an iterable of objects with an ``endpoints``
        attribute
    :return: the most common endpoint address, or None if there are no
        endpoints at all
    """
    addresses = Counter()
    for c in conversations:
        addresses.update(c.endpoints)
    if addresses:
        # most_common() yields (address, count) pairs; only the address
        # is useful as a clue to compare against an endpoint.
        return addresses.most_common(1)[0][0]
    return None
1083
1084
def stringify_keys(x):
    """Return a copy of dict x with each key joined into one string.

    Every key is expected to be a sequence of strings; its parts are
    joined with tab characters.
    """
    return {'\t'.join(parts): value for parts, value in x.items()}
1091
1092
def unstringify_keys(x):
    """Return a copy of dict x with each key split into a tuple.

    Each key is converted with str() and split on tab characters,
    reversing stringify_keys().
    """
    return {tuple(str(joined).split('\t')): value
            for joined, value in x.items()}
1099
1100
class TrafficModel(object):
    """An n-gram model of client traffic.

    ``ngrams`` maps a tuple of the previous n - 1 packet types (or wait
    states) to the list of items seen to follow them.
    ``query_details`` maps a packet type to the list of ``extra``
    tuples seen with it.
    """

    def __init__(self, n=3):
        self.ngrams = {}
        self.query_details = {}
        self.n = n
        self.dns_opcounts = defaultdict(int)
        self.cumulative_duration = 0.0
        self.conversation_rate = [0, 1]

    def learn(self, conversations, dns_opcounts=None):
        """Add the client packets of the given conversations to the model.

        :param conversations: a list of conversations to learn from
        :param dns_opcounts: an optional dict of per-opcode dns counts,
            added to the model's own counts
        """
        if dns_opcounts is None:
            dns_opcounts = {}

        prev = 0.0
        cum_duration = 0.0
        key = (NON_PACKET,) * (self.n - 1)

        server = guess_server_address(conversations)

        for k, v in dns_opcounts.items():
            self.dns_opcounts[k] += v

        if len(conversations) > 1:
            elapsed =\
                conversations[-1].start_time - conversations[0].start_time
            self.conversation_rate[0] = len(conversations)
            self.conversation_rate[1] = elapsed

        for c in conversations:
            # NOTE(review): `server` is rebound here, so each later
            # conversation receives the previous conversation's guess
            # as its clue -- confirm this is intended.
            client, server = c.guess_client_server(server)
            cum_duration += c.get_duration()
            key = (NON_PACKET,) * (self.n - 1)
            for p in c:
                # only packets sent by the client are modelled
                if p.src != client:
                    continue

                elapsed = p.timestamp - prev
                prev = p.timestamp
                if elapsed > WAIT_THRESHOLD:
                    # add the wait as an extra state
                    wait = 'wait:%d' % (math.log(max(1.0,
                                                     elapsed * WAIT_SCALE)))
                    self.ngrams.setdefault(key, []).append(wait)
                    key = key[1:] + (wait,)

                short_p = p.as_packet_type()
                self.query_details.setdefault(short_p,
                                              []).append(tuple(p.extra))
                self.ngrams.setdefault(key, []).append(short_p)
                key = key[1:] + (short_p,)

        self.cumulative_duration += cum_duration
        # add in the end
        self.ngrams.setdefault(key, []).append(NON_PACKET)

    def save(self, f):
        """Write the model as JSON.

        :param f: a file name or a writable file object. A file opened
            here from a name is closed again; a file object passed in
            is left open.
        """
        ngrams = {}
        for k, v in self.ngrams.items():
            k = '\t'.join(k)
            ngrams[k] = dict(Counter(v))

        query_details = {}
        for k, v in self.query_details.items():
            # '-' stands for an empty tuple of extra details
            query_details[k] = dict(Counter('\t'.join(x) if x else '-'
                                            for x in v))

        d = {
            'ngrams': ngrams,
            'query_details': query_details,
            'cumulative_duration': self.cumulative_duration,
            'conversation_rate': self.conversation_rate,
        }
        d['dns'] = self.dns_opcounts

        if isinstance(f, str):
            with open(f, 'w') as fh:
                json.dump(d, fh, indent=2)
        else:
            json.dump(d, f, indent=2)

    def load(self, f):
        """Merge a model saved with save() into this one.

        The ngrams, query details and dns counts are added to what is
        already present; the cumulative duration and conversation rate
        are replaced.

        :param f: a file name or a readable file object. A file opened
            here from a name is closed again; a file object passed in
            is left open.
        """
        if isinstance(f, str):
            with open(f) as fh:
                d = json.load(fh)
        else:
            d = json.load(f)

        for k, v in d['ngrams'].items():
            k = tuple(str(k).split('\t'))
            values = self.ngrams.setdefault(k, [])
            for p, count in v.items():
                values.extend([str(p)] * count)

        for k, v in d['query_details'].items():
            values = self.query_details.setdefault(str(k), [])
            for p, count in v.items():
                if p == '-':
                    values.extend([()] * count)
                else:
                    values.extend([tuple(str(p).split('\t'))] * count)

        if 'dns' in d:
            for k, v in d['dns'].items():
                self.dns_opcounts[k] += v

        self.cumulative_duration = d['cumulative_duration']
        self.conversation_rate = d['conversation_rate']

    def construct_conversation(self, timestamp=0.0, client=2, server=1,
                               hard_stop=None, packet_rate=1):
        """Construct an individual conversation from the model.

        :param timestamp: the start time of the conversation
        :param client: the client endpoint
        :param server: the server endpoint
        :param hard_stop: if not None, stop once a packet would fall
            after this time
        :param packet_rate: divides every generated wait, so larger
            values pack the packets closer together
        :return: the new Conversation
        """

        c = Conversation(timestamp, (server, client))

        key = (NON_PACKET,) * (self.n - 1)

        while key in self.ngrams:
            p = random.choice(self.ngrams[key])
            if p == NON_PACKET:
                break
            if p in self.query_details:
                extra = random.choice(self.query_details[p])
            else:
                extra = []

            protocol, opcode = p.split(':', 1)
            if protocol == 'wait':
                # a wait state only advances the clock
                log_wait_time = int(opcode) + random.random()
                wait = math.exp(log_wait_time) / (WAIT_SCALE * packet_rate)
                timestamp += wait
            else:
                log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
                wait = math.exp(log_wait) / packet_rate
                timestamp += wait
                if hard_stop is not None and timestamp > hard_stop:
                    break
                c.add_short_packet(timestamp, protocol, opcode, extra)

            key = key[1:] + (p,)

        return c

    def generate_conversations(self, rate, duration, packet_rate=1):
        """Generate a list of conversations from the model.

        :param rate: scales the number of conversations generated
        :param duration: the length of time the traffic should cover
        :param packet_rate: passed on to construct_conversation()
        :return: the non-empty conversations, sorted by start time
        """

        # We run the simulation over max(rate_t, duration * 2), where
        # rate_t is the elapsed time recorded in conversation_rate.
        rate_n, rate_t = self.conversation_rate

        duration2 = max(rate_t, duration * 2)
        n = rate * duration2 * rate_n / rate_t

        server = 1
        client = 2

        conversations = []
        end = duration2
        # NOTE(review): this value is overwritten by the random
        # per-conversation start in the loop below before it is ever
        # read -- confirm whether a separate window start was intended.
        start = end - duration

        while client < n + 2:
            start = random.uniform(0, duration2)
            c = self.construct_conversation(start,
                                            client,
                                            server,
                                            hard_stop=(duration2 * 5),
                                            packet_rate=packet_rate)

            c.forget_packets_outside_window(start, end)
            c.renormalise_times(start)
            if len(c) != 0:
                conversations.append(c)
            client += 1

        print(("we have %d conversations at rate %f" %
               (len(conversations), rate)), file=sys.stderr)
        # Sort on an explicit key instead of relying on
        # Conversation.__cmp__, which Python 3 does not use. Every
        # conversation kept above has at least one packet, so its
        # start_time is not None.
        conversations.sort(key=lambda conv: conv.start_time)
        return conversations
1273
1274
# Maps a protocol name to the two-character IP protocol field used
# when a packet is synthesised from a 'protocol:opcode' pair (see
# Conversation.add_short_packet and expand_short_packet). Protocols
# missing from this table are given '06' by those callers.
# NOTE(review): presumably these are hexadecimal IP protocol numbers
# (06 = TCP, 11 = UDP) -- confirm.
IP_PROTOCOLS = {
    'dns': '11',
    'rpc_netlogon': '06',
    'kerberos': '06',      # ratio 16248:258
    'smb': '06',
    'smb2': '06',
    'ldap': '06',
    'cldap': '11',
    'lsarpc': '06',
    'samr': '06',
    'dcerpc': '06',
    'epm': '06',
    'drsuapi': '06',
    'browser': '11',
    'smb_netlogon': '11',
    'srvsvc': '06',
    'nbns': '11',
}
1293
# Maps a (protocol, opcode) pair, both strings, to the description
# placed in a synthesised packet (see Conversation.add_short_packet
# and expand_short_packet). Pairs missing from this table get an
# empty description from those callers.
OP_DESCRIPTIONS = {
    ('browser', '0x01'): 'Host Announcement (0x01)',
    ('browser', '0x02'): 'Request Announcement (0x02)',
    ('browser', '0x08'): 'Browser Election Request (0x08)',
    ('browser', '0x09'): 'Get Backup List Request (0x09)',
    ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
    ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
    ('cldap', '3'): 'searchRequest',
    ('cldap', '5'): 'searchResDone',
    ('dcerpc', '0'): 'Request',
    ('dcerpc', '11'): 'Bind',
    ('dcerpc', '12'): 'Bind_ack',
    ('dcerpc', '13'): 'Bind_nak',
    ('dcerpc', '14'): 'Alter_context',
    ('dcerpc', '15'): 'Alter_context_resp',
    ('dcerpc', '16'): 'AUTH3',
    ('dcerpc', '2'): 'Response',
    ('dns', '0'): 'query',
    ('dns', '1'): 'response',
    ('drsuapi', '0'): 'DsBind',
    ('drsuapi', '12'): 'DsCrackNames',
    ('drsuapi', '13'): 'DsWriteAccountSpn',
    ('drsuapi', '1'): 'DsUnbind',
    ('drsuapi', '2'): 'DsReplicaSync',
    ('drsuapi', '3'): 'DsGetNCChanges',
    ('drsuapi', '4'): 'DsReplicaUpdateRefs',
    ('epm', '3'): 'Map',
    ('kerberos', ''): '',
    ('ldap', '0'): 'bindRequest',
    ('ldap', '1'): 'bindResponse',
    ('ldap', '2'): 'unbindRequest',
    ('ldap', '3'): 'searchRequest',
    ('ldap', '4'): 'searchResEntry',
    ('ldap', '5'): 'searchResDone',
    ('ldap', ''): '*** Unknown ***',
    ('lsarpc', '14'): 'lsa_LookupNames',
    ('lsarpc', '15'): 'lsa_LookupSids',
    ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
    ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
    ('lsarpc', '6'): 'lsa_OpenPolicy',
    ('lsarpc', '76'): 'lsa_LookupSids3',
    ('lsarpc', '77'): 'lsa_LookupNames4',
    ('nbns', '0'): 'query',
    ('nbns', '1'): 'response',
    ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
    ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
    ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
    ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
    ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
    ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
    ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
    ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
    ('samr', '0',): 'Connect',
    ('samr', '16'): 'GetAliasMembership',
    ('samr', '17'): 'LookupNames',
    ('samr', '18'): 'LookupRids',
    ('samr', '19'): 'OpenGroup',
    ('samr', '1'): 'Close',
    ('samr', '25'): 'QueryGroupMember',
    ('samr', '34'): 'OpenUser',
    ('samr', '36'): 'QueryUserInfo',
    ('samr', '39'): 'GetGroupsForUser',
    ('samr', '3'): 'QuerySecurity',
    ('samr', '5'): 'LookupDomain',
    ('samr', '64'): 'Connect5',
    ('samr', '6'): 'EnumDomains',
    ('samr', '7'): 'OpenDomain',
    ('samr', '8'): 'QueryDomainInfo',
    ('smb', '0x04'): 'Close (0x04)',
    ('smb', '0x24'): 'Locking AndX (0x24)',
    ('smb', '0x2e'): 'Read AndX (0x2e)',
    ('smb', '0x32'): 'Trans2 (0x32)',
    ('smb', '0x71'): 'Tree Disconnect (0x71)',
    ('smb', '0x72'): 'Negotiate Protocol (0x72)',
    ('smb', '0x73'): 'Session Setup AndX (0x73)',
    ('smb', '0x74'): 'Logoff AndX (0x74)',
    ('smb', '0x75'): 'Tree Connect AndX (0x75)',
    ('smb', '0xa2'): 'NT Create AndX (0xa2)',
    ('smb2', '0'): 'NegotiateProtocol',
    ('smb2', '11'): 'Ioctl',
    ('smb2', '14'): 'Find',
    ('smb2', '16'): 'GetInfo',
    ('smb2', '18'): 'Break',
    ('smb2', '1'): 'SessionSetup',
    ('smb2', '2'): 'SessionLogoff',
    ('smb2', '3'): 'TreeConnect',
    ('smb2', '4'): 'TreeDisconnect',
    ('smb2', '5'): 'Create',
    ('smb2', '6'): 'Close',
    ('smb2', '8'): 'Read',
    ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
    ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
                               'user unknown (0x17)'),
    ('srvsvc', '16'): 'NetShareGetInfo',
    ('srvsvc', '21'): 'NetSrvGetInfo',
}
1390
1391
def expand_short_packet(p, timestamp, src, dest, extra):
    """Expand a 'protocol:opcode' string into a tab-separated packet line.

    The description and IP protocol are looked up in OP_DESCRIPTIONS
    and IP_PROTOCOLS, defaulting to '' and '06'. Every field, including
    the items of ``extra``, must already be a string.
    """
    protocol, opcode = p.split(':', 1)
    fields = [
        timestamp,
        IP_PROTOCOLS.get(protocol, '06'),
        '',
        src,
        dest,
        protocol,
        opcode,
        OP_DESCRIPTIONS.get((protocol, opcode), ''),
    ]
    return '\t'.join(fields + list(extra))
1400
1401
def replay(conversations,
           host=None,
           creds=None,
           lp=None,
           accounts=None,
           dns_rate=0,
           duration=None,
           **kwargs):
    """Replay the conversations, each one in its own forked process.

    Conversations are paired with accounts in start-time order and
    forked in batches; the parent reaps finished children between
    batches. When the duration is over, or on any exception, all
    remaining children are killed.

    :param conversations: the conversations to replay
    :param host: the server, passed to ReplayContext
    :param creds: the credentials, passed to ReplayContext
    :param lp: the loadparm context, passed to ReplayContext
    :param accounts: a sequence of accounts, one per conversation;
        surplus conversations are dropped by zip()
    :param dns_rate: if non-zero, also run a DnsHammer at this rate
    :param duration: how long to run for; if None, one second more
        than the last packet timestamp of the last conversation to start
    :param kwargs: further arguments for ReplayContext
    """

    context = ReplayContext(server=host,
                            creds=creds,
                            lp=lp,
                            **kwargs)

    if len(accounts) < len(conversations):
        # report the counts; formatting the lists themselves with %d
        # would raise TypeError
        print(("we have %d accounts but %d conversations" %
               (len(accounts), len(conversations))), file=sys.stderr)

    # Latest first, so that pop() hands out the earliest conversation.
    cstack = list(zip(
        sorted(conversations, key=lambda x: x.start_time, reverse=True),
        accounts))

    # Set the process group so that the calling scripts are not killed
    # when the forked child processes are killed.
    os.setpgrp()

    start = time.time()

    if duration is None:
        # end 1 second after the last packet of the last conversation
        # to start. Conversations other than the last could still be
        # going, but we don't care.
        duration = cstack[0][0].packets[-1].timestamp + 1.0
        print("We will stop after %.1f seconds" % duration,
              file=sys.stderr)

    end = start + duration

    print("Replaying traffic for %u conversations over %d seconds"
          % (len(conversations), duration))

    children = {}
    if dns_rate:
        dns_hammer = DnsHammer(dns_rate, duration)
        cstack.append((dns_hammer, None))

    try:
        while True:
            # we spawn a batch, wait for finishers, then spawn another
            now = time.time()
            batch_end = min(now + 2.0, end)
            fork_time = 0.0
            fork_n = 0
            while cstack:
                c, account = cstack.pop()
                if c.start_time + start > batch_end:
                    # not due in this batch; put it back for the next
                    cstack.append((c, account))
                    break

                st = time.time()
                pid = c.replay_in_fork_with_delay(start, context, account)
                children[pid] = c
                t = time.time()
                elapsed = t - st
                fork_time += elapsed
                fork_n += 1
                print("forked %s in pid %s (in %fs)" % (c, pid,
                                                        elapsed),
                      file=sys.stderr)

            if fork_n:
                print(("forked %d times in %f seconds (avg %f)" %
                       (fork_n, fork_time, fork_time / fork_n)),
                      file=sys.stderr)
            elif cstack:
                debug(2, "no forks in batch ending %f" % batch_end)

            while time.time() < batch_end - 1.0:
                time.sleep(0.01)
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != 10:  # no child processes
                        raise
                    break
                if pid:
                    c = children.pop(pid, None)
                    print(("process %d finished conversation %s;"
                           " %d to go" %
                           (pid, c, len(children))), file=sys.stderr)

            if time.time() >= end:
                print("time to stop", file=sys.stderr)
                break

    except Exception:
        print("EXCEPTION in parent", file=sys.stderr)
        traceback.print_exc()
    finally:
        # signal the children with 15 twice, then with 9
        for s in (15, 15, 9):
            print(("killing %d children with -%d" %
                   (len(children), s)), file=sys.stderr)
            for pid in children:
                try:
                    os.kill(pid, s)
                except OSError as e:
                    if e.errno != 3:  # don't fail if it has already died
                        raise
            time.sleep(0.5)
            end = time.time() + 1
            while children:
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != 10:  # no child processes
                        raise
                    # Nothing is left to wait for. Stop here rather
                    # than falling through with a stale pid from an
                    # earlier waitpid() call.
                    break
                if pid != 0:
                    c = children.pop(pid, None)
                    print(("kill -%d %d KILLED conversation %s; "
                           "%d to go" %
                           (s, pid, c, len(children))),
                          file=sys.stderr)
                if time.time() >= end:
                    break

            if not children:
                break
            time.sleep(1)

        if children:
            print("%d children are missing" % len(children),
                  file=sys.stderr)

        # there may be stragglers that were forked just as ^C was hit
        # and don't appear in the list of children. We can get them
        # with killpg, but that will also kill us, so this is^H^H would be
        # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
        # so as not to have to fuss around writing signal handlers.
        try:
            os.killpg(0, 2)
        except KeyboardInterrupt:
            print("ignoring fake ^C", file=sys.stderr)
1544
1545
def openLdb(host, creds, lp):
    """Open a SamDB connection to ``host`` over ldap.

    The connection uses a system session together with the given
    credentials and loadparm context.
    """
    session_info = system_session()
    url = "ldap://%s" % host
    return SamDB(url=url,
                 session_info=session_info,
                 credentials=creds,
                 lp=lp)
1553
1554
def ou_name(ldb, instance_id):
    """Return the DN of the ou holding the given instance's accounts.

    The ou is named after the instance id and sits under
    ou=traffic_replay in the domain returned by ldb.domain_dn().
    """
    domain_dn = ldb.domain_dn()
    return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id, domain_dn)
1559
1560
def create_ou(ldb, instance_id):
    """Create the instance's ou and return its DN.

    The parent ou is created first, then the instance ou itself; an
    LdbError with status 68 (already exists) is ignored for either.
    All created user and machine accounts belong to this ou, which
    allows the created resources to be cleaned up easily.
    """
    ou = ou_name(ldb, instance_id)
    parent_ou = ou.split(',', 1)[1]

    for dn in (parent_ou, ou):
        try:
            ldb.add({"dn":          dn,
                     "objectclass": "organizationalunit"})
        except LdbError as e:
            (status, _) = e.args
            # ignore already exists
            if status != 68:
                raise

    return ou
1584
1585
class ConversationAccounts(object):
    """The machine account and user account used by one conversation.

    Holds the machine's netbios name and password, and the user's name
    and password.
    """
    def __init__(self, netbios_name, machinepass, username, userpass):
        # machine account
        self.netbios_name, self.machinepass = netbios_name, machinepass
        # user account
        self.username, self.userpass = username, userpass
1594
1595
def generate_replay_accounts(ldb, instance_id, number, password):
    """Create the accounts and return a ConversationAccounts for each.

    Accounts are numbered from 1 to ``number``; every machine and user
    account is given the same password.
    """
    generate_traffic_accounts(ldb, instance_id, number, password)

    return [
        ConversationAccounts("STGM-%d-%d" % (instance_id, index),
                             password,
                             "STGU-%d-%d" % (instance_id, index),
                             password)
        for index in range(1, number + 1)
    ]
1609
1610
def generate_traffic_accounts(ldb, instance_id, number, password):
    """Create the specified number of user and machine accounts.

    Accounts are not explicitly deleted between runs, so for each kind
    of account this works backwards from the highest numbered one and
    stops as soon as an account already exists (LdbError status 68) or
    all the required accounts have been created.
    """
    def add_from_the_top(create, name_format):
        """Create accounts from ``number`` downwards; return the count."""
        count = 0
        for index in range(number, 0, -1):
            name = name_format % (instance_id, index)
            try:
                create(ldb, instance_id, name, password)
            except LdbError as e:
                (status, _) = e.args
                if status != 68:
                    raise
                # already exists, and so do all the lower numbered ones
                break
            count += 1
        return count

    print(("Generating machine and conversation accounts, "
           "as required for %d conversations" % number),
          file=sys.stderr)

    added = add_from_the_top(create_machine_account, "STGM-%d-%d")
    if added > 0:
        print("Added %d new machine accounts" % added,
              file=sys.stderr)

    added = add_from_the_top(create_user_account, "STGU-%d-%d")
    if added > 0:
        print("Added %d new user accounts" % added,
              file=sys.stderr)
1654
1655
def create_machine_account(ldb, instance_id, netbios_name, machinepass):
    """Create a machine account via ldap.

    The account is added as cn=<netbios_name> beneath the instance's ou,
    with sAMAccountName "<netbios_name>$".  A tab separated timing record
    for the add is printed to stdout.

    machinepass may be text or UTF-8 encoded bytes.  LdbError from the
    add is not caught here.
    """

    ou = ou_name(ldb, instance_id)
    dn = "cn=%s,%s" % (netbios_name, ou)

    # The value sent as unicodePwd is the password wrapped in double
    # quotes and encoded as UTF-16-LE.  Normalise to text first so this
    # works on both Python 2 and Python 3 (there is no unicode() builtin
    # on Python 3).
    if isinstance(machinepass, bytes):
        machinepass = machinepass.decode('utf-8')
    utf16pw = (u'"%s"' % machinepass).encode('utf-16-le')

    start = time.time()
    ldb.add({
        "dn": dn,
        "objectclass": "computer",
        "sAMAccountName": "%s$" % netbios_name,
        "userAccountControl":
            str(UF_TRUSTED_FOR_DELEGATION | UF_SERVER_TRUST_ACCOUNT),
        "unicodePwd": utf16pw})
    end = time.time()
    duration = end - start
    print("%f\t0\tcreate\tmachine\t%f\tTrue\t" % (end, duration))
1675
1676
def create_user_account(ldb, instance_id, username, userpass):
    """Create a user account via ldap.

    The account is added as cn=<username> beneath the instance's ou and
    an ACE is then added to its DACL.  A tab separated timing record
    covering both operations is printed to stdout.

    userpass may be text or UTF-8 encoded bytes.  LdbError from the add
    is not caught here.
    """
    ou = ou_name(ldb, instance_id)
    user_dn = "cn=%s,%s" % (username, ou)

    # The value sent as unicodePwd is the password wrapped in double
    # quotes and encoded as UTF-16-LE.  Normalise to text first so this
    # works on both Python 2 and Python 3 (there is no unicode() builtin
    # on Python 3).
    if isinstance(userpass, bytes):
        userpass = userpass.decode('utf-8')
    utf16pw = (u'"%s"' % userpass).encode('utf-16-le')

    start = time.time()
    ldb.add({
        "dn": user_dn,
        "objectclass": "user",
        "sAMAccountName": username,
        "userAccountControl": str(UF_NORMAL_ACCOUNT),
        "unicodePwd": utf16pw
    })

    # grant user write permission to do things like write account SPN
    sdutils = sd_utils.SDUtils(ldb)
    sdutils.dacl_add_ace(user_dn,  "(A;;WP;;;PS)")

    end = time.time()
    duration = end - start
    print("%f\t0\tcreate\tuser\t%f\tTrue\t" % (end, duration))
1700
1701
def create_group(ldb, instance_id, name):
    """Add a group called `name` beneath the instance's ou via ldap.

    A tab separated timing record for the add is printed to stdout.
    LdbError from the add is not caught here.
    """
    group_dn = "cn=%s,%s" % (name, ou_name(ldb, instance_id))

    began = time.time()
    ldb.add(dict(dn=group_dn,
                 objectclass="group",
                 sAMAccountName=name))
    finished = time.time()

    print("%f\t0\tcreate\tgroup\t%f\tTrue\t" % (finished, finished - began))
1716
1717
def user_name(instance_id, i):
    """Return the name of user number i for the given instance id.

    Names have the form STGU-<instance_id>-<i>.
    """
    numbering = (instance_id, i)
    return "STGU-%d-%d" % numbering
1721
1722
def generate_users(ldb, instance_id, number, password):
    """Add user accounts to the server, returning how many were created.

    Users are created from `number` down to 1.  Creation stops at the
    first user that already exists (LDB status 68); any other LdbError
    is re-raised.
    """
    created = 0
    for index in reversed(range(1, number + 1)):
        name = user_name(instance_id, index)
        try:
            create_user_account(ldb, instance_id, name, password)
        except LdbError as err:
            (status, _) = err.args
            if status != 68:
                raise
            # Stop if entry exists
            break
        created += 1

    return created
1740
1741
def group_name(instance_id, i):
    """Return the name of group number i for the given instance id.

    Names have the form STGG-<instance_id>-<i>.
    """
    numbering = (instance_id, i)
    return "STGG-%d-%d" % numbering
1745
1746
def generate_groups(ldb, instance_id, number):
    """Add groups to the server, returning how many were created.

    Groups are created from `number` down to 1.  Creation stops at the
    first group that already exists (LDB status 68); any other LdbError
    is re-raised.
    """
    created = 0
    for index in reversed(range(1, number + 1)):
        name = group_name(instance_id, index)
        try:
            create_group(ldb, instance_id, name)
        except LdbError as err:
            (status, _) = err.args
            if status != 68:
                raise
            # Stop if entry exists
            break
        created += 1
    return created
1763
1764
def clean_up_accounts(ldb, instance_id):
    """Delete the instance's ou, and everything beneath it, from the server.

    A missing ou (LDB status 32) is silently ignored; any other LdbError
    is re-raised.
    """
    target = ou_name(ldb, instance_id)
    try:
        ldb.delete(target, ["tree_delete:1"])
    except LdbError as err:
        (status, _) = err.args
        if status == 32:
            # does not exist, so there is nothing to clean up
            return
        raise
1775
1776
def generate_users_and_groups(ldb, instance_id, password,
                              number_of_users, number_of_groups,
                              group_memberships):
    """Create the requested users and groups and allocate users to groups.

    The instance's ou is created first.  Groups are only generated when
    number_of_groups is positive, and memberships are only assigned when
    group_memberships is positive.  Progress messages and a final
    summary are written to stderr.
    """
    create_ou(ldb, instance_id)

    print("Generating dummy user accounts", file=sys.stderr)
    users_added = generate_users(ldb, instance_id, number_of_users, password)

    groups_added = 0
    if number_of_groups > 0:
        print("Generating dummy groups", file=sys.stderr)
        groups_added = generate_groups(ldb, instance_id, number_of_groups)

    assignments = []
    if group_memberships > 0:
        print("Assigning users to groups", file=sys.stderr)
        assignments = assign_groups(number_of_groups,
                                    groups_added,
                                    number_of_users,
                                    users_added,
                                    group_memberships)
        print("Adding users to groups", file=sys.stderr)
        add_users_to_groups(ldb, instance_id, assignments)

    # New groups were added alongside pre-existing ones, but no new users
    # were created to be placed in them.
    only_some_groups_new = groups_added != number_of_groups
    if groups_added > 0 and users_added == 0 and only_some_groups_new:
        print("Warning: the added groups will contain no members",
              file=sys.stderr)

    summary = ("Added %d users, %d groups and %d group memberships" %
               (users_added, groups_added, len(assignments)))
    print(summary, file=sys.stderr)
1812
1813
def assign_groups(number_of_groups,
                  groups_added,
                  number_of_users,
                  users_added,
                  group_memberships):
    """Allocate users to groups.

    The intention is to have a few users that belong to most groups, while
    the majority of users belong to a few groups.

    A few groups will contain most users, with the remaining only having a
    few users.

    Only pairs involving a newly added user or a newly added group are
    generated.  The requested number of memberships is scaled by the
    fraction of users that were newly added, and is capped at the number
    of such pairs that actually exist.

    Returns a set of (user number, group number) tuples, both 1-based.
    The set is empty if there are no users, no groups, or no memberships
    were requested.
    """
    assignments = set()
    # Nothing can be assigned without at least one user and one group;
    # bail out before dividing by number_of_users or asking randint for
    # an empty range.
    if (group_memberships <= 0 or
            number_of_users <= 0 or
            number_of_groups <= 0):
        return assignments

    # Weight of a group containing a user, and of a user belonging to a
    # group, indexed by (number - 1).  Low numbers are strongly favoured.
    group_dist = [1 / (x ** 1.3) for x in range(1, number_of_groups + 1)]
    user_dist = [1 / (x + 0.001) for x in range(1, number_of_users + 1)]

    # Calculate the number of group memberships required
    group_memberships = math.ceil(
        float(group_memberships) *
        (float(users_added) / float(number_of_users)))

    # Indexes at or below these values refer to users/groups that
    # already existed before this run.
    existing_users  = number_of_users  - users_added  - 1
    existing_groups = number_of_groups - groups_added - 1

    # A pair is eligible unless both its user and its group already
    # existed.  Asking for more memberships than there are eligible pairs
    # would make the loop below spin forever, so cap the request.
    old_users = min(max(number_of_users - users_added, 0), number_of_users)
    old_groups = min(max(number_of_groups - groups_added, 0),
                     number_of_groups)
    eligible_pairs = (number_of_users * number_of_groups -
                      old_users * old_groups)
    group_memberships = min(group_memberships, eligible_pairs)

    while len(assignments) < group_memberships:
        user        = random.randint(0, number_of_users - 1)
        group       = random.randint(0, number_of_groups - 1)
        probability = group_dist[group] * user_dist[user]

        if ((random.random() < probability * 10000) and
           (group > existing_groups or user > existing_users)):
            # the + 1 converts the array index to the corresponding
            # group or user number
            assignments.add(((user + 1), (group + 1)))

    return assignments
1871
1872
def add_users_to_groups(db, instance_id, assignments):
    """Add users to their assigned groups.

    `assignments` is an iterable of (user number, group number) tuples.
    For each one the user's dn is added to the "member" attribute of the
    group, and a tab separated timing record for the modify is printed
    to stdout.
    """
    base = ou_name(db, instance_id)

    for (user_number, group_number) in assignments:
        member_dn = "cn=%s,%s" % (user_name(instance_id, user_number), base)
        target_dn = "cn=%s,%s" % (group_name(instance_id, group_number), base)

        change = ldb.Message()
        change.dn = ldb.Dn(db, target_dn)
        change["member"] = ldb.MessageElement(member_dn,
                                              ldb.FLAG_MOD_ADD,
                                              "member")

        began = time.time()
        db.modify(change)
        finished = time.time()
        print("%f\t0\tadd\tuser\t%f\tTrue\t" % (finished, finished - began))
1896
1897
def generate_stats(statsdir, timing_file):
    """Generate and print the summary stats for a run.

    Every file in statsdir is read as tab separated records of the form
    time, conversation, protocol, type, duration, successful, error.
    Lines that cannot be parsed are echoed to stderr and otherwise
    ignored.

    If timing_file is not None, a header line followed by every valid
    record is written to it.

    Totals, and then per protocol/operation latency figures, are printed
    to stdout: as fixed width columns when stdout is a terminal, and tab
    separated otherwise.
    """
    first      = sys.float_info.max
    last       = 0
    successful = 0
    failed     = 0
    latencies  = {}
    failures   = {}
    unique_conversations = set()

    if timing_file is not None:
        tw = timing_file.write
    else:
        def tw(x):
            pass

    tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")

    for filename in os.listdir(statsdir):
        path = os.path.join(statsdir, filename)
        with open(path, 'r') as f:
            for line in f:
                # Parse the whole record before touching any totals, so
                # that a truncated line is never partially counted.
                try:
                    fields       = line.rstrip('\n').split('\t')
                    end_time     = float(fields[0])
                    conversation = fields[1]
                    protocol     = fields[2]
                    packet_type  = fields[3]
                    latency      = float(fields[4])
                    succeeded    = fields[5] == 'True'
                except (ValueError, IndexError):
                    # not a valid line print and ignore
                    print(line, file=sys.stderr)
                    continue

                first = min(end_time - latency, first)
                last  = max(end_time, last)

                by_type = latencies.setdefault(protocol, {})
                by_type.setdefault(packet_type, []).append(latency)

                failed_by_type = failures.setdefault(protocol, {})
                failed_by_type.setdefault(packet_type, 0)

                if succeeded:
                    successful += 1
                else:
                    failed += 1
                    failed_by_type[packet_type] += 1

                unique_conversations.add(conversation)
                tw(line)

    conversations = len(unique_conversations)

    # With no records, or with every operation at one instant, there is
    # no interval to average over; report a rate of zero rather than
    # dividing by zero.
    duration = last - first
    if successful == 0 or duration <= 0:
        success_rate = 0
    else:
        success_rate = successful / duration
    if failed == 0 or duration <= 0:
        failure_rate = 0
    else:
        failure_rate = failed / duration

    print("Total conversations:   %10d" % conversations)
    print("Successful operations: %10d (%.3f per second)"
          % (successful, success_rate))
    print("Failed operations:     %10d (%.3f per second)"
          % (failed, failure_rate))

    print("Protocol    Op Code  Description                               "
          " Count       Failed         Mean       Median          "
          "95%        Range          Max")

    # isatty is a method: it has to be called, otherwise the test is
    # always true and the tab separated format can never be selected.
    if sys.stdout.isatty():
        row_format = ("%-12s   %4s  %-35s %12d %12d %12.6f "
                      "%12.6f %12.6f %12.6f %12.6f")
    else:
        row_format = "%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f"

    protocols = sorted(latencies.keys())
    for protocol in protocols:
        packet_types = sorted(latencies[protocol], key=opcode_key)
        for packet_type in packet_types:
            values      = sorted(latencies[protocol][packet_type])
            count       = len(values)
            type_failed = failures[protocol][packet_type]
            mean        = sum(values) / count
            median      = calc_percentile(values, 0.50)
            percentile  = calc_percentile(values, 0.95)
            rng         = values[-1] - values[0]
            maxv        = values[-1]
            desc        = OP_DESCRIPTIONS.get((protocol, packet_type), '')
            print(row_format
                  % (protocol,
                     packet_type,
                     desc,
                     count,
                     type_failed,
                     mean,
                     median,
                     percentile,
                     rng,
                     maxv))
2016
2017
def opcode_key(v):
    """Sort key for the operation code to ensure that it sorts numerically.

    Values that int() accepts are returned as a zero padded, three digit
    string; anything else is returned unchanged.
    """
    try:
        return "%03d" % int(v)
    except (TypeError, ValueError, OverflowError):
        # Not a number: fall back to the value itself.  Only the errors
        # int() raises for bad input are caught, so that things like
        # KeyboardInterrupt are not swallowed.
        return v
2024
2025
def calc_percentile(values, percentile):
    """Return the requested percentile of an ascending list of values.

    `percentile` is a fraction (0.95 for the 95th percentile).  When the
    percentile falls between two entries the result is interpolated
    linearly between them.  An empty list gives 0.
    """
    if not values:
        return 0

    # Fractional index of the percentile within the list.
    position = (len(values) - 1) * percentile
    lower = math.floor(position)
    upper = math.ceil(position)

    if lower == upper:
        # Landed exactly on an entry, no interpolation needed.
        return values[int(position)]

    # Weight each neighbour by its closeness to the fractional index.
    below = values[int(lower)] * (upper - position)
    above = values[int(upper)] * (position - lower)
    return below + above
2042
2043
def mk_masked_dir(*path):
    """Create a directory that is inaccessible to group and others.

    In a testenv we end up with 0777 directories that look an alarming
    green colour with ls. Use umask to avoid that.

    The path components are joined with os.path.join() and the resulting
    path is returned.  Errors from os.mkdir() (for example if the
    directory already exists) propagate to the caller.
    """
    d = os.path.join(*path)
    mask = os.umask(0o077)
    try:
        os.mkdir(d)
    finally:
        # Always put the process umask back, even if mkdir failed.
        os.umask(mask)
    return d