traffic: change machine creds secure channel type
[samba.git] / python / samba / emulate / traffic.py
1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
3 #
4 # Copyright (C) Catalyst IT Ltd. 2017
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19 from __future__ import print_function, division
20
21 import time
22 import os
23 import random
24 import json
25 import math
26 import sys
27 import signal
28 import itertools
29
30 from collections import OrderedDict, Counter, defaultdict
31 from samba.emulate import traffic_packets
32 from samba.samdb import SamDB
33 import ldb
34 from ldb import LdbError
35 from samba.dcerpc import ClientConnection
36 from samba.dcerpc import security, drsuapi, lsa
37 from samba.dcerpc import netlogon
38 from samba.dcerpc.netlogon import netr_Authenticator
39 from samba.dcerpc import srvsvc
40 from samba.dcerpc import samr
41 from samba.drs_utils import drs_DsBind
42 import traceback
43 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
44 from samba.auth import system_session
45 from samba.dsdb import UF_WORKSTATION_TRUST_ACCOUNT, UF_PASSWD_NOTREQD
46 from samba.dsdb import UF_NORMAL_ACCOUNT
47 from samba.dcerpc.misc import SEC_CHAN_BDC
48 from samba import gensec
49 from samba import sd_utils
50
# NOTE(review): not used in the visible part of this file; presumably an
# estimate (in seconds) of the overhead of a sleep call -- confirm.
SLEEP_OVERHEAD = 3e-4

# we don't use None, because it complicates [de]serialisation
NON_PACKET = '-'

# (protocol, opcode) pairs that suggest the packet was sent by the client.
# Packet.client_score() returns the weight for a matching packet.
CLIENT_CLUES = {
    ('dns', '0'): 1.0,      # query
    ('smb', '0x72'): 1.0,   # Negotiate protocol
    ('ldap', '0'): 1.0,     # bind
    ('ldap', '3'): 1.0,     # searchRequest
    ('ldap', '2'): 1.0,     # unbindRequest
    ('cldap', '3'): 1.0,
    ('dcerpc', '11'): 1.0,  # bind
    ('dcerpc', '14'): 1.0,  # Alter_context
    ('nbns', '0'): 1.0,     # query
}

# (protocol, opcode) pairs that suggest the packet was sent by the server.
# Packet.client_score() returns the negated weight for a matching packet.
SERVER_CLUES = {
    ('dns', '1'): 1.0,      # response
    ('ldap', '1'): 1.0,     # bind response
    ('ldap', '4'): 1.0,     # search result
    ('ldap', '5'): 1.0,     # search done
    ('cldap', '5'): 1.0,
    ('dcerpc', '12'): 1.0,  # bind_ack
    ('dcerpc', '13'): 1.0,  # bind_nak
    ('dcerpc', '15'): 1.0,  # Alter_context response
}

# Packet.is_really_a_packet() returns False for packets of these protocols,
# so they are not kept in a Conversation.
SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}

# NOTE(review): the wait constants below are not used in the visible part
# of this file; WAIT_THRESHOLD is simply 1 / WAIT_SCALE.
WAIT_SCALE = 10.0
WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
NO_WAIT_LOG_TIME_RANGE = (-10, -3)

# DEBUG_LEVEL can be changed by scripts with -d
# debug() only prints messages whose level is <= DEBUG_LEVEL.
DEBUG_LEVEL = 0
87
88
def debug(level, msg, *args):
    """Write a debug message to standard error when *level* is enabled.

    :param level: verbosity of this message; it is only shown when
                  level <= DEBUG_LEVEL (scripts set DEBUG_LEVEL with -d).
    :param msg:   the message text, which may contain C-style (%)
                  format specifiers.
    :param args:  values for the format specifiers.  When no values are
                  given, msg is printed verbatim without %-formatting.
    """
    if level > DEBUG_LEVEL:
        return
    text = msg % tuple(args) if args else msg
    print(text, file=sys.stderr)
105
106
def debug_lineno(*args):
    """Print the caller's location, then each argument, to stderr.

    The caller's function name and (highlighted) line number are written
    first; each argument is then printed unformatted on its own line,
    followed by an empty line.  stderr is flushed afterwards.
    """
    caller = traceback.extract_stack(limit=2)[0]
    func_name = caller[2]
    line_no = caller[1]
    print(" %s:" "\033[01;33m" "%s " "\033[00m" % (func_name, line_no),
          end=' ', file=sys.stderr)
    for item in args:
        print(item, file=sys.stderr)
    print(file=sys.stderr)
    sys.stderr.flush()
118
119
def random_colour_print():
    """Return a printer that writes to stderr in one randomly chosen colour.

    A 256-colour ANSI foreground index in the range 18..231 is picked once,
    when this function is called.  The returned function prints each of
    its arguments on its own line in that colour, resetting the colour at
    the end of every line.
    """
    colour = random.randrange(214) + 18
    start = "\033[38;5;%dm" % colour

    def coloured_print(*args):
        for item in args:
            print("%s%s\033[00m" % (start, item), file=sys.stderr)

    return coloured_print
130
131
class FakePacketError(Exception):
    """Raised when a packet does not fit the conversation it is added to
    (e.g. its endpoints differ from the conversation's endpoints)."""
134
135
class Packet(object):
    """Details of a network packet"""
    def __init__(self, fields):
        """Build a packet from a summary line or a list of fields.

        :param fields: either a tab-separated string (a trailing newline is
                       stripped) or a sequence whose first eight items are
                       timestamp, ip_protocol, stream_number, src, dest,
                       protocol, opcode and desc.  Any further items are kept
                       in self.extra.
        """
        if isinstance(fields, str):
            fields = fields.rstrip('\n').split('\t')

        (timestamp,
         ip_protocol,
         stream_number,
         src,
         dest,
         protocol,
         opcode,
         desc) = fields[:8]
        extra = fields[8:]

        self.timestamp = float(timestamp)
        self.ip_protocol = ip_protocol
        try:
            self.stream_number = int(stream_number)
        except (ValueError, TypeError):
            # an empty field (or None) means "no stream number"
            self.stream_number = None
        self.src = int(src)
        self.dest = int(dest)
        self.protocol = protocol
        self.opcode = opcode
        self.desc = desc
        self.extra = extra

        # Sorted, so that packets in both directions between the same two
        # endpoints share the same endpoints tuple.
        if self.src < self.dest:
            self.endpoints = (self.src, self.dest)
        else:
            self.endpoints = (self.dest, self.src)

    def as_summary(self, time_offset=0.0):
        """Format the packet as a traffic_summary line.

        :param time_offset: added to the packet's timestamp.
        :return: a (timestamp, line) tuple; timestamp includes the offset and
                 line holds the tab-separated fields in the same order that
                 __init__ reads them (no trailing newline).
        """
        extra = '\t'.join(self.extra)
        t = self.timestamp + time_offset
        # Stream 0 is a valid stream number, so only None becomes an empty
        # field ("stream_number or ''" would turn stream 0 into no stream).
        if self.stream_number is None:
            stream = ''
        else:
            stream = self.stream_number
        return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
                (t,
                 self.ip_protocol,
                 stream,
                 self.src,
                 self.dest,
                 self.protocol,
                 self.opcode,
                 self.desc,
                 extra))

    def __str__(self):
        return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
                (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
                 self.stream_number, self.protocol, self.opcode, self.desc,
                 ('«' + ' '.join(self.extra) + '»' if self.extra else '')))

    def __repr__(self):
        return "<Packet @%s>" % self

    def copy(self):
        """Return a new Packet with the same fields."""
        return self.__class__([self.timestamp,
                               self.ip_protocol,
                               self.stream_number,
                               self.src,
                               self.dest,
                               self.protocol,
                               self.opcode,
                               self.desc] + self.extra)

    def as_packet_type(self):
        """Return the packet type as a 'protocol:opcode' string."""
        t = '%s:%s' % (self.protocol, self.opcode)
        return t

    def client_score(self):
        """A positive number means we think it is a client; a negative number
        means we think it is a server. Zero means no idea. range: -1 to 1.
        """
        key = (self.protocol, self.opcode)
        if key in CLIENT_CLUES:
            return CLIENT_CLUES[key]
        if key in SERVER_CLUES:
            return -SERVER_CLUES[key]
        return 0.0

    def play(self, conversation, context):
        """Send the packet over the network, if required.

        Some packets are ignored, i.e. for protocols not handled,
        server response messages, or messages that are generated by the
        protocol layer associated with other packets.

        The handler is traffic_packets.packet_<protocol>_<opcode>.  When it
        returns a true value, or raises, a tab-separated timing line is
        printed to stdout: end time, conversation id, protocol, opcode,
        duration, a True/False success flag and, on failure, the exception.
        """
        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        try:
            fn = getattr(traffic_packets, fn_name)

        except AttributeError:
            print("Conversation(%s) Missing handler %s" %
                  (conversation.conversation_id, fn_name),
                  file=sys.stderr)
            return

        # Don't display a message for kerberos packets, they're not directly
        # generated they're used to indicate kerberos should be used
        if self.protocol != "kerberos":
            debug(2, "Conversation(%s) Calling handler %s" %
                     (conversation.conversation_id, fn_name))

        start = time.time()
        try:
            if fn(self, conversation, context):
                # Only collect timing data for functions that generate
                # network traffic, or fail
                end = time.time()
                duration = end - start
                print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
                      (end, conversation.conversation_id, self.protocol,
                       self.opcode, duration))
        except Exception as e:
            end = time.time()
            duration = end - start
            print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
                  (end, conversation.conversation_id, self.protocol,
                   self.opcode, duration, e))

    def __cmp__(self, other):
        """Order packets by timestamp, returning -1, 0 or 1.

        Python 2 truncates a non-integer __cmp__ result to an int, so the
        sign is returned rather than the float difference (which would make
        packets less than a second apart compare as equal).
        """
        return ((self.timestamp > other.timestamp) -
                (self.timestamp < other.timestamp))

    def __lt__(self, other):
        """Timestamp ordering for sorting; Python 3 ignores __cmp__."""
        return self.timestamp < other.timestamp

    def is_really_a_packet(self, missing_packet_stats=None):
        """Is the packet one that can be ignored?

        If so removing it will have no effect on the replay.  Returns False
        for skipped protocols, ldap continuation packets, packets with no
        handler in traffic_packets (a message is printed to stderr) and
        packets whose handler is traffic_packets.null_packet.
        missing_packet_stats is accepted but not used.
        """
        if self.protocol in SKIPPED_PROTOCOLS:
            # Ignore any packets for the protocols we're not interested in.
            return False
        if self.protocol == "ldap" and self.opcode == '':
            # skip ldap continuation packets
            return False

        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        fn = getattr(traffic_packets, fn_name, None)
        if not fn:
            print("missing packet %s" % fn_name, file=sys.stderr)
            return False
        if fn is traffic_packets.null_packet:
            return False
        return True
283
284
class ReplayContext(object):
    """State/Context for an individual conversation between a simulated client
       and a server.
    """

    def __init__(self,
                 server=None,
                 lp=None,
                 creds=None,
                 badpassword_frequency=None,
                 prefer_kerberos=None,
                 tempdir=None,
                 statsdir=None,
                 ou=None,
                 base_dn=None,
                 domain=None,
                 domain_sid=None):
        """Set up the replay state shared by a conversation's packets.

        Note: this reads 'realm' from *lp* (so *lp* must be supplied despite
        its None default) and connects to *server* over LDAP via
        generate_ldap_search_tables().
        """

        self.server                   = server
        self.ldap_connections         = []
        self.dcerpc_connections       = []
        self.lsarpc_connections       = []
        self.lsarpc_connections_named = []
        self.drsuapi_connections      = []
        self.srvsvc_connections       = []
        self.samr_contexts            = []
        self.netlogon_connection      = None
        self.creds                    = creds
        self.lp                       = lp
        self.prefer_kerberos          = prefer_kerberos
        self.ou                       = ou
        self.base_dn                  = base_dn
        self.domain                   = domain
        self.statsdir                 = statsdir
        self.global_tempdir           = tempdir
        self.domain_sid               = domain_sid
        self.realm                    = lp.get('realm')

        # Bad password attempt controls
        self.badpassword_frequency    = badpassword_frequency
        self.last_lsarpc_bad          = False
        self.last_lsarpc_named_bad    = False
        self.last_simple_bind_bad     = False
        self.last_bind_bad            = False
        self.last_srvsvc_bad          = False
        self.last_drsuapi_bad         = False
        self.last_netlogon_bad        = False
        self.last_samlogon_bad        = False
        self.generate_ldap_search_tables()

        # itertools.count().next only exists on Python 2; the builtin next()
        # works on both Python 2 and 3.
        conversation_ids = itertools.count()
        self.next_conversation_id = lambda: next(conversation_ids)

    def generate_ldap_search_tables(self):
        """Build the DN lookup tables used by get_matching_dn().

        Every DN under the domain DN is grouped by its pattern: the
        upper-cased first two characters of each RDN, e.g.
        "CN=Foo,DC=example,DC=com" -> "CN,DC,DC".  DNs starting with
        "CN=NTDS Settings," are also recorded as candidates for the
        'invocationId' attribute.
        """
        session = system_session()

        db = SamDB(url="ldap://%s" % self.server,
                   session_info=session,
                   credentials=self.creds,
                   lp=self.lp)

        res = db.search(db.domain_dn(),
                        scope=ldb.SCOPE_SUBTREE,
                        controls=["paged_results:1:1000"],
                        attrs=['dn'])

        # find a list of dns for each pattern
        # e.g. CN,CN,CN,DC,DC
        dn_map = {}
        attribute_clue_map = {
            'invocationId': []
        }

        for r in res:
            dn = str(r.dn)
            pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
            dns = dn_map.setdefault(pattern, [])
            dns.append(dn)
            if dn.startswith('CN=NTDS Settings,'):
                attribute_clue_map['invocationId'].append(dn)

        # extend the map in case we are working with a different
        # number of DC components: each pattern ending in ",DC" is also
        # registered with 1 to 5 trailing DC components, unless that
        # pattern already exists.  Iterate over a snapshot of the keys,
        # because new keys are added inside the loop (iterating the live
        # dict raises RuntimeError on Python 3).
        for k in list(dn_map):
            if k[-3:] != ',DC':
                continue
            p = k[:-3]
            while p[-3:] == ',DC':
                p = p[:-3]
            for _ in range(5):
                p += ',DC'
                if p != k and p in dn_map:
                    print('dn_map collison %s %s' % (k, p),
                          file=sys.stderr)
                    continue
                dn_map[p] = dn_map[k]

        self.dn_map = dn_map
        self.attribute_clue_map = attribute_clue_map

    def generate_process_local_config(self, account, conversation):
        """Configure this context for one conversation's *account*.

        Copies the account's names and passwords, points the private, lock
        and state directories of self.lp at a per-conversation temporary
        directory, falls back to the DOMAIN environment variable when no
        domain was given, and generates the machine and user credentials.
        Does nothing when *account* is None.
        """
        if account is None:
            return
        self.netbios_name             = account.netbios_name
        self.machinepass              = account.machinepass
        self.username                 = account.username
        self.userpass                 = account.userpass

        self.tempdir = mk_masked_dir(self.global_tempdir,
                                     'conversation-%d' %
                                     conversation.conversation_id)

        self.lp.set("private dir",     self.tempdir)
        self.lp.set("lock dir",        self.tempdir)
        self.lp.set("state directory", self.tempdir)
        self.lp.set("tls verify peer", "no_check")

        # If the domain was not specified, check for the environment
        # variable.
        if self.domain is None:
            self.domain = os.environ["DOMAIN"]

        self.remoteAddress = "/root/ncalrpc_as_system"
        self.samlogon_dn   = ("cn=%s,%s" %
                              (self.netbios_name, self.ou))
        self.user_dn       = ("cn=%s,%s" %
                              (self.username, self.ou))

        self.generate_machine_creds()
        self.generate_user_creds()

    def with_random_bad_credentials(self, f, good, bad, failed_last_time):
        """Execute the supplied logon function, randomly choosing the
           bad credentials.

           Based on the frequency in badpassword_frequency randomly perform the
           function with the supplied bad credentials (a frequency of None
           or <= 0 disables this).  Exceptions from the bad attempt are
           ignored.  The function is always then run with the good
           credentials, and its result returned.

           failed_last_time is used to prevent consecutive bad credential
           attempts: when it is true no bad attempt is made this time, and
           the returned flag is reset so the next call may try again.  So
           the over all bad credential frequency will be lower than that
           requested, but not significantly.

           :return: (result of f(good), whether a bad attempt was made now)
        """
        frequency = self.badpassword_frequency or 0
        if (not failed_last_time and frequency > 0 and
                random.random() < frequency):
            try:
                f(bad)
            except Exception:
                # Ignore any exceptions as the operation may fail
                # as it's being performed with bad credentials
                pass
            failed_last_time = True
        else:
            failed_last_time = False

        result = f(good)
        return (result, failed_last_time)

    def _kerberos_state(self):
        """Return the kerberos state to set on generated credentials."""
        if self.prefer_kerberos:
            return MUST_USE_KERBEROS
        return DONT_USE_KERBEROS

    def generate_user_creds(self):
        """Generate the conversation specific user Credentials.

        Each Conversation has an associated user account used to simulate
        any non Administrative user traffic.

        Generates user credentials with good and bad passwords and ldap
        simple bind credentials with good and bad passwords.  The bad
        password is the good one with its last four characters removed.
        """
        kerberos_state = self._kerberos_state()

        self.user_creds = Credentials()
        self.user_creds.guess(self.lp)
        self.user_creds.set_workstation(self.netbios_name)
        self.user_creds.set_password(self.userpass)
        self.user_creds.set_username(self.username)
        self.user_creds.set_domain(self.domain)
        self.user_creds.set_kerberos_state(kerberos_state)

        self.user_creds_bad = Credentials()
        self.user_creds_bad.guess(self.lp)
        self.user_creds_bad.set_workstation(self.netbios_name)
        self.user_creds_bad.set_password(self.userpass[:-4])
        self.user_creds_bad.set_username(self.username)
        self.user_creds_bad.set_kerberos_state(kerberos_state)

        # Credentials for ldap simple bind.
        self.simple_bind_creds = Credentials()
        self.simple_bind_creds.guess(self.lp)
        self.simple_bind_creds.set_workstation(self.netbios_name)
        self.simple_bind_creds.set_password(self.userpass)
        self.simple_bind_creds.set_username(self.username)
        self.simple_bind_creds.set_gensec_features(
            self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
        self.simple_bind_creds.set_kerberos_state(kerberos_state)
        self.simple_bind_creds.set_bind_dn(self.user_dn)

        self.simple_bind_creds_bad = Credentials()
        self.simple_bind_creds_bad.guess(self.lp)
        self.simple_bind_creds_bad.set_workstation(self.netbios_name)
        self.simple_bind_creds_bad.set_password(self.userpass[:-4])
        self.simple_bind_creds_bad.set_username(self.username)
        self.simple_bind_creds_bad.set_gensec_features(
            self.simple_bind_creds_bad.get_gensec_features() |
            gensec.FEATURE_SEAL)
        self.simple_bind_creds_bad.set_kerberos_state(kerberos_state)
        self.simple_bind_creds_bad.set_bind_dn(self.user_dn)

    def generate_machine_creds(self):
        """Generate the conversation specific machine Credentials.

        Each Conversation has an associated machine account.

        Generates machine credentials with good and bad passwords.  The bad
        password is the good one with its last four characters removed.
        """
        kerberos_state = self._kerberos_state()

        # NOTE(review): the secure channel type is SEC_CHAN_BDC; confirm it
        # matches the userAccountControl the machine accounts are created
        # with, or the server may reject the schannel authentication.
        self.machine_creds = Credentials()
        self.machine_creds.guess(self.lp)
        self.machine_creds.set_workstation(self.netbios_name)
        self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds.set_password(self.machinepass)
        self.machine_creds.set_username(self.netbios_name + "$")
        self.machine_creds.set_domain(self.domain)
        self.machine_creds.set_kerberos_state(kerberos_state)

        self.machine_creds_bad = Credentials()
        self.machine_creds_bad.guess(self.lp)
        self.machine_creds_bad.set_workstation(self.netbios_name)
        self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds_bad.set_password(self.machinepass[:-4])
        self.machine_creds_bad.set_username(self.netbios_name + "$")
        self.machine_creds_bad.set_kerberos_state(kerberos_state)

    def get_matching_dn(self, pattern, attributes=None):
        """Return a DN from the directory that resembles *pattern*.

        If *attributes* has entries in attribute_clue_map, one of those DNs
        is chosen at random.  Otherwise DNs are looked up by pattern (e.g.
        "CN,CN,CN,CN,DC,DC"); variants with a different number of trailing
        DC components were already added to dn_map by
        generate_ldap_search_tables().  If there is no match, leading
        components are chopped off one at a time (CN,CN,CN,DC,DC, then
        CN,CN,DC,DC, ...).  As a last resort, including for an empty
        pattern, base_dn is returned.
        """
        attr_clue = self.attribute_clue_map.get(attributes)
        if attr_clue:
            return random.choice(attr_clue)

        pattern = pattern.upper()
        while pattern:
            if pattern in self.dn_map:
                return random.choice(self.dn_map[pattern])
            # chop one off the front and try it all again.
            pattern = pattern[3:]

        return self.base_dn

    def get_dcerpc_connection(self, new=False):
        """Return the latest netlogon-UUID ClientConnection, or a new one."""
        guid = '12345678-1234-abcd-ef00-01234567cffb'  # RPC_NETLOGON UUID
        if self.dcerpc_connections and not new:
            return self.dcerpc_connections[-1]
        c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
                             (guid, 1), self.lp)
        self.dcerpc_connections.append(c)
        return c

    def get_srvsvc_connection(self, new=False):
        """Return the latest srvsvc (named pipe) connection, or a new one
        opened with the user credentials."""
        if self.srvsvc_connections and not new:
            return self.srvsvc_connections[-1]

        def connect(creds):
            return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
                                 self.lp,
                                 creds)

        (c, self.last_srvsvc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_srvsvc_bad)

        self.srvsvc_connections.append(c)
        return c

    def get_lsarpc_connection(self, new=False):
        """Return the latest lsarpc TCP (schannel, sealed, signed)
        connection, or a new one opened with the machine credentials."""
        if self.lsarpc_connections and not new:
            return self.lsarpc_connections[-1]

        def connect(creds):
            binding_options = 'schannel,seal,sign'
            return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
                              (self.server, binding_options),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_bad)

        self.lsarpc_connections.append(c)
        return c

    def get_lsarpc_named_pipe_connection(self, new=False):
        """Return the latest lsarpc named pipe connection, or a new one
        opened with the machine credentials."""
        if self.lsarpc_connections_named and not new:
            return self.lsarpc_connections_named[-1]

        def connect(creds):
            return lsa.lsarpc("ncacn_np:%s" % (self.server),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_named_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_named_bad)

        self.lsarpc_connections_named.append(c)
        return c

    def get_drsuapi_connection_pair(self, new=False, unbind=False):
        """get a (drs, drs_handle) tuple

        A new pair is opened (sealed, user credentials) and bound with
        drs_DsBind when none exists or *new* is true.  *unbind* is accepted
        but not used here.
        """
        if self.drsuapi_connections and not new:
            c = self.drsuapi_connections[-1]
            return c

        def connect(creds):
            binding_options = 'seal'
            binding_string = "ncacn_ip_tcp:%s[%s]" %\
                             (self.server, binding_options)
            return drsuapi.drsuapi(binding_string, self.lp, creds)

        (drs, self.last_drsuapi_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_drsuapi_bad)

        # the supported extensions returned by drs_DsBind are not needed
        (drs_handle, _) = drs_DsBind(drs)
        c = (drs, drs_handle)
        self.drsuapi_connections.append(c)
        return c

    def get_ldap_connection(self, new=False, simple=False):
        """Return the latest SamDB connection, or a new one.

        With *simple* a new connection uses ldaps:// and the simple bind
        credentials; otherwise ldap:// with the user credentials.
        """
        if self.ldap_connections and not new:
            return self.ldap_connections[-1]

        def simple_bind(creds):
            """
            To run simple bind against Windows, we need to run
            following commands in PowerShell:

                Install-windowsfeature ADCS-Cert-Authority
                Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
                Restart-Computer

            """
            return SamDB('ldaps://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)

        def sasl_bind(creds):
            return SamDB('ldap://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)
        if simple:
            (samdb, self.last_simple_bind_bad) = \
                self.with_random_bad_credentials(simple_bind,
                                                 self.simple_bind_creds,
                                                 self.simple_bind_creds_bad,
                                                 self.last_simple_bind_bad)
        else:
            (samdb, self.last_bind_bad) = \
                self.with_random_bad_credentials(sasl_bind,
                                                 self.user_creds,
                                                 self.user_creds_bad,
                                                 self.last_bind_bad)

        self.ldap_connections.append(samdb)
        return samdb

    def get_samr_context(self, new=False):
        """Return the latest SamrContext (using self.creds), or a new one."""
        if not self.samr_contexts or new:
            self.samr_contexts.append(
                SamrContext(self.server, lp=self.lp, creds=self.creds))
        return self.samr_contexts[-1]

    def get_netlogon_connection(self):
        """Return the cached netlogon connection, opening it (schannel,
        sealed, machine credentials) on first use."""

        if self.netlogon_connection:
            return self.netlogon_connection

        def connect(creds):
            return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
                                     (self.server),
                                     self.lp,
                                     creds)
        (c, self.last_netlogon_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_netlogon_bad)
        self.netlogon_connection = c
        return c

    def guess_a_dns_lookup(self):
        """Return a (name, record type) pair: the realm's A record."""
        return (self.realm, 'A')

    def get_authenticator(self):
        """Return a (current, subsequent) pair of netr_Authenticators.

        current is filled in from machine_creds.new_client_authenticator();
        subsequent is left empty.
        """
        auth = self.machine_creds.new_client_authenticator()
        current  = netr_Authenticator()
        # Iterating the credential yields 1-char strings on Python 2 but
        # ints on Python 3 (bytes), so only call ord() when needed.
        current.cred.data = [x if isinstance(x, int) else ord(x)
                             for x in auth["credential"]]
        current.timestamp = auth["timestamp"]

        subsequent = netr_Authenticator()
        return (current, subsequent)
717
718
class SamrContext(object):
    """A samr connection, opened on demand, plus the handles used with it.

    get_connection() and get_handle() create the connection and the
    Connect2 handle on first use and cache them; the other handle/sid/rid
    attributes start as None and are not set by this class.
    """
    def __init__(self, server, lp=None, creds=None):
        # where and how to connect
        self.server = server
        self.lp = lp
        self.creds = creds
        # created lazily by get_connection() / get_handle()
        self.connection = None
        self.handle = None
        # NOTE(review): presumably filled in by the samr packet handlers
        self.domain_handle = None
        self.domain_sid = None
        self.group_handle = None
        self.user_handle = None
        self.rids = None

    def get_connection(self):
        """Return the samr connection, opening it (sealed TCP) if needed."""
        if self.connection:
            return self.connection
        self.connection = samr.samr(
            "ncacn_ip_tcp:%s[seal]" % (self.server),
            lp_ctx=self.lp,
            credentials=self.creds)
        return self.connection

    def get_handle(self):
        """Return the Connect2 handle (maximum allowed access), creating it
        over get_connection() if needed."""
        if self.handle:
            return self.handle
        conn = self.get_connection()
        self.handle = conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
        return self.handle
748
749
class Conversation(object):
    """Details of a conversation between a simulated client and a server.

    Packets are stored with timestamps relative to the conversation's
    ``start_time``.  ``client_balance`` accumulates each packet's
    ``client_score()`` (subtracted when the packet comes from the first
    endpoint, added otherwise); guess_client_server() treats a negative
    balance as meaning the first endpoint is the client.
    """
    conversation_id = None

    def __init__(self, start_time=None, endpoints=None):
        self.start_time = start_time
        self.endpoints = endpoints
        self.packets = []
        self.msg = random_colour_print()
        self.client_balance = 0.0

    def _sort_key(self):
        """Return a key ordering conversations by start time, with
        conversations that have no start time (None) first."""
        if self.start_time is None:
            return (0, 0)
        return (1, self.start_time)

    def __cmp__(self, other):
        # Only consulted by Python 2, where it must return an int.  The old
        # float difference was truncated, so conversations starting less
        # than a second apart compared as equal.
        mine, theirs = self._sort_key(), other._sort_key()
        return (mine > theirs) - (mine < theirs)

    def __lt__(self, other):
        # Python 3 ignores __cmp__; list.sort() and sorted() only need __lt__.
        # __eq__ is deliberately left alone so equality and hashing stay
        # identity-based.
        if not isinstance(other, Conversation):
            return NotImplemented
        return self._sort_key() < other._sort_key()

    def _account_for(self, p):
        """Update the client balance for packet *p*, and keep *p* if it is a
        real packet rather than just a hint about who is the client."""
        if p.src == p.endpoints[0]:
            self.client_balance -= p.client_score()
        else:
            self.client_balance += p.client_score()

        if p.is_really_a_packet():
            self.packets.append(p)

    def add_packet(self, packet):
        """Add a copy of *packet* to this conversation, with its timestamp
        made relative to the conversation's start.

        The first packet added fixes ``start_time`` and ``endpoints`` if they
        were not given to the constructor.

        :raises FakePacketError: if the packet's endpoints differ from the
            conversation's.
        """
        p = packet.copy()

        if self.start_time is None:
            self.start_time = p.timestamp

        if self.endpoints is None:
            self.endpoints = p.endpoints

        if p.endpoints != self.endpoints:
            raise FakePacketError("Conversation endpoints %s don't match "
                                  "packet endpoints %s" %
                                  (self.endpoints, p.endpoints))

        p.timestamp -= self.start_time
        self._account_for(p)

    def add_short_packet(self, timestamp, protocol, opcode, extra,
                         client=True):
        """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
        (possibly empty) list of extra data. If client is True, assume
        this packet is from the client to the server.
        """
        src, dest = self.guess_client_server()
        if not client:
            src, dest = dest, src

        desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
        ip_protocol = IP_PROTOCOLS.get(protocol, '06')
        fields = [timestamp - self.start_time, ip_protocol,
                  '', src, dest,
                  protocol, opcode, desc]
        fields.extend(extra)
        packet = Packet(fields)
        # XXX we're assuming the timestamp is already adjusted for
        # this conversation?
        # XXX should we adjust client balance for guessed packets?
        self._account_for(packet)

    def __str__(self):
        return ("<Conversation %s %s starting %.3f %d packets>" %
                (self.conversation_id, self.endpoints, self.start_time,
                 len(self.packets)))

    __repr__ = __str__

    def __iter__(self):
        return iter(self.packets)

    def __len__(self):
        return len(self.packets)

    def get_duration(self):
        """Return the time between the first and last packets (0 if there
        are fewer than two)."""
        if len(self.packets) < 2:
            return 0
        return self.packets[-1].timestamp - self.packets[0].timestamp

    def replay_as_summary_lines(self):
        """Return each packet rendered as a summary line."""
        return [p.as_summary(self.start_time) for p in self.packets]

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        """Fork a new process and replay the conversation.

        The parent returns the child's pid immediately.  The child sets up
        its process-local config, redirects stdout to a per-conversation
        stats file, sleeps until ``start + start_time`` and replays; it
        never returns, always leaving through os._exit().
        """
        def signal_handler(signum, frame):
            """Signal handler closes standard out and error.

            Triggered by a sigterm, ensures that the log messages are flushed
            to disk and not lost.
            """
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

        t = self.start_time
        now = time.time() - start
        gap = t - now
        # we are replaying strictly in order, so it is safe to sleep
        # in the main process if the gap is big enough. This reduces
        # the number of concurrent threads, which allows us to make
        # larger loads.  (Currently disabled by the "and False".)
        if gap > 0.15 and False:
            print("sleeping for %f in main process" % (gap - 0.1),
                  file=sys.stderr)
            time.sleep(gap - 0.1)
            now = time.time() - start
            gap = t - now
            print("gap is now %f" % gap, file=sys.stderr)

        self.conversation_id = context.next_conversation_id()
        pid = os.fork()
        if pid != 0:
            return pid
        pid = os.getpid()
        signal.signal(signal.SIGTERM, signal_handler)
        # we must never return, or we'll end up running parts of the
        # parent's clean-up code. So we work in a try...finally, and
        # try to print any exceptions.

        try:
            context.generate_process_local_config(account, self)
            sys.stdin.close()
            os.close(0)
            filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
                                    self.conversation_id)
            sys.stdout.close()
            sys.stdout = open(filename, 'w')

            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            miss = t - (time.time() - start)
            self.msg("starting %s [miss %.3f pid %d]" % (self, miss, pid))
            self.replay(context)
        except Exception:
            print(("EXCEPTION in child PID %d, conversation %s" % (pid, self)),
                  file=sys.stderr)
            # print_exc's first positional parameter is `limit`, so the
            # stream must be passed by keyword.
            traceback.print_exc(file=sys.stderr)
        finally:
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

    def replay(self, context=None):
        """Play each packet at its relative timestamp.  With no context,
        only log what would be played."""
        start = time.time()

        for p in self.packets:
            now = time.time() - start
            gap = p.timestamp - now
            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            miss = p.timestamp - (time.time() - start)
            if context is None:
                self.msg("packet %s [miss %.3f pid %d]" % (p, miss,
                                                           os.getpid()))
                continue
            p.play(self, context)

    def guess_client_server(self, server_clue=None):
        """Have a go at deciding who is the server and who is the client.
        returns (client, server)
        """
        a, b = self.endpoints

        if self.client_balance < 0:
            return (a, b)

        # in the absence of a clue, we will fall through to assuming
        # the lowest number is the server (which is usually true).

        if self.client_balance == 0 and server_clue == b:
            return (a, b)

        return (b, a)

    def forget_packets_outside_window(self, s, e):
        """Prune any packets outside the time window we're interested in

        :param s: start of the window
        :param e: end of the window
        """
        self.packets = [p for p in self.packets if s <= p.timestamp <= e]
        self.start_time = self.packets[0].timestamp if self.packets else None

    def renormalise_times(self, start_time):
        """Adjust the packet start times relative to the new start time."""
        for p in self.packets:
            p.timestamp -= start_time

        if self.start_time is not None:
            self.start_time -= start_time
960
961
class DnsHammer(Conversation):
    """A lightweight conversation that generates a lot of dns:0 packets on
    the fly.

    ``int(dns_rate * duration)`` query times are drawn uniformly from
    ``[0, duration]`` and replayed in sorted order.
    """

    def __init__(self, dns_rate, duration):
        # Conversation.__init__ is not called, but the attributes that
        # inherited methods (__len__, __iter__, guess_client_server) read are
        # still set, so those methods do not raise AttributeError.
        self.endpoints = None
        self.packets = []
        self.client_balance = 0.0

        n = int(dns_rate * duration)
        self.times = sorted(random.uniform(0, duration) for i in range(n))
        self.rate = dns_rate
        self.duration = duration
        self.start_time = 0
        self.msg = random_colour_print()

    def __str__(self):
        return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
                (len(self.times), self.duration, self.rate))

    # Conversation binds __repr__ to *its* __str__, which describes
    # endpoints and packets rather than the hammer; use ours instead.
    __repr__ = __str__

    # replay_in_fork_with_delay() is inherited unchanged from Conversation.

    def replay(self, context=None):
        """Send a dns:0 query at each scheduled time, printing one
        tab-separated timing line per query to stdout.  With no context,
        only log what would be sent."""
        start = time.time()
        fn = traffic_packets.packet_dns_0
        for t in self.times:
            now = time.time() - start
            gap = t - now
            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            if context is None:
                miss = t - (time.time() - start)
                self.msg("packet %s [miss %.3f pid %d]" % (t, miss,
                                                           os.getpid()))
                continue

            packet_start = time.time()
            try:
                # The hammer stands in for both the packet and the
                # conversation arguments of the packet function.
                fn(self, self, context)
                end = time.time()
                duration = end - packet_start
                print("%f\tDNS\tdns\t0\t%f\tTrue\t" % (end, duration))
            except Exception as e:
                end = time.time()
                duration = end - packet_start
                print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" % (end, duration, e))
1011
1012
def ingest_summaries(files, dns_mode='count'):
    """Load traffic summary files and gather their packets into Conversations.

    :param files: file names or open file objects.  Each line is parsed as a
        Packet, and each file is closed once read.
    :param dns_mode: unless this is 'include', dns packets are not kept as
        conversation traffic; they are only counted per opcode.
    :return: a 4-tuple ``(conversations, mean_interval, duration,
        dns_counts)``.  ``conversations`` holds only non-empty
        conversations, with timestamps rebased to the earliest packet.
        ``mean_interval`` is the number of conversations divided by
        ``duration``, the span of the capture in seconds.  ``dns_counts``
        maps dns opcode to count.
    """
    dns_counts = defaultdict(int)
    packets = []
    for f in files:
        if isinstance(f, str):
            f = open(f)
        try:
            print("Ingesting %s" % (f.name,), file=sys.stderr)
            for line in f:
                p = Packet(line)
                if p.protocol == 'dns' and dns_mode != 'include':
                    dns_counts[p.opcode] += 1
                else:
                    packets.append(p)
        finally:
            f.close()

    if not packets:
        # Same 4-tuple shape as the normal return, so callers can unpack it.
        return [], 0.0, 0.0, dns_counts

    start_time = min(p.timestamp for p in packets)
    last_packet = max(p.timestamp for p in packets)

    print("gathering packets into conversations", file=sys.stderr)
    conversations = OrderedDict()
    for p in packets:
        p.timestamp -= start_time
        c = conversations.get(p.endpoints)
        if c is None:
            c = Conversation()
            conversations[p.endpoints] = c
        c.add_packet(p)

    # We only care about conversations with actual traffic, so we
    # filter out conversations with nothing to say. We do that here,
    # rather than earlier, because those empty packets contain useful
    # hints as to which end of the conversation was the client.
    conversation_list = [c for c in conversations.values() if len(c) != 0]

    # This is obviously not correct, as many conversations will appear
    # to start roughly simultaneously at the beginning of the snapshot.
    # To which we say: oh well, so be it.
    duration = float(last_packet - start_time)
    # A capture whose packets all share one timestamp has no duration;
    # report 0 rather than dividing by zero.
    mean_interval = len(conversations) / duration if duration else 0.0

    return conversation_list, mean_interval, duration, dns_counts
1064
1065
def guess_server_address(conversations):
    """Guess the server's address as the most common conversation endpoint.

    :param conversations: objects with an ``endpoints`` pair.
    :return: the endpoint address appearing in the most conversations, or
        None if there are no conversations.
    """
    addresses = Counter()
    for c in conversations:
        addresses.update(c.endpoints)
    if not addresses:
        return None
    # most_common(1) is [(address, count)].  Callers compare the result
    # with endpoint addresses, so return only the address.
    return addresses.most_common(1)[0][0]
1073
1074
def stringify_keys(x):
    """Return a copy of mapping *x* whose tuple keys are joined with tabs,
    so that it can be serialised (e.g. to JSON)."""
    return {'\t'.join(key): value for key, value in x.items()}
1081
1082
def unstringify_keys(x):
    """Inverse of stringify_keys: split each key of *x* on tabs back into
    a tuple of strings."""
    return {tuple(str(key).split('\t')): value for key, value in x.items()}
1089
1090
class TrafficModel(object):
    """An n-gram model of client traffic, learnt from conversations and used
    to generate new ones.

    ``ngrams`` maps a tuple of the ``n - 1`` preceding states (packet types
    such as 'ldap:3', 'wait:N' delay states, or NON_PACKET) to the list of
    states seen to follow it.  ``query_details`` maps a packet type to the
    ``extra`` tuples seen with it.
    """
    def __init__(self, n=3):
        self.ngrams = {}
        self.query_details = {}
        self.n = n
        self.dns_opcounts = defaultdict(int)
        self.cumulative_duration = 0.0
        # [number of conversations, seconds from first to last conversation
        # start], as set by learn().
        self.conversation_rate = [0, 1]

    def learn(self, conversations, dns_opcounts=None):
        """Add the client traffic in *conversations* to the model.

        :param conversations: a list of Conversations, presumably in start
            time order (the first and last are used for the conversation
            rate).
        :param dns_opcounts: optional mapping of dns opcode to count, added
            to ``self.dns_opcounts``.
        """
        if dns_opcounts is None:
            dns_opcounts = {}
        cum_duration = 0.0

        # A capture-wide hint for guess_client_server().  It is kept
        # separate from the per-conversation guesses so that one
        # conversation's guess does not become the next one's clue.
        server_clue = guess_server_address(conversations)

        for k, v in dns_opcounts.items():
            self.dns_opcounts[k] += v

        if len(conversations) > 1:
            elapsed = (conversations[-1].start_time -
                       conversations[0].start_time)
            self.conversation_rate[0] = len(conversations)
            self.conversation_rate[1] = elapsed

        for c in conversations:
            client, _ = c.guess_client_server(server_clue)
            cum_duration += c.get_duration()
            key = (NON_PACKET,) * (self.n - 1)
            # Packet timestamps are relative to their own conversation, so
            # gaps are measured from this conversation's start, not from
            # the previous conversation's last packet.
            prev = 0.0
            for p in c:
                if p.src != client:
                    continue

                elapsed = p.timestamp - prev
                prev = p.timestamp
                if elapsed > WAIT_THRESHOLD:
                    # add the wait as an extra state
                    wait = 'wait:%d' % (math.log(max(1.0,
                                                     elapsed * WAIT_SCALE)))
                    self.ngrams.setdefault(key, []).append(wait)
                    key = key[1:] + (wait,)

                short_p = p.as_packet_type()
                self.query_details.setdefault(short_p,
                                              []).append(tuple(p.extra))
                self.ngrams.setdefault(key, []).append(short_p)
                key = key[1:] + (short_p,)

            # Record that this conversation ends here, so
            # construct_conversation() can learn to stop.
            self.ngrams.setdefault(key, []).append(NON_PACKET)

        self.cumulative_duration += cum_duration

    def save(self, f):
        """Write the model as JSON to *f*, a file name or writable file.

        A file opened here from a name is closed again; a file object
        passed in is left open.
        """
        ngrams = {}
        for k, v in self.ngrams.items():
            ngrams['\t'.join(k)] = dict(Counter(v))

        query_details = {}
        for k, v in self.query_details.items():
            query_details[k] = dict(Counter('\t'.join(x) if x else '-'
                                            for x in v))

        d = {
            'ngrams': ngrams,
            'query_details': query_details,
            'cumulative_duration': self.cumulative_duration,
            'conversation_rate': self.conversation_rate,
        }
        d['dns'] = self.dns_opcounts

        if isinstance(f, str):
            with open(f, 'w') as fh:
                json.dump(d, fh, indent=2)
        else:
            json.dump(d, f, indent=2)

    def load(self, f):
        """Merge a model saved by save() from *f*, a file name or readable
        file.

        ngrams, query_details and dns counts are added to the existing
        ones; cumulative_duration and conversation_rate are replaced.
        """
        if isinstance(f, str):
            with open(f) as fh:
                d = json.load(fh)
        else:
            d = json.load(f)

        for k, v in d['ngrams'].items():
            k = tuple(str(k).split('\t'))
            values = self.ngrams.setdefault(k, [])
            for p, count in v.items():
                values.extend([str(p)] * count)

        for k, v in d['query_details'].items():
            values = self.query_details.setdefault(str(k), [])
            for p, count in v.items():
                if p == '-':
                    values.extend([()] * count)
                else:
                    values.extend([tuple(str(p).split('\t'))] * count)

        if 'dns' in d:
            for k, v in d['dns'].items():
                self.dns_opcounts[k] += v

        self.cumulative_duration = d['cumulative_duration']
        self.conversation_rate = d['conversation_rate']

    def construct_conversation(self, timestamp=0.0, client=2, server=1,
                               hard_stop=None, packet_rate=1):
        """Construct an individual conversation from the model.

        Starting from the all-NON_PACKET key, states are drawn at random from
        ``ngrams`` until NON_PACKET is drawn or the key is unknown.
        'wait:N' states only advance the clock.  Other states are added
        with add_short_packet().  A packet that would fall after
        *hard_stop* ends the conversation.  Delays are divided by
        *packet_rate*.
        """
        c = Conversation(timestamp, (server, client))

        key = (NON_PACKET,) * (self.n - 1)

        while key in self.ngrams:
            p = random.choice(self.ngrams[key])
            if p == NON_PACKET:
                break
            if p in self.query_details:
                extra = random.choice(self.query_details[p])
            else:
                extra = []

            protocol, opcode = p.split(':', 1)
            if protocol == 'wait':
                log_wait_time = int(opcode) + random.random()
                wait = math.exp(log_wait_time) / (WAIT_SCALE * packet_rate)
                timestamp += wait
            else:
                log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
                wait = math.exp(log_wait) / packet_rate
                timestamp += wait
                if hard_stop is not None and timestamp > hard_stop:
                    break
                c.add_short_packet(timestamp, protocol, opcode, extra)

            key = key[1:] + (p,)

        return c

    def generate_conversations(self, rate, duration, packet_rate=1):
        """Generate a list of conversations from the model, sorted by start
        time."""

        # We run the simulation for at least twice as long as our
        # desired duration, and take a section of it.
        rate_n, rate_t = self.conversation_rate

        duration2 = max(rate_t, duration * 2)
        n = rate * duration2 * rate_n / rate_t

        server = 1
        client = 2

        conversations = []
        end = duration2

        while client < n + 2:
            # NOTE(review): each conversation's own start time is also used
            # as the start of the window below, so the window is
            # [start, end] for each conversation, not the last ``duration``
            # seconds of the run.  An earlier, unused
            # ``start = end - duration`` suggests the latter was intended.
            # Left unchanged pending confirmation.
            start = random.uniform(0, duration2)
            c = self.construct_conversation(start,
                                            client,
                                            server,
                                            hard_stop=(duration2 * 5),
                                            packet_rate=packet_rate)

            c.forget_packets_outside_window(start, end)
            c.renormalise_times(start)
            if len(c) != 0:
                conversations.append(c)
            client += 1

        print("we have %d conversations at rate %f" %
              (len(conversations), rate), file=sys.stderr)
        # Sort with an explicit key: it does not depend on Conversation
        # comparison methods, which Python 3 requires.
        conversations.sort(key=lambda conv: conv.start_time)
        return conversations
1263
1264
# IP protocol field written into synthesised summary lines, per application
# protocol.  The values look like hexadecimal IP protocol numbers: '06' (TCP)
# and '11' (0x11 == 17, UDP, used here for dns, cldap, browser, nbns and
# smb_netlogon).
IP_PROTOCOLS = dict([
    ('dns', '11'),
    ('rpc_netlogon', '06'),
    ('kerberos', '06'),      # ratio 16248:258
    ('smb', '06'),
    ('smb2', '06'),
    ('ldap', '06'),
    ('cldap', '11'),
    ('lsarpc', '06'),
    ('samr', '06'),
    ('dcerpc', '06'),
    ('epm', '06'),
    ('drsuapi', '06'),
    ('browser', '11'),
    ('smb_netlogon', '11'),
    ('srvsvc', '06'),
    ('nbns', '11'),
])
1283
# Human-readable description of each (protocol, opcode) pair, used to fill the
# description column of synthesised summary lines.  The table is written
# grouped by protocol and flattened into {(protocol, opcode): description},
# keeping the same key order.
OP_DESCRIPTIONS = dict(
    ((protocol, opcode), description)
    for protocol, table in (
        ('browser', (
            ('0x01', 'Host Announcement (0x01)'),
            ('0x02', 'Request Announcement (0x02)'),
            ('0x08', 'Browser Election Request (0x08)'),
            ('0x09', 'Get Backup List Request (0x09)'),
            ('0x0c', 'Domain/Workgroup Announcement (0x0c)'),
            ('0x0f', 'Local Master Announcement (0x0f)'),
        )),
        ('cldap', (
            ('3', 'searchRequest'),
            ('5', 'searchResDone'),
        )),
        ('dcerpc', (
            ('0', 'Request'),
            ('11', 'Bind'),
            ('12', 'Bind_ack'),
            ('13', 'Bind_nak'),
            ('14', 'Alter_context'),
            ('15', 'Alter_context_resp'),
            ('16', 'AUTH3'),
            ('2', 'Response'),
        )),
        ('dns', (
            ('0', 'query'),
            ('1', 'response'),
        )),
        ('drsuapi', (
            ('0', 'DsBind'),
            ('12', 'DsCrackNames'),
            ('13', 'DsWriteAccountSpn'),
            ('1', 'DsUnbind'),
            ('2', 'DsReplicaSync'),
            ('3', 'DsGetNCChanges'),
            ('4', 'DsReplicaUpdateRefs'),
        )),
        ('epm', (
            ('3', 'Map'),
        )),
        ('kerberos', (
            ('', ''),
        )),
        ('ldap', (
            ('0', 'bindRequest'),
            ('1', 'bindResponse'),
            ('2', 'unbindRequest'),
            ('3', 'searchRequest'),
            ('4', 'searchResEntry'),
            ('5', 'searchResDone'),
            ('', '*** Unknown ***'),
        )),
        ('lsarpc', (
            ('14', 'lsa_LookupNames'),
            ('15', 'lsa_LookupSids'),
            ('39', 'lsa_QueryTrustedDomainInfoBySid'),
            ('40', 'lsa_SetTrustedDomainInfo'),
            ('6', 'lsa_OpenPolicy'),
            ('76', 'lsa_LookupSids3'),
            ('77', 'lsa_LookupNames4'),
        )),
        ('nbns', (
            ('0', 'query'),
            ('1', 'response'),
        )),
        ('rpc_netlogon', (
            ('21', 'NetrLogonDummyRoutine1'),
            ('26', 'NetrServerAuthenticate3'),
            ('29', 'NetrLogonGetDomainInfo'),
            ('30', 'NetrServerPasswordSet2'),
            ('39', 'NetrLogonSamLogonEx'),
            ('40', 'DsrEnumerateDomainTrusts'),
            ('45', 'NetrLogonSamLogonWithFlags'),
            ('4', 'NetrServerReqChallenge'),
        )),
        ('samr', (
            ('0', 'Connect'),
            ('16', 'GetAliasMembership'),
            ('17', 'LookupNames'),
            ('18', 'LookupRids'),
            ('19', 'OpenGroup'),
            ('1', 'Close'),
            ('25', 'QueryGroupMember'),
            ('34', 'OpenUser'),
            ('36', 'QueryUserInfo'),
            ('39', 'GetGroupsForUser'),
            ('3', 'QuerySecurity'),
            ('5', 'LookupDomain'),
            ('64', 'Connect5'),
            ('6', 'EnumDomains'),
            ('7', 'OpenDomain'),
            ('8', 'QueryDomainInfo'),
        )),
        ('smb', (
            ('0x04', 'Close (0x04)'),
            ('0x24', 'Locking AndX (0x24)'),
            ('0x2e', 'Read AndX (0x2e)'),
            ('0x32', 'Trans2 (0x32)'),
            ('0x71', 'Tree Disconnect (0x71)'),
            ('0x72', 'Negotiate Protocol (0x72)'),
            ('0x73', 'Session Setup AndX (0x73)'),
            ('0x74', 'Logoff AndX (0x74)'),
            ('0x75', 'Tree Connect AndX (0x75)'),
            ('0xa2', 'NT Create AndX (0xa2)'),
        )),
        ('smb2', (
            ('0', 'NegotiateProtocol'),
            ('11', 'Ioctl'),
            ('14', 'Find'),
            ('16', 'GetInfo'),
            ('18', 'Break'),
            ('1', 'SessionSetup'),
            ('2', 'SessionLogoff'),
            ('3', 'TreeConnect'),
            ('4', 'TreeDisconnect'),
            ('5', 'Create'),
            ('6', 'Close'),
            ('8', 'Read'),
        )),
        ('smb_netlogon', (
            ('0x12', 'SAM LOGON request from client (0x12)'),
            ('0x17', ('SAM Active Directory Response - '
                      'user unknown (0x17)')),
        )),
        ('srvsvc', (
            ('16', 'NetShareGetInfo'),
            ('21', 'NetSrvGetInfo'),
        )),
    )
    for opcode, description in table
)
1380
1381
def expand_short_packet(p, timestamp, src, dest, extra):
    """Expand a 'protocol:opcode' packet type into a tab-separated summary
    line.

    The fields are: timestamp, IP protocol, an empty field, src, dest,
    protocol, opcode, description, then each item of *extra*.  Every field
    must already be a string.
    """
    protocol, opcode = p.split(':', 1)
    fields = [timestamp,
              IP_PROTOCOLS.get(protocol, '06'),
              '',
              src,
              dest,
              protocol,
              opcode,
              OP_DESCRIPTIONS.get((protocol, opcode), '')]
    return '\t'.join(fields + list(extra))
1390
1391
def replay(conversations,
           host=None,
           creds=None,
           lp=None,
           accounts=None,
           dns_rate=0,
           duration=None,
           **kwargs):
    """Replay *conversations* against *host*, each in a forked child process.

    Conversations are paired with *accounts* latest-start first.  If there
    are fewer accounts than conversations, the earliest-starting
    conversations are dropped.  Children are forked in batches covering
    about the next two seconds of start times.  After *duration* seconds
    (by default, one second after the last packet of the last conversation
    to start), any remaining children are killed.

    :param dns_rate: if non-zero, also run a DnsHammer issuing this many
        dns:0 queries per second.
    :param kwargs: passed on to ReplayContext.
    """
    import errno

    context = ReplayContext(server=host,
                            creds=creds,
                            lp=lp,
                            **kwargs)

    if accounts is None:
        accounts = []

    if len(accounts) < len(conversations):
        print(("we have %d accounts but %d conversations" %
               (len(accounts), len(conversations))), file=sys.stderr)

    cstack = list(zip(
        sorted(conversations, key=lambda x: x.start_time, reverse=True),
        accounts))

    # Set the process group so that the calling scripts are not killed
    # when the forked child processes are killed.
    os.setpgrp()

    start = time.time()

    if duration is None:
        if cstack:
            # end 1 second after the last packet of the last conversation
            # to start. Conversations other than the last could still be
            # going, but we don't care.  Children replay each packet at
            # start_time + packet timestamp (see
            # Conversation.replay_in_fork_with_delay), so start_time must
            # be included.
            last = cstack[0][0]
            duration = last.start_time + last.packets[-1].timestamp + 1.0
        else:
            duration = 0.0
        print("We will stop after %.1f seconds" % duration,
              file=sys.stderr)

    end = start + duration

    print("Replaying traffic for %u conversations over %d seconds"
          % (len(conversations), duration))

    children = {}
    if dns_rate:
        dns_hammer = DnsHammer(dns_rate, duration)
        cstack.append((dns_hammer, None))

    try:
        while True:
            # we spawn a batch, wait for finishers, then spawn another
            now = time.time()
            batch_end = min(now + 2.0, end)
            fork_time = 0.0
            fork_n = 0
            while cstack:
                c, account = cstack.pop()
                if c.start_time + start > batch_end:
                    cstack.append((c, account))
                    break

                st = time.time()
                pid = c.replay_in_fork_with_delay(start, context, account)
                children[pid] = c
                t = time.time()
                elapsed = t - st
                fork_time += elapsed
                fork_n += 1
                print("forked %s in pid %s (in %fs)" % (c, pid,
                                                        elapsed),
                      file=sys.stderr)

            if fork_n:
                print(("forked %d times in %f seconds (avg %f)" %
                       (fork_n, fork_time, fork_time / fork_n)),
                      file=sys.stderr)
            elif cstack:
                debug(2, "no forks in batch ending %f" % batch_end)

            while time.time() < batch_end - 1.0:
                time.sleep(0.01)
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != errno.ECHILD:
                        raise
                    break
                if pid:
                    c = children.pop(pid, None)
                    print(("process %d finished conversation %s;"
                           " %d to go" %
                           (pid, c, len(children))), file=sys.stderr)

            if time.time() >= end:
                print("time to stop", file=sys.stderr)
                break

    except Exception:
        print("EXCEPTION in parent", file=sys.stderr)
        traceback.print_exc()
    finally:
        for s in (signal.SIGTERM, signal.SIGTERM, signal.SIGKILL):
            print(("killing %d children with -%d" %
                   (len(children), s)), file=sys.stderr)
            for pid in children:
                try:
                    os.kill(pid, s)
                except OSError as e:
                    # don't fail if it has already died
                    if e.errno != errno.ESRCH:
                        raise
            time.sleep(0.5)
            deadline = time.time() + 1
            while children:
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != errno.ECHILD:
                        raise
                    # No children left to reap.  Stop rather than reuse a
                    # stale pid from an earlier iteration.
                    break
                if pid != 0:
                    c = children.pop(pid, None)
                    print(("kill -%d %d KILLED conversation %s; "
                           "%d to go" %
                           (s, pid, c, len(children))),
                          file=sys.stderr)
                if time.time() >= deadline:
                    break

            if not children:
                break
            time.sleep(1)

        if children:
            print("%d children are missing" % len(children),
                  file=sys.stderr)

        # there may be stragglers that were forked just as ^C was hit
        # and don't appear in the list of children. We can get them
        # with killpg, but that will also kill us, so this is^H^H would be
        # goodbye, except we cheat and pretend to use ^C (SIGINT),
        # so as not to have to fuss around writing signal handlers.
        try:
            os.killpg(0, signal.SIGINT)
        except KeyboardInterrupt:
            print("ignoring fake ^C", file=sys.stderr)
1534
1535
def openLdb(host, creds, lp):
    """Open a SamDB connection to *host* over LDAP with the system session
    and the given credentials and loadparm context."""
    return SamDB(url="ldap://%s" % host,
                 session_info=system_session(),
                 credentials=creds,
                 lp=lp)
1543
1544
def ou_name(ldb, instance_id):
    """Return the DN of the ou for *instance_id*, which sits below
    ou=traffic_replay in the domain returned by ``ldb.domain_dn()``."""
    instance_rdn = "ou=instance-%d" % instance_id
    domain = "%s" % (ldb.domain_dn(),)
    return ",".join([instance_rdn, "ou=traffic_replay", domain])
1549
1550
def create_ou(ldb, instance_id):
    """Create an ou, all created user and machine accounts will belong to it.

    This allows all the created resources to be cleaned up easily.  The
    parent ou (everything after the first RDN of ou_name()) is created
    first, then the per-instance ou.  Either one already existing is not an
    error.

    :return: the DN of the per-instance ou.
    """
    ou = ou_name(ldb, instance_id)
    for dn in (ou.split(',', 1)[1], ou):
        try:
            ldb.add({"dn":          dn,
                     "objectclass": "organizationalunit"})
        except LdbError as e:
            # Exceptions are not iterable on Python 3; unpack e.args.
            (status, _) = e.args
            # ignore already exists (68, LDB_ERR_ENTRY_ALREADY_EXISTS)
            if status != 68:
                raise
    return ou
1574
1575
class ConversationAccounts(object):
    """The machine account and user account (names and passwords) that one
    replayed conversation uses.
    """
    def __init__(self, netbios_name, machinepass, username, userpass):
        self.netbios_name, self.machinepass = netbios_name, machinepass
        self.username, self.userpass = username, userpass
1584
1585
def generate_replay_accounts(ldb, instance_id, number, password):
    """Generate a series of unique machine and user account names.

    The accounts are first created as needed via generate_traffic_accounts().
    Returns one ConversationAccounts per index 1..number, named
    STGM-<instance>-<i> / STGU-<instance>-<i>, with *password* used for both
    the machine and the user account.
    """
    generate_traffic_accounts(ldb, instance_id, number, password)
    return [ConversationAccounts("STGM-%d-%d" % (instance_id, idx),
                                 password,
                                 "STGU-%d-%d" % (instance_id, idx),
                                 password)
            for idx in range(1, number + 1)]
1599
1600
def generate_traffic_accounts(ldb, instance_id, number, password):
    """Create the specified number of user and machine accounts.

    Accounts are not explicitly deleted between runs, so this function
    starts with the last account and iterates backwards.  It stops either
    when it finds an already existing account or when it has generated all
    the required accounts.  Machine accounts are STGM-<instance>-<i> and
    user accounts are STGU-<instance>-<i>.
    """
    print(("Generating machine and conversation accounts, "
           "as required for %d conversations" % number),
          file=sys.stderr)

    def create_missing(create, name_format):
        """Call *create* for indices number..1 until an account already
        exists; return how many accounts were added."""
        added = 0
        for i in range(number, 0, -1):
            name = name_format % (instance_id, i)
            try:
                create(ldb, instance_id, name, password)
            except LdbError as e:
                # Exceptions are not iterable on Python 3; unpack e.args.
                (status, _) = e.args
                if status == 68:  # LDB_ERR_ENTRY_ALREADY_EXISTS
                    break
                raise
            added += 1
        return added

    added = create_missing(create_machine_account, "STGM-%d-%d")
    if added > 0:
        print("Added %d new machine accounts" % added,
              file=sys.stderr)

    added = create_missing(create_user_account, "STGU-%d-%d")
    if added > 0:
        print("Added %d new user accounts" % added,
              file=sys.stderr)
1644
1645
def create_machine_account(ldb, instance_id, netbios_name, machinepass):
    """Create a machine account via ldap.

    The account is added as cn=<netbios_name> under the OU returned by
    ou_name(), with sAMAccountName "<netbios_name>$" and userAccountControl
    UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD. A tab-separated timing
    record for the add is printed to stdout.

    :param ldb: the ldb connection the account is added through.
    :param instance_id: traffic instance id, used to locate the OU.
    :param netbios_name: machine name (without the trailing '$').
    :param machinepass: the account password (text).

    Raises LdbError if the add fails; callers treat status 68 as
    "already exists".
    """
    ou = ou_name(ldb, instance_id)
    dn = "cn=%s,%s" % (netbios_name, ou)
    # The password is written to unicodePwd wrapped in double quotes and
    # encoded as UTF-16-LE. Formatting a text string and encoding it works
    # on both Python 2 and 3; the previous unicode()/bytes+str construction
    # only worked on Python 2.
    utf16pw = (u'"%s"' % machinepass).encode('utf-16-le')
    start = time.time()
    ldb.add({
        "dn": dn,
        "objectclass": "computer",
        "sAMAccountName": "%s$" % netbios_name,
        "userAccountControl":
        str(UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD),
        "unicodePwd": utf16pw})
    end = time.time()
    duration = end - start
    print("%f\t0\tcreate\tmachine\t%f\tTrue\t" % (end, duration))
1665
1666
def create_user_account(ldb, instance_id, username, userpass):
    """Create a user account via ldap.

    The user is added as cn=<username> under the OU returned by ou_name(),
    with userAccountControl UF_NORMAL_ACCOUNT. An ACE "(A;;WP;;;PS)" is then
    added to the user's DACL. A tab-separated timing record covering both
    steps is printed to stdout.

    :param ldb: the ldb connection the account is added through.
    :param instance_id: traffic instance id, used to locate the OU.
    :param username: sAMAccountName / cn of the new user.
    :param userpass: the account password (text).

    Raises LdbError if the add fails; callers treat status 68 as
    "already exists".
    """
    ou = ou_name(ldb, instance_id)
    user_dn = "cn=%s,%s" % (username, ou)
    # The password is written to unicodePwd wrapped in double quotes and
    # encoded as UTF-16-LE. This form is portable across Python 2 and 3
    # (unicode() and bytes+str concatenation are Python 2-only).
    utf16pw = (u'"%s"' % userpass).encode('utf-16-le')
    start = time.time()
    ldb.add({
        "dn": user_dn,
        "objectclass": "user",
        "sAMAccountName": username,
        "userAccountControl": str(UF_NORMAL_ACCOUNT),
        "unicodePwd": utf16pw
    })

    # grant user write permission to do things like write account SPN
    sdutils = sd_utils.SDUtils(ldb)
    sdutils.dacl_add_ace(user_dn,  "(A;;WP;;;PS)")

    end = time.time()
    duration = end - start
    print("%f\t0\tcreate\tuser\t%f\tTrue\t" % (end, duration))
1690
1691
def create_group(ldb, instance_id, name):
    """Create a group via ldap.

    Adds cn=<name> (objectclass "group") under the instance's OU and prints
    a tab-separated timing record for the add to stdout.
    """
    group_dn = "cn=%s,%s" % (name, ou_name(ldb, instance_id))
    started = time.time()
    ldb.add({"dn": group_dn, "objectclass": "group"})
    finished = time.time()
    print("%f\t0\tcreate\tgroup\t%f\tTrue\t" % (finished, finished - started))
1705
1706
def user_name(instance_id, i):
    """Return the name of the i-th generated user for this instance.

    For example, user 2 of instance 1 is named "STGU-1-2".
    """
    return "-".join(("STGU", "%d" % instance_id, "%d" % i))
1710
1711
def generate_users(ldb, instance_id, number, password):
    """Add users to the server.

    Creates users number..1 (named by user_name()) with ``password``,
    stopping at the first one that already exists.

    Returns the number of users actually created. Any LdbError other than
    status 68 (entry already exists) is re-raised.
    """
    users = 0
    for i in range(number, 0, -1):
        username = user_name(instance_id, i)
        try:
            create_user_account(ldb, instance_id, username, password)
        except LdbError as e:
            # e.args works on Python 2 and 3; exception objects are not
            # iterable on Python 3.
            (status, _) = e.args
            # Stop if entry exists
            if status == 68:
                break
            raise
        users += 1

    return users
1729
1730
def group_name(instance_id, i):
    """Return the name of the i-th generated group for this instance.

    For example, group 3 of instance 1 is named "STGG-1-3".
    """
    return "-".join(("STGG", "%d" % instance_id, "%d" % i))
1734
1735
def generate_groups(ldb, instance_id, number):
    """Create the required number of groups on the server.

    Creates groups number..1 (named by group_name()), stopping at the first
    one that already exists.

    Returns the number of groups actually created. Any LdbError other than
    status 68 (entry already exists) is re-raised.
    """
    groups = 0
    for i in range(number, 0, -1):
        name = group_name(instance_id, i)
        try:
            create_group(ldb, instance_id, name)
        except LdbError as e:
            # e.args works on Python 2 and 3; exception objects are not
            # iterable on Python 3.
            (status, _) = e.args
            # Stop if entry exists
            if status == 68:
                break
            raise
        groups += 1
    return groups
1752
1753
def clean_up_accounts(ldb, instance_id):
    """Remove the created accounts and groups from the server.

    Deletes the instance's OU, and everything beneath it, using the
    "tree_delete:1" control. A missing OU (status 32) is ignored; any other
    LdbError is re-raised.
    """
    ou = ou_name(ldb, instance_id)
    try:
        ldb.delete(ou, ["tree_delete:1"])
    except LdbError as e:
        # e.args works on Python 2 and 3; exception objects are not
        # iterable on Python 3.
        (status, _) = e.args
        # ignore does not exist
        if status != 32:
            raise
1764
1765
def generate_users_and_groups(ldb, instance_id, password,
                              number_of_users, number_of_groups,
                              group_memberships):
    """Generate the required users and groups, allocating the users to
       those groups.

    Ensures the instance OU exists, then creates any missing users and
    (if number_of_groups > 0) groups. If group_memberships > 0, it allocates
    memberships with assign_groups() and applies them with
    add_users_to_groups(). Progress and a final summary go to stderr.
    """
    create_ou(ldb, instance_id)

    print("Generating dummy user accounts", file=sys.stderr)
    users_added = generate_users(ldb, instance_id, number_of_users, password)

    groups_added = 0
    if number_of_groups > 0:
        print("Generating dummy groups", file=sys.stderr)
        groups_added = generate_groups(ldb, instance_id, number_of_groups)

    assignments = []
    if group_memberships > 0:
        print("Assigning users to groups", file=sys.stderr)
        assignments = assign_groups(number_of_groups, groups_added,
                                    number_of_users, users_added,
                                    group_memberships)
        print("Adding users to groups", file=sys.stderr)
        add_users_to_groups(ldb, instance_id, assignments)

    new_groups_unpopulated = (groups_added > 0 and users_added == 0 and
                              number_of_groups != groups_added)
    if new_groups_unpopulated:
        print("Warning: the added groups will contain no members",
              file=sys.stderr)

    summary = ("Added %d users, %d groups and %d group memberships" %
               (users_added, groups_added, len(assignments)))
    print(summary, file=sys.stderr)
1801
1802
def assign_groups(number_of_groups,
                  groups_added,
                  number_of_users,
                  users_added,
                  group_memberships):
    """Allocate users to groups.

    The intention is to have a few users that belong to most groups, while
    the majority of users belong to a few groups.

    A few groups will contain most users, with the remaining only having a
    few users.

    Only pairs involving a newly added user or a newly added group are
    chosen. Users and groups are assumed to be numbered so that the highest
    ``users_added`` / ``groups_added`` numbers are the new ones.

    The requested ``group_memberships`` is scaled by the fraction of users
    that were newly added. It is also capped at the number of eligible pairs
    so the sampling loop always terminates.

    Returns a set of (user, group) tuples of 1-based user and group numbers.
    The set is empty when no memberships are requested or when there are no
    users or no groups.
    """

    def generate_user_distribution(n):
        """Probability distribution of a user belonging to a group.
        """
        dist = []
        for x in range(1, n + 1):
            p = 1 / (x + 0.001)
            dist.append(p)
        return dist

    def generate_group_distribution(n):
        """Probability distribution of a group containing a user."""
        dist = []
        for x in range(1, n + 1):
            p = 1 / (x**1.3)
            dist.append(p)
        return dist

    assignments = set()
    if group_memberships <= 0:
        return assignments
    # With no users we would divide by zero below, and with no groups
    # random.randint(0, -1) would raise; there is nothing to assign anyway.
    if number_of_users <= 0 or number_of_groups <= 0:
        return assignments

    group_dist = generate_group_distribution(number_of_groups)
    user_dist = generate_user_distribution(number_of_users)

    # Calculate the number of group memberships required
    group_memberships = math.ceil(
        float(group_memberships) *
        (float(users_added) / float(number_of_users)))

    # Indexes above these are newly added users/groups.
    existing_users = number_of_users - users_added - 1
    existing_groups = number_of_groups - groups_added - 1

    # Pairs where both the user and the group already existed are never
    # chosen, so more memberships than this can never be found. Without the
    # cap the loop below would spin forever.
    eligible_pairs = (number_of_users * number_of_groups -
                      (number_of_users - users_added) *
                      (number_of_groups - groups_added))
    group_memberships = min(group_memberships, eligible_pairs)

    while len(assignments) < group_memberships:
        user = random.randint(0, number_of_users - 1)
        group = random.randint(0, number_of_groups - 1)
        probability = group_dist[group] * user_dist[user]

        if ((random.random() < probability * 10000) and
           (group > existing_groups or user > existing_users)):
            # the + 1 converts the array index to the corresponding
            # group or user number
            assignments.add(((user + 1), (group + 1)))

    return assignments
1860
1861
def add_users_to_groups(db, instance_id, assignments):
    """Add users to their assigned groups.

    Takes the (user, group) tuples generated by assign_groups, where each
    element is a 1-based user or group number. For each tuple it adds the
    user's DN to the "member" attribute of the corresponding group. A
    tab-separated timing record is printed to stdout for every modify.

    :param db: the ldb connection to modify through. It is named ``db``
        because ``ldb`` in this function refers to the ldb module
        (ldb.Message, ldb.Dn, ...).
    :param instance_id: traffic instance id, used for names and the OU.
    :param assignments: iterable of (user, group) number pairs.
    """

    ou = ou_name(db, instance_id)

    def build_dn(name):
        return("cn=%s,%s" % (name, ou))

    for (user, group) in assignments:
        user_dn = build_dn(user_name(instance_id, user))
        group_dn = build_dn(group_name(instance_id, group))

        m = ldb.Message()
        m.dn = ldb.Dn(db, group_dn)
        m["member"] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
        start = time.time()
        db.modify(m)
        end = time.time()
        duration = end - start
        print("%f\t0\tadd\tuser\t%f\tTrue\t" % (end, duration))
1885
1886
def generate_stats(statsdir, timing_file):
    """Generate and print the summary stats for a run.

    Reads every file in ``statsdir``. Each line is expected to be
    tab-separated with at least the fields
    ``time, conv, protocol, type, duration, successful``.

    Valid lines are copied to ``timing_file`` (unless it is None) beneath a
    header row. Lines that cannot be parsed are echoed to stderr and
    ignored entirely.

    Prints overall totals, then per-operation latency statistics, to stdout.
    The layout is human-readable when stdout is a terminal and
    tab-separated otherwise.
    """
    first = sys.float_info.max
    last = 0
    successful = 0
    failed = 0
    # latencies[protocol][packet_type] -> list of per-operation durations
    latencies = defaultdict(lambda: defaultdict(list))
    # failures[protocol][packet_type] -> count of unsuccessful operations
    failures = defaultdict(lambda: defaultdict(int))
    unique_conversations = set()

    if timing_file is not None:
        tw = timing_file.write
    else:
        def tw(x):
            pass

    tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")

    for filename in os.listdir(statsdir):
        path = os.path.join(statsdir, filename)
        with open(path, 'r') as f:
            for line in f:
                # Parse every field before updating any accumulator. A
                # malformed line (e.g. missing the success column) is then
                # ignored completely, instead of contributing a latency but
                # no success/failure count.
                try:
                    fields = line.rstrip('\n').split('\t')
                    timestamp = float(fields[0])
                    conversation = fields[1]
                    protocol = fields[2]
                    packet_type = fields[3]
                    latency = float(fields[4])
                    succeeded = fields[5] == 'True'
                except (ValueError, IndexError):
                    # not a valid line print and ignore
                    print(line, file=sys.stderr)
                    continue

                first = min(timestamp - latency, first)
                last = max(timestamp, last)
                latencies[protocol][packet_type].append(latency)
                if succeeded:
                    successful += 1
                else:
                    failed += 1
                    failures[protocol][packet_type] += 1
                unique_conversations.add(conversation)
                tw(line)

    conversations = len(unique_conversations)
    duration = last - first
    # Guard against a zero-length run (e.g. one zero-latency record), which
    # previously raised ZeroDivisionError.
    if successful == 0 or duration <= 0:
        success_rate = 0
    else:
        success_rate = successful / duration
    if failed == 0 or duration <= 0:
        failure_rate = 0
    else:
        failure_rate = failed / duration

    # print the stats in more human-readable format when stdout is going to the
    # console (as opposed to being redirected to a file)
    tty = sys.stdout.isatty()
    if tty:
        print("Total conversations:   %10d" % conversations)
        print("Successful operations: %10d (%.3f per second)"
              % (successful, success_rate))
        print("Failed operations:     %10d (%.3f per second)"
              % (failed, failure_rate))
    else:
        print("(%d, %d, %d, %.3f, %.3f)" %
              (conversations, successful, failed, success_rate, failure_rate))

    # The row format must match the header. The old per-row check used
    # `sys.stdout.isatty` without calling it, so it was always truthy and
    # the tab-separated rows were never printed.
    if tty:
        print("Protocol    Op Code  Description                               "
              " Count       Failed         Mean       Median          "
              "95%        Range          Max")
        row_format = ("%-12s   %4s  %-35s %12d %12d %12.6f "
                      "%12.6f %12.6f %12.6f %12.6f")
    else:
        print("proto\top_code\tdesc\tcount\tfailed\tmean\tmedian\t95%\trange"
              "\tmax")
        row_format = "%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f"

    for protocol in sorted(latencies):
        for packet_type in sorted(latencies[protocol], key=opcode_key):
            values = sorted(latencies[protocol][packet_type])
            count = len(values)
            mean = sum(values) / count
            print(row_format % (
                protocol,
                packet_type,
                OP_DESCRIPTIONS.get((protocol, packet_type), ''),
                count,
                failures[protocol][packet_type],
                mean,
                calc_percentile(values, 0.50),
                calc_percentile(values, 0.95),
                values[-1] - values[0],
                values[-1]))
2014
2015
def opcode_key(v):
    """Sort key for the operation code to ensure that it sorts numerically.

    Integer-like op codes are zero-padded to three digits ("5" -> "005") so
    they sort numerically as strings. Anything that cannot be converted with
    int() is returned unchanged.
    """
    try:
        return "%03d" % int(v)
    except (ValueError, TypeError, OverflowError):
        # Only conversion failures fall back to the raw value; a bare
        # `except:` would also have swallowed KeyboardInterrupt/SystemExit.
        return v
2022
2023
def calc_percentile(values, percentile):
    """Return the given percentile (a fraction, e.g. 0.95) of ``values``.

    ``values`` must already be sorted in ascending order. The result is
    linearly interpolated between the two nearest ranks; an empty list
    yields 0.
    """
    if not values:
        return 0
    rank = (len(values) - 1) * percentile
    lower = math.floor(rank)
    upper = math.ceil(rank)
    if lower == upper:
        # the rank lands exactly on an element; no interpolation needed
        return values[int(rank)]
    # weight each neighbour by the distance to the other one
    return (values[int(lower)] * (upper - rank) +
            values[int(upper)] * (rank - lower))
2040
2041
def mk_masked_dir(*path):
    """Create a directory with group/other permissions masked out.

    In a testenv we end up with 0777 directories that look an alarming
    green colour with ls. Use umask to avoid that.

    The path components are joined with os.path.join, and the joined path
    of the new directory is returned. The umask is temporarily set to 0o077
    for the mkdir. The previous umask is restored even if os.mkdir raises
    (e.g. because the directory already exists).
    """
    d = os.path.join(*path)
    mask = os.umask(0o077)
    try:
        os.mkdir(d)
    finally:
        # Without this, a failing mkdir would leave the whole process
        # running with umask 0o077.
        os.umask(mask)
    return d