traffic: simplify forget_packets_outside_window
[samba.git] / python / samba / emulate / traffic.py
1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
3 #
4 # Copyright (C) Catalyst IT Ltd. 2017
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19 from __future__ import print_function, division
20
21 import time
22 import os
23 import random
24 import json
25 import math
26 import sys
27 import signal
28 import itertools
29
30 from collections import OrderedDict, Counter, defaultdict
31 from samba.emulate import traffic_packets
32 from samba.samdb import SamDB
33 import ldb
34 from ldb import LdbError
35 from samba.dcerpc import ClientConnection
36 from samba.dcerpc import security, drsuapi, lsa
37 from samba.dcerpc import netlogon
38 from samba.dcerpc.netlogon import netr_Authenticator
39 from samba.dcerpc import srvsvc
40 from samba.dcerpc import samr
41 from samba.drs_utils import drs_DsBind
42 import traceback
43 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
44 from samba.auth import system_session
45 from samba.dsdb import UF_WORKSTATION_TRUST_ACCOUNT, UF_PASSWD_NOTREQD
46 from samba.dsdb import UF_NORMAL_ACCOUNT
47 from samba.dcerpc.misc import SEC_CHAN_WKSTA
48 from samba import gensec
49 from samba import sd_utils
50
SLEEP_OVERHEAD = 3e-4

# A stand-in meaning "no packet"; None is avoided because it complicates
# [de]serialisation.
NON_PACKET = '-'

# (protocol, opcode) pairs hinting that the sender is a client.
# Packet.client_score() returns the weight found here (all are 1.0).
CLIENT_CLUES = dict.fromkeys([
    ('dns', '0'),       # query
    ('smb', '0x72'),    # Negotiate protocol
    ('ldap', '0'),      # bind
    ('ldap', '3'),      # searchRequest
    ('ldap', '2'),      # unbindRequest
    ('cldap', '3'),
    ('dcerpc', '11'),   # bind
    ('dcerpc', '14'),   # Alter_context
    ('nbns', '0'),      # query
], 1.0)

# (protocol, opcode) pairs hinting that the sender is a server.
# Packet.client_score() returns the negated weight found here.
SERVER_CLUES = dict.fromkeys([
    ('dns', '1'),       # response
    ('ldap', '1'),      # bind response
    ('ldap', '4'),      # search result
    ('ldap', '5'),      # search done
    ('cldap', '5'),
    ('dcerpc', '12'),   # bind_ack
    ('dcerpc', '13'),   # bind_nak
    ('dcerpc', '15'),   # Alter_context response
], 1.0)

# Packets of these protocols are reported as not really packets by
# Packet.is_really_a_packet().
SKIPPED_PROTOCOLS = set(["smb", "smb2", "browser", "smb_netlogon"])

WAIT_SCALE = 10.0
WAIT_THRESHOLD = 1.0 / WAIT_SCALE
NO_WAIT_LOG_TIME_RANGE = (-10, -3)

# DEBUG_LEVEL can be changed by scripts with -d
DEBUG_LEVEL = 0
87
88
def debug(level, msg, *args):
    """Write a debug message to standard error if *level* is enabled.

    :param level: Verbosity of this message. It is printed only when
                  ``level <= DEBUG_LEVEL``; scripts set DEBUG_LEVEL with -d.
    :param msg:   Message text, which may contain C-style (%) format
                  specifiers.
    :param args:  Values for the format specifiers. When none are given,
                  *msg* is printed verbatim (a bare '%' is safe).
    """
    if level > DEBUG_LEVEL:
        return
    text = msg % tuple(args) if args else msg
    print(text, file=sys.stderr)
105
106
def debug_lineno(*args):
    """Print each argument on its own line to stderr, preceded by the
    calling function's name and its (highlighted) line number.

    The arguments are printed unformatted, followed by a blank line, and
    stderr is flushed afterwards.
    """
    # limit=2 yields [caller frame, this frame]; index 0 is the caller.
    caller = traceback.extract_stack(limit=2)[0]
    header = " %s:\033[01;33m%s \033[00m" % (caller[2], caller[1])
    print(header, end=' ', file=sys.stderr)
    for item in args:
        print(item, file=sys.stderr)
    print(file=sys.stderr)
    sys.stderr.flush()
118
119
def random_colour_print():
    """Return a printer that writes to stderr in one randomly chosen colour.

    A 256-colour ANSI foreground code in the range 18..231 is picked once.
    The returned function prints each of its arguments on a separate line
    in that colour, resetting the colour at the end of each line.
    """
    colour = random.randrange(214) + 18
    start = "\033[38;5;%dm" % colour

    def coloured(*items):
        for item in items:
            print(start + "%s\033[00m" % (item,), file=sys.stderr)

    return coloured
130
131
class FakePacketError(Exception):
    """Raised by Conversation.add_packet() when a packet's endpoints do not
    match the conversation's endpoints."""
134
135
class Packet(object):
    """Details of a network packet.

    Built either from a traffic-summary line (a tab-separated string) or from
    an equivalent list of fields, in this order:

        timestamp, ip_protocol, stream_number, src, dest, protocol, opcode,
        desc, [extra...]
    """
    def __init__(self, fields):
        if isinstance(fields, str):
            fields = fields.rstrip('\n').split('\t')

        (timestamp,
         ip_protocol,
         stream_number,
         src,
         dest,
         protocol,
         opcode,
         desc) = fields[:8]
        extra = fields[8:]

        self.timestamp = float(timestamp)
        self.ip_protocol = ip_protocol
        try:
            self.stream_number = int(stream_number)
        except (ValueError, TypeError):
            # An empty ('') or missing (None) stream number becomes None.
            self.stream_number = None
        self.src = int(src)
        self.dest = int(dest)
        self.protocol = protocol
        self.opcode = opcode
        self.desc = desc
        self.extra = extra

        # Direction-independent (lower, higher) pair, so packets flowing
        # either way between the same two endpoints compare equal.
        if self.src < self.dest:
            self.endpoints = (self.src, self.dest)
        else:
            self.endpoints = (self.dest, self.src)

    def as_summary(self, time_offset=0.0):
        """Format the packet as a traffic_summary line.

        :param time_offset: value added to the packet's timestamp.
        :return: a ``(timestamp, line)`` tuple. The timestamp includes
                 *time_offset*, and *line* is the tab-separated summary that
                 ``Packet(line)`` parses back.
        """
        extra = '\t'.join(self.extra)
        t = self.timestamp + time_offset
        # Only a missing stream number is written as an empty field.
        # Formerly `stream_number or ''` also blanked stream 0, which then
        # parsed back as None.
        if self.stream_number is None:
            stream_number = ''
        else:
            stream_number = self.stream_number
        return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
                (t,
                 self.ip_protocol,
                 stream_number,
                 self.src,
                 self.dest,
                 self.protocol,
                 self.opcode,
                 self.desc,
                 extra))

    def __str__(self):
        return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
                (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
                 self.stream_number, self.protocol, self.opcode, self.desc,
                 ('«' + ' '.join(self.extra) + '»' if self.extra else '')))

    def __repr__(self):
        return "<Packet @%s>" % self

    def copy(self):
        """Return a new Packet with the same fields."""
        return self.__class__([self.timestamp,
                               self.ip_protocol,
                               self.stream_number,
                               self.src,
                               self.dest,
                               self.protocol,
                               self.opcode,
                               self.desc] + self.extra)

    def as_packet_type(self):
        """Return the packet type as a 'protocol:opcode' string."""
        t = '%s:%s' % (self.protocol, self.opcode)
        return t

    def client_score(self):
        """A positive number means we think it is a client; a negative number
        means we think it is a server. Zero means no idea. range: -1 to 1.
        """
        key = (self.protocol, self.opcode)
        if key in CLIENT_CLUES:
            return CLIENT_CLUES[key]
        if key in SERVER_CLUES:
            return -SERVER_CLUES[key]
        return 0.0

    def play(self, conversation, context):
        """Send the packet over the network, if required.

        Some packets are ignored, i.e. for  protocols not handled,
        server response messages, or messages that are generated by the
        protocol layer associated with other packets.

        The handler is ``traffic_packets.packet_<protocol>_<opcode>``. When
        it returns a true value or raises, a timing line is printed to
        stdout. Exceptions are reported in that line, not propagated.
        """
        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        try:
            fn = getattr(traffic_packets, fn_name)
        except AttributeError:
            print("Conversation(%s) Missing handler %s" %
                  (conversation.conversation_id, fn_name),
                  file=sys.stderr)
            return

        # Don't display a message for kerberos packets, they're not directly
        # generated they're used to indicate kerberos should be used
        if self.protocol != "kerberos":
            debug(2, "Conversation(%s) Calling handler %s" %
                     (conversation.conversation_id, fn_name))

        start = time.time()
        try:
            if fn(self, conversation, context):
                # Only collect timing data for functions that generate
                # network traffic, or fail
                end = time.time()
                duration = end - start
                print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
                      (end, conversation.conversation_id, self.protocol,
                       self.opcode, duration))
        except Exception as e:
            end = time.time()
            duration = end - start
            print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
                  (end, conversation.conversation_id, self.protocol,
                   self.opcode, duration, e))

    def __cmp__(self, other):
        """Order packets by timestamp, returning -1, 0 or 1.

        The result must be an integer. Python 2 truncates a float result,
        so returning the raw timestamp difference made packets less than
        one second apart compare equal.
        """
        return ((self.timestamp > other.timestamp) -
                (self.timestamp < other.timestamp))

    def __lt__(self, other):
        # Python 3 ignores __cmp__; this gives sorted()/min() the same order.
        return self.timestamp < other.timestamp

    def is_really_a_packet(self, missing_packet_stats=None):
        """Is the packet one that can be ignored?

        If so removing it will have no effect on the replay

        Returns False for skipped protocols, for ldap continuation packets
        (empty opcode), for packets whose handler is
        ``traffic_packets.null_packet``, and for packets with no handler
        (the last case is reported on stderr). Returns True otherwise.
        ``missing_packet_stats`` is accepted but not used.
        """
        if self.protocol in SKIPPED_PROTOCOLS:
            # Ignore any packets for the protocols we're not interested in.
            return False
        if self.protocol == "ldap" and self.opcode == '':
            # skip ldap continuation packets
            return False

        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        try:
            fn = getattr(traffic_packets, fn_name)
            if fn is traffic_packets.null_packet:
                return False
        except AttributeError:
            print("missing packet %s" % fn_name, file=sys.stderr)
            return False
        return True
284
285
class ReplayContext(object):
    """State/Context for an individual conversation between a simulated client
       and a server.
    """

    def __init__(self,
                 server=None,
                 lp=None,
                 creds=None,
                 badpassword_frequency=None,
                 prefer_kerberos=None,
                 tempdir=None,
                 statsdir=None,
                 ou=None,
                 base_dn=None,
                 domain=None,
                 domain_sid=None):

        self.server                   = server
        self.ldap_connections         = []
        self.dcerpc_connections       = []
        self.lsarpc_connections       = []
        self.lsarpc_connections_named = []
        self.drsuapi_connections      = []
        self.srvsvc_connections       = []
        self.samr_contexts            = []
        self.netlogon_connection      = None
        self.creds                    = creds
        self.lp                       = lp
        self.prefer_kerberos          = prefer_kerberos
        self.ou                       = ou
        self.base_dn                  = base_dn
        self.domain                   = domain
        self.statsdir                 = statsdir
        self.global_tempdir           = tempdir
        self.domain_sid               = domain_sid
        self.realm                    = lp.get('realm')

        # Bad password attempt controls
        self.badpassword_frequency    = badpassword_frequency
        self.last_lsarpc_bad          = False
        self.last_lsarpc_named_bad    = False
        self.last_simple_bind_bad     = False
        self.last_bind_bad            = False
        self.last_srvsvc_bad          = False
        self.last_drsuapi_bad         = False
        self.last_netlogon_bad        = False
        self.last_samlogon_bad        = False
        self.generate_ldap_search_tables()
        # itertools.count().next exists only on Python 2. Calling the
        # builtin next() on the counter works on Python 2 and 3.
        conversation_ids = itertools.count()
        self.next_conversation_id = lambda: next(conversation_ids)

    def generate_ldap_search_tables(self):
        """Build the DN lookup tables used by get_matching_dn().

        Every object under the domain DN is fetched and grouped by its
        pattern, which is the upper-cased first two characters of each RDN
        (e.g. CN,CN,CN,DC,DC). DNs starting with 'CN=NTDS Settings,' are
        also recorded as candidates for the 'invocationId' attribute.
        """
        session = system_session()

        db = SamDB(url="ldap://%s" % self.server,
                   session_info=session,
                   credentials=self.creds,
                   lp=self.lp)

        res = db.search(db.domain_dn(),
                        scope=ldb.SCOPE_SUBTREE,
                        controls=["paged_results:1:1000"],
                        attrs=['dn'])

        # find a list of dns for each pattern
        # e.g. CN,CN,CN,DC,DC
        dn_map = {}
        attribute_clue_map = {
            'invocationId': []
        }

        for r in res:
            dn = str(r.dn)
            pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
            dn_map.setdefault(pattern, []).append(dn)
            if dn.startswith('CN=NTDS Settings,'):
                attribute_clue_map['invocationId'].append(dn)

        # extend the map in case we are working with a different
        # number of DC components.
        self._extend_dn_map(dn_map)

        self.dn_map = dn_map
        self.attribute_clue_map = attribute_clue_map

    @staticmethod
    def _extend_dn_map(dn_map):
        """Alias each ',DC'-terminated pattern to variants with one to five
        trailing ',DC' components, modifying *dn_map* in place.

        An existing pattern is never overwritten. Such a clash is reported
        on stderr instead.
        """
        # Iterate over a snapshot of the keys. New keys are added inside the
        # loop, which raises RuntimeError on Python 3 if the dict itself is
        # being iterated.
        for k in list(dn_map):
            if k[-3:] != ',DC':
                continue
            p = k[:-3]
            while p[-3:] == ',DC':
                p = p[:-3]
            for i in range(5):
                p += ',DC'
                if p != k and p in dn_map:
                    print('dn_map collison %s %s' % (k, p),
                          file=sys.stderr)
                    continue
                dn_map[p] = dn_map[k]

    def generate_process_local_config(self, account, conversation):
        """Configure this context for one conversation's account.

        Copies the account's names and passwords, points lp's private, lock
        and state directories at a per-conversation temporary directory, and
        generates the machine and user credentials. Does nothing if
        *account* is None.
        """
        if account is None:
            return
        self.netbios_name             = account.netbios_name
        self.machinepass              = account.machinepass
        self.username                 = account.username
        self.userpass                 = account.userpass

        self.tempdir = mk_masked_dir(self.global_tempdir,
                                     'conversation-%d' %
                                     conversation.conversation_id)

        self.lp.set("private dir",     self.tempdir)
        self.lp.set("lock dir",        self.tempdir)
        self.lp.set("state directory", self.tempdir)
        self.lp.set("tls verify peer", "no_check")

        # If the domain was not specified, check for the environment
        # variable.
        if self.domain is None:
            self.domain = os.environ["DOMAIN"]

        self.remoteAddress = "/root/ncalrpc_as_system"
        self.samlogon_dn   = ("cn=%s,%s" %
                              (self.netbios_name, self.ou))
        self.user_dn       = ("cn=%s,%s" %
                              (self.username, self.ou))

        self.generate_machine_creds()
        self.generate_user_creds()

    def with_random_bad_credentials(self, f, good, bad, failed_last_time):
        """Execute the supplied logon function, randomly choosing the
           bad credentials.

           Based on the frequency in badpassword_frequency randomly perform the
           function with the supplied bad credentials.
           If run with bad credentials, the function is re-run with the good
           credentials.
           failed_last_time is used to prevent consecutive bad credential
           attempts. So the over all bad credential frequency will be lower
           than that requested, but not significantly.

           Returns a ``(result of f(good), failed_last_time)`` tuple, where
           the second item says whether a bad-credential attempt was made
           in this call (or is passed through unchanged if it was True).
        """
        if not failed_last_time:
            if (self.badpassword_frequency > 0 and
               random.random() < self.badpassword_frequency):
                try:
                    f(bad)
                except Exception:
                    # Ignore any exceptions as the operation may fail
                    # as it's being performed with bad credentials.
                    # (Not a bare except: that would swallow
                    # KeyboardInterrupt and SystemExit too.)
                    pass
                failed_last_time = True
            else:
                failed_last_time = False

        result = f(good)
        return (result, failed_last_time)

    def generate_user_creds(self):
        """Generate the conversation specific user Credentials.

        Each Conversation has an associated user account used to simulate
        any non Administrative user traffic.

        Generates user credentials with good and bad passwords and ldap
        simple bind credentials with good and bad passwords.
        """
        self.user_creds = Credentials()
        self.user_creds.guess(self.lp)
        self.user_creds.set_workstation(self.netbios_name)
        self.user_creds.set_password(self.userpass)
        self.user_creds.set_username(self.username)
        self.user_creds.set_domain(self.domain)
        if self.prefer_kerberos:
            self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.user_creds.set_kerberos_state(DONT_USE_KERBEROS)

        # The "bad" password is the good one with its last 4 chars dropped.
        self.user_creds_bad = Credentials()
        self.user_creds_bad.guess(self.lp)
        self.user_creds_bad.set_workstation(self.netbios_name)
        self.user_creds_bad.set_password(self.userpass[:-4])
        self.user_creds_bad.set_username(self.username)
        if self.prefer_kerberos:
            self.user_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.user_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)

        # Credentials for ldap simple bind.
        self.simple_bind_creds = Credentials()
        self.simple_bind_creds.guess(self.lp)
        self.simple_bind_creds.set_workstation(self.netbios_name)
        self.simple_bind_creds.set_password(self.userpass)
        self.simple_bind_creds.set_username(self.username)
        self.simple_bind_creds.set_gensec_features(
            self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
        if self.prefer_kerberos:
            self.simple_bind_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.simple_bind_creds.set_kerberos_state(DONT_USE_KERBEROS)
        self.simple_bind_creds.set_bind_dn(self.user_dn)

        self.simple_bind_creds_bad = Credentials()
        self.simple_bind_creds_bad.guess(self.lp)
        self.simple_bind_creds_bad.set_workstation(self.netbios_name)
        self.simple_bind_creds_bad.set_password(self.userpass[:-4])
        self.simple_bind_creds_bad.set_username(self.username)
        self.simple_bind_creds_bad.set_gensec_features(
            self.simple_bind_creds_bad.get_gensec_features() |
            gensec.FEATURE_SEAL)
        if self.prefer_kerberos:
            self.simple_bind_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.simple_bind_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
        self.simple_bind_creds_bad.set_bind_dn(self.user_dn)

    def generate_machine_creds(self):
        """Generate the conversation specific machine Credentials.

        Each Conversation has an associated machine account.

        Generates machine credentials with good and bad passwords.
        """

        self.machine_creds = Credentials()
        self.machine_creds.guess(self.lp)
        self.machine_creds.set_workstation(self.netbios_name)
        self.machine_creds.set_secure_channel_type(SEC_CHAN_WKSTA)
        self.machine_creds.set_password(self.machinepass)
        self.machine_creds.set_username(self.netbios_name + "$")
        self.machine_creds.set_domain(self.domain)
        if self.prefer_kerberos:
            self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.machine_creds.set_kerberos_state(DONT_USE_KERBEROS)

        self.machine_creds_bad = Credentials()
        self.machine_creds_bad.guess(self.lp)
        self.machine_creds_bad.set_workstation(self.netbios_name)
        self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_WKSTA)
        self.machine_creds_bad.set_password(self.machinepass[:-4])
        self.machine_creds_bad.set_username(self.netbios_name + "$")
        if self.prefer_kerberos:
            self.machine_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.machine_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)

    def get_matching_dn(self, pattern, attributes=None):
        # If the pattern is an empty string, we assume ROOTDSE,
        # Otherwise we try adding or removing DC suffixes, then
        # shorter leading patterns until we hit one.
        # e.g if there is no CN,CN,CN,CN,DC,DC
        # we first try       CN,CN,CN,CN,DC
        # and                CN,CN,CN,CN,DC,DC,DC
        # then change to        CN,CN,CN,DC,DC
        # and as last resort we use the base_dn
        # (The DC-suffix variants are precomputed into dn_map by
        # _extend_dn_map(); here only leading components are chopped.)
        attr_clue = self.attribute_clue_map.get(attributes)
        if attr_clue:
            return random.choice(attr_clue)

        pattern = pattern.upper()
        while pattern:
            if pattern in self.dn_map:
                return random.choice(self.dn_map[pattern])
            # chop one off the front and try it all again.
            pattern = pattern[3:]

        return self.base_dn

    def get_dcerpc_connection(self, new=False):
        guid = '12345678-1234-abcd-ef00-01234567cffb'  # RPC_NETLOGON UUID
        if self.dcerpc_connections and not new:
            return self.dcerpc_connections[-1]
        c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
                             (guid, 1), self.lp)
        self.dcerpc_connections.append(c)
        return c

    def get_srvsvc_connection(self, new=False):
        if self.srvsvc_connections and not new:
            return self.srvsvc_connections[-1]

        def connect(creds):
            return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
                                 self.lp,
                                 creds)

        (c, self.last_srvsvc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_srvsvc_bad)

        self.srvsvc_connections.append(c)
        return c

    def get_lsarpc_connection(self, new=False):
        if self.lsarpc_connections and not new:
            return self.lsarpc_connections[-1]

        def connect(creds):
            binding_options = 'schannel,seal,sign'
            return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
                              (self.server, binding_options),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_bad)

        self.lsarpc_connections.append(c)
        return c

    def get_lsarpc_named_pipe_connection(self, new=False):
        if self.lsarpc_connections_named and not new:
            return self.lsarpc_connections_named[-1]

        def connect(creds):
            return lsa.lsarpc("ncacn_np:%s" % (self.server),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_named_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_named_bad)

        self.lsarpc_connections_named.append(c)
        return c

    def get_drsuapi_connection_pair(self, new=False, unbind=False):
        """get a (drs, drs_handle) tuple"""
        if self.drsuapi_connections and not new:
            c = self.drsuapi_connections[-1]
            return c

        def connect(creds):
            binding_options = 'seal'
            binding_string = "ncacn_ip_tcp:%s[%s]" %\
                             (self.server, binding_options)
            return drsuapi.drsuapi(binding_string, self.lp, creds)

        (drs, self.last_drsuapi_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_drsuapi_bad)

        (drs_handle, supported_extensions) = drs_DsBind(drs)
        c = (drs, drs_handle)
        self.drsuapi_connections.append(c)
        return c

    def get_ldap_connection(self, new=False, simple=False):
        if self.ldap_connections and not new:
            return self.ldap_connections[-1]

        def simple_bind(creds):
            """
            To run simple bind against Windows, we need to run
            following commands in PowerShell:

                Install-windowsfeature ADCS-Cert-Authority
                Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
                Restart-Computer

            """
            return SamDB('ldaps://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)

        def sasl_bind(creds):
            return SamDB('ldap://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)
        if simple:
            (samdb, self.last_simple_bind_bad) = \
                self.with_random_bad_credentials(simple_bind,
                                                 self.simple_bind_creds,
                                                 self.simple_bind_creds_bad,
                                                 self.last_simple_bind_bad)
        else:
            (samdb, self.last_bind_bad) = \
                self.with_random_bad_credentials(sasl_bind,
                                                 self.user_creds,
                                                 self.user_creds_bad,
                                                 self.last_bind_bad)

        self.ldap_connections.append(samdb)
        return samdb

    def get_samr_context(self, new=False):
        if not self.samr_contexts or new:
            self.samr_contexts.append(
                SamrContext(self.server, lp=self.lp, creds=self.creds))
        return self.samr_contexts[-1]

    def get_netlogon_connection(self):

        if self.netlogon_connection:
            return self.netlogon_connection

        def connect(creds):
            return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
                                     (self.server),
                                     self.lp,
                                     creds)
        (c, self.last_netlogon_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_netlogon_bad)
        self.netlogon_connection = c
        return c

    def guess_a_dns_lookup(self):
        return (self.realm, 'A')

    def get_authenticator(self):
        # NOTE(review): ord() over auth["credential"] assumes it iterates as
        # 1-char strings (Python 2 str) -- confirm for Python 3.
        auth = self.machine_creds.new_client_authenticator()
        current  = netr_Authenticator()
        current.cred.data = [ord(x) for x in auth["credential"]]
        current.timestamp = auth["timestamp"]

        subsequent = netr_Authenticator()
        return (current, subsequent)
718
719
class SamrContext(object):
    """A lazily opened samr connection, plus the handles used with it.
    """
    def __init__(self, server, lp=None, creds=None):
        self.server = server
        self.lp = lp
        self.creds = creds
        # Opened on demand by get_connection() / get_handle().
        self.connection = None
        self.handle = None
        # Not set anywhere in this class; presumably filled in by callers.
        self.domain_handle = None
        self.domain_sid = None
        self.group_handle = None
        self.user_handle = None
        self.rids = None

    def get_connection(self):
        """Return the samr connection, opening a sealed TCP connection to
        the server on first use."""
        if self.connection:
            return self.connection
        binding = "ncacn_ip_tcp:%s[seal]" % (self.server)
        self.connection = samr.samr(binding,
                                    lp_ctx=self.lp,
                                    credentials=self.creds)
        return self.connection

    def get_handle(self):
        """Return the Connect2 handle, requesting it with maximum allowed
        access on first use."""
        if self.handle:
            return self.handle
        conn = self.get_connection()
        self.handle = conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
        return self.handle
749
750
class Conversation(object):
    """Details of a conversation between a simulated client and a server."""
    conversation_id = None

    def __init__(self, start_time=None, endpoints=None):
        """Create an empty conversation.

        :param start_time: when the conversation starts; if None it is set
                           from the first packet added with add_packet().
        :param endpoints: the endpoint pair this conversation is between;
                          if None it is set from the first packet added
                          with add_packet().
        """
        self.start_time = start_time
        self.endpoints = endpoints
        self.packets = []
        self.msg = random_colour_print()
        # Running score consulted by guess_client_server(): packets sent by
        # endpoints[0] subtract their client_score(), all others add it.
        self.client_balance = 0.0

    def __cmp__(self, other):
        """Python 2 ordering by start time, with None sorting first.

        Returns -1, 0 or 1.  Python 2 truncates a __cmp__ result to an
        int, so returning the raw difference of two float start times would
        make conversations starting less than a second apart compare equal.
        """
        if self.start_time is None:
            if other.start_time is None:
                return 0
            return -1
        if other.start_time is None:
            return 1
        if self.start_time < other.start_time:
            return -1
        if self.start_time > other.start_time:
            return 1
        return 0

    def __lt__(self, other):
        """Order by start time (None first).

        Python 3 ignores __cmp__, so sort() needs this method to order
        conversations at all; Python 2 sort() prefers it too.
        """
        return self.__cmp__(other) < 0

    def _adjust_client_balance(self, p):
        """Fold packet *p*'s client score into client_balance.

        Packets sent by p.endpoints[0] lower the balance; all others
        raise it.
        """
        if p.src == p.endpoints[0]:
            self.client_balance -= p.client_score()
        else:
            self.client_balance += p.client_score()

    def add_packet(self, packet):
        """Add a packet object to this conversation, making a local copy with
        a conversation-relative timestamp.

        :raises FakePacketError: if the packet's endpoints differ from the
                                 conversation's.
        """
        p = packet.copy()

        if self.start_time is None:
            self.start_time = p.timestamp

        if self.endpoints is None:
            self.endpoints = p.endpoints

        if p.endpoints != self.endpoints:
            raise FakePacketError("Conversation endpoints %s don't match "
                                  "packet endpoints %s" %
                                  (self.endpoints, p.endpoints))

        p.timestamp -= self.start_time
        self._adjust_client_balance(p)

        # Non-packets still count towards the balance above, but are not
        # kept for replay.
        if p.is_really_a_packet():
            self.packets.append(p)

    def add_short_packet(self, timestamp, p, extra, client=True):
        """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
        (possibly empty) list of extra data. If client is True, assume
        this packet is from the client to the server.
        """
        protocol, opcode = p.split(':', 1)
        src, dest = self.guess_client_server()
        if not client:
            src, dest = dest, src

        desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
        ip_protocol = IP_PROTOCOLS.get(protocol, '06')
        fields = [timestamp - self.start_time, ip_protocol,
                  '', src, dest,
                  protocol, opcode, desc]
        fields.extend(extra)
        packet = Packet(fields)
        # XXX we're assuming the timestamp is already adjusted for
        # this conversation?
        # XXX should we adjust client balance for guessed packets?
        self._adjust_client_balance(packet)
        if packet.is_really_a_packet():
            self.packets.append(packet)

    def __str__(self):
        return ("<Conversation %s %s starting %.3f %d packets>" %
                (self.conversation_id, self.endpoints, self.start_time,
                 len(self.packets)))

    __repr__ = __str__

    def __iter__(self):
        return iter(self.packets)

    def __len__(self):
        return len(self.packets)

    def get_duration(self):
        """Time between the first and last packets (0 with fewer than two)."""
        if len(self.packets) < 2:
            return 0
        return self.packets[-1].timestamp - self.packets[0].timestamp

    def replay_as_summary_lines(self):
        """Return each packet's as_summary() line relative to start_time."""
        return [p.as_summary(self.start_time) for p in self.packets]

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        """Fork a new process and replay the conversation.

        :param start: time.time() value at which the overall replay began;
                      the child sleeps until start + self.start_time.
        :param context: the replay context; used for the conversation id,
                        per-process config and the stats directory.
        :param account: passed to context.generate_process_local_config().
        :return: in the parent, the child's pid.  The child never returns;
                 it always leaves via os._exit(0).
        """
        def signal_handler(signum, frame):
            """Signal handler closes standard out and error.

            Triggered by a sigterm, ensures that the log messages are flushed
            to disk and not lost.
            """
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

        t = self.start_time
        now = time.time() - start
        gap = t - now
        # we are replaying strictly in order, so it is safe to sleep
        # in the main process if the gap is big enough. This reduces
        # the number of concurrent threads, which allows us to make
        # larger loads.
        # NOTE: currently disabled by the "and False".
        if gap > 0.15 and False:
            print("sleeping for %f in main process" % (gap - 0.1),
                  file=sys.stderr)
            time.sleep(gap - 0.1)
            now = time.time() - start
            gap = t - now
            print("gap is now %f" % gap, file=sys.stderr)

        self.conversation_id = context.next_conversation_id()
        pid = os.fork()
        if pid != 0:
            return pid
        pid = os.getpid()
        signal.signal(signal.SIGTERM, signal_handler)
        # we must never return, or we'll end up running parts of the
        # parent's clean-up code. So we work in a try...finally, and
        # try to print any exceptions.

        try:
            context.generate_process_local_config(account, self)
            sys.stdin.close()
            os.close(0)
            filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
                                    self.conversation_id)
            sys.stdout.close()
            sys.stdout = open(filename, 'w')

            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            miss = t - (time.time() - start)
            self.msg("starting %s [miss %.3f pid %d]" % (self, miss, pid))
            self.replay(context)
        except Exception:
            print(("EXCEPTION in child PID %d, conversation %s" % (pid, self)),
                  file=sys.stderr)
            traceback.print_exc(sys.stderr)
        finally:
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

    def replay(self, context=None):
        """Play each packet at its conversation-relative time.

        With no context, packets are only logged via self.msg.
        """
        start = time.time()

        for p in self.packets:
            now = time.time() - start
            gap = p.timestamp - now
            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            miss = p.timestamp - (time.time() - start)
            if context is None:
                self.msg("packet %s [miss %.3f pid %d]" % (p, miss,
                                                           os.getpid()))
                continue
            p.play(self, context)

    def guess_client_server(self, server_clue=None):
        """Have a go at deciding who is the server and who is the client.
        returns (client, server)

        A negative client_balance picks endpoints[0] as the client.  With
        a zero balance, a server_clue equal to endpoints[1] also does.
        Otherwise endpoints[1] is taken to be the client.
        """
        a, b = self.endpoints

        if self.client_balance < 0:
            return (a, b)

        # in the absence of a clue, we will fall through to assuming
        # the lowest number is the server (which is usually true).

        if self.client_balance == 0 and server_clue == b:
            return (a, b)

        return (b, a)

    def forget_packets_outside_window(self, s, e):
        """Prune any packets outside the time window we're interested in.

        Packets with s <= timestamp <= e are kept.  start_time becomes the
        timestamp of the first remaining packet, or None if none remain.

        :param s: start of the window
        :param e: end of the window
        """
        self.packets = [p for p in self.packets if s <= p.timestamp <= e]
        self.start_time = self.packets[0].timestamp if self.packets else None

    def renormalise_times(self, start_time):
        """Adjust the packet start times relative to the new start time."""
        for p in self.packets:
            p.timestamp -= start_time

        if self.start_time is not None:
            self.start_time -= start_time
961
962
class DnsHammer(Conversation):
    """A lightweight conversation that generates a lot of dns:0 packets on
    the fly, instead of replaying recorded packets."""

    def __init__(self, dns_rate, duration):
        """Schedule int(dns_rate * duration) DNS queries at sorted,
        uniformly random offsets in [0, duration].

        :param dns_rate: queries per unit of time.
        :param duration: length of the run, in the same time unit.
        """
        n = int(dns_rate * duration)
        times = [random.uniform(0, duration) for i in range(n)]
        times.sort()
        # Run the base initialiser so the inherited attributes (packets,
        # endpoints, client_balance, msg) exist and inherited methods such
        # as __len__, __iter__ and get_duration() work.  It is called after
        # the times are drawn so they are the first values taken from the
        # random stream.
        Conversation.__init__(self, start_time=0)
        self.times = times
        self.rate = dns_rate
        self.duration = duration

    def __str__(self):
        return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
                (len(self.times), self.duration, self.rate))

    def replay(self, context=None):
        """Send each scheduled DNS query at its offset from now.

        With no context, only log when each query would have been sent.
        Otherwise call traffic_packets.packet_dns_0 and print a
        tab-separated stats line with its end time, duration and success
        (or the exception).
        """
        start = time.time()
        fn = traffic_packets.packet_dns_0
        for t in self.times:
            now = time.time() - start
            gap = t - now
            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            if context is None:
                miss = t - (time.time() - start)
                self.msg("packet %s [miss %.3f pid %d]" % (t, miss,
                                                           os.getpid()))
                continue

            packet_start = time.time()
            try:
                fn(self, self, context)
                end = time.time()
                duration = end - packet_start
                print("%f\tDNS\tdns\t0\t%f\tTrue\t" % (end, duration))
            except Exception as e:
                end = time.time()
                duration = end - packet_start
                print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" % (end, duration, e))
1012
1013
def ingest_summaries(files, dns_mode='count'):
    """Load traffic summary files and generate Conversations from them.

    :param files: iterable of file names or open file objects containing
                  traffic summary lines.  Every file is closed after it
                  has been read (or on error).
    :param dns_mode: unless this is 'include', dns packets are not put in
                     conversations but only counted by opcode.
    :return: a 4-tuple (conversations, conversation_rate, duration,
             dns_counts).  conversation_rate is the number of conversations
             (including ones with no real packets) divided by duration; it
             is 0 when duration is 0, and both are 0 when there are no
             non-dns packets.
    """

    dns_counts = defaultdict(int)
    packets = []
    for f in files:
        if isinstance(f, str):
            f = open(f)
        try:
            print("Ingesting %s" % (f.name,), file=sys.stderr)
            for line in f:
                p = Packet(line)
                if p.protocol == 'dns' and dns_mode != 'include':
                    dns_counts[p.opcode] += 1
                else:
                    packets.append(p)
        finally:
            f.close()

    if not packets:
        # Same 4-tuple shape as the normal return so callers can unpack it.
        return [], 0, 0, dns_counts

    start_time = min(p.timestamp for p in packets)
    last_packet = max(p.timestamp for p in packets)

    print("gathering packets into conversations", file=sys.stderr)
    conversations = OrderedDict()
    for p in packets:
        p.timestamp -= start_time
        c = conversations.get(p.endpoints)
        if c is None:
            c = Conversation()
            conversations[p.endpoints] = c
        c.add_packet(p)

    # We only care about conversations with actual traffic, so we
    # filter out conversations with nothing to say. We do that here,
    # rather than earlier, because those empty packets contain useful
    # hints as to which end of the conversation was the client.
    conversation_list = [c for c in conversations.values() if len(c) != 0]

    # This is obviously not correct, as many conversations will appear
    # to start roughly simultaneously at the beginning of the snapshot.
    # To which we say: oh well, so be it.
    duration = float(last_packet - start_time)
    if duration > 0:
        # NB: despite the name this is a rate (conversations per unit time)
        mean_interval = len(conversations) / duration
    else:
        # every packet shares one timestamp; avoid dividing by zero
        mean_interval = 0

    return conversation_list, mean_interval, duration, dns_counts
1065
1066
def guess_server_address(conversations):
    """Guess the server's address as the most frequent endpoint.

    :param conversations: iterable of objects with an ``endpoints`` pair.
    :return: the most common address, or None if there are no endpoints.
    """
    addresses = Counter()
    for c in conversations:
        addresses.update(c.endpoints)
    if not addresses:
        return None
    # most_common() yields (element, count) pairs; we want the element.
    address, _count = addresses.most_common(1)[0]
    return address
1074
1075
def stringify_keys(x):
    """Return a copy of dict *x* with each tuple key joined by tabs."""
    return {'\t'.join(key): value for key, value in x.items()}
1082
1083
def unstringify_keys(x):
    """Inverse of stringify_keys: split each key on tabs into a tuple."""
    return {tuple(str(key).split('\t')): value for key, value in x.items()}
1090
1091
class TrafficModel(object):
    """An n-gram model of the packets clients send in conversations.

    ``ngrams`` maps a tuple of the previous (n - 1) packet types to the
    list of packet types seen to follow it, with repeats, so a random
    choice reproduces the observed frequencies.  'wait:N' pseudo-packets
    record long pauses, and NON_PACKET marks both the start state and the
    end of a conversation.  ``query_details`` maps a packet type to the
    extra-data tuples seen with it.
    """
    def __init__(self, n=3):
        self.ngrams = {}
        self.query_details = {}
        self.n = n
        self.dns_opcounts = defaultdict(int)
        self.cumulative_duration = 0.0
        # [number of conversations, time they were spread over]
        self.conversation_rate = [0, 1]

    def learn(self, conversations, dns_opcounts=None):
        """Add the client packets of *conversations* to the model.

        :param conversations: list of Conversations.  The conversation rate
                              is measured between the first and last
                              start_time, so the list is presumably sorted
                              by start time.
        :param dns_opcounts: optional mapping of dns opcode -> count to add
                             to the model's dns counts.
        """
        if dns_opcounts is None:
            dns_opcounts = {}
        prev = 0.0
        cum_duration = 0.0
        key = (NON_PACKET,) * (self.n - 1)

        server = guess_server_address(conversations)

        for k, v in dns_opcounts.items():
            self.dns_opcounts[k] += v

        if len(conversations) > 1:
            elapsed =\
                conversations[-1].start_time - conversations[0].start_time
            self.conversation_rate[0] = len(conversations)
            self.conversation_rate[1] = elapsed

        for c in conversations:
            # NOTE(review): server is rebound here, so every conversation
            # after the first gets the previous conversation's guessed
            # server as its clue.
            client, server = c.guess_client_server(server)
            cum_duration += c.get_duration()
            key = (NON_PACKET,) * (self.n - 1)
            for p in c:
                if p.src != client:
                    continue

                elapsed = p.timestamp - prev
                prev = p.timestamp
                if elapsed > WAIT_THRESHOLD:
                    # add the wait as an extra state
                    wait = 'wait:%d' % (math.log(max(1.0,
                                                     elapsed * WAIT_SCALE)))
                    self.ngrams.setdefault(key, []).append(wait)
                    key = key[1:] + (wait,)

                short_p = p.as_packet_type()
                self.query_details.setdefault(short_p,
                                              []).append(tuple(p.extra))
                self.ngrams.setdefault(key, []).append(short_p)
                key = key[1:] + (short_p,)

        self.cumulative_duration += cum_duration
        # add in the end
        # NOTE(review): only the last conversation's final state gets an
        # end marker here.
        self.ngrams.setdefault(key, []).append(NON_PACKET)

    def save(self, f):
        """Write the model as JSON to *f*, a file name or writable file.

        A file opened here from a name is closed afterwards; a file object
        passed in is left open for the caller.
        """
        ngrams = {}
        for k, v in self.ngrams.items():
            k = '\t'.join(k)
            ngrams[k] = dict(Counter(v))

        query_details = {}
        for k, v in self.query_details.items():
            query_details[k] = dict(Counter('\t'.join(x) if x else '-'
                                            for x in v))

        d = {
            'ngrams': ngrams,
            'query_details': query_details,
            'cumulative_duration': self.cumulative_duration,
            'conversation_rate': self.conversation_rate,
        }
        d['dns'] = self.dns_opcounts

        if isinstance(f, str):
            with open(f, 'w') as fh:
                json.dump(d, fh, indent=2)
        else:
            json.dump(d, f, indent=2)

    def load(self, f):
        """Merge a model saved by save() from *f*, a file name or file.

        ngrams, query_details and dns counts are added to the existing
        ones; cumulative_duration and conversation_rate are replaced.  A
        file opened here from a name is closed afterwards.
        """
        if isinstance(f, str):
            with open(f) as fh:
                d = json.load(fh)
        else:
            d = json.load(f)

        for k, v in d['ngrams'].items():
            k = tuple(str(k).split('\t'))
            values = self.ngrams.setdefault(k, [])
            for p, count in v.items():
                values.extend([str(p)] * count)

        for k, v in d['query_details'].items():
            values = self.query_details.setdefault(str(k), [])
            for p, count in v.items():
                if p == '-':
                    values.extend([()] * count)
                else:
                    values.extend([tuple(str(p).split('\t'))] * count)

        if 'dns' in d:
            for k, v in d['dns'].items():
                self.dns_opcounts[k] += v

        self.cumulative_duration = d['cumulative_duration']
        self.conversation_rate = d['conversation_rate']

    def construct_conversation(self, timestamp=0.0, client=2, server=1,
                               hard_stop=None, packet_rate=1):
        """Construct an individual conversation from the model.

        Walks the n-gram chain from the start state until NON_PACKET is
        chosen, the current state has no recorded successors, or a packet
        would land after *hard_stop*.

        :param timestamp: start time of the conversation.
        :param client: client endpoint.
        :param server: server endpoint.
        :param hard_stop: if not None, stop before adding a packet later
                          than this time.
        :param packet_rate: every generated gap is divided by this, so
                            larger values give faster conversations.
        """

        c = Conversation(timestamp, (server, client))

        key = (NON_PACKET,) * (self.n - 1)

        while key in self.ngrams:
            p = random.choice(self.ngrams[key])
            if p == NON_PACKET:
                break
            if p in self.query_details:
                extra = random.choice(self.query_details[p])
            else:
                extra = []

            protocol, opcode = p.split(':', 1)
            if protocol == 'wait':
                log_wait_time = int(opcode) + random.random()
                wait = math.exp(log_wait_time) / (WAIT_SCALE * packet_rate)
                timestamp += wait
            else:
                log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
                wait = math.exp(log_wait) / packet_rate
                timestamp += wait
                if hard_stop is not None and timestamp > hard_stop:
                    break
                c.add_short_packet(timestamp, p, extra)

            key = key[1:] + (p,)

        return c

    def generate_conversations(self, rate, duration, packet_rate=1):
        """Generate a list of conversations from the model.

        :param rate: scales the number of conversations relative to the
                     conversation rate observed while learning.
        :param duration: length of the period of interest.
        :param packet_rate: passed on to construct_conversation().
        :return: the non-empty conversations, sorted by start_time.
        """

        # We simulate over a period at least twice as long as the desired
        # duration (or the originally observed period, if that is longer).
        rate_n, rate_t = self.conversation_rate

        duration2 = max(rate_t, duration * 2)
        n = rate * duration2 * rate_n / rate_t

        server = 1
        client = 2

        conversations = []
        end = duration2

        while client < n + 2:
            start = random.uniform(0, duration2)
            c = self.construct_conversation(start,
                                            client,
                                            server,
                                            hard_stop=(duration2 * 5),
                                            packet_rate=packet_rate)

            c.forget_packets_outside_window(start, end)
            c.renormalise_times(start)
            if len(c) != 0:
                conversations.append(c)
            client += 1

        print(("we have %d conversations at rate %f" %
               (len(conversations), rate)), file=sys.stderr)
        # Sort with an explicit key so the order depends only on
        # start_time, not on Conversation's comparison methods.
        conversations.sort(key=lambda conv: conv.start_time)
        return conversations
1264
1265
# Protocol name -> IP protocol field used when building packets from short
# 'protocol:opcode' descriptions.  The values appear to be two-digit hex IP
# protocol numbers ('06' = TCP, '11' = 0x11 = 17 = UDP); callers fall back
# to '06' for protocols not listed here.
IP_PROTOCOLS = {
    'dns': '11',
    'rpc_netlogon': '06',
    'kerberos': '06',      # ratio 16248:258
    'smb': '06',
    'smb2': '06',
    'ldap': '06',
    'cldap': '11',
    'lsarpc': '06',
    'samr': '06',
    'dcerpc': '06',
    'epm': '06',
    'drsuapi': '06',
    'browser': '11',
    'smb_netlogon': '11',
    'srvsvc': '06',
    'nbns': '11',
}
1284
# (protocol, opcode) -> human-readable description, filled into the
# description field when a packet is expanded from a short
# 'protocol:opcode' form.  Unlisted pairs get an empty description.
OP_DESCRIPTIONS = {
    ('browser', '0x01'): 'Host Announcement (0x01)',
    ('browser', '0x02'): 'Request Announcement (0x02)',
    ('browser', '0x08'): 'Browser Election Request (0x08)',
    ('browser', '0x09'): 'Get Backup List Request (0x09)',
    ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
    ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
    ('cldap', '3'): 'searchRequest',
    ('cldap', '5'): 'searchResDone',
    ('dcerpc', '0'): 'Request',
    ('dcerpc', '11'): 'Bind',
    ('dcerpc', '12'): 'Bind_ack',
    ('dcerpc', '13'): 'Bind_nak',
    ('dcerpc', '14'): 'Alter_context',
    ('dcerpc', '15'): 'Alter_context_resp',
    ('dcerpc', '16'): 'AUTH3',
    ('dcerpc', '2'): 'Response',
    ('dns', '0'): 'query',
    ('dns', '1'): 'response',
    ('drsuapi', '0'): 'DsBind',
    ('drsuapi', '12'): 'DsCrackNames',
    ('drsuapi', '13'): 'DsWriteAccountSpn',
    ('drsuapi', '1'): 'DsUnbind',
    ('drsuapi', '2'): 'DsReplicaSync',
    ('drsuapi', '3'): 'DsGetNCChanges',
    ('drsuapi', '4'): 'DsReplicaUpdateRefs',
    ('epm', '3'): 'Map',
    ('kerberos', ''): '',
    ('ldap', '0'): 'bindRequest',
    ('ldap', '1'): 'bindResponse',
    ('ldap', '2'): 'unbindRequest',
    ('ldap', '3'): 'searchRequest',
    ('ldap', '4'): 'searchResEntry',
    ('ldap', '5'): 'searchResDone',
    ('ldap', ''): '*** Unknown ***',
    ('lsarpc', '14'): 'lsa_LookupNames',
    ('lsarpc', '15'): 'lsa_LookupSids',
    ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
    ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
    ('lsarpc', '6'): 'lsa_OpenPolicy',
    ('lsarpc', '76'): 'lsa_LookupSids3',
    ('lsarpc', '77'): 'lsa_LookupNames4',
    ('nbns', '0'): 'query',
    ('nbns', '1'): 'response',
    ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
    ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
    ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
    ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
    ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
    ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
    ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
    ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
    ('samr', '0',): 'Connect',  # same key as ('samr', '0')
    ('samr', '16'): 'GetAliasMembership',
    ('samr', '17'): 'LookupNames',
    ('samr', '18'): 'LookupRids',
    ('samr', '19'): 'OpenGroup',
    ('samr', '1'): 'Close',
    ('samr', '25'): 'QueryGroupMember',
    ('samr', '34'): 'OpenUser',
    ('samr', '36'): 'QueryUserInfo',
    ('samr', '39'): 'GetGroupsForUser',
    ('samr', '3'): 'QuerySecurity',
    ('samr', '5'): 'LookupDomain',
    ('samr', '64'): 'Connect5',
    ('samr', '6'): 'EnumDomains',
    ('samr', '7'): 'OpenDomain',
    ('samr', '8'): 'QueryDomainInfo',
    ('smb', '0x04'): 'Close (0x04)',
    ('smb', '0x24'): 'Locking AndX (0x24)',
    ('smb', '0x2e'): 'Read AndX (0x2e)',
    ('smb', '0x32'): 'Trans2 (0x32)',
    ('smb', '0x71'): 'Tree Disconnect (0x71)',
    ('smb', '0x72'): 'Negotiate Protocol (0x72)',
    ('smb', '0x73'): 'Session Setup AndX (0x73)',
    ('smb', '0x74'): 'Logoff AndX (0x74)',
    ('smb', '0x75'): 'Tree Connect AndX (0x75)',
    ('smb', '0xa2'): 'NT Create AndX (0xa2)',
    ('smb2', '0'): 'NegotiateProtocol',
    ('smb2', '11'): 'Ioctl',
    ('smb2', '14'): 'Find',
    ('smb2', '16'): 'GetInfo',
    ('smb2', '18'): 'Break',
    ('smb2', '1'): 'SessionSetup',
    ('smb2', '2'): 'SessionLogoff',
    ('smb2', '3'): 'TreeConnect',
    ('smb2', '4'): 'TreeDisconnect',
    ('smb2', '5'): 'Create',
    ('smb2', '6'): 'Close',
    ('smb2', '8'): 'Read',
    ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
    ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
                               'user unknown (0x17)'),
    ('srvsvc', '16'): 'NetShareGetInfo',
    ('srvsvc', '21'): 'NetSrvGetInfo',
}
1381
1382
def expand_short_packet(p, timestamp, src, dest, extra):
    """Expand a 'protocol:opcode' packet type into a tab-separated
    summary line.

    :param p: packet type of the form 'protocol:opcode'.
    :param timestamp: packet time (number or string).
    :param src: source endpoint (number or string).
    :param dest: destination endpoint (number or string).
    :param extra: sequence of extra fields appended to the line.
    :return: all fields joined with tabs.
    """
    import numbers

    protocol, opcode = p.split(':', 1)
    desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
    ip_protocol = IP_PROTOCOLS.get(protocol, '06')

    line = [timestamp, ip_protocol, '', src, dest, protocol, opcode, desc]
    line.extend(extra)
    # Timestamps and endpoints are usually numbers, which str.join rejects;
    # convert numbers only so string fields pass through untouched.
    return '\t'.join(str(x) if isinstance(x, numbers.Number) else x
                     for x in line)
1391
1392
def replay(conversations,
           host=None,
           creds=None,
           lp=None,
           accounts=None,
           dns_rate=0,
           duration=None,
           **kwargs):
    """Replay *conversations* against *host*, forking one process per
    conversation in order of start_time.

    :param conversations: Conversations to replay.
    :param host: server to replay against (ReplayContext's server).
    :param creds: credentials passed to ReplayContext.
    :param lp: loadparm passed to ReplayContext.
    :param accounts: per-conversation accounts.  They are zipped with the
                     conversations sorted latest-first, so when there are
                     fewer accounts the earliest conversations are dropped.
    :param dns_rate: if non-zero, also run a DnsHammer at this rate.
    :param duration: seconds to run for; by default until 1 second after
                     the last packet of the last conversation to start.
    :param kwargs: passed on to ReplayContext.
    """
    import errno

    context = ReplayContext(server=host,
                            creds=creds,
                            lp=lp,
                            **kwargs)

    if len(accounts) < len(conversations):
        print(("we have %d accounts but %d conversations" %
               (len(accounts), len(conversations))), file=sys.stderr)

    cstack = list(zip(
        sorted(conversations, key=lambda x: x.start_time, reverse=True),
        accounts))

    # Set the process group so that the calling scripts are not killed
    # when the forked child processes are killed.
    os.setpgrp()

    start = time.time()

    if duration is None:
        # end 1 second after the last packet of the last conversation
        # to start. Conversations other than the last could still be
        # going, but we don't care.
        duration = cstack[0][0].packets[-1].timestamp + 1.0
        print("We will stop after %.1f seconds" % duration,
              file=sys.stderr)

    end = start + duration

    print("Replaying traffic for %u conversations over %d seconds"
          % (len(conversations), duration))

    children = {}
    if dns_rate:
        dns_hammer = DnsHammer(dns_rate, duration)
        cstack.append((dns_hammer, None))

    try:
        while True:
            # we spawn a batch, wait for finishers, then spawn another
            now = time.time()
            batch_end = min(now + 2.0, end)
            fork_time = 0.0
            fork_n = 0
            while cstack:
                c, account = cstack.pop()
                if c.start_time + start > batch_end:
                    cstack.append((c, account))
                    break

                st = time.time()
                pid = c.replay_in_fork_with_delay(start, context, account)
                children[pid] = c
                t = time.time()
                elapsed = t - st
                fork_time += elapsed
                fork_n += 1
                print("forked %s in pid %s (in %fs)" % (c, pid,
                                                        elapsed),
                      file=sys.stderr)

            if fork_n:
                print(("forked %d times in %f seconds (avg %f)" %
                       (fork_n, fork_time, fork_time / fork_n)),
                      file=sys.stderr)
            elif cstack:
                debug(2, "no forks in batch ending %f" % batch_end)

            while time.time() < batch_end - 1.0:
                time.sleep(0.01)
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != errno.ECHILD:  # no child processes
                        raise
                    break
                if pid:
                    c = children.pop(pid, None)
                    print(("process %d finished conversation %s;"
                           " %d to go" %
                           (pid, c, len(children))), file=sys.stderr)

            if time.time() >= end:
                print("time to stop", file=sys.stderr)
                break

    except Exception:
        print("EXCEPTION in parent", file=sys.stderr)
        traceback.print_exc()
    finally:
        # Ask twice with SIGTERM, then insist with SIGKILL.
        for s in (signal.SIGTERM, signal.SIGTERM, signal.SIGKILL):
            print(("killing %d children with -%d" %
                   (len(children), s)), file=sys.stderr)
            for pid in children:
                try:
                    os.kill(pid, s)
                except OSError as e:
                    # don't fail if it has already died
                    if e.errno != errno.ESRCH:
                        raise
            time.sleep(0.5)
            deadline = time.time() + 1
            while children:
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != errno.ECHILD:
                        raise
                    # Nothing left to reap; whatever is still in children
                    # is reported as missing below.
                    break
                if pid != 0:
                    c = children.pop(pid, None)
                    print(("kill -%d %d KILLED conversation %s; "
                           "%d to go" %
                           (s, pid, c, len(children))),
                          file=sys.stderr)
                else:
                    # avoid a busy spin while children are exiting
                    time.sleep(0.01)
                if time.time() >= deadline:
                    break

            if not children:
                break
            time.sleep(1)

        if children:
            print("%d children are missing" % len(children),
                  file=sys.stderr)

        # there may be stragglers that were forked just as ^C was hit
        # and don't appear in the list of children. We can get them
        # with killpg, but that will also kill us, so this is^H^H would be
        # goodbye, except we cheat and pretend to use ^C (SIGINT),
        # so as not to have to fuss around writing signal handlers.
        try:
            os.killpg(0, signal.SIGINT)
        except KeyboardInterrupt:
            print("ignoring fake ^C", file=sys.stderr)
1535
1536
def openLdb(host, creds, lp):
    """Open a SamDB over ldap://host using the system session info and the
    given credentials and loadparm."""
    return SamDB(url="ldap://%s" % host,
                 session_info=system_session(),
                 credentials=creds,
                 lp=lp)
1544
1545
def ou_name(ldb, instance_id):
    """Return the DN of the OU for traffic instance *instance_id*.

    The OU sits below ou=traffic_replay, directly under the domain DN.
    """
    parent = "ou=traffic_replay,%s" % ldb.domain_dn()
    return "ou=instance-%d,%s" % (instance_id, parent)
1550
1551
def create_ou(ldb, instance_id):
    """Create an ou, all created user and machine accounts will belong to it.

    This allows all the created resources to be cleaned up easily.  The
    parent OU is created first; either OU already existing is not an error.

    :return: the DN of the instance OU.
    """
    ou = ou_name(ldb, instance_id)
    for dn in (ou.split(',', 1)[1], ou):
        try:
            ldb.add({"dn":          dn,
                     "objectclass": "organizationalunit"})
        except LdbError as e:
            # LdbError carries (status, message) in e.args; exceptions are
            # not iterable on Python 3, so index args rather than unpack e.
            status = e.args[0]
            # ignore already exists (68, LDB_ERR_ENTRY_ALREADY_EXISTS)
            if status != 68:
                raise
    return ou
1575
1576
class ConversationAccounts(object):
    """The machine and user account credentials a conversation replays with.
    """

    def __init__(self, netbios_name, machinepass, username, userpass):
        # machine account
        self.netbios_name = netbios_name
        self.machinepass = machinepass
        # user account
        self.username = username
        self.userpass = userpass
1585
1586
def generate_replay_accounts(ldb, instance_id, number, password):
    """Ensure *number* machine/user account pairs exist for this instance
    and return one ConversationAccounts per pair, numbered from 1; all of
    them use *password*."""
    generate_traffic_accounts(ldb, instance_id, number, password)
    return [ConversationAccounts("STGM-%d-%d" % (instance_id, i),
                                 password,
                                 "STGU-%d-%d" % (instance_id, i),
                                 password)
            for i in range(1, number + 1)]
1600
1601
def generate_traffic_accounts(ldb, instance_id, number, password):
    """Create the specified number of user and machine accounts.

    Accounts are not explicitly deleted between runs, so this function
    starts with the highest-numbered account and works backwards, stopping
    either when it finds an account that already exists (assuming the
    lower-numbered ones do too) or when it has created all the required
    accounts.
    """
    print(("Generating machine and conversation accounts, "
           "as required for %d conversations" % number),
          file=sys.stderr)

    added = _create_accounts_backwards(
        number,
        lambda i: create_machine_account(ldb, instance_id,
                                         "STGM-%d-%d" % (instance_id, i),
                                         password))
    if added > 0:
        print("Added %d new machine accounts" % added,
              file=sys.stderr)

    added = _create_accounts_backwards(
        number,
        lambda i: create_user_account(ldb, instance_id,
                                      "STGU-%d-%d" % (instance_id, i),
                                      password))
    if added > 0:
        print("Added %d new user accounts" % added,
              file=sys.stderr)


def _create_accounts_backwards(number, create):
    """Call create(i) for i = number..1, stopping at the first account that
    already exists; return how many accounts were created."""
    added = 0
    for i in range(number, 0, -1):
        try:
            create(i)
        except LdbError as e:
            # e.args is (status, message); 68 is
            # LDB_ERR_ENTRY_ALREADY_EXISTS.
            if e.args[0] == 68:
                break
            raise
        added += 1
    return added
1645
1646
def create_machine_account(ldb, instance_id, netbios_name, machinepass):
    """Create a machine account via ldap.

    Adds a ``computer`` object ``cn=<netbios_name>`` under the instance's
    OU, with sAMAccountName ``<netbios_name>$``, userAccountControl set to
    UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD and ``machinepass`` as
    its password.  A tab-separated timing line
    (end time, 0, "create", "machine", duration, "True") is printed to
    stdout.  LdbError from the add (e.g. entry already exists) propagates.
    """
    ou = ou_name(ldb, instance_id)
    dn = "cn=%s,%s" % (netbios_name, ou)
    # unicodePwd is sent as the password wrapped in double quotes and
    # encoded as UTF-16-LE.  Formatting into a unicode literal works on both
    # Python 2 and Python 3; the builtin ``unicode`` does not exist on
    # Python 3 and str + bytes concatenation fails there.
    utf16pw = (u'"%s"' % machinepass).encode('utf-16-le')
    start = time.time()
    ldb.add({
        "dn": dn,
        "objectclass": "computer",
        "sAMAccountName": "%s$" % netbios_name,
        "userAccountControl":
        str(UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD),
        "unicodePwd": utf16pw})
    end = time.time()
    duration = end - start
    print("%f\t0\tcreate\tmachine\t%f\tTrue\t" % (end, duration))
1666
1667
def create_user_account(ldb, instance_id, username, userpass):
    """Create a user account via ldap.

    Adds a ``user`` object ``cn=<username>`` under the instance's OU, with
    sAMAccountName ``username``, userAccountControl UF_NORMAL_ACCOUNT and
    ``userpass`` as its password.  It then adds the ACE "(A;;WP;;;PS)" to
    the new object's DACL, so the user can write its own attributes
    (e.g. its account SPN).  A tab-separated timing line
    (end time, 0, "create", "user", duration, "True") is printed to stdout.
    LdbError from the add (e.g. entry already exists) propagates.
    """
    ou = ou_name(ldb, instance_id)
    user_dn = "cn=%s,%s" % (username, ou)
    # unicodePwd is sent as the password wrapped in double quotes and
    # encoded as UTF-16-LE.  Formatting into a unicode literal works on both
    # Python 2 and Python 3; the builtin ``unicode`` does not exist on
    # Python 3 and str + bytes concatenation fails there.
    utf16pw = (u'"%s"' % userpass).encode('utf-16-le')
    start = time.time()
    ldb.add({
        "dn": user_dn,
        "objectclass": "user",
        "sAMAccountName": username,
        "userAccountControl": str(UF_NORMAL_ACCOUNT),
        "unicodePwd": utf16pw
    })

    # grant user write permission to do things like write account SPN
    sdutils = sd_utils.SDUtils(ldb)
    sdutils.dacl_add_ace(user_dn, "(A;;WP;;;PS)")

    end = time.time()
    duration = end - start
    print("%f\t0\tcreate\tuser\t%f\tTrue\t" % (end, duration))
1691
1692
def create_group(ldb, instance_id, name):
    """Create a group via ldap.

    Adds a ``group`` object ``cn=<name>`` under the instance's OU.  It then
    prints a tab-separated timing line
    (end time, 0, "create", "group", duration, "True") to stdout.
    """
    group_dn = "cn=%s,%s" % (name, ou_name(ldb, instance_id))
    started = time.time()
    ldb.add({"dn": group_dn, "objectclass": "group"})
    finished = time.time()
    print("%f\t0\tcreate\tgroup\t%f\tTrue\t" % (finished, finished - started))
1706
1707
def user_name(instance_id, i):
    """Return the name of dummy user number ``i`` of this instance.

    The name has the form ``STGU-<instance_id>-<i>``.
    """
    template = "STGU-%d-%d"
    return template % (instance_id, i)
1711
1712
def generate_users(ldb, instance_id, number, password):
    """Add users to the server.

    Users are created from ``number`` down to 1.  Creation stops at the first
    user that already exists (LDB error 68), because accounts are left on
    the server between runs.  Any other LdbError is re-raised.

    :return: the number of users actually created.
    """
    users = 0
    for i in range(number, 0, -1):
        username = user_name(instance_id, i)
        try:
            create_user_account(ldb, instance_id, username, password)
        except LdbError as e:
            # Use e.args: exception objects are not iterable on Python 3
            (status, _) = e.args
            # Stop if entry exists
            if status == 68:
                break
            raise
        users += 1

    return users
1730
1731
def group_name(instance_id, i):
    """Return the name of dummy group number ``i`` of this instance.

    The name has the form ``STGG-<instance_id>-<i>``.
    """
    template = "STGG-%d-%d"
    return template % (instance_id, i)
1735
1736
def generate_groups(ldb, instance_id, number):
    """Create the required number of groups on the server.

    Groups are created from ``number`` down to 1.  Creation stops at the
    first group that already exists (LDB error 68).  Any other LdbError is
    re-raised.

    :return: the number of groups actually created.
    """
    groups = 0
    for i in range(number, 0, -1):
        name = group_name(instance_id, i)
        try:
            create_group(ldb, instance_id, name)
        except LdbError as e:
            # Use e.args: exception objects are not iterable on Python 3
            (status, _) = e.args
            # Stop if entry exists
            if status == 68:
                break
            raise
        groups += 1
    return groups
1753
1754
def clean_up_accounts(ldb, instance_id):
    """Remove the created accounts and groups from the server.

    Deletes the instance's OU and everything beneath it, using the
    "tree_delete:1" control.  If the OU does not exist (LDB error 32) the
    error is ignored; any other LdbError is re-raised.
    """
    ou = ou_name(ldb, instance_id)
    try:
        ldb.delete(ou, ["tree_delete:1"])
    except LdbError as e:
        # Use e.args: exception objects are not iterable on Python 3
        (status, _) = e.args
        # ignore does not exist
        if status != 32:
            raise
1765
1766
def generate_users_and_groups(ldb, instance_id, password,
                              number_of_users, number_of_groups,
                              group_memberships):
    """Generate the required users and groups, allocating the users to
       those groups.

    Ensures the instance's OU exists, creates any missing users and groups,
    and (when ``group_memberships`` > 0) assigns new memberships via
    assign_groups/add_users_to_groups.  Progress and a final summary are
    written to stderr.
    """
    create_ou(ldb, instance_id)

    print("Generating dummy user accounts", file=sys.stderr)
    users_added = generate_users(ldb, instance_id, number_of_users, password)

    groups_added = 0
    if number_of_groups > 0:
        print("Generating dummy groups", file=sys.stderr)
        groups_added = generate_groups(ldb, instance_id, number_of_groups)

    assignments = []
    if group_memberships > 0:
        print("Assigning users to groups", file=sys.stderr)
        assignments = assign_groups(number_of_groups, groups_added,
                                    number_of_users, users_added,
                                    group_memberships)
        print("Adding users to groups", file=sys.stderr)
        add_users_to_groups(ldb, instance_id, assignments)

    # New groups but no new users: assign_groups scales the membership
    # count by users_added, so nothing gets assigned.
    memberless_groups = (groups_added > 0 and users_added == 0 and
                         groups_added != number_of_groups)
    if memberless_groups:
        print("Warning: the added groups will contain no members",
              file=sys.stderr)

    summary = ("Added %d users, %d groups and %d group memberships" %
               (users_added, groups_added, len(assignments)))
    print(summary, file=sys.stderr)
1802
1803
def assign_groups(number_of_groups,
                  groups_added,
                  number_of_users,
                  users_added,
                  group_memberships):
    """Allocate users to groups.

    The intention is to have a few users that belong to most groups, while
    the majority of users belong to a few groups.

    A few groups will contain most users, with the remaining only having a
    few users.

    The requested ``group_memberships`` is scaled by the fraction of users
    that were newly added (users_added / number_of_users).  Only pairs that
    involve a newly added user or a newly added group are generated, since
    the others may already exist from a previous run.  The target is capped
    at the number of such pairs, so the loop always terminates.

    :return: a set of (user, group) tuples of 1-based user and group
             numbers; empty when there is nothing to assign.
    """

    def generate_user_distribution(n):
        """Probability distribution of a user belonging to a group.
        """
        dist = []
        for x in range(1, n + 1):
            p = 1 / (x + 0.001)
            dist.append(p)
        return dist

    def generate_group_distribution(n):
        """Probability distribution of a group containing a user."""
        dist = []
        for x in range(1, n + 1):
            p = 1 / (x**1.3)
            dist.append(p)
        return dist

    assignments = set()
    # No users or no groups means no possible pairs.  This also avoids the
    # division by number_of_users and randint(0, -1) below.
    if group_memberships <= 0 or number_of_users <= 0 or number_of_groups <= 0:
        return assignments

    group_dist = generate_group_distribution(number_of_groups)
    user_dist  = generate_user_distribution(number_of_users)

    # Calculate the number of group memberships required
    group_memberships = math.ceil(
        float(group_memberships) *
        (float(users_added) / float(number_of_users)))

    # Pairs where both the user and the group pre-existed are never chosen.
    # Cap the target at the number of eligible pairs; otherwise the loop
    # below would spin forever once every eligible pair has been taken.
    existing_user_count = number_of_users - users_added
    existing_group_count = number_of_groups - groups_added
    eligible_pairs = (number_of_users * number_of_groups -
                      existing_user_count * existing_group_count)
    group_memberships = min(group_memberships, eligible_pairs)

    # highest array index of a pre-existing user / group
    existing_users  = existing_user_count - 1
    existing_groups = existing_group_count - 1
    while len(assignments) < group_memberships:
        user        = random.randint(0, number_of_users - 1)
        group       = random.randint(0, number_of_groups - 1)
        probability = group_dist[group] * user_dist[user]

        if ((random.random() < probability * 10000) and
           (group > existing_groups or user > existing_users)):
            # the + 1 converts the array index to the corresponding
            # group or user number
            assignments.add(((user + 1), (group + 1)))

    return assignments
1861
1862
def add_users_to_groups(db, instance_id, assignments):
    """Add users to their assigned groups.

    Takes the collection of (user, group) tuples generated by assign_groups
    (1-based user and group numbers).  Each user is added to its group by
    an ldap modify that adds the user's DN to the group's ``member``
    attribute.  A tab-separated timing line
    (end time, 0, "add", "user", duration, "True") is printed to stdout for
    each modification.
    """

    ou = ou_name(db, instance_id)

    def build_dn(name):
        return "cn=%s,%s" % (name, ou)

    for (user, group) in assignments:
        user_dn  = build_dn(user_name(instance_id, user))
        group_dn = build_dn(group_name(instance_id, group))

        # ``ldb`` here is the ldb module; the connection is ``db``.
        m = ldb.Message()
        m.dn = ldb.Dn(db, group_dn)
        m["member"] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
        start = time.time()
        db.modify(m)
        end = time.time()
        duration = end - start
        print("%f\t0\tadd\tuser\t%f\tTrue\t" % (end, duration))
1886
1887
def generate_stats(statsdir, timing_file):
    """Generate and print the summary stats for a run.

    Every file in ``statsdir`` is read as tab-separated lines:
    time, conversation, protocol, packet type, latency, successful, ...
    Lines that cannot be parsed are echoed to stderr and ignored.  Valid
    lines are copied verbatim to ``timing_file`` after a header row, unless
    ``timing_file`` is None.

    Two things are printed to stdout: a run summary (conversations,
    successful/failed operation counts and their per-second rates) and a
    latency table per (protocol, packet type).  A human-readable layout is
    used when stdout is a terminal; otherwise the output is
    machine-readable (tab-separated).

    :param statsdir: directory containing the stats files to summarise.
    :param timing_file: writable file object, or None to skip the copy.
    """
    first      = sys.float_info.max
    last       = 0
    successful = 0
    failed     = 0
    latencies  = {}
    failures   = {}
    unique_conversations = set()

    if timing_file is not None:
        tw = timing_file.write
    else:
        def tw(x):
            pass

    tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")

    for filename in os.listdir(statsdir):
        path = os.path.join(statsdir, filename)
        with open(path, 'r') as f:
            for line in f:
                # Parse and validate every field before touching the
                # accumulators, so a short or malformed line cannot leave
                # them partially updated.
                try:
                    fields       = line.rstrip('\n').split('\t')
                    timestamp    = float(fields[0])
                    conversation = fields[1]
                    protocol     = fields[2]
                    packet_type  = fields[3]
                    latency      = float(fields[4])
                    succeeded    = fields[5] == 'True'
                except (ValueError, IndexError):
                    # not a valid line print and ignore
                    print(line, file=sys.stderr)
                    continue

                first = min(timestamp - latency, first)
                last  = max(timestamp, last)

                latencies.setdefault(protocol, {}).setdefault(
                    packet_type, []).append(latency)
                proto_failures = failures.setdefault(protocol, {})
                proto_failures.setdefault(packet_type, 0)

                if succeeded:
                    successful += 1
                else:
                    failed += 1
                    proto_failures[packet_type] += 1

                unique_conversations.add(conversation)
                tw(line)

    conversations = len(unique_conversations)
    duration = last - first
    # duration is 0 when all operations share one instant.  It is negative
    # when nothing valid was read.  Report a rate of 0 rather than dividing.
    if successful == 0 or duration <= 0:
        success_rate = 0
    else:
        success_rate = successful / duration
    if failed == 0 or duration <= 0:
        failure_rate = 0
    else:
        failure_rate = failed / duration

    # print the stats in more human-readable format when stdout is going to the
    # console (as opposed to being redirected to a file)
    tty = sys.stdout.isatty()
    if tty:
        print("Total conversations:   %10d" % conversations)
        print("Successful operations: %10d (%.3f per second)"
              % (successful, success_rate))
        print("Failed operations:     %10d (%.3f per second)"
              % (failed, failure_rate))
    else:
        print("(%d, %d, %d, %.3f, %.3f)" %
              (conversations, successful, failed, success_rate, failure_rate))

    if tty:
        print("Protocol    Op Code  Description                               "
              " Count       Failed         Mean       Median          "
              "95%        Range          Max")
        row_format = ("%-12s   %4s  %-35s %12d %12d %12.6f "
                      "%12.6f %12.6f %12.6f %12.6f")
    else:
        print("proto\top_code\tdesc\tcount\tfailed\tmean\tmedian\t95%\trange"
              "\tmax")
        row_format = "%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f"

    for protocol in sorted(latencies):
        for packet_type in sorted(latencies[protocol], key=opcode_key):
            values      = sorted(latencies[protocol][packet_type])
            count       = len(values)
            type_failed = failures[protocol][packet_type]
            mean        = sum(values) / count
            median      = calc_percentile(values, 0.50)
            percentile  = calc_percentile(values, 0.95)
            rng         = values[-1] - values[0]
            maxv        = values[-1]
            desc        = OP_DESCRIPTIONS.get((protocol, packet_type), '')
            print(row_format
                  % (protocol,
                     packet_type,
                     desc,
                     count,
                     type_failed,
                     mean,
                     median,
                     percentile,
                     rng,
                     maxv))
2015
2016
def opcode_key(v):
    """Sort key for the operation code to ensure that it sorts numerically.

    Numeric op codes become zero-padded three-digit strings ("7" -> "007").
    Anything that is not an integer is returned unchanged.
    """
    try:
        return "%03d" % int(v)
    except (ValueError, TypeError):
        # Non-numeric op codes (e.g. names) sort as-is.  Catching only these
        # errors keeps KeyboardInterrupt/SystemExit from being swallowed.
        return v
2023
2024
def calc_percentile(values, percentile):
    """Return the ``percentile`` (0.0-1.0) of an ascending-sorted list.

    The result is linearly interpolated between the two samples adjacent to
    the fractional rank (len(values) - 1) * percentile.  An empty list
    yields 0.
    """
    if not values:
        return 0
    rank = (len(values) - 1) * percentile
    lower = math.floor(rank)
    upper = math.ceil(rank)
    if lower == upper:
        # The rank lands exactly on a sample.
        return values[int(rank)]
    return (values[int(lower)] * (upper - rank) +
            values[int(upper)] * (rank - lower))
2041
2042
def mk_masked_dir(*path):
    """Create a directory with umask 077, so only its owner can access it.

    In a testenv we end up with 0777 directories that look an alarming
    green colour with ls. Use umask to avoid that.  The previous process
    umask is always restored, even if the directory cannot be created.

    :param path: path components, joined with os.path.join.
    :return: the path of the created directory.
    :raises OSError: if the directory cannot be created (e.g. it exists).
    """
    d = os.path.join(*path)
    mask = os.umask(0o077)
    try:
        os.mkdir(d)
    finally:
        # os.umask is process-wide; never leave it at 077 on failure.
        os.umask(mask)
    return d