python: bulk convert zip to list
[samba.git] / python / samba / emulate / traffic.py
1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
3 #
4 # Copyright (C) Catalyst IT Ltd. 2017
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19 from __future__ import print_function
20
21 import time
22 import os
23 import random
24 import json
25 import math
26 import sys
27 import signal
28 import itertools
29
30 from collections import OrderedDict, Counter, defaultdict
31 from samba.emulate import traffic_packets
32 from samba.samdb import SamDB
33 import ldb
34 from ldb import LdbError
35 from samba.dcerpc import ClientConnection
36 from samba.dcerpc import security, drsuapi, lsa
37 from samba.dcerpc import netlogon
38 from samba.dcerpc.netlogon import netr_Authenticator
39 from samba.dcerpc import srvsvc
40 from samba.dcerpc import samr
41 from samba.drs_utils import drs_DsBind
42 import traceback
43 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
44 from samba.auth import system_session
45 from samba.dsdb import UF_WORKSTATION_TRUST_ACCOUNT, UF_PASSWD_NOTREQD
46 from samba.dsdb import UF_NORMAL_ACCOUNT
47 from samba.dcerpc.misc import SEC_CHAN_WKSTA
48 from samba import gensec
49
# NOTE(review): not referenced in this part of the file; presumably a
# fixed per-sleep overhead in seconds -- confirm against the replay code.
SLEEP_OVERHEAD = 3e-4

# we don't use None, because it complicates [de]serialisation
NON_PACKET = '-'

# (protocol, opcode) pairs that suggest the sender is a client, mapped to
# the positive score Packet.client_score() returns for them.
CLIENT_CLUES = {
    ('dns', '0'): 1.0,      # query
    ('smb', '0x72'): 1.0,   # Negotiate protocol
    ('ldap', '0'): 1.0,     # bind
    ('ldap', '3'): 1.0,     # searchRequest
    ('ldap', '2'): 1.0,     # unbindRequest
    ('cldap', '3'): 1.0,
    ('dcerpc', '11'): 1.0,  # bind
    ('dcerpc', '14'): 1.0,  # Alter_context
    ('nbns', '0'): 1.0,     # query
}

# (protocol, opcode) pairs that suggest the sender is a server;
# Packet.client_score() returns the negated value for these.
SERVER_CLUES = {
    ('dns', '1'): 1.0,      # response
    ('ldap', '1'): 1.0,     # bind response
    ('ldap', '4'): 1.0,     # search result
    ('ldap', '5'): 1.0,     # search done
    ('cldap', '5'): 1.0,
    ('dcerpc', '12'): 1.0,  # bind_ack
    ('dcerpc', '13'): 1.0,  # bind_nak
    ('dcerpc', '15'): 1.0,  # Alter_context response
}

# Packet.is_really_a_packet() returns False for any of these protocols.
SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}

# NOTE(review): the three WAIT/LOG constants are not used in this part of
# the file; their units and meaning should be confirmed where they are used.
WAIT_SCALE = 10.0
WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
NO_WAIT_LOG_TIME_RANGE = (-10, -3)

# DEBUG_LEVEL can be changed by scripts with -d
DEBUG_LEVEL = 0
86
87
def debug(level, msg, *args):
    """Write a debug message to standard error if the level permits.

    :param level: Verbosity of this message; it is emitted only when
                  level is <= the module-wide DEBUG_LEVEL, which scripts
                  can change with the -d option.
    :param msg:   The message text, optionally containing C-style
                  format specifiers.
    :param args:  Values for the format specifiers. When none are given
                  msg is printed verbatim, with no % interpolation.
    """
    if level > DEBUG_LEVEL:
        return

    text = msg % tuple(args) if args else msg
    print(text, file=sys.stderr)
104
105
def debug_lineno(*args):
    """Log each argument on its own line to stderr, prefixed by the name
    and (highlighted) line number of the calling function.
    """
    # With limit=2 the stack is [caller, this function]; entry 0 is the
    # caller, whose fields 1 and 2 are the line number and function name.
    caller = traceback.extract_stack(limit=2)[0]
    line_number, function_name = caller[1], caller[2]

    header = " %s:\033[01;33m%s \033[00m" % (function_name, line_number)
    print(header, end=' ', file=sys.stderr)
    for item in args:
        print(item, file=sys.stderr)
    print(file=sys.stderr)
    sys.stderr.flush()
117
118
def random_colour_print():
    """Build a printer bound to one randomly chosen terminal colour.

    The returned function writes each of its arguments to stderr on its
    own line, wrapped in the escape sequence for the colour picked here
    and followed by a colour reset.
    """
    colour = 18 + random.randrange(214)
    colour_on = "\033[38;5;%dm" % colour
    colour_off = "\033[00m"

    def coloured(*lines):
        for line in lines:
            # %s on a one-tuple so tuple arguments are printed intact.
            print(colour_on + "%s" % (line,) + colour_off, file=sys.stderr)

    return coloured
129
130
class FakePacketError(Exception):
    """Raised when a packet's endpoints do not match its conversation's."""
133
134
class Packet(object):
    """Details of a network packet.

    A packet is built from the fields of a traffic_summary line, in this
    order: timestamp, ip_protocol, stream_number, src, dest, protocol,
    opcode, desc, followed by any number of extra fields.
    """
    def __init__(self, fields):
        """Parse a packet from a summary line or a sequence of fields.

        :param fields: Either a tab separated string (a trailing newline
                       is stripped) or a sequence whose first eight items
                       are the fields named above; anything after those
                       is kept as extra data.
        """
        if isinstance(fields, str):
            fields = fields.rstrip('\n').split('\t')

        (timestamp,
         ip_protocol,
         stream_number,
         src,
         dest,
         protocol,
         opcode,
         desc) = fields[:8]

        self.timestamp = float(timestamp)
        self.ip_protocol = ip_protocol
        try:
            self.stream_number = int(stream_number)
        except (ValueError, TypeError):
            # An empty or missing stream number is stored as None.
            self.stream_number = None
        self.src = int(src)
        self.dest = int(dest)
        self.protocol = protocol
        self.opcode = opcode
        self.desc = desc
        # Always a list, so copy() can concatenate it whatever sequence
        # type the fields arrived in.
        self.extra = list(fields[8:])

        # The endpoint pair is direction independent: lowest id first.
        if self.src < self.dest:
            self.endpoints = (self.src, self.dest)
        else:
            self.endpoints = (self.dest, self.src)

    def as_summary(self, time_offset=0.0):
        """Format the packet as a traffic_summary line.

        :param time_offset: Amount added to the packet's timestamp.
        :return: A tuple of (adjusted timestamp, tab separated line).
        """
        extra = '\t'.join(self.extra)
        t = self.timestamp + time_offset
        # Only a missing stream number becomes an empty field; stream 0
        # is a real stream and must survive a round trip.
        if self.stream_number is None:
            stream_number = ''
        else:
            stream_number = self.stream_number
        return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
                (t,
                 self.ip_protocol,
                 stream_number,
                 self.src,
                 self.dest,
                 self.protocol,
                 self.opcode,
                 self.desc,
                 extra))

    def __str__(self):
        return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
                (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
                 self.stream_number, self.protocol, self.opcode, self.desc,
                 ('«' + ' '.join(self.extra) + '»' if self.extra else '')))

    def __repr__(self):
        return "<Packet @%s>" % self

    def copy(self):
        """Return a new packet of the same class built from this one's
        fields."""
        return self.__class__([self.timestamp,
                               self.ip_protocol,
                               self.stream_number,
                               self.src,
                               self.dest,
                               self.protocol,
                               self.opcode,
                               self.desc] + self.extra)

    def as_packet_type(self):
        """Return the packet type as a 'protocol:opcode' string."""
        return '%s:%s' % (self.protocol, self.opcode)

    def client_score(self):
        """A positive number means we think it is a client; a negative number
        means we think it is a server. Zero means no idea. range: -1 to 1.
        """
        key = (self.protocol, self.opcode)
        if key in CLIENT_CLUES:
            return CLIENT_CLUES[key]
        if key in SERVER_CLUES:
            return -SERVER_CLUES[key]
        return 0.0

    def play(self, conversation, context):
        """Send the packet over the network, if required.

        Some packets are ignored, i.e. for  protocols not handled,
        server response messages, or messages that are generated by the
        protocol layer associated with other packets.

        The handler is looked up in traffic_packets under the name
        packet_<protocol>_<opcode>. A line of timing data is printed to
        stdout when the handler returns a true value or raises.
        """
        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        try:
            fn = getattr(traffic_packets, fn_name)
        except AttributeError:
            print("Conversation(%s) Missing handler %s" %
                  (conversation.conversation_id, fn_name),
                  file=sys.stderr)
            return

        # Don't display a message for kerberos packets, they're not directly
        # generated they're used to indicate kerberos should be used
        if self.protocol != "kerberos":
            debug(2, "Conversation(%s) Calling handler %s" %
                     (conversation.conversation_id, fn_name))

        start = time.time()
        try:
            if fn(self, conversation, context):
                # Only collect timing data for functions that generate
                # network traffic, or fail
                end = time.time()
                duration = end - start
                print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
                      (end, conversation.conversation_id, self.protocol,
                       self.opcode, duration))
        except Exception as e:
            end = time.time()
            duration = end - start
            print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
                  (end, conversation.conversation_id, self.protocol,
                   self.opcode, duration, e))

    def __cmp__(self, other):
        """Order packets by timestamp; returns -1, 0 or 1."""
        return ((self.timestamp > other.timestamp) -
                (self.timestamp < other.timestamp))

    def __lt__(self, other):
        # Python 3 ignores __cmp__, so sorting needs a rich comparison.
        return self.timestamp < other.timestamp

    def is_really_a_packet(self, missing_packet_stats=None):
        """Does this packet need to be kept for the replay?

        Returns False for packets that can be ignored: those for skipped
        protocols, ldap continuation packets, packets whose handler is
        traffic_packets.null_packet, and packets with no handler at all.
        Removing such a packet will have no effect on the replay.

        :param missing_packet_stats: Accepted but not used here.
        """
        if self.protocol in SKIPPED_PROTOCOLS:
            # Ignore any packets for the protocols we're not interested in.
            return False
        if self.protocol == "ldap" and self.opcode == '':
            # skip ldap continuation packets
            return False

        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        try:
            fn = getattr(traffic_packets, fn_name)
        except AttributeError:
            print("missing packet %s" % fn_name, file=sys.stderr)
            return False
        return fn is not traffic_packets.null_packet
283
284
class ReplayContext(object):
    """State/Context for an individual conversation between a simulated client
       and a server.
    """

    def __init__(self,
                 server=None,
                 lp=None,
                 creds=None,
                 badpassword_frequency=None,
                 prefer_kerberos=None,
                 tempdir=None,
                 statsdir=None,
                 ou=None,
                 base_dn=None,
                 domain=None,
                 domain_sid=None):
        """Store the replay settings and build the LDAP search tables.

        Constructing a context connects to the server over LDAP (see
        generate_ldap_search_tables) and reads the realm from lp, so both
        server and lp must be usable even though they default to None.
        """
        self.server                   = server
        self.ldap_connections         = []
        self.dcerpc_connections       = []
        self.lsarpc_connections       = []
        self.lsarpc_connections_named = []
        self.drsuapi_connections      = []
        self.srvsvc_connections       = []
        self.samr_contexts            = []
        self.netlogon_connection      = None
        self.creds                    = creds
        self.lp                       = lp
        self.prefer_kerberos          = prefer_kerberos
        self.ou                       = ou
        self.base_dn                  = base_dn
        self.domain                   = domain
        self.statsdir                 = statsdir
        self.global_tempdir           = tempdir
        self.domain_sid               = domain_sid
        self.realm                    = lp.get('realm')

        # Bad password attempt controls
        self.badpassword_frequency    = badpassword_frequency
        self.last_lsarpc_bad          = False
        self.last_lsarpc_named_bad    = False
        self.last_simple_bind_bad     = False
        self.last_bind_bad            = False
        self.last_srvsvc_bad          = False
        self.last_drsuapi_bad         = False
        self.last_netlogon_bad        = False
        self.last_samlogon_bad        = False
        self.generate_ldap_search_tables()

        # itertools.count() has no .next attribute on Python 3, so wrap
        # the builtin next() to keep this a zero-argument callable.
        conversation_ids = itertools.count()
        self.next_conversation_id = lambda: next(conversation_ids)

    def generate_ldap_search_tables(self):
        """Search the whole domain and index the DNs found.

        Sets self.dn_map, mapping a pattern of upper-cased two letter
        component prefixes (e.g. CN,CN,CN,DC,DC) to the list of DNs with
        that shape, and self.attribute_clue_map, mapping an attribute
        name to DNs likely to carry it.
        """
        session = system_session()

        db = SamDB(url="ldap://%s" % self.server,
                   session_info=session,
                   credentials=self.creds,
                   lp=self.lp)

        res = db.search(db.domain_dn(),
                        scope=ldb.SCOPE_SUBTREE,
                        attrs=['dn'])

        # find a list of dns for each pattern
        # e.g. CN,CN,CN,DC,DC
        dn_map = {}
        attribute_clue_map = {
            'invocationId': []
        }

        for r in res:
            dn = str(r.dn)
            pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
            dn_map.setdefault(pattern, []).append(dn)
            if dn.startswith('CN=NTDS Settings,'):
                attribute_clue_map['invocationId'].append(dn)

        # extend the map in case we are working with a different
        # number of DC components.
        # Iterate over a snapshot of the keys: entries are added inside
        # the loop, which a live dict view does not allow on Python 3.
        for k in list(dn_map):
            if k[-3:] != ',DC':
                continue
            p = k[:-3]
            while p[-3:] == ',DC':
                p = p[:-3]
            # p is now the pattern without any DC suffix; alias it with
            # one to five DC components.
            for _ in range(5):
                p += ',DC'
                if p != k and p in dn_map:
                    print('dn_map collison %s %s' % (k, p),
                          file=sys.stderr)
                    continue
                dn_map[p] = dn_map[k]

        self.dn_map = dn_map
        self.attribute_clue_map = attribute_clue_map

    def generate_process_local_config(self, account, conversation):
        """Configure this context for one conversation's account.

        Does nothing when account is None. Otherwise copies the account
        details, points the lp directories at a per-conversation temporary
        directory and generates the machine and user credentials.

        :raises KeyError: if no domain was given and the DOMAIN
                          environment variable is not set.
        """
        if account is None:
            return
        self.netbios_name             = account.netbios_name
        self.machinepass              = account.machinepass
        self.username                 = account.username
        self.userpass                 = account.userpass

        self.tempdir = mk_masked_dir(self.global_tempdir,
                                     'conversation-%d' %
                                     conversation.conversation_id)

        self.lp.set("private dir",     self.tempdir)
        self.lp.set("lock dir",        self.tempdir)
        self.lp.set("state directory", self.tempdir)
        self.lp.set("tls verify peer", "no_check")

        # If the domain was not specified, check for the environment
        # variable.
        if self.domain is None:
            self.domain = os.environ["DOMAIN"]

        self.remoteAddress = "/root/ncalrpc_as_system"
        self.samlogon_dn   = ("cn=%s,%s" %
                              (self.netbios_name, self.ou))
        self.user_dn       = ("cn=%s,%s" %
                              (self.username, self.ou))

        self.generate_machine_creds()
        self.generate_user_creds()

    def with_random_bad_credentials(self, f, good, bad, failed_last_time):
        """Execute the supplied logon function, randomly choosing the
           bad credentials.

           Based on the frequency in badpassword_frequency randomly perform the
           function with the supplied bad credentials.
           If run with bad credentials, the function is re-run with the good
           credentials.
           failed_last_time is used to prevent consecutive bad credential
           attempts. So the over all bad credential frequency will be lower
           than that requested, but not significantly.

           :param f: Function taking a single credentials argument.
           :param good: Credentials expected to succeed.
           :param bad: Credentials expected to fail.
           :param failed_last_time: Whether the previous call made a bad
                                    credential attempt.
           :return: A tuple of (result of f(good), whether this call made
                    a bad credential attempt).
        """
        # A frequency of None (the constructor default) means never.
        frequency = self.badpassword_frequency or 0
        attempt_bad = (not failed_last_time and
                       frequency > 0 and
                       random.random() < frequency)
        if attempt_bad:
            try:
                f(bad)
            except Exception:
                # Ignore any exceptions as the operation may fail
                # as it's being performed with bad credentials
                pass

        result = f(good)
        # Report only what happened in this call, so that a bad attempt
        # suppresses the next call's attempt rather than all later ones.
        return (result, attempt_bad)

    def _make_credentials(self, username, password,
                          secure_channel_type=None, seal=False,
                          bind_dn=None):
        """Build a Credentials object for this conversation's workstation.

        :param secure_channel_type: Set on the credentials when not None.
        :param seal: Add gensec.FEATURE_SEAL to the gensec features.
        :param bind_dn: Bind DN to set when not None.
        """
        creds = Credentials()
        creds.guess(self.lp)
        creds.set_workstation(self.netbios_name)
        if secure_channel_type is not None:
            creds.set_secure_channel_type(secure_channel_type)
        creds.set_password(password)
        creds.set_username(username)
        if seal:
            creds.set_gensec_features(
                creds.get_gensec_features() | gensec.FEATURE_SEAL)
        if self.prefer_kerberos:
            creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            creds.set_kerberos_state(DONT_USE_KERBEROS)
        if bind_dn is not None:
            creds.set_bind_dn(bind_dn)
        return creds

    def generate_user_creds(self):
        """Generate the conversation specific user Credentials.

        Each Conversation has an associated user account used to simulate
        any non Administrative user traffic.

        Generates user credentials with good and bad passwords and ldap
        simple bind credentials with good and bad passwords.
        """
        # The bad password is the real one with its last four characters
        # removed.
        bad_password = self.userpass[:-4]

        self.user_creds = self._make_credentials(self.username,
                                                 self.userpass)
        self.user_creds_bad = self._make_credentials(self.username,
                                                     bad_password)

        # Credentials for ldap simple bind.
        self.simple_bind_creds = self._make_credentials(
            self.username, self.userpass,
            seal=True, bind_dn=self.user_dn)
        self.simple_bind_creds_bad = self._make_credentials(
            self.username, bad_password,
            seal=True, bind_dn=self.user_dn)

    def generate_machine_creds(self):
        """Generate the conversation specific machine Credentials.

        Each Conversation has an associated machine account.

        Generates machine credentials with good and bad passwords.
        """
        machine_account = self.netbios_name + "$"

        self.machine_creds = self._make_credentials(
            machine_account, self.machinepass,
            secure_channel_type=SEC_CHAN_WKSTA)
        self.machine_creds_bad = self._make_credentials(
            machine_account, self.machinepass[:-4],
            secure_channel_type=SEC_CHAN_WKSTA)

    def get_matching_dn(self, pattern, attributes=None):
        """Pick a random DN suited to the pattern or attributes.

        :param pattern: Comma separated component prefixes, e.g.
                        CN,CN,DC,DC; case is ignored.
        :param attributes: Key into the attribute clue map; when it has
                           any DNs one of those is returned instead.
        :return: A DN string, or self.base_dn if nothing matches.
        """
        # If the pattern is an empty string, we assume ROOTDSE,
        # Otherwise we try progressively shorter leading patterns until
        # we hit one (the table already holds aliases with differing
        # numbers of DC suffixes).
        # e.g if there is no CN,CN,CN,CN,DC,DC
        # we change to          CN,CN,CN,DC,DC
        # and as last resort we use the base_dn
        attr_clue = self.attribute_clue_map.get(attributes)
        if attr_clue:
            return random.choice(attr_clue)

        pattern = pattern.upper()
        while pattern:
            if pattern in self.dn_map:
                return random.choice(self.dn_map[pattern])
            # chop one off the front and try it all again.
            pattern = pattern[3:]

        return self.base_dn

    def get_dcerpc_connection(self, new=False):
        """Return the latest dcerpc connection, making one if there is
        none or if new is true."""
        guid = '12345678-1234-abcd-ef00-01234567cffb'  # RPC_NETLOGON UUID
        if self.dcerpc_connections and not new:
            return self.dcerpc_connections[-1]
        c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
                             (guid, 1), self.lp)
        self.dcerpc_connections.append(c)
        return c

    def get_srvsvc_connection(self, new=False):
        """Return the latest srvsvc connection, making one with the user
        credentials if there is none or if new is true."""
        if self.srvsvc_connections and not new:
            return self.srvsvc_connections[-1]

        def connect(creds):
            return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
                                 self.lp,
                                 creds)

        (c, self.last_srvsvc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_srvsvc_bad)

        self.srvsvc_connections.append(c)
        return c

    def get_lsarpc_connection(self, new=False):
        """Return the latest lsarpc TCP connection, making one with the
        machine credentials if there is none or if new is true."""
        if self.lsarpc_connections and not new:
            return self.lsarpc_connections[-1]

        def connect(creds):
            binding_options = 'schannel,seal,sign'
            return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
                              (self.server, binding_options),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_bad)

        self.lsarpc_connections.append(c)
        return c

    def get_lsarpc_named_pipe_connection(self, new=False):
        """Return the latest lsarpc named pipe connection, making one with
        the machine credentials if there is none or if new is true."""
        if self.lsarpc_connections_named and not new:
            return self.lsarpc_connections_named[-1]

        def connect(creds):
            return lsa.lsarpc("ncacn_np:%s" % (self.server),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_named_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_named_bad)

        self.lsarpc_connections_named.append(c)
        return c

    def get_drsuapi_connection_pair(self, new=False, unbind=False):
        """get a (drs, drs_handle) tuple

        The latest pair is reused unless there is none or new is true.
        The unbind parameter is accepted but not used here.
        """
        if self.drsuapi_connections and not new:
            return self.drsuapi_connections[-1]

        def connect(creds):
            binding_options = 'seal'
            binding_string = "ncacn_ip_tcp:%s[%s]" %\
                             (self.server, binding_options)
            return drsuapi.drsuapi(binding_string, self.lp, creds)

        (drs, self.last_drsuapi_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_drsuapi_bad)

        # The supported extensions returned by the bind are not needed.
        (drs_handle, _) = drs_DsBind(drs)
        c = (drs, drs_handle)
        self.drsuapi_connections.append(c)
        return c

    def get_ldap_connection(self, new=False, simple=False):
        """Return the latest LDAP connection, making one if there is none
        or if new is true.

        :param simple: When making a connection, use ldaps:// with the
                       simple bind credentials instead of ldap:// with
                       the user credentials.
        """
        if self.ldap_connections and not new:
            return self.ldap_connections[-1]

        def simple_bind(creds):
            return SamDB('ldaps://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)

        def sasl_bind(creds):
            return SamDB('ldap://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)
        if simple:
            (samdb, self.last_simple_bind_bad) = \
                self.with_random_bad_credentials(simple_bind,
                                                 self.simple_bind_creds,
                                                 self.simple_bind_creds_bad,
                                                 self.last_simple_bind_bad)
        else:
            (samdb, self.last_bind_bad) = \
                self.with_random_bad_credentials(sasl_bind,
                                                 self.user_creds,
                                                 self.user_creds_bad,
                                                 self.last_bind_bad)

        self.ldap_connections.append(samdb)
        return samdb

    def get_samr_context(self, new=False):
        """Return the latest SamrContext, making one if there is none or
        if new is true."""
        if not self.samr_contexts or new:
            self.samr_contexts.append(SamrContext(self.server))
        return self.samr_contexts[-1]

    def get_netlogon_connection(self):
        """Return the netlogon connection, making it with the machine
        credentials on first use."""
        if self.netlogon_connection:
            return self.netlogon_connection

        def connect(creds):
            return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
                                     (self.server),
                                     self.lp,
                                     creds)
        (c, self.last_netlogon_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_netlogon_bad)
        self.netlogon_connection = c
        return c

    def guess_a_dns_lookup(self):
        """Return a (name, record type) pair to look up: the realm's A
        record."""
        return (self.realm, 'A')

    def get_authenticator(self):
        """Return a (current, subsequent) pair of netr_Authenticator
        objects, the first filled in from the machine credentials."""
        auth = self.machine_creds.new_client_authenticator()
        current  = netr_Authenticator()
        # Iterating bytes yields ints on Python 3 but one character
        # strings on Python 2; accept either.
        current.cred.data = [x if isinstance(x, int) else ord(x)
                             for x in auth["credential"]]
        current.timestamp = auth["timestamp"]

        subsequent = netr_Authenticator()
        return (current, subsequent)
704
705
class SamrContext(object):
    """A samr connection to one server, opened on demand, together with
    the handles associated with it.
    """
    def __init__(self, server):
        # Nothing is opened yet; the connection and handle are filled in
        # by get_connection() and get_handle() when first requested.
        for attribute in ('connection',
                          'handle',
                          'domain_handle',
                          'domain_sid',
                          'group_handle',
                          'user_handle',
                          'rids'):
            setattr(self, attribute, None)
        self.server = server

    def get_connection(self):
        """Return the samr connection, opening it on first use."""
        if self.connection:
            return self.connection
        self.connection = samr.samr("ncacn_ip_tcp:%s" % (self.server))
        return self.connection

    def get_handle(self):
        """Return the handle from Connect2, requesting it on first use."""
        if self.handle:
            return self.handle
        connection = self.get_connection()
        self.handle = connection.Connect2(
            None, security.SEC_FLAG_MAXIMUM_ALLOWED)
        return self.handle
729
730
731 class Conversation(object):
732     """Details of a converation between a simulated client and a server."""
733     conversation_id = None
734
735     def __init__(self, start_time=None, endpoints=None):
736         self.start_time = start_time
737         self.endpoints = endpoints
738         self.packets = []
739         self.msg = random_colour_print()
740         self.client_balance = 0.0
741
742     def __cmp__(self, other):
743         if self.start_time is None:
744             if other.start_time is None:
745                 return 0
746             return -1
747         if other.start_time is None:
748             return 1
749         return self.start_time - other.start_time
750
751     def add_packet(self, packet):
752         """Add a packet object to this conversation, making a local copy with
753         a conversation-relative timestamp."""
754         p = packet.copy()
755
756         if self.start_time is None:
757             self.start_time = p.timestamp
758
759         if self.endpoints is None:
760             self.endpoints = p.endpoints
761
762         if p.endpoints != self.endpoints:
763             raise FakePacketError("Conversation endpoints %s don't match"
764                                   "packet endpoints %s" %
765                                   (self.endpoints, p.endpoints))
766
767         p.timestamp -= self.start_time
768
769         if p.src == p.endpoints[0]:
770             self.client_balance -= p.client_score()
771         else:
772             self.client_balance += p.client_score()
773
774         if p.is_really_a_packet():
775             self.packets.append(p)
776
777     def add_short_packet(self, timestamp, p, extra, client=True):
778         """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
779         (possibly empty) list of extra data. If client is True, assume
780         this packet is from the client to the server.
781         """
782         protocol, opcode = p.split(':', 1)
783         src, dest = self.guess_client_server()
784         if not client:
785             src, dest = dest, src
786
787         desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
788         ip_protocol = IP_PROTOCOLS.get(protocol, '06')
789         fields = [timestamp - self.start_time, ip_protocol,
790                   '', src, dest,
791                   protocol, opcode, desc]
792         fields.extend(extra)
793         packet = Packet(fields)
794         # XXX we're assuming the timestamp is already adjusted for
795         # this conversation?
796         # XXX should we adjust client balance for guessed packets?
797         if packet.src == packet.endpoints[0]:
798             self.client_balance -= packet.client_score()
799         else:
800             self.client_balance += packet.client_score()
801         if packet.is_really_a_packet():
802             self.packets.append(packet)
803
804     def __str__(self):
805         return ("<Conversation %s %s starting %.3f %d packets>" %
806                 (self.conversation_id, self.endpoints, self.start_time,
807                  len(self.packets)))
808
809     __repr__ = __str__
810
    def __iter__(self):
        """Iterate over the packets stored in this conversation."""
        return iter(self.packets)
813
    def __len__(self):
        """Return the number of packets stored in this conversation."""
        return len(self.packets)
816
817     def get_duration(self):
818         if len(self.packets) < 2:
819             return 0
820         return self.packets[-1].timestamp - self.packets[0].timestamp
821
822     def replay_as_summary_lines(self):
823         lines = []
824         for p in self.packets:
825             lines.append(p.as_summary(self.start_time))
826         return lines
827
    def replay_in_fork_with_delay(self, start, context=None, account=None):
        """Fork a new process and replay the conversation.

        The parent returns the child's pid straight after forking. The
        child redirects its standard output to a per-conversation stats
        file, sleeps until the conversation is due, replays it, and then
        always terminates with os._exit(0); it never returns to the caller.

        :param start: the time.time() value at which the whole replay began
        :param context: replay context; although it defaults to None it is
                        dereferenced unconditionally, so it is required
        :param account: passed to context.generate_process_local_config()
        :return: the child's pid (in the parent process only)
        """
        def signal_handler(signal, frame):
            """Signal handler closes standard out and error.

            Triggered by a sigterm, ensures that the log messages are flushed
            to disk and not lost.
            """
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

        t = self.start_time
        now = time.time() - start
        gap = t - now
        # we are replaying strictly in order, so it is safe to sleep
        # in the main process if the gap is big enough. This reduces
        # the number of concurrent threads, which allows us to make
        # larger loads.
        # NOTE: the "and False" makes this condition always false, so the
        # main-process sleep below is currently disabled.
        if gap > 0.15 and False:
            print("sleeping for %f in main process" % (gap - 0.1),
                  file=sys.stderr)
            time.sleep(gap - 0.1)
            now = time.time() - start
            gap = t - now
            print("gap is now %f" % gap, file=sys.stderr)

        # The id is assigned before forking, so both parent and child
        # see the same value.
        self.conversation_id = context.next_conversation_id()
        pid = os.fork()
        if pid != 0:
            # parent: hand the child's pid back to the caller
            return pid
        pid = os.getpid()
        signal.signal(signal.SIGTERM, signal_handler)
        # we must never return, or we'll end up running parts of the
        # parent's clean-up code. So we work in a try...finally, and
        # try to print any exceptions.

        try:
            context.generate_process_local_config(account, self)
            sys.stdin.close()
            os.close(0)
            # From here on, anything printed to stdout lands in this
            # conversation's stats file.
            filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
                                    self.conversation_id)
            sys.stdout.close()
            sys.stdout = open(filename, 'w')

            # Wait for the scheduled start, less a small allowance for the
            # overhead of sleeping.
            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            # How far off the scheduled start we actually are.
            miss = t - (time.time() - start)
            self.msg("starting %s [miss %.3f pid %d]" % (self, miss, pid))
            self.replay(context)
        except Exception:
            print(("EXCEPTION in child PID %d, conversation %s" % (pid, self)),
                  file=sys.stderr)
            traceback.print_exc(sys.stderr)
        finally:
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)
890
891     def replay(self, context=None):
892         start = time.time()
893
894         for p in self.packets:
895             now = time.time() - start
896             gap = p.timestamp - now
897             sleep_time = gap - SLEEP_OVERHEAD
898             if sleep_time > 0:
899                 time.sleep(sleep_time)
900
901             miss = p.timestamp - (time.time() - start)
902             if context is None:
903                 self.msg("packet %s [miss %.3f pid %d]" % (p, miss,
904                                                            os.getpid()))
905                 continue
906             p.play(self, context)
907
908     def guess_client_server(self, server_clue=None):
909         """Have a go at deciding who is the server and who is the client.
910         returns (client, server)
911         """
912         a, b = self.endpoints
913
914         if self.client_balance < 0:
915             return (a, b)
916
917         # in the absense of a clue, we will fall through to assuming
918         # the lowest number is the server (which is usually true).
919
920         if self.client_balance == 0 and server_clue == b:
921             return (a, b)
922
923         return (b, a)
924
925     def forget_packets_outside_window(self, s, e):
926         """Prune any packets outside the timne window we're interested in
927
928         :param s: start of the window
929         :param e: end of the window
930         """
931
932         new_packets = []
933         for p in self.packets:
934             if p.timestamp < s or p.timestamp > e:
935                 continue
936             new_packets.append(p)
937
938         self.packets = new_packets
939         if new_packets:
940             self.start_time = new_packets[0].timestamp
941         else:
942             self.start_time = None
943
944     def renormalise_times(self, start_time):
945         """Adjust the packet start times relative to the new start time."""
946         for p in self.packets:
947             p.timestamp -= start_time
948
949         if self.start_time is not None:
950             self.start_time -= start_time
951
952
class DnsHammer(Conversation):
    """A lightweight conversation that generates a lot of dns:0 packets on
    the fly, at randomly chosen times, instead of storing packet objects."""

    def __init__(self, dns_rate, duration):
        count = int(dns_rate * duration)
        # Send times are drawn uniformly over the duration and replayed in
        # ascending order.
        self.times = sorted(random.uniform(0, duration)
                            for _ in range(count))
        self.rate = dns_rate
        self.duration = duration
        self.start_time = 0
        self.msg = random_colour_print()

    def __str__(self):
        summary = (len(self.times), self.duration, self.rate)
        return "<DnsHammer %d packets over %.1fs (rate %.2f)>" % summary

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        """Delegate unchanged to Conversation.replay_in_fork_with_delay."""
        return Conversation.replay_in_fork_with_delay(
            self, start, context, account)

    def replay(self, context=None):
        """Issue one dns:0 query at each scheduled time, printing a
        tab-separated result line per query. With no context, only a
        message per scheduled time is emitted."""
        began = time.time()
        send_query = traffic_packets.packet_dns_0
        for when in self.times:
            delay = (when - (time.time() - began)) - SLEEP_OVERHEAD
            if delay > 0:
                time.sleep(delay)

            if context is None:
                miss = when - (time.time() - began)
                self.msg("packet %s [miss %.3f pid %d]" % (when, miss,
                                                           os.getpid()))
                continue

            query_began = time.time()
            try:
                send_query(self, self, context)
                finished = time.time()
                print("%f\tDNS\tdns\t0\t%f\tTrue\t" %
                      (finished, finished - query_began))
            except Exception as err:
                finished = time.time()
                print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" %
                      (finished, finished - query_began, err))
1002
1003
def ingest_summaries(files, dns_mode='count'):
    """Load traffic summary files and gather their packets into
    Conversations.

    :param files: an iterable of file names and/or open file objects; every
                  one of them is closed once it has been read
    :param dns_mode: unless this is 'include', dns packets are only counted
                     per opcode rather than added to conversations
    :return: a tuple (conversations, mean_interval, duration, dns_counts).
             conversations holds only the conversations that ended up with
             packets. When no packets were ingested the first three items
             are an empty list, 0 and 0.0.
    """

    dns_counts = defaultdict(int)
    packets = []
    for f in files:
        if isinstance(f, str):
            f = open(f)
        # Close the file even if a line fails to parse.
        try:
            print("Ingesting %s" % (f.name,), file=sys.stderr)
            for line in f:
                p = Packet(line)
                if p.protocol == 'dns' and dns_mode != 'include':
                    dns_counts[p.opcode] += 1
                else:
                    packets.append(p)
        finally:
            f.close()

    if not packets:
        # Same shape as the normal return, so callers can unpack the
        # result unconditionally.
        return [], 0, 0.0, dns_counts

    start_time = min(p.timestamp for p in packets)
    last_packet = max(p.timestamp for p in packets)

    print("gathering packets into conversations", file=sys.stderr)
    conversations = OrderedDict()
    for p in packets:
        p.timestamp -= start_time
        c = conversations.get(p.endpoints)
        if c is None:
            c = Conversation()
            conversations[p.endpoints] = c
        c.add_packet(p)

    # We only care about conversations with actual traffic, so we
    # filter out conversations with nothing to say. We do that here,
    # rather than earlier, because those empty packets contain useful
    # hints as to which end of the conversation was the client.
    conversation_list = [c for c in conversations.values() if len(c) != 0]

    # This is obviously not correct, as many conversations will appear
    # to start roughly simultaneously at the beginning of the snapshot.
    # To which we say: oh well, so be it.
    duration = float(last_packet - start_time)
    if duration > 0:
        mean_interval = len(conversations) / duration
    else:
        # All packets share one timestamp (e.g. a single-packet capture);
        # there is no span to divide by.
        mean_interval = 0

    return conversation_list, mean_interval, duration, dns_counts
1055
1056
def guess_server_address(conversations):
    """Guess the server address: the endpoint seen in the most
    conversations.

    :param conversations: an iterable of objects with an 'endpoints' pair
    :return: the most common endpoint, or None if there are no
             conversations
    """
    addresses = Counter()
    for c in conversations:
        addresses.update(c.endpoints)
    if not addresses:
        return None
    # most_common() yields (address, count) pairs; only the address is
    # comparable with a conversation endpoint.
    return addresses.most_common(1)[0][0]
1064
1065
def stringify_keys(x):
    """Return a copy of x in which each key, a sequence of strings, has
    been joined into a single tab-separated string."""
    return {'\t'.join(key): value for key, value in x.items()}
1072
1073
def unstringify_keys(x):
    """Return a copy of x in which each key has been split on tabs into a
    tuple of strings."""
    return {tuple(str(key).split('\t')): value for key, value in x.items()}
1080
1081
class TrafficModel(object):
    """An n-gram model of client traffic.

    The model maps each (n - 1)-tuple of preceding packet types to the
    packet types observed to follow it, and each packet type to the extra
    fields observed with it. It can be learnt from conversations, saved to
    and loaded from JSON, and used to generate new conversations.
    """

    def __init__(self, n=3):
        self.ngrams = {}
        self.query_details = {}
        self.n = n
        self.dns_opcounts = defaultdict(int)
        self.cumulative_duration = 0.0
        self.conversation_rate = [0, 1]

    def learn(self, conversations, dns_opcounts=None):
        """Add the given conversations to the model.

        :param conversations: a sequence of conversations to learn from
        :param dns_opcounts: optional mapping of dns opcode to count, added
                             to the model's own dns counts
        """
        prev = 0.0
        cum_duration = 0.0
        key = (NON_PACKET,) * (self.n - 1)

        server = guess_server_address(conversations)

        if dns_opcounts:
            for k, v in dns_opcounts.items():
                self.dns_opcounts[k] += v

        if len(conversations) > 1:
            elapsed =\
                conversations[-1].start_time - conversations[0].start_time
            self.conversation_rate[0] = len(conversations)
            self.conversation_rate[1] = elapsed

        for c in conversations:
            client, server = c.guess_client_server(server)
            cum_duration += c.get_duration()
            key = (NON_PACKET,) * (self.n - 1)
            for p in c:
                # only packets sent by the client are modelled
                if p.src != client:
                    continue

                # NOTE(review): prev carries over from one conversation to
                # the next; confirm that is intended.
                elapsed = p.timestamp - prev
                prev = p.timestamp
                if elapsed > WAIT_THRESHOLD:
                    # add the wait as an extra state
                    wait = 'wait:%d' % (math.log(max(1.0,
                                                     elapsed * WAIT_SCALE)))
                    self.ngrams.setdefault(key, []).append(wait)
                    key = key[1:] + (wait,)

                short_p = p.as_packet_type()
                self.query_details.setdefault(short_p,
                                              []).append(tuple(p.extra))
                self.ngrams.setdefault(key, []).append(short_p)
                key = key[1:] + (short_p,)

        self.cumulative_duration += cum_duration
        # add in the end
        # NOTE(review): this runs once, after the loop, so only the final
        # key gets an end marker; confirm that is intended.
        self.ngrams.setdefault(key, []).append(NON_PACKET)

    def save(self, f):
        """Write the model as JSON.

        :param f: a file name or a writable file object. A file opened here
                  from a name is closed again; a file object passed in is
                  left open for the caller.
        """
        ngrams = {}
        for k, v in self.ngrams.items():
            k = '\t'.join(k)
            ngrams[k] = dict(Counter(v))

        query_details = {}
        for k, v in self.query_details.items():
            # '-' stands in for "no extra fields"
            query_details[k] = dict(Counter('\t'.join(x) if x else '-'
                                            for x in v))

        d = {
            'ngrams': ngrams,
            'query_details': query_details,
            'cumulative_duration': self.cumulative_duration,
            'conversation_rate': self.conversation_rate,
        }
        d['dns'] = self.dns_opcounts

        if isinstance(f, str):
            # Closing the file ensures the model is flushed to disk.
            with open(f, 'w') as out:
                json.dump(d, out, indent=2)
        else:
            json.dump(d, f, indent=2)

    def load(self, f):
        """Merge a JSON model, as written by save(), into this model.

        N-grams, query details and dns counts are added to what is already
        present; cumulative_duration and conversation_rate are replaced.

        :param f: a file name or a readable file object. A file opened here
                  from a name is closed again.
        """
        if isinstance(f, str):
            with open(f) as src:
                d = json.load(src)
        else:
            d = json.load(f)

        for k, v in d['ngrams'].items():
            k = tuple(str(k).split('\t'))
            values = self.ngrams.setdefault(k, [])
            for p, count in v.items():
                values.extend([str(p)] * count)

        for k, v in d['query_details'].items():
            values = self.query_details.setdefault(str(k), [])
            for p, count in v.items():
                if p == '-':
                    values.extend([()] * count)
                else:
                    values.extend([tuple(str(p).split('\t'))] * count)

        if 'dns' in d:
            for k, v in d['dns'].items():
                self.dns_opcounts[k] += v

        self.cumulative_duration = d['cumulative_duration']
        self.conversation_rate = d['conversation_rate']

    def construct_conversation(self, timestamp=0.0, client=2, server=1,
                               hard_stop=None, packet_rate=1):
        """Construct an individual conversation from the model.

        Starting from the all-NON_PACKET key, successors are drawn at
        random until NON_PACKET is drawn, the key is unknown, or a packet
        would fall after hard_stop.
        """

        c = Conversation(timestamp, (server, client))

        key = (NON_PACKET,) * (self.n - 1)

        while key in self.ngrams:
            p = random.choice(self.ngrams.get(key, NON_PACKET))
            if p == NON_PACKET:
                break
            if p in self.query_details:
                extra = random.choice(self.query_details[p])
            else:
                extra = []

            protocol, opcode = p.split(':', 1)
            if protocol == 'wait':
                # a wait state only advances the clock
                log_wait_time = int(opcode) + random.random()
                wait = math.exp(log_wait_time) / (WAIT_SCALE * packet_rate)
                timestamp += wait
            else:
                log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
                wait = math.exp(log_wait) / packet_rate
                timestamp += wait
                if hard_stop is not None and timestamp > hard_stop:
                    break
                c.add_short_packet(timestamp, p, extra)

            key = key[1:] + (p,)

        return c

    def generate_conversations(self, rate, duration, packet_rate=1):
        """Generate a list of conversations from the model.

        :return: the non-empty generated conversations, sorted
        """

        # We run the simulation for longer than the desired duration: at
        # least twice as long, and at least the learnt elapsed time.
        rate_n, rate_t = self.conversation_rate

        duration2 = max(rate_t, duration * 2)
        n = rate * duration2 * rate_n / rate_t

        server = 1
        client = 2

        conversations = []
        end = duration2

        while client < n + 2:
            # NOTE(review): the window applied below starts at this
            # conversation's own random start, not at (end - duration);
            # confirm that is intended.
            start = random.uniform(0, duration2)
            c = self.construct_conversation(start,
                                            client,
                                            server,
                                            hard_stop=(duration2 * 5),
                                            packet_rate=packet_rate)

            c.forget_packets_outside_window(start, end)
            c.renormalise_times(start)
            if len(c) != 0:
                conversations.append(c)
            client += 1

        print(("we have %d conversations at rate %f" %
               (len(conversations), rate)), file=sys.stderr)
        conversations.sort()
        return conversations
1254
1255
# Maps a protocol name to the IP protocol field used when a packet is
# built from a short 'protocol:opcode' description. Protocols missing
# from this table are looked up with a default of '06'.
# NOTE(review): the values look like hexadecimal IP protocol numbers
# ('06' TCP, '11' UDP) -- confirm.
IP_PROTOCOLS = {
    'dns': '11',
    'rpc_netlogon': '06',
    'kerberos': '06',      # ratio 16248:258
    'smb': '06',
    'smb2': '06',
    'ldap': '06',
    'cldap': '11',
    'lsarpc': '06',
    'samr': '06',
    'dcerpc': '06',
    'epm': '06',
    'drsuapi': '06',
    'browser': '11',
    'smb_netlogon': '11',
    'srvsvc': '06',
    'nbns': '11',
}
1274
# Maps a (protocol, opcode) pair, both strings, to the description field
# used when a packet is built from a short 'protocol:opcode' description.
# Pairs missing from this table are looked up with a default of ''.
OP_DESCRIPTIONS = {
    ('browser', '0x01'): 'Host Announcement (0x01)',
    ('browser', '0x02'): 'Request Announcement (0x02)',
    ('browser', '0x08'): 'Browser Election Request (0x08)',
    ('browser', '0x09'): 'Get Backup List Request (0x09)',
    ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
    ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
    ('cldap', '3'): 'searchRequest',
    ('cldap', '5'): 'searchResDone',
    ('dcerpc', '0'): 'Request',
    ('dcerpc', '11'): 'Bind',
    ('dcerpc', '12'): 'Bind_ack',
    ('dcerpc', '13'): 'Bind_nak',
    ('dcerpc', '14'): 'Alter_context',
    ('dcerpc', '15'): 'Alter_context_resp',
    ('dcerpc', '16'): 'AUTH3',
    ('dcerpc', '2'): 'Response',
    ('dns', '0'): 'query',
    ('dns', '1'): 'response',
    ('drsuapi', '0'): 'DsBind',
    ('drsuapi', '12'): 'DsCrackNames',
    ('drsuapi', '13'): 'DsWriteAccountSpn',
    ('drsuapi', '1'): 'DsUnbind',
    ('drsuapi', '2'): 'DsReplicaSync',
    ('drsuapi', '3'): 'DsGetNCChanges',
    ('drsuapi', '4'): 'DsReplicaUpdateRefs',
    ('epm', '3'): 'Map',
    ('kerberos', ''): '',
    ('ldap', '0'): 'bindRequest',
    ('ldap', '1'): 'bindResponse',
    ('ldap', '2'): 'unbindRequest',
    ('ldap', '3'): 'searchRequest',
    ('ldap', '4'): 'searchResEntry',
    ('ldap', '5'): 'searchResDone',
    ('ldap', ''): '*** Unknown ***',
    ('lsarpc', '14'): 'lsa_LookupNames',
    ('lsarpc', '15'): 'lsa_LookupSids',
    ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
    ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
    ('lsarpc', '6'): 'lsa_OpenPolicy',
    ('lsarpc', '76'): 'lsa_LookupSids3',
    ('lsarpc', '77'): 'lsa_LookupNames4',
    ('nbns', '0'): 'query',
    ('nbns', '1'): 'response',
    ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
    ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
    ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
    ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
    ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
    ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
    ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
    ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
    ('samr', '0',): 'Connect',
    ('samr', '16'): 'GetAliasMembership',
    ('samr', '17'): 'LookupNames',
    ('samr', '18'): 'LookupRids',
    ('samr', '19'): 'OpenGroup',
    ('samr', '1'): 'Close',
    ('samr', '25'): 'QueryGroupMember',
    ('samr', '34'): 'OpenUser',
    ('samr', '36'): 'QueryUserInfo',
    ('samr', '39'): 'GetGroupsForUser',
    ('samr', '3'): 'QuerySecurity',
    ('samr', '5'): 'LookupDomain',
    ('samr', '64'): 'Connect5',
    ('samr', '6'): 'EnumDomains',
    ('samr', '7'): 'OpenDomain',
    ('samr', '8'): 'QueryDomainInfo',
    ('smb', '0x04'): 'Close (0x04)',
    ('smb', '0x24'): 'Locking AndX (0x24)',
    ('smb', '0x2e'): 'Read AndX (0x2e)',
    ('smb', '0x32'): 'Trans2 (0x32)',
    ('smb', '0x71'): 'Tree Disconnect (0x71)',
    ('smb', '0x72'): 'Negotiate Protocol (0x72)',
    ('smb', '0x73'): 'Session Setup AndX (0x73)',
    ('smb', '0x74'): 'Logoff AndX (0x74)',
    ('smb', '0x75'): 'Tree Connect AndX (0x75)',
    ('smb', '0xa2'): 'NT Create AndX (0xa2)',
    ('smb2', '0'): 'NegotiateProtocol',
    ('smb2', '11'): 'Ioctl',
    ('smb2', '14'): 'Find',
    ('smb2', '16'): 'GetInfo',
    ('smb2', '18'): 'Break',
    ('smb2', '1'): 'SessionSetup',
    ('smb2', '2'): 'SessionLogoff',
    ('smb2', '3'): 'TreeConnect',
    ('smb2', '4'): 'TreeDisconnect',
    ('smb2', '5'): 'Create',
    ('smb2', '6'): 'Close',
    ('smb2', '8'): 'Read',
    ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
    ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
                               'user unknown (0x17)'),
    ('srvsvc', '16'): 'NetShareGetInfo',
    ('srvsvc', '21'): 'NetSrvGetInfo',
}
1371
1372
def expand_short_packet(p, timestamp, src, dest, extra):
    """Expand a 'protocol:opcode' string into a tab-separated summary
    line, filling in the IP protocol and description from the lookup
    tables and appending the extra fields."""
    protocol, opcode = p.split(':', 1)
    fields = [
        timestamp,
        IP_PROTOCOLS.get(protocol, '06'),
        '',
        src,
        dest,
        protocol,
        opcode,
        OP_DESCRIPTIONS.get((protocol, opcode), ''),
    ]
    return '\t'.join(fields + list(extra))
1381
1382
def replay(conversations,
           host=None,
           creds=None,
           lp=None,
           accounts=None,
           dns_rate=0,
           duration=None,
           **kwargs):
    """Replay conversations against a server, one forked child each.

    Conversations are paired with accounts and forked in start-time order,
    in batches of about two seconds. When the duration has elapsed (or on
    any error) all remaining children are killed.

    :param conversations: the conversations to replay
    :param host: server name, passed to ReplayContext as 'server'
    :param creds: credentials passed to ReplayContext
    :param lp: loadparm context passed to ReplayContext
    :param accounts: a sequence of accounts, paired one-to-one with the
                     conversations; conversations beyond the number of
                     accounts are not replayed
    :param dns_rate: if non-zero, a DnsHammer at this rate is also run
    :param duration: seconds to run for; by default, one second past the
                     last packet of the last conversation to start
    :param kwargs: further keyword arguments for ReplayContext
    """
    # Local import: errno is only needed here.
    import errno

    context = ReplayContext(server=host,
                            creds=creds,
                            lp=lp,
                            **kwargs)

    if len(accounts) < len(conversations):
        print(("we have %d accounts but %d conversations" %
               (len(accounts), len(conversations))), file=sys.stderr)

    # Sorted latest-first so that pop() yields the earliest conversation.
    cstack = list(zip(
        sorted(conversations, key=lambda x: x.start_time, reverse=True),
        accounts))

    # Set the process group so that the calling scripts are not killed
    # when the forked child processes are killed.
    os.setpgrp()

    start = time.time()

    if duration is None:
        # end 1 second after the last packet of the last conversation
        # to start. Conversations other than the last could still be
        # going, but we don't care.
        duration = cstack[0][0].packets[-1].timestamp + 1.0
        print("We will stop after %.1f seconds" % duration,
              file=sys.stderr)

    end = start + duration

    print("Replaying traffic for %u conversations over %d seconds"
          % (len(conversations), duration))

    children = {}
    if dns_rate:
        dns_hammer = DnsHammer(dns_rate, duration)
        cstack.append((dns_hammer, None))

    try:
        while True:
            # we spawn a batch, wait for finishers, then spawn another
            now = time.time()
            batch_end = min(now + 2.0, end)
            fork_time = 0.0
            fork_n = 0
            while cstack:
                c, account = cstack.pop()
                if c.start_time + start > batch_end:
                    # not due in this batch; put it back for the next one
                    cstack.append((c, account))
                    break

                st = time.time()
                pid = c.replay_in_fork_with_delay(start, context, account)
                children[pid] = c
                t = time.time()
                elapsed = t - st
                fork_time += elapsed
                fork_n += 1
                print("forked %s in pid %s (in %fs)" % (c, pid,
                                                        elapsed),
                      file=sys.stderr)

            if fork_n:
                print(("forked %d times in %f seconds (avg %f)" %
                       (fork_n, fork_time, fork_time / fork_n)),
                      file=sys.stderr)
            elif cstack:
                debug(2, "no forks in batch ending %f" % batch_end)

            # reap finished children until a second before the batch ends
            while time.time() < batch_end - 1.0:
                time.sleep(0.01)
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != errno.ECHILD:  # no child processes
                        raise
                    break
                if pid:
                    c = children.pop(pid, None)
                    print(("process %d finished conversation %s;"
                           " %d to go" %
                           (pid, c, len(children))), file=sys.stderr)

            if time.time() >= end:
                print("time to stop", file=sys.stderr)
                break

    except Exception:
        print("EXCEPTION in parent", file=sys.stderr)
        traceback.print_exc()
    finally:
        # SIGTERM twice, then SIGKILL
        for s in (15, 15, 9):
            print(("killing %d children with -%d" %
                   (len(children), s)), file=sys.stderr)
            for pid in children:
                try:
                    os.kill(pid, s)
                except OSError as e:
                    # don't fail if it has already died
                    if e.errno != errno.ESRCH:
                        raise
            time.sleep(0.5)
            end = time.time() + 1
            while children:
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != errno.ECHILD:
                        raise
                    # Nothing is left to reap; carrying on would only
                    # re-report a stale pid until the timeout.
                    break
                if pid != 0:
                    c = children.pop(pid, None)
                    print(("kill -%d %d KILLED conversation %s; "
                           "%d to go" %
                           (s, pid, c, len(children))),
                          file=sys.stderr)
                if time.time() >= end:
                    break

            if not children:
                break
            time.sleep(1)

        if children:
            print("%d children are missing" % len(children),
                  file=sys.stderr)

        # there may be stragglers that were forked just as ^C was hit
        # and don't appear in the list of children. We can get them
        # with killpg, but that will also kill us, so this is^H^H would be
        # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
        # so as not to have to fuss around writing signal handlers.
        try:
            os.killpg(0, 2)
        except KeyboardInterrupt:
            print("ignoring fake ^C", file=sys.stderr)
1525
1526
def openLdb(host, creds, lp):
    """Open a SamDB connection to ldap://<host> using a system session
    and the given credentials and loadparm context."""
    return SamDB(url="ldap://%s" % host,
                 session_info=system_session(),
                 credentials=creds,
                 lp=lp)
1534
1535
def ou_name(ldb, instance_id):
    """Generate an ou name from the instance id"""
    domain_dn = ldb.domain_dn()
    return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id, domain_dn)
1540
1541
def create_ou(ldb, instance_id):
    """Create an ou, all created user and machine accounts will belong to it.

    This allows all the created resources to be cleaned up easily. The
    parent ou is created first, then the per-instance ou; either may
    already exist.

    :param ldb: the database connection to add the ous through
    :param instance_id: the instance number used in the ou name
    :return: the dn of the per-instance ou
    :raises LdbError: for any failure other than "already exists"
    """
    ou = ou_name(ldb, instance_id)
    parent_ou = ou.split(',', 1)[1]
    for dn in (parent_ou, ou):
        try:
            ldb.add({"dn":          dn,
                     "objectclass": "organizationalunit"})
        except LdbError as e:
            # Unpack e.args rather than the exception itself, which is
            # only iterable on Python 2.
            (status, _) = e.args
            # ignore already exists
            if status != 68:
                raise
    return ou
1565
1566
class ConversationAccounts(object):
    """The machine account and user account (names and passwords) used by
    one conversation.
    """
    def __init__(self, netbios_name, machinepass, username, userpass):
        # machine account
        self.netbios_name, self.machinepass = netbios_name, machinepass
        # user account
        self.username, self.userpass = username, userpass
1575
1576
def generate_replay_accounts(ldb, instance_id, number, password):
    """Ensure the traffic accounts exist, then return one
    ConversationAccounts per account number from 1 to number, all sharing
    the given password."""

    generate_traffic_accounts(ldb, instance_id, number, password)
    return [
        ConversationAccounts("STGM-%d-%d" % (instance_id, index),
                             password,
                             "STGU-%d-%d" % (instance_id, index),
                             password)
        for index in range(1, number + 1)
    ]
1590
1591
def generate_traffic_accounts(ldb, instance_id, number, password):
    """Create the specified number of user and machine accounts.

    Accounts are not explicitly deleted between runs, so for each kind of
    account this function starts with the last account and iterates
    backwards, stopping either when it finds an already existing account
    or when it has generated all the required accounts.

    :param ldb: the database connection to create the accounts through
    :param instance_id: the instance number used in the account names
    :param number: how many accounts of each kind are required
    :param password: the password given to every created account
    :raises LdbError: for any failure other than "already exists"
    """
    print(("Generating machine and conversation accounts, "
           "as required for %d conversations" % number),
          file=sys.stderr)

    # (name format, creation function, summary message) per account kind;
    # machine accounts are handled first, then user accounts.
    account_kinds = (
        ("STGM-%d-%d", create_machine_account,
         "Added %d new machine accounts"),
        ("STGU-%d-%d", create_user_account,
         "Added %d new user accounts"),
    )

    for name_format, create, summary in account_kinds:
        added = 0
        for i in range(number, 0, -1):
            name = name_format % (instance_id, i)
            try:
                create(ldb, instance_id, name, password)
            except LdbError as e:
                # Unpack e.args rather than the exception itself, which
                # is only iterable on Python 2.
                (status, _) = e.args
                if status == 68:
                    # already exists: stop working backwards
                    break
                raise
            added += 1

        if added > 0:
            print(summary % added, file=sys.stderr)
1633         print("Added %d new user accounts" % added,
1634               file=sys.stderr)
1635
1636
def create_machine_account(ldb, instance_id, netbios_name, machinepass):
    """Create a machine account via ldap.

    The account is added as cn=<netbios_name> under the OU returned by
    ou_name(), with sAMAccountName "<netbios_name>$". The password is sent
    in the unicodePwd attribute as the double-quoted password encoded in
    UTF-16-LE.

    A tab-separated timing record for the add operation is printed to
    stdout. Errors from ldb.add() propagate to the caller.
    """

    ou = ou_name(ldb, instance_id)
    dn = "cn=%s,%s" % (netbios_name, ou)
    # The unicode() builtin does not exist in Python 3, and mixing str with
    # bytes raises TypeError there. Normalise to text first, then encode
    # the quoted password exactly once.
    if isinstance(machinepass, bytes):
        machinepass = machinepass.decode('utf-8')
    utf16pw = ('"' + machinepass + '"').encode('utf-16-le')
    start = time.time()
    ldb.add({
        "dn": dn,
        "objectclass": "computer",
        "sAMAccountName": "%s$" % netbios_name,
        "userAccountControl":
        str(UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD),
        "unicodePwd": utf16pw})
    end = time.time()
    duration = end - start
    print("%f\t0\tcreate\tmachine\t%f\tTrue\t" % (end, duration))
1656
1657
def create_user_account(ldb, instance_id, username, userpass):
    """Create a user account via ldap.

    The account is added as cn=<username> under the OU returned by
    ou_name(), with sAMAccountName set to the username. The password is
    sent in the unicodePwd attribute as the double-quoted password encoded
    in UTF-16-LE.

    A tab-separated timing record for the add operation is printed to
    stdout. Errors from ldb.add() propagate to the caller.
    """
    ou = ou_name(ldb, instance_id)
    user_dn = "cn=%s,%s" % (username, ou)
    # The unicode() builtin does not exist in Python 3, and mixing str with
    # bytes raises TypeError there. Normalise to text first, then encode
    # the quoted password exactly once.
    if isinstance(userpass, bytes):
        userpass = userpass.decode('utf-8')
    utf16pw = ('"' + userpass + '"').encode('utf-16-le')
    start = time.time()
    ldb.add({
        "dn": user_dn,
        "objectclass": "user",
        "sAMAccountName": username,
        "userAccountControl": str(UF_NORMAL_ACCOUNT),
        "unicodePwd": utf16pw
    })
    end = time.time()
    duration = end - start
    print("%f\t0\tcreate\tuser\t%f\tTrue\t" % (end, duration))
1676
1677
def create_group(ldb, instance_id, name):
    """Add a group object named `name` to the instance's OU via ldap.

    The add operation is timed and a tab-separated timing record is
    printed to stdout.
    """
    record = {
        "dn": "cn=%s,%s" % (name, ou_name(ldb, instance_id)),
        "objectclass": "group",
    }
    began = time.time()
    ldb.add(record)
    finished = time.time()
    print("%f\t0\tcreate\tgroup\t%f\tTrue\t" % (finished, finished - began))
1691
1692
def user_name(instance_id, i):
    """Generate a user name based on the instance id and user index."""
    name_parts = (instance_id, i)
    return "STGU-%d-%d" % name_parts
1696
1697
def generate_users(ldb, instance_id, number, password):
    """Add users to the server.

    Works backwards from user `number` down to user 1, creating each
    account with create_user_account(), and stops at the first account
    that already exists.

    Returns the number of users that were created. Any LdbError other
    than "entry already exists" is re-raised.
    """
    users = 0
    for i in range(number, 0, -1):
        try:
            username = user_name(instance_id, i)
            create_user_account(ldb, instance_id, username, password)
            users += 1
        except LdbError as e:
            # Exceptions are not iterable in Python 3, so unpack the
            # (status, message) pair from e.args rather than from e.
            (status, _) = e.args
            # Stop if entry exists
            if status == 68:
                break
            else:
                raise

    return users
1715
1716
def group_name(instance_id, i):
    """Generate a group name from the instance id and group index."""
    name_parts = (instance_id, i)
    return "STGG-%d-%d" % name_parts
1720
1721
def generate_groups(ldb, instance_id, number):
    """Create the required number of groups on the server.

    Works backwards from group `number` down to group 1, creating each
    group with create_group(), and stops at the first group that already
    exists.

    Returns the number of groups that were created. Any LdbError other
    than "entry already exists" is re-raised.
    """
    groups = 0
    for i in range(number, 0, -1):
        try:
            name = group_name(instance_id, i)
            create_group(ldb, instance_id, name)
            groups += 1
        except LdbError as e:
            # Exceptions are not iterable in Python 3, so unpack the
            # (status, message) pair from e.args rather than from e.
            (status, _) = e.args
            # Stop if entry exists
            if status == 68:
                break
            else:
                raise
    return groups
1738
1739
def clean_up_accounts(ldb, instance_id):
    """Remove the created accounts and groups from the server.

    Deletes the instance's OU using the "tree_delete:1" control. An OU
    that does not exist is not treated as an error; any other LdbError
    is re-raised.
    """
    ou = ou_name(ldb, instance_id)
    try:
        ldb.delete(ou, ["tree_delete:1"])
    except LdbError as e:
        # Exceptions are not iterable in Python 3, so unpack the
        # (status, message) pair from e.args rather than from e.
        (status, _) = e.args
        # ignore does not exist
        if status != 32:
            raise
1750
1751
def generate_users_and_groups(ldb, instance_id, password,
                              number_of_users, number_of_groups,
                              group_memberships):
    """Create the requested users and groups and link users to groups.

    The OU is created first, then the users, then (if any were requested)
    the groups, and finally the group memberships. Progress and a final
    summary are written to stderr.
    """
    def report(message):
        print(message, file=sys.stderr)

    create_ou(ldb, instance_id)

    report("Generating dummy user accounts")
    users_added = generate_users(ldb, instance_id, number_of_users, password)

    groups_added = 0
    if number_of_groups > 0:
        report("Generating dummy groups")
        groups_added = generate_groups(ldb, instance_id, number_of_groups)

    assignments = []
    if group_memberships > 0:
        report("Assigning users to groups")
        assignments = assign_groups(number_of_groups, groups_added,
                                    number_of_users, users_added,
                                    group_memberships)
        report("Adding users to groups")
        add_users_to_groups(ldb, instance_id, assignments)

    no_new_users = users_added == 0
    some_groups_existed = number_of_groups != groups_added
    if groups_added > 0 and no_new_users and some_groups_existed:
        report("Warning: the added groups will contain no members")

    summary = (users_added, groups_added, len(assignments))
    report("Added %d users, %d groups and %d group memberships" % summary)
1787
1788
def assign_groups(number_of_groups,
                  groups_added,
                  number_of_users,
                  users_added,
                  group_memberships):
    """Allocate users to groups.

    The intention is to have a few users that belong to most groups, while
    the majority of users belong to a few groups.

    A few groups will contain most users, with the remaining only having a
    few users.

    Only pairs that involve a newly added user or a newly added group are
    generated. The requested number of memberships is scaled by the
    fraction of users that were newly added, and capped at the number of
    eligible pairs so that the sampling loop always terminates.

    Returns a set of (user number, group number) tuples, both 1-based. The
    set is empty when no memberships are requested or when there are no
    users or no groups to choose from.
    """
    assignments = set()
    if (group_memberships <= 0 or
        number_of_users <= 0 or
        number_of_groups <= 0):
        # Nothing to pick from. Without this guard the code below would
        # divide by zero or call randint() on an empty range.
        return assignments

    # Relative weight of a group containing a user
    group_dist = [1 / (x ** 1.3) for x in range(1, number_of_groups + 1)]
    # Relative weight of a user belonging to a group
    user_dist = [1 / (x + 0.001) for x in range(1, number_of_users + 1)]

    # Calculate the number of group memberships required
    group_memberships = math.ceil(
        float(group_memberships) *
        (float(users_added) / float(number_of_users)))

    # Indexes at or below these belong to users/groups that already existed
    existing_users  = number_of_users  - users_added  - 1
    existing_groups = number_of_groups - groups_added - 1

    # A pair is only eligible if the user or the group is new. Asking for
    # more memberships than there are eligible pairs would loop forever.
    ineligible_pairs = (max(number_of_users - users_added, 0) *
                        max(number_of_groups - groups_added, 0))
    eligible_pairs = number_of_users * number_of_groups - ineligible_pairs
    group_memberships = min(group_memberships, eligible_pairs)

    while len(assignments) < group_memberships:
        user        = random.randint(0, number_of_users - 1)
        group       = random.randint(0, number_of_groups - 1)
        probability = group_dist[group] * user_dist[user]

        if ((random.random() < probability * 10000) and
           (group > existing_groups or user > existing_users)):
            # the + 1 converts the array index to the corresponding
            # group or user number
            assignments.add(((user + 1), (group + 1)))

    return assignments
1846
1847
def add_users_to_groups(db, instance_id, assignments):
    """Add users to their assigned groups.

    `assignments` is iterated as (user, group) number pairs. For each
    pair the user's DN is added to the "member" attribute of the group.
    Every modify is timed and a tab-separated timing record is printed
    to stdout.
    """
    base = ou_name(db, instance_id)

    for user_index, group_index in assignments:
        member_dn = "cn=%s,%s" % (user_name(instance_id, user_index), base)
        target_dn = "cn=%s,%s" % (group_name(instance_id, group_index), base)

        change = ldb.Message()
        change.dn = ldb.Dn(db, target_dn)
        change["member"] = ldb.MessageElement(member_dn,
                                              ldb.FLAG_MOD_ADD,
                                              "member")
        began = time.time()
        db.modify(change)
        finished = time.time()
        print("%f\t0\tadd\tuser\t%f\tTrue\t" % (finished, finished - began))
1871
1872
def generate_stats(statsdir, timing_file):
    """Generate and print the summary stats for a run.

    Every file in `statsdir` is read as tab-separated records of the form
    time, conversation, protocol, type, duration, successful[, error].
    Lines that cannot be parsed are echoed to stderr and ignored entirely.

    If `timing_file` is not None, a header line followed by every valid
    record is written to it.

    The summary is printed to stdout: a human-readable layout when stdout
    is a terminal, otherwise a compact tab-separated layout.
    """
    first      = sys.float_info.max
    last       = 0
    successful = 0
    failed     = 0
    latencies  = {}
    failures   = {}
    unique_conversations = set()

    if timing_file is not None:
        tw = timing_file.write
    else:
        def tw(x):
            pass

    tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")

    for filename in os.listdir(statsdir):
        path = os.path.join(statsdir, filename)
        with open(path, 'r') as f:
            for line in f:
                # Parse the whole record before touching any counters, so
                # that a malformed line never ends up partially counted.
                try:
                    fields       = line.rstrip('\n').split('\t')
                    timestamp    = float(fields[0])
                    conversation = fields[1]
                    protocol     = fields[2]
                    packet_type  = fields[3]
                    latency      = float(fields[4])
                    succeeded    = fields[5] == 'True'
                except (ValueError, IndexError):
                    # not a valid line print and ignore
                    print(line, file=sys.stderr)
                    continue

                first = min(timestamp - latency, first)
                last  = max(timestamp, last)

                by_type = latencies.setdefault(protocol, {})
                by_type.setdefault(packet_type, []).append(latency)

                failed_by_type = failures.setdefault(protocol, {})
                failed_by_type.setdefault(packet_type, 0)

                if succeeded:
                    successful += 1
                else:
                    failed += 1
                    failed_by_type[packet_type] += 1

                unique_conversations.add(conversation)
                tw(line)

    conversations = len(unique_conversations)
    duration = last - first

    # A run with no measurable duration would otherwise divide by zero.
    if successful == 0 or duration <= 0:
        success_rate = 0
    else:
        success_rate = successful / duration
    if failed == 0 or duration <= 0:
        failure_rate = 0
    else:
        failure_rate = failed / duration

    # print the stats in more human-readable format when stdout is going to the
    # console (as opposed to being redirected to a file)
    to_console = sys.stdout.isatty()

    if to_console:
        print("Total conversations:   %10d" % conversations)
        print("Successful operations: %10d (%.3f per second)"
              % (successful, success_rate))
        print("Failed operations:     %10d (%.3f per second)"
              % (failed, failure_rate))
    else:
        print("(%d, %d, %d, %.3f, %.3f)" %
              (conversations, successful, failed, success_rate, failure_rate))

    if to_console:
        print("Protocol    Op Code  Description                               "
              " Count       Failed         Mean       Median          "
              "95%        Range          Max")
        row_format = ("%-12s   %4s  %-35s %12d %12d %12.6f "
                      "%12.6f %12.6f %12.6f %12.6f")
    else:
        print("proto\top_code\tdesc\tcount\tfailed\tmean\tmedian\t95%\trange"
              "\tmax")
        row_format = "%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f"

    for protocol in sorted(latencies.keys()):
        packet_types = sorted(latencies[protocol], key=opcode_key)
        for packet_type in packet_types:
            values     = sorted(latencies[protocol][packet_type])
            count      = len(values)
            op_failed  = failures[protocol][packet_type]
            mean       = sum(values) / count
            median     = calc_percentile(values, 0.50)
            percentile = calc_percentile(values, 0.95)
            rng        = values[-1] - values[0]
            maxv       = values[-1]
            desc       = OP_DESCRIPTIONS.get((protocol, packet_type), '')
            print(row_format
                  % (protocol,
                     packet_type,
                     desc,
                     count,
                     op_failed,
                     mean,
                     median,
                     percentile,
                     rng,
                     maxv))
2000
2001
def opcode_key(v):
    """Sort key for the operation code to ensure that it sorts numerically.

    Values that int() accepts are returned as a zero-padded three-digit
    string; anything else is returned unchanged.
    """
    try:
        return "%03d" % int(v)
    except (ValueError, TypeError):
        # Not numeric: fall back to the raw value. A bare except here would
        # also swallow KeyboardInterrupt and SystemExit.
        return v
2008
2009
def calc_percentile(values, percentile):
    """Return the requested percentile of an ascending-sorted sequence.

    `percentile` is a fraction (0.5 for the median). When the rank falls
    between two elements the result is interpolated linearly between them.
    An empty sequence yields 0.
    """
    if values:
        rank = (len(values) - 1) * percentile
        below = math.floor(rank)
        above = math.ceil(rank)
        if below != above:
            # weight each neighbour by its closeness to the rank
            lower_part = values[int(below)] * (above - rank)
            upper_part = values[int(above)] * (rank - below)
            return lower_part + upper_part
        # the rank lands exactly on an element
        return values[int(rank)]
    return 0
2026
2027
def mk_masked_dir(*path):
    """Create a directory that is not accessible to group or others.

    In a testenv we end up with 0777 directories that look an alarming
    green colour with ls. Use umask to avoid that.

    The path components are joined with os.path.join(). Returns the joined
    path. Errors from os.mkdir() propagate to the caller.
    """
    d = os.path.join(*path)
    mask = os.umask(0o077)
    try:
        os.mkdir(d)
    finally:
        # Restore the caller's umask even if mkdir fails, so that an error
        # here does not change the permissions of files created later.
        os.umask(mask)
    return d