traffic: add credentials to samr
[samba.git] / python / samba / emulate / traffic.py
1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
3 #
4 # Copyright (C) Catalyst IT Ltd. 2017
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19 from __future__ import print_function, division
20
21 import time
22 import os
23 import random
24 import json
25 import math
26 import sys
27 import signal
28 import itertools
29
30 from collections import OrderedDict, Counter, defaultdict
31 from samba.emulate import traffic_packets
32 from samba.samdb import SamDB
33 import ldb
34 from ldb import LdbError
35 from samba.dcerpc import ClientConnection
36 from samba.dcerpc import security, drsuapi, lsa
37 from samba.dcerpc import netlogon
38 from samba.dcerpc.netlogon import netr_Authenticator
39 from samba.dcerpc import srvsvc
40 from samba.dcerpc import samr
41 from samba.drs_utils import drs_DsBind
42 import traceback
43 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
44 from samba.auth import system_session
45 from samba.dsdb import UF_WORKSTATION_TRUST_ACCOUNT, UF_PASSWD_NOTREQD
46 from samba.dsdb import UF_NORMAL_ACCOUNT
47 from samba.dcerpc.misc import SEC_CHAN_WKSTA
48 from samba import gensec
49
# Estimated overhead, in seconds, of one sleep call (assumed from the name;
# its use is not visible in this part of the file).
SLEEP_OVERHEAD = 3e-4

# Marker for "no packet". A string is used instead of None because None
# complicates [de]serialisation.
NON_PACKET = '-'

# (protocol, opcode) pairs whose sender is probably a client. The value is
# the weight returned by Packet.client_score() for a match.
CLIENT_CLUES = dict.fromkeys([
    ('dns', '0'),       # query
    ('smb', '0x72'),    # Negotiate protocol
    ('ldap', '0'),      # bind
    ('ldap', '3'),      # searchRequest
    ('ldap', '2'),      # unbindRequest
    ('cldap', '3'),
    ('dcerpc', '11'),   # bind
    ('dcerpc', '14'),   # Alter_context
    ('nbns', '0'),      # query
], 1.0)

# (protocol, opcode) pairs whose sender is probably a server. The negated
# value is returned by Packet.client_score() for a match.
SERVER_CLUES = dict.fromkeys([
    ('dns', '1'),       # response
    ('ldap', '1'),      # bind response
    ('ldap', '4'),      # search result
    ('ldap', '5'),      # search done
    ('cldap', '5'),
    ('dcerpc', '12'),   # bind_ack
    ('dcerpc', '13'),   # bind_nak
    ('dcerpc', '15'),   # Alter_context response
], 1.0)

# Packets of these protocols are never replayed (see
# Packet.is_really_a_packet).
SKIPPED_PROTOCOLS = set([
    "smb",
    "smb2",
    "browser",
    "smb_netlogon",
])

WAIT_SCALE = 10.0
WAIT_THRESHOLD = 1.0 / WAIT_SCALE
NO_WAIT_LOG_TIME_RANGE = (-10, -3)

# Verbosity threshold for debug(); scripts change it with -d.
DEBUG_LEVEL = 0
86
87
def debug(level, msg, *args):
    """Write a debug message to standard error when *level* is enabled.

    :param level: Verbosity of this message. It is written only when it is
                  less than or equal to the module-level DEBUG_LEVEL, which
                  scripts set with the -d option.
    :param msg:   The message. When *args* are given it is used as a
                  C-style (%) format string; otherwise it is printed as-is.
    :param args:  Values substituted into the format specifiers of *msg*.
    """
    if level > DEBUG_LEVEL:
        return
    text = msg % tuple(args) if args else msg
    print(text, file=sys.stderr)
104
105
def debug_lineno(*args):
    """Print the caller's function name and line number, then *args*.

    The location is written to stderr with the line number highlighted.
    Each argument is then printed unformatted on its own line, followed by
    a blank line, and stderr is flushed.
    """
    # extract_stack(limit=2) yields [caller frame, this frame].
    caller = traceback.extract_stack(limit=2)[0]
    location = (" %s:" "\033[01;33m" "%s " "\033[00m") % (caller[2], caller[1])
    print(location, end=' ', file=sys.stderr)
    for item in args:
        print(item, file=sys.stderr)
    print(file=sys.stderr)
    sys.stderr.flush()
117
118
def random_colour_print():
    """Build a printer that writes each argument to stderr in one colour.

    The 256-colour ANSI code is chosen at random (18-231) when this function
    is called. Every call of the returned printer reuses that colour, and
    prints each argument on its own line followed by a colour reset.
    """
    colour = random.randrange(214) + 18
    start = "\033[38;5;%dm" % colour

    def printer(*args):
        for item in args:
            print("%s%s\033[00m" % (start, item), file=sys.stderr)

    return printer
129
130
class FakePacketError(Exception):
    """Raised for a packet that does not fit where it is being added.

    For example, Conversation.add_packet raises it when the packet's
    endpoints differ from the conversation's.
    """
133
134
class Packet(object):
    """Details of a network packet.

    A packet is built from a traffic-summary line (a tab separated string)
    or from an equivalent sequence of fields, in this order:

        timestamp, ip_protocol, stream_number, src, dest,
        protocol, opcode, desc, [extra, ...]
    """
    def __init__(self, fields):
        if isinstance(fields, str):
            fields = fields.rstrip('\n').split('\t')

        (timestamp,
         ip_protocol,
         stream_number,
         src,
         dest,
         protocol,
         opcode,
         desc) = fields[:8]
        extra = fields[8:]

        self.timestamp = float(timestamp)
        self.ip_protocol = ip_protocol
        try:
            self.stream_number = int(stream_number)
        except (ValueError, TypeError):
            # An empty ('') or missing (None) stream number is allowed.
            self.stream_number = None
        self.src = int(src)
        self.dest = int(dest)
        self.protocol = protocol
        self.opcode = opcode
        self.desc = desc
        self.extra = extra

        # Store the endpoints in ascending order, so that both directions
        # of a conversation get the same key.
        if self.src < self.dest:
            self.endpoints = (self.src, self.dest)
        else:
            self.endpoints = (self.dest, self.src)

    def as_summary(self, time_offset=0.0):
        """Format the packet as a traffic_summary line.

        :param time_offset: value added to the packet's timestamp.
        :return: a (timestamp, line) tuple. The timestamp includes the offset
                 and the line is the tab separated summary.
        """
        extra = '\t'.join(self.extra)
        t = self.timestamp + time_offset
        # Only a missing stream number becomes ''. A stream number of 0
        # must be written out, or it would be read back as None.
        if self.stream_number is None:
            stream_number = ''
        else:
            stream_number = self.stream_number
        return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
                (t,
                 self.ip_protocol,
                 stream_number,
                 self.src,
                 self.dest,
                 self.protocol,
                 self.opcode,
                 self.desc,
                 extra))

    def __str__(self):
        return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
                (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
                 self.stream_number, self.protocol, self.opcode, self.desc,
                 ('«' + ' '.join(self.extra) + '»' if self.extra else '')))

    def __repr__(self):
        return "<Packet @%s>" % self

    def copy(self):
        """Return a new packet of the same class with the same fields."""
        return self.__class__([self.timestamp,
                               self.ip_protocol,
                               self.stream_number,
                               self.src,
                               self.dest,
                               self.protocol,
                               self.opcode,
                               self.desc] + self.extra)

    def as_packet_type(self):
        """Return the packet type as a 'protocol:opcode' string."""
        t = '%s:%s' % (self.protocol, self.opcode)
        return t

    def client_score(self):
        """A positive number means we think it is a client; a negative number
        means we think it is a server. Zero means no idea. range: -1 to 1.
        """
        key = (self.protocol, self.opcode)
        if key in CLIENT_CLUES:
            return CLIENT_CLUES[key]
        if key in SERVER_CLUES:
            return -SERVER_CLUES[key]
        return 0.0

    def play(self, conversation, context):
        """Send the packet over the network, if required.

        Some packets are ignored, e.g. those for protocols that are not
        handled, server response messages, or messages that the protocol
        layer generates for other packets.

        The handler is traffic_packets.packet_<protocol>_<opcode>. If the
        handler returns a true value, a timing line ending in 'True' is
        printed to stdout. If it raises, a line ending in 'False' and the
        exception is printed instead.
        """
        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        try:
            fn = getattr(traffic_packets, fn_name)
        except AttributeError:
            print("Conversation(%s) Missing handler %s" %
                  (conversation.conversation_id, fn_name),
                  file=sys.stderr)
            return

        # Don't display a message for kerberos packets. They are not
        # generated directly; they indicate that kerberos should be used.
        if self.protocol != "kerberos":
            debug(2, "Conversation(%s) Calling handler %s",
                  conversation.conversation_id, fn_name)

        start = time.time()
        try:
            if fn(self, conversation, context):
                # Only collect timing data for functions that generate
                # network traffic, or fail
                end = time.time()
                duration = end - start
                print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
                      (end, conversation.conversation_id, self.protocol,
                       self.opcode, duration))
        except Exception as e:
            end = time.time()
            duration = end - start
            print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
                  (end, conversation.conversation_id, self.protocol,
                   self.opcode, duration, e))

    def __cmp__(self, other):
        """Order packets by timestamp (Python 2 comparison protocol).

        Returns -1, 0 or 1. Python 2 truncates a float result from __cmp__
        to an int, so returning the raw timestamp difference made packets
        less than a second apart compare equal.
        """
        return ((self.timestamp > other.timestamp) -
                (self.timestamp < other.timestamp))

    def __lt__(self, other):
        """Order packets by timestamp (used by sorting on Python 3)."""
        return self.timestamp < other.timestamp

    def is_really_a_packet(self, missing_packet_stats=None):
        """Is the packet one that needs to be replayed?

        Returns False for packets whose removal has no effect on the
        replay: skipped protocols, ldap continuation packets, packets
        handled by traffic_packets.null_packet, and packets with no handler
        (which are reported on stderr).

        :param missing_packet_stats: accepted but not used here.
        """
        if self.protocol in SKIPPED_PROTOCOLS:
            # Ignore any packets for the protocols we're not interested in.
            return False
        if self.protocol == "ldap" and self.opcode == '':
            # skip ldap continuation packets
            return False

        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        try:
            fn = getattr(traffic_packets, fn_name)
            if fn is traffic_packets.null_packet:
                return False
        except AttributeError:
            print("missing packet %s" % fn_name, file=sys.stderr)
            return False
        return True
283
284
class ReplayContext(object):
    """State/Context for an individual conversation between a simulated client
       and a server.
    """

    def __init__(self,
                 server=None,
                 lp=None,
                 creds=None,
                 badpassword_frequency=None,
                 prefer_kerberos=None,
                 tempdir=None,
                 statsdir=None,
                 ou=None,
                 base_dn=None,
                 domain=None,
                 domain_sid=None):
        """Initialise the context and load the server's DN lookup tables.

        :param server: name or address of the server to connect to.
        :param lp: loadparm context. Its 'realm' is read immediately, so it
                   must be supplied despite the None default.
        :param creds: credentials used for the LDAP search that builds the
                      DN tables, and for samr connections.
        :param badpassword_frequency: probability (0 to 1) that a new
                      connection is first attempted with bad credentials.
        :param prefer_kerberos: if true, generated credentials require
                      Kerberos; otherwise Kerberos is disabled.
        :param tempdir: parent directory for per-conversation directories.
        :param statsdir: stored as-is; not used within this class.
        :param ou: DN of the container holding the conversation accounts.
        :param base_dn: fallback DN returned by get_matching_dn().
        :param domain: domain name; taken from $DOMAIN when None.
        :param domain_sid: stored as-is; not used within this class.
        """
        self.server                   = server
        self.ldap_connections         = []
        self.dcerpc_connections       = []
        self.lsarpc_connections       = []
        self.lsarpc_connections_named = []
        self.drsuapi_connections      = []
        self.srvsvc_connections       = []
        self.samr_contexts            = []
        self.netlogon_connection      = None
        self.creds                    = creds
        self.lp                       = lp
        self.prefer_kerberos          = prefer_kerberos
        self.ou                       = ou
        self.base_dn                  = base_dn
        self.domain                   = domain
        self.statsdir                 = statsdir
        self.global_tempdir           = tempdir
        self.domain_sid               = domain_sid
        self.realm                    = lp.get('realm')

        # Bad password attempt controls
        self.badpassword_frequency    = badpassword_frequency
        self.last_lsarpc_bad          = False
        self.last_lsarpc_named_bad    = False
        self.last_simple_bind_bad     = False
        self.last_bind_bad            = False
        self.last_srvsvc_bad          = False
        self.last_drsuapi_bad         = False
        self.last_netlogon_bad        = False
        self.last_samlogon_bad        = False
        self.generate_ldap_search_tables()
        # count().next exists only on Python 2; next(counter) works on both.
        counter = itertools.count()
        self.next_conversation_id = lambda: next(counter)

    def generate_ldap_search_tables(self):
        """Build lookup tables of the DNs that exist on the server.

        Sets self.dn_map, which maps a DN "pattern" such as 'CN,CN,DC,DC'
        (the upper-cased first two characters of each RDN) to the DNs with
        that pattern. Also sets self.attribute_clue_map, which maps
        'invocationId' to the DNs starting with 'CN=NTDS Settings,'.
        """
        session = system_session()

        db = SamDB(url="ldap://%s" % self.server,
                   session_info=session,
                   credentials=self.creds,
                   lp=self.lp)

        res = db.search(db.domain_dn(),
                        scope=ldb.SCOPE_SUBTREE,
                        controls=["paged_results:1:1000"],
                        attrs=['dn'])

        # find a list of dns for each pattern
        # e.g. CN,CN,CN,DC,DC
        dn_map = {}
        attribute_clue_map = {
            'invocationId': []
        }

        for r in res:
            dn = str(r.dn)
            pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
            dn_map.setdefault(pattern, []).append(dn)
            if dn.startswith('CN=NTDS Settings,'):
                attribute_clue_map['invocationId'].append(dn)

        # extend the map in case we are working with a different
        # number of DC components.
        self._extend_dn_map(dn_map)

        self.dn_map = dn_map
        self.attribute_clue_map = attribute_clue_map

    @staticmethod
    def _extend_dn_map(dn_map):
        """Alias every ',DC'-terminated pattern with 1 to 5 trailing DCs.

        Each alias shares the original pattern's DN list. If an alias would
        clash with a different existing pattern, it is reported on stderr
        and skipped. dn_map is modified in place.
        """
        # Iterate over a snapshot of the keys: the loop adds entries, and
        # iterating the live dict would raise RuntimeError on Python 3.
        for k in list(dn_map):
            if k[-3:] != ',DC':
                continue
            p = k[:-3]
            while p[-3:] == ',DC':
                p = p[:-3]
            for _ in range(5):
                p += ',DC'
                if p != k and p in dn_map:
                    print('dn_map collison %s %s' % (k, p),
                          file=sys.stderr)
                    continue
                dn_map[p] = dn_map[k]

    def generate_process_local_config(self, account, conversation):
        """Configure this context for one conversation's account.

        Copies the account's names and passwords, creates a per-conversation
        temporary directory, points the loadparm private/lock/state
        directories at it, and generates the machine and user credentials.
        Does nothing if account is None.
        """
        if account is None:
            return
        self.netbios_name             = account.netbios_name
        self.machinepass              = account.machinepass
        self.username                 = account.username
        self.userpass                 = account.userpass

        self.tempdir = mk_masked_dir(self.global_tempdir,
                                     'conversation-%d' %
                                     conversation.conversation_id)

        self.lp.set("private dir",     self.tempdir)
        self.lp.set("lock dir",        self.tempdir)
        self.lp.set("state directory", self.tempdir)
        self.lp.set("tls verify peer", "no_check")

        # If the domain was not specified, check for the environment
        # variable.
        if self.domain is None:
            self.domain = os.environ["DOMAIN"]

        self.remoteAddress = "/root/ncalrpc_as_system"
        self.samlogon_dn   = ("cn=%s,%s" %
                              (self.netbios_name, self.ou))
        self.user_dn       = ("cn=%s,%s" %
                              (self.username, self.ou))

        self.generate_machine_creds()
        self.generate_user_creds()

    def with_random_bad_credentials(self, f, good, bad, failed_last_time):
        """Call f with good credentials, sometimes trying bad ones first.

        With probability badpassword_frequency, f(bad) is called first and
        any exception it raises is ignored. f(good) is then always called
        and its result returned.

        failed_last_time prevents consecutive bad credential attempts. When
        it is True, no bad attempt is made this time and the returned flag
        is reset to False, so a bad attempt can happen again on a later call.
        The overall bad credential frequency is therefore slightly lower
        than requested.

        :return: (f(good), whether bad credentials were tried on this call)
        """
        if failed_last_time:
            # Bad credentials were used last time; skip them once.
            failed_last_time = False
        elif ((self.badpassword_frequency or 0) > 0 and
              random.random() < self.badpassword_frequency):
            try:
                f(bad)
            except Exception:
                # Ignore any exceptions as the operation may fail
                # as it's being performed with bad credentials
                pass
            failed_last_time = True

        result = f(good)
        return (result, failed_last_time)

    def _new_creds(self, username, password, secure_channel_type=None,
                   seal=False, bind_dn=None):
        """Create a Credentials object for this conversation's workstation.

        The setters are always called in the same order: guess, workstation,
        [secure channel type], password, username, [gensec SEAL feature],
        kerberos state, [bind dn]. Bracketed steps happen only when the
        matching argument is given.
        """
        creds = Credentials()
        creds.guess(self.lp)
        creds.set_workstation(self.netbios_name)
        if secure_channel_type is not None:
            creds.set_secure_channel_type(secure_channel_type)
        creds.set_password(password)
        creds.set_username(username)
        if seal:
            creds.set_gensec_features(creds.get_gensec_features() |
                                      gensec.FEATURE_SEAL)
        if self.prefer_kerberos:
            creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            creds.set_kerberos_state(DONT_USE_KERBEROS)
        if bind_dn is not None:
            creds.set_bind_dn(bind_dn)
        return creds

    def generate_user_creds(self):
        """Generate the conversation specific user Credentials.

        Each Conversation has an associated user account used to simulate
        any non Administrative user traffic.

        Generates user credentials with good and bad passwords, and ldap
        simple bind credentials (with SEAL and the user's bind DN) with good
        and bad passwords. The bad password is the good one without its
        last four characters.
        """
        bad_password = self.userpass[:-4]
        self.user_creds = self._new_creds(self.username, self.userpass)
        self.user_creds_bad = self._new_creds(self.username, bad_password)

        # Credentials for ldap simple bind.
        self.simple_bind_creds = self._new_creds(self.username,
                                                 self.userpass,
                                                 seal=True,
                                                 bind_dn=self.user_dn)
        self.simple_bind_creds_bad = self._new_creds(self.username,
                                                     bad_password,
                                                     seal=True,
                                                     bind_dn=self.user_dn)

    def generate_machine_creds(self):
        """Generate the conversation specific machine Credentials.

        Each Conversation has an associated machine account ('<netbios>$',
        workstation secure channel).

        Generates machine credentials with good and bad passwords. The bad
        password is the good one without its last four characters.
        """
        account_name = self.netbios_name + "$"
        self.machine_creds = self._new_creds(
            account_name, self.machinepass,
            secure_channel_type=SEC_CHAN_WKSTA)
        self.machine_creds_bad = self._new_creds(
            account_name, self.machinepass[:-4],
            secure_channel_type=SEC_CHAN_WKSTA)

    def get_matching_dn(self, pattern, attributes=None):
        """Return a DN on the server that resembles pattern.

        If attributes has a non-empty entry in attribute_clue_map, a random
        DN from that entry is returned. Otherwise leading components are
        dropped from pattern (e.g. CN,CN,CN,CN,DC,DC, then CN,CN,CN,DC,DC,
        ...) until one appears in dn_map, and a random DN with that pattern
        is returned. dn_map already contains aliases with extra or fewer
        trailing DCs. As a last resort, base_dn is returned.
        """
        attr_clue = self.attribute_clue_map.get(attributes)
        if attr_clue:
            return random.choice(attr_clue)

        pattern = pattern.upper()
        while pattern:
            if pattern in self.dn_map:
                return random.choice(self.dn_map[pattern])
            # chop one off the front and try it all again.
            pattern = pattern[3:]

        return self.base_dn

    def get_dcerpc_connection(self, new=False):
        """Return the latest RPC_NETLOGON ClientConnection, or a new one."""
        guid = '12345678-1234-abcd-ef00-01234567cffb'  # RPC_NETLOGON UUID
        if self.dcerpc_connections and not new:
            return self.dcerpc_connections[-1]
        c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
                             (guid, 1), self.lp)
        self.dcerpc_connections.append(c)
        return c

    def get_srvsvc_connection(self, new=False):
        """Return the latest srvsvc named-pipe connection, or a new one
        made with the user credentials."""
        if self.srvsvc_connections and not new:
            return self.srvsvc_connections[-1]

        def connect(creds):
            return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
                                 self.lp,
                                 creds)

        (c, self.last_srvsvc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_srvsvc_bad)

        self.srvsvc_connections.append(c)
        return c

    def get_lsarpc_connection(self, new=False):
        """Return the latest lsarpc TCP connection (schannel, seal, sign),
        or a new one made with the machine credentials."""
        if self.lsarpc_connections and not new:
            return self.lsarpc_connections[-1]

        def connect(creds):
            binding_options = 'schannel,seal,sign'
            return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
                              (self.server, binding_options),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_bad)

        self.lsarpc_connections.append(c)
        return c

    def get_lsarpc_named_pipe_connection(self, new=False):
        """Return the latest lsarpc named-pipe connection, or a new one
        made with the machine credentials."""
        if self.lsarpc_connections_named and not new:
            return self.lsarpc_connections_named[-1]

        def connect(creds):
            return lsa.lsarpc("ncacn_np:%s" % (self.server),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_named_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_named_bad)

        self.lsarpc_connections_named.append(c)
        return c

    def get_drsuapi_connection_pair(self, new=False, unbind=False):
        """get a (drs, drs_handle) tuple

        A new pair is made with the user credentials and bound with
        drs_DsBind. The unbind argument is accepted but not used here.
        """
        if self.drsuapi_connections and not new:
            c = self.drsuapi_connections[-1]
            return c

        def connect(creds):
            binding_options = 'seal'
            binding_string = "ncacn_ip_tcp:%s[%s]" %\
                             (self.server, binding_options)
            return drsuapi.drsuapi(binding_string, self.lp, creds)

        (drs, self.last_drsuapi_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_drsuapi_bad)

        (drs_handle, _supported_extensions) = drs_DsBind(drs)
        c = (drs, drs_handle)
        self.drsuapi_connections.append(c)
        return c

    def get_ldap_connection(self, new=False, simple=False):
        """Return the latest SamDB connection, or a new one.

        With simple=True the new connection uses ldaps:// and the simple
        bind credentials; otherwise ldap:// with the user credentials.
        """
        if self.ldap_connections and not new:
            return self.ldap_connections[-1]

        def simple_bind(creds):
            return SamDB('ldaps://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)

        def sasl_bind(creds):
            return SamDB('ldap://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)
        if simple:
            (samdb, self.last_simple_bind_bad) = \
                self.with_random_bad_credentials(simple_bind,
                                                 self.simple_bind_creds,
                                                 self.simple_bind_creds_bad,
                                                 self.last_simple_bind_bad)
        else:
            (samdb, self.last_bind_bad) = \
                self.with_random_bad_credentials(sasl_bind,
                                                 self.user_creds,
                                                 self.user_creds_bad,
                                                 self.last_bind_bad)

        self.ldap_connections.append(samdb)
        return samdb

    def get_samr_context(self, new=False):
        """Return the latest SamrContext, or a new one using self.creds."""
        if not self.samr_contexts or new:
            self.samr_contexts.append(
                SamrContext(self.server, lp=self.lp, creds=self.creds))
        return self.samr_contexts[-1]

    def get_netlogon_connection(self):
        """Return the cached netlogon connection (schannel, seal), creating
        it with the machine credentials on first use."""

        if self.netlogon_connection:
            return self.netlogon_connection

        def connect(creds):
            return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
                                     (self.server),
                                     self.lp,
                                     creds)
        (c, self.last_netlogon_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_netlogon_bad)
        self.netlogon_connection = c
        return c

    def guess_a_dns_lookup(self):
        """Return a (name, record type) pair: the realm's A record."""
        return (self.realm, 'A')

    def get_authenticator(self):
        """Return a (current, subsequent) pair of netr_Authenticator.

        current is filled from a new client authenticator derived from the
        machine credentials; subsequent is left empty.
        """
        auth = self.machine_creds.new_client_authenticator()
        current = netr_Authenticator()
        current.cred.data = self._byte_values(auth["credential"])
        current.timestamp = auth["timestamp"]

        subsequent = netr_Authenticator()
        return (current, subsequent)

    @staticmethod
    def _byte_values(data):
        """Return data as a list of ints.

        Accepts a str of characters (Python 2) or bytes, whose items are
        already ints on Python 3.
        """
        return [x if isinstance(x, int) else ord(x) for x in data]
706
707
class SamrContext(object):
    """State/Context associated with a samr connection.

    The connection and the Connect2 handle are opened on first use and
    reused afterwards.
    """
    def __init__(self, server, lp=None, creds=None):
        self.server = server
        self.lp = lp
        self.creds = creds
        # Opened lazily by get_connection() / get_handle().
        self.connection = None
        self.handle = None
        # Not assigned anywhere in this class; presumably filled in by the
        # samr packet handlers (verify against traffic_packets).
        self.domain_handle = None
        self.domain_sid = None
        self.group_handle = None
        self.user_handle = None
        self.rids = None

    def get_connection(self):
        """Return the sealed samr TCP connection, opening it on first use."""
        if self.connection:
            return self.connection
        self.connection = samr.samr(
            "ncacn_ip_tcp:%s[seal]" % (self.server),
            lp_ctx=self.lp,
            credentials=self.creds)
        return self.connection

    def get_handle(self):
        """Return the Connect2 handle (maximum allowed access), obtaining
        it on first use."""
        if self.handle:
            return self.handle
        conn = self.get_connection()
        self.handle = conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
        return self.handle
737
738
739 class Conversation(object):
740     """Details of a converation between a simulated client and a server."""
741     conversation_id = None
742
743     def __init__(self, start_time=None, endpoints=None):
744         self.start_time = start_time
745         self.endpoints = endpoints
746         self.packets = []
747         self.msg = random_colour_print()
748         self.client_balance = 0.0
749
750     def __cmp__(self, other):
751         if self.start_time is None:
752             if other.start_time is None:
753                 return 0
754             return -1
755         if other.start_time is None:
756             return 1
757         return self.start_time - other.start_time
758
759     def add_packet(self, packet):
760         """Add a packet object to this conversation, making a local copy with
761         a conversation-relative timestamp."""
762         p = packet.copy()
763
764         if self.start_time is None:
765             self.start_time = p.timestamp
766
767         if self.endpoints is None:
768             self.endpoints = p.endpoints
769
770         if p.endpoints != self.endpoints:
771             raise FakePacketError("Conversation endpoints %s don't match"
772                                   "packet endpoints %s" %
773                                   (self.endpoints, p.endpoints))
774
775         p.timestamp -= self.start_time
776
777         if p.src == p.endpoints[0]:
778             self.client_balance -= p.client_score()
779         else:
780             self.client_balance += p.client_score()
781
782         if p.is_really_a_packet():
783             self.packets.append(p)
784
785     def add_short_packet(self, timestamp, p, extra, client=True):
786         """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
787         (possibly empty) list of extra data. If client is True, assume
788         this packet is from the client to the server.
789         """
790         protocol, opcode = p.split(':', 1)
791         src, dest = self.guess_client_server()
792         if not client:
793             src, dest = dest, src
794
795         desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
796         ip_protocol = IP_PROTOCOLS.get(protocol, '06')
797         fields = [timestamp - self.start_time, ip_protocol,
798                   '', src, dest,
799                   protocol, opcode, desc]
800         fields.extend(extra)
801         packet = Packet(fields)
802         # XXX we're assuming the timestamp is already adjusted for
803         # this conversation?
804         # XXX should we adjust client balance for guessed packets?
805         if packet.src == packet.endpoints[0]:
806             self.client_balance -= packet.client_score()
807         else:
808             self.client_balance += packet.client_score()
809         if packet.is_really_a_packet():
810             self.packets.append(packet)
811
812     def __str__(self):
813         return ("<Conversation %s %s starting %.3f %d packets>" %
814                 (self.conversation_id, self.endpoints, self.start_time,
815                  len(self.packets)))
816
817     __repr__ = __str__
818
819     def __iter__(self):
820         return iter(self.packets)
821
822     def __len__(self):
823         return len(self.packets)
824
825     def get_duration(self):
826         if len(self.packets) < 2:
827             return 0
828         return self.packets[-1].timestamp - self.packets[0].timestamp
829
830     def replay_as_summary_lines(self):
831         lines = []
832         for p in self.packets:
833             lines.append(p.as_summary(self.start_time))
834         return lines
835
836     def replay_in_fork_with_delay(self, start, context=None, account=None):
837         """Fork a new process and replay the conversation.
838         """
839         def signal_handler(signal, frame):
840             """Signal handler closes standard out and error.
841
842             Triggered by a sigterm, ensures that the log messages are flushed
843             to disk and not lost.
844             """
845             sys.stderr.close()
846             sys.stdout.close()
847             os._exit(0)
848
849         t = self.start_time
850         now = time.time() - start
851         gap = t - now
852         # we are replaying strictly in order, so it is safe to sleep
853         # in the main process if the gap is big enough. This reduces
854         # the number of concurrent threads, which allows us to make
855         # larger loads.
856         if gap > 0.15 and False:
857             print("sleeping for %f in main process" % (gap - 0.1),
858                   file=sys.stderr)
859             time.sleep(gap - 0.1)
860             now = time.time() - start
861             gap = t - now
862             print("gap is now %f" % gap, file=sys.stderr)
863
864         self.conversation_id = context.next_conversation_id()
865         pid = os.fork()
866         if pid != 0:
867             return pid
868         pid = os.getpid()
869         signal.signal(signal.SIGTERM, signal_handler)
870         # we must never return, or we'll end up running parts of the
871         # parent's clean-up code. So we work in a try...finally, and
872         # try to print any exceptions.
873
874         try:
875             context.generate_process_local_config(account, self)
876             sys.stdin.close()
877             os.close(0)
878             filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
879                                     self.conversation_id)
880             sys.stdout.close()
881             sys.stdout = open(filename, 'w')
882
883             sleep_time = gap - SLEEP_OVERHEAD
884             if sleep_time > 0:
885                 time.sleep(sleep_time)
886
887             miss = t - (time.time() - start)
888             self.msg("starting %s [miss %.3f pid %d]" % (self, miss, pid))
889             self.replay(context)
890         except Exception:
891             print(("EXCEPTION in child PID %d, conversation %s" % (pid, self)),
892                   file=sys.stderr)
893             traceback.print_exc(sys.stderr)
894         finally:
895             sys.stderr.close()
896             sys.stdout.close()
897             os._exit(0)
898
899     def replay(self, context=None):
900         start = time.time()
901
902         for p in self.packets:
903             now = time.time() - start
904             gap = p.timestamp - now
905             sleep_time = gap - SLEEP_OVERHEAD
906             if sleep_time > 0:
907                 time.sleep(sleep_time)
908
909             miss = p.timestamp - (time.time() - start)
910             if context is None:
911                 self.msg("packet %s [miss %.3f pid %d]" % (p, miss,
912                                                            os.getpid()))
913                 continue
914             p.play(self, context)
915
916     def guess_client_server(self, server_clue=None):
917         """Have a go at deciding who is the server and who is the client.
918         returns (client, server)
919         """
920         a, b = self.endpoints
921
922         if self.client_balance < 0:
923             return (a, b)
924
925         # in the absense of a clue, we will fall through to assuming
926         # the lowest number is the server (which is usually true).
927
928         if self.client_balance == 0 and server_clue == b:
929             return (a, b)
930
931         return (b, a)
932
933     def forget_packets_outside_window(self, s, e):
934         """Prune any packets outside the timne window we're interested in
935
936         :param s: start of the window
937         :param e: end of the window
938         """
939
940         new_packets = []
941         for p in self.packets:
942             if p.timestamp < s or p.timestamp > e:
943                 continue
944             new_packets.append(p)
945
946         self.packets = new_packets
947         if new_packets:
948             self.start_time = new_packets[0].timestamp
949         else:
950             self.start_time = None
951
952     def renormalise_times(self, start_time):
953         """Adjust the packet start times relative to the new start time."""
954         for p in self.packets:
955             p.timestamp -= start_time
956
957         if self.start_time is not None:
958             self.start_time -= start_time
959
960
class DnsHammer(Conversation):
    """A lightweight conversation that generates a lot of dns:0 packets on
    the fly.

    Instead of replaying recorded packets, it issues
    ``int(dns_rate * duration)`` DNS queries at uniformly random times in
    ``[0, duration]``.  For each query it prints a tab-separated timing
    line to stdout.
    """

    def __init__(self, dns_rate, duration):
        n = int(dns_rate * duration)
        # Draw the query times first.  Conversation.__init__ calls
        # random_colour_print(), and keeping that after the draws leaves
        # the times unchanged for a given random seed.
        self.times = sorted(random.uniform(0, duration) for i in range(n))
        self.rate = dns_rate
        self.duration = duration
        # Set up the Conversation state (start_time, endpoints, packets,
        # client_balance, msg).  Inherited methods such as __len__,
        # __iter__ and get_duration then work instead of raising
        # AttributeError.  The hammer starts at time 0.
        Conversation.__init__(self, start_time=0)

    def __str__(self):
        return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
                (len(self.times), self.duration, self.rate))

    def replay(self, context=None):
        """Issue the DNS queries at their scheduled times.

        With no context each query is only logged via self.msg.  Otherwise
        traffic_packets.packet_dns_0 is called for each query, and a timing
        line ending in True or False (with the exception) is printed.
        """
        start = time.time()
        fn = traffic_packets.packet_dns_0
        for t in self.times:
            now = time.time() - start
            gap = t - now
            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            if context is None:
                miss = t - (time.time() - start)
                self.msg("packet %s [miss %.3f pid %d]" % (t, miss,
                                                           os.getpid()))
                continue

            packet_start = time.time()
            try:
                fn(self, self, context)
                end = time.time()
                duration = end - packet_start
                print("%f\tDNS\tdns\t0\t%f\tTrue\t" % (end, duration))
            except Exception as e:
                end = time.time()
                duration = end - packet_start
                print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" % (end, duration, e))
1010
1011
def ingest_summaries(files, dns_mode='count'):
    """Load traffic summary files and gather their packets into
    Conversations.

    :param files: an iterable of filenames or open file objects.  Every
        file is closed once read, including file objects supplied by the
        caller.
    :param dns_mode: 'include' keeps DNS packets in the conversations; any
        other value only counts them by opcode.
    :return: a tuple ``(conversations, mean_interval, duration,
        dns_counts)``.
        - conversations: the non-empty Conversations.
        - mean_interval: the number of conversations (including empty ones)
          divided by duration; 0 if duration is 0.
        - duration: the time from the first to the last packet.
        - dns_counts: a mapping of DNS opcode -> count.
        Input with no non-DNS packets yields ``([], 0, 0, dns_counts)``.
    """
    dns_counts = defaultdict(int)
    packets = []
    for f in files:
        if isinstance(f, str):
            f = open(f)
        try:
            print("Ingesting %s" % (f.name,), file=sys.stderr)
            for line in f:
                p = Packet(line)
                if p.protocol == 'dns' and dns_mode != 'include':
                    dns_counts[p.opcode] += 1
                else:
                    packets.append(p)
        finally:
            f.close()

    if not packets:
        # Keep the same 4-tuple shape as the normal return so callers can
        # always unpack it.
        return [], 0, 0, dns_counts

    start_time = min(p.timestamp for p in packets)
    last_packet = max(p.timestamp for p in packets)

    print("gathering packets into conversations", file=sys.stderr)
    conversations = OrderedDict()
    for p in packets:
        p.timestamp -= start_time
        c = conversations.get(p.endpoints)
        if c is None:
            c = Conversation()
            conversations[p.endpoints] = c
        c.add_packet(p)

    # We only care about conversations with actual traffic, so we
    # filter out conversations with nothing to say. We do that here,
    # rather than earlier, because those empty packets contain useful
    # hints as to which end of the conversation was the client.
    conversation_list = [c for c in conversations.values() if len(c) != 0]

    # This is obviously not correct, as many conversations will appear
    # to start roughly simultaneously at the beginning of the snapshot.
    # To which we say: oh well, so be it.
    duration = float(last_packet - start_time)
    # Guard against a capture whose packets all share one timestamp.
    if duration:
        mean_interval = len(conversations) / duration
    else:
        mean_interval = 0.0

    return conversation_list, mean_interval, duration, dns_counts
1063
1064
def guess_server_address(conversations):
    """Guess the server's address as the endpoint that occurs most often
    across all the conversations' endpoint pairs.

    :param conversations: an iterable of objects with an ``endpoints`` pair.
    :return: the most common address, or None if there are no
        conversations.
    """
    addresses = Counter()
    for c in conversations:
        addresses.update(c.endpoints)
    if not addresses:
        return None
    # most_common() yields (address, count) pairs; return the address
    # itself so it can be compared with a conversation's endpoints.
    return addresses.most_common(1)[0][0]
1072
1073
def stringify_keys(x):
    """Return a copy of dict `x` whose tuple keys are joined into single
    tab-separated strings."""
    return dict(('\t'.join(key), value) for key, value in x.items())
1080
1081
def unstringify_keys(x):
    """Return a copy of dict `x` with each key converted to str and split
    on tabs into a tuple."""
    return dict((tuple(str(key).split('\t')), value)
                for key, value in x.items())
1088
1089
class TrafficModel(object):
    """An n-gram model of client traffic.

    The model is learnt from Conversations (learn), saved to and loaded from
    JSON (save/load), and used to generate new Conversations
    (construct_conversation, generate_conversations).

    ``ngrams`` maps a tuple of the ``n - 1`` preceding states to the list of
    states observed next.  A state is a packet type such as 'ldap:3', a
    'wait:N' state, or NON_PACKET padding.  ``query_details`` maps a packet
    type to the list of ``extra`` tuples seen with it.
    """

    def __init__(self, n=3):
        self.ngrams = {}
        self.query_details = {}
        self.n = n
        self.dns_opcounts = defaultdict(int)
        self.cumulative_duration = 0.0
        # [number of conversations, time from the first to the last
        # conversation start], as recorded by learn().
        self.conversation_rate = [0, 1]

    def learn(self, conversations, dns_opcounts=None):
        """Add the client packets of `conversations` to the model.

        :param conversations: a sequence of Conversations, presumably in
            start-time order (the first and last are used to measure the
            conversation rate).
        :param dns_opcounts: optional mapping of DNS opcode -> count to add
            to the model's DNS statistics.
        """
        if dns_opcounts is None:
            dns_opcounts = {}
        prev = 0.0
        cum_duration = 0.0
        key = (NON_PACKET,) * (self.n - 1)

        server = guess_server_address(conversations)

        for k, v in dns_opcounts.items():
            self.dns_opcounts[k] += v

        if len(conversations) > 1:
            elapsed = (conversations[-1].start_time -
                       conversations[0].start_time)
            self.conversation_rate[0] = len(conversations)
            self.conversation_rate[1] = elapsed

        for c in conversations:
            # `server` is rebound on every iteration, so the server guessed
            # for one conversation becomes the clue for the next.
            client, server = c.guess_client_server(server)
            cum_duration += c.get_duration()
            key = (NON_PACKET,) * (self.n - 1)
            for p in c:
                if p.src != client:
                    continue

                # NOTE(review): prev is not reset between conversations, so
                # the first gap of a conversation is measured from the
                # previous conversation's last client packet; confirm this
                # is intended.
                elapsed = p.timestamp - prev
                prev = p.timestamp
                if elapsed > WAIT_THRESHOLD:
                    # add the wait as an extra state, bucketed by the log
                    # of the scaled gap
                    wait = 'wait:%d' % (math.log(max(1.0,
                                                     elapsed * WAIT_SCALE)))
                    self.ngrams.setdefault(key, []).append(wait)
                    key = key[1:] + (wait,)

                short_p = p.as_packet_type()
                self.query_details.setdefault(short_p,
                                              []).append(tuple(p.extra))
                self.ngrams.setdefault(key, []).append(short_p)
                key = key[1:] + (short_p,)

        self.cumulative_duration += cum_duration
        # add in the end
        # NOTE(review): only the key reached by the last conversation gets
        # this end marker; confirm whether every conversation should.
        self.ngrams.setdefault(key, []).append(NON_PACKET)

    def save(self, f):
        """Write the model as JSON.

        :param f: a filename, or a writable file object (left open).
        """
        ngrams = {}
        for k, v in self.ngrams.items():
            ngrams['\t'.join(k)] = dict(Counter(v))

        query_details = {}
        for k, v in self.query_details.items():
            # an empty extra tuple is stored as '-'
            query_details[k] = dict(Counter('\t'.join(x) if x else '-'
                                            for x in v))

        d = {
            'ngrams': ngrams,
            'query_details': query_details,
            'cumulative_duration': self.cumulative_duration,
            'conversation_rate': self.conversation_rate,
            'dns': self.dns_opcounts,
        }

        if isinstance(f, str):
            # we opened it, so make sure it is flushed and closed
            with open(f, 'w') as fh:
                json.dump(d, fh, indent=2)
        else:
            json.dump(d, f, indent=2)

    def load(self, f):
        """Read a model written by save() into this one.

        The ngrams, query details and DNS counts are added to any already
        in the model.  cumulative_duration and conversation_rate are
        replaced.

        :param f: a filename, or a readable file object (left open).
        """
        if isinstance(f, str):
            with open(f) as fh:
                d = json.load(fh)
        else:
            d = json.load(f)

        for k, v in d['ngrams'].items():
            k = tuple(str(k).split('\t'))
            values = self.ngrams.setdefault(k, [])
            for p, count in v.items():
                values.extend([str(p)] * count)

        for k, v in d['query_details'].items():
            values = self.query_details.setdefault(str(k), [])
            for p, count in v.items():
                if p == '-':
                    values.extend([()] * count)
                else:
                    values.extend([tuple(str(p).split('\t'))] * count)

        if 'dns' in d:
            for k, v in d['dns'].items():
                self.dns_opcounts[k] += v

        self.cumulative_duration = d['cumulative_duration']
        self.conversation_rate = d['conversation_rate']

    def construct_conversation(self, timestamp=0.0, client=2, server=1,
                               hard_stop=None, packet_rate=1):
        """Construct an individual conversation from the model.

        Starting from the all-NON_PACKET key, a next state is drawn at
        random from the n-gram table.  Drawing stops when NON_PACKET is
        drawn, when the current key has no entry, or (if hard_stop is
        given) when a packet would fall after hard_stop.  'wait:N' states
        only advance the clock; other states are added with
        add_short_packet().

        :param timestamp: the conversation's start time.
        :param packet_rate: divides every gap between packets, so values
            above 1 speed the conversation up.
        :return: the new Conversation.
        """
        c = Conversation(timestamp, (server, client))

        key = (NON_PACKET,) * (self.n - 1)

        while key in self.ngrams:
            p = random.choice(self.ngrams.get(key, NON_PACKET))
            if p == NON_PACKET:
                break
            if p in self.query_details:
                extra = random.choice(self.query_details[p])
            else:
                extra = []

            protocol, opcode = p.split(':', 1)
            if protocol == 'wait':
                log_wait_time = int(opcode) + random.random()
                wait = math.exp(log_wait_time) / (WAIT_SCALE * packet_rate)
                timestamp += wait
            else:
                log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
                wait = math.exp(log_wait) / packet_rate
                timestamp += wait
                if hard_stop is not None and timestamp > hard_stop:
                    break
                c.add_short_packet(timestamp, p, extra)

            key = key[1:] + (p,)

        return c

    def generate_conversations(self, rate, duration, packet_rate=1):
        """Generate a list of conversations from the model.

        :param rate: scale factor applied to the learnt conversation rate.
        :param duration: the length of simulated traffic wanted.
        :param packet_rate: passed to construct_conversation (divides the
            gaps between packets).
        :return: the non-empty conversations, sorted.
        """
        # We run the simulation for at least twice as long as our desired
        # duration (or for the learnt conversation span, if longer).
        rate_n, rate_t = self.conversation_rate

        duration2 = max(rate_t, duration * 2)
        n = rate * duration2 * rate_n / rate_t

        server = 1
        client = 2

        conversations = []
        end = duration2

        while client < n + 2:
            start = random.uniform(0, duration2)
            c = self.construct_conversation(start,
                                            client,
                                            server,
                                            hard_stop=(duration2 * 5),
                                            packet_rate=packet_rate)

            # NOTE(review): the window starts at this conversation's own
            # start rather than at `end - duration`; confirm intended.
            c.forget_packets_outside_window(start, end)
            c.renormalise_times(start)
            if len(c) != 0:
                conversations.append(c)
            client += 1

        print(("we have %d conversations at rate %f" %
               (len(conversations), rate)), file=sys.stderr)
        conversations.sort()
        return conversations
1262
1263
# IP protocol number, as a two-digit hex string, used when a packet line is
# synthesised for each protocol.  These match the IANA numbers for TCP
# ('06') and UDP ('11' == 0x11 == 17).  The lookups in this file default to
# '06' for protocols missing from the table.
IP_PROTOCOLS = {
    # carried over UDP
    'dns': '11',
    'cldap': '11',
    'browser': '11',
    'smb_netlogon': '11',
    'nbns': '11',
    # carried over TCP
    'rpc_netlogon': '06',
    'kerberos': '06',      # ratio 16248:258
    'smb': '06',
    'smb2': '06',
    'ldap': '06',
    'lsarpc': '06',
    'samr': '06',
    'dcerpc': '06',
    'epm': '06',
    'drsuapi': '06',
    'srvsvc': '06',
}
1282
# Human-readable description of each (protocol, opcode) pair.  It is used as
# the description field when a packet line is synthesised.  The lookups in
# this file default to '' for pairs missing from the table.
OP_DESCRIPTIONS = {
    ('browser', '0x01'): 'Host Announcement (0x01)',
    ('browser', '0x02'): 'Request Announcement (0x02)',
    ('browser', '0x08'): 'Browser Election Request (0x08)',
    ('browser', '0x09'): 'Get Backup List Request (0x09)',
    ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
    ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
    ('cldap', '3'): 'searchRequest',
    ('cldap', '5'): 'searchResDone',
    ('dcerpc', '0'): 'Request',
    ('dcerpc', '11'): 'Bind',
    ('dcerpc', '12'): 'Bind_ack',
    ('dcerpc', '13'): 'Bind_nak',
    ('dcerpc', '14'): 'Alter_context',
    ('dcerpc', '15'): 'Alter_context_resp',
    ('dcerpc', '16'): 'AUTH3',
    ('dcerpc', '2'): 'Response',
    ('dns', '0'): 'query',
    ('dns', '1'): 'response',
    ('drsuapi', '0'): 'DsBind',
    ('drsuapi', '12'): 'DsCrackNames',
    ('drsuapi', '13'): 'DsWriteAccountSpn',
    ('drsuapi', '1'): 'DsUnbind',
    ('drsuapi', '2'): 'DsReplicaSync',
    ('drsuapi', '3'): 'DsGetNCChanges',
    ('drsuapi', '4'): 'DsReplicaUpdateRefs',
    ('epm', '3'): 'Map',
    ('kerberos', ''): '',
    ('ldap', '0'): 'bindRequest',
    ('ldap', '1'): 'bindResponse',
    ('ldap', '2'): 'unbindRequest',
    ('ldap', '3'): 'searchRequest',
    ('ldap', '4'): 'searchResEntry',
    ('ldap', '5'): 'searchResDone',
    ('ldap', ''): '*** Unknown ***',
    ('lsarpc', '14'): 'lsa_LookupNames',
    ('lsarpc', '15'): 'lsa_LookupSids',
    ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
    ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
    ('lsarpc', '6'): 'lsa_OpenPolicy',
    ('lsarpc', '76'): 'lsa_LookupSids3',
    ('lsarpc', '77'): 'lsa_LookupNames4',
    ('nbns', '0'): 'query',
    ('nbns', '1'): 'response',
    ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
    ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
    ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
    ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
    ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
    ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
    ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
    ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
    ('samr', '0'): 'Connect',
    ('samr', '16'): 'GetAliasMembership',
    ('samr', '17'): 'LookupNames',
    ('samr', '18'): 'LookupRids',
    ('samr', '19'): 'OpenGroup',
    ('samr', '1'): 'Close',
    ('samr', '25'): 'QueryGroupMember',
    ('samr', '34'): 'OpenUser',
    ('samr', '36'): 'QueryUserInfo',
    ('samr', '39'): 'GetGroupsForUser',
    ('samr', '3'): 'QuerySecurity',
    ('samr', '5'): 'LookupDomain',
    ('samr', '64'): 'Connect5',
    ('samr', '6'): 'EnumDomains',
    ('samr', '7'): 'OpenDomain',
    ('samr', '8'): 'QueryDomainInfo',
    ('smb', '0x04'): 'Close (0x04)',
    ('smb', '0x24'): 'Locking AndX (0x24)',
    ('smb', '0x2e'): 'Read AndX (0x2e)',
    ('smb', '0x32'): 'Trans2 (0x32)',
    ('smb', '0x71'): 'Tree Disconnect (0x71)',
    ('smb', '0x72'): 'Negotiate Protocol (0x72)',
    ('smb', '0x73'): 'Session Setup AndX (0x73)',
    ('smb', '0x74'): 'Logoff AndX (0x74)',
    ('smb', '0x75'): 'Tree Connect AndX (0x75)',
    ('smb', '0xa2'): 'NT Create AndX (0xa2)',
    ('smb2', '0'): 'NegotiateProtocol',
    ('smb2', '11'): 'Ioctl',
    ('smb2', '14'): 'Find',
    ('smb2', '16'): 'GetInfo',
    ('smb2', '18'): 'Break',
    ('smb2', '1'): 'SessionSetup',
    ('smb2', '2'): 'SessionLogoff',
    ('smb2', '3'): 'TreeConnect',
    ('smb2', '4'): 'TreeDisconnect',
    ('smb2', '5'): 'Create',
    ('smb2', '6'): 'Close',
    ('smb2', '8'): 'Read',
    ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
    ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
                               'user unknown (0x17)'),
    ('srvsvc', '16'): 'NetShareGetInfo',
    ('srvsvc', '21'): 'NetSrvGetInfo',
}
1379
1380
def expand_short_packet(p, timestamp, src, dest, extra):
    """Expand a 'protocol:opcode' string into a tab-separated packet line.

    The fields, in order, are:
    - timestamp
    - IP protocol (looked up in IP_PROTOCOLS, default '06')
    - an empty field
    - src
    - dest
    - protocol
    - opcode
    - description (looked up in OP_DESCRIPTIONS, default '')
    - any `extra` fields

    Each field is formatted with '%s', so a numeric timestamp or address is
    accepted; str.join alone would raise TypeError on it.
    """
    protocol, opcode = p.split(':', 1)
    desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
    ip_protocol = IP_PROTOCOLS.get(protocol, '06')

    line = [timestamp, ip_protocol, '', src, dest, protocol, opcode, desc]
    line.extend(extra)
    return '\t'.join(['%s' % (field,) for field in line])
1389
1390
def replay(conversations,
           host=None,
           creds=None,
           lp=None,
           accounts=None,
           dns_rate=0,
           duration=None,
           **kwargs):
    """Replay conversations against a server, each in its own forked
    process.

    Conversations are forked in batches as their start times approach.
    When the duration is up, or on an exception, the remaining children are
    sent signal 15 twice and then signal 9.

    :param conversations: the Conversations to replay.
    :param host: the server, passed to ReplayContext.
    :param creds: credentials, passed to ReplayContext.
    :param lp: loadparm, passed to ReplayContext.
    :param accounts: a sequence of accounts, paired one-to-one with the
        conversations in start-time order.  Conversations without an
        account are not replayed.
    :param dns_rate: if non-zero, also run a DnsHammer making this many DNS
        queries per second.
    :param duration: seconds to run for.  By default, runs until one second
        after the last packet of the last conversation to start.
    :param kwargs: passed on to ReplayContext.
    """
    import errno

    context = ReplayContext(server=host,
                            creds=creds,
                            lp=lp,
                            **kwargs)

    if len(accounts) < len(conversations):
        print(("we have %d accounts but %d conversations" %
               (len(accounts), len(conversations))), file=sys.stderr)

    # Latest-starting first, so that pop() yields the next one due.
    cstack = list(zip(
        sorted(conversations, key=lambda x: x.start_time, reverse=True),
        accounts))

    # Set the process group so that the calling scripts are not killed
    # when the forked child processes are killed.
    os.setpgrp()

    start = time.time()

    if duration is None:
        # end 1 second after the last packet of the last conversation
        # to start. Packet timestamps are relative to their conversation's
        # start_time, so that is added back on. Conversations other than
        # the last could still be going, but we don't care.
        if cstack:
            last = cstack[0][0]
            duration = last.start_time + last.packets[-1].timestamp + 1.0
        else:
            duration = 0.0
        print("We will stop after %.1f seconds" % duration,
              file=sys.stderr)

    end = start + duration

    print("Replaying traffic for %u conversations over %d seconds"
          % (len(conversations), duration))

    children = {}
    if dns_rate:
        dns_hammer = DnsHammer(dns_rate, duration)
        cstack.append((dns_hammer, None))

    try:
        while True:
            # we spawn a batch, wait for finishers, then spawn another
            now = time.time()
            batch_end = min(now + 2.0, end)
            fork_time = 0.0
            fork_n = 0
            while cstack:
                c, account = cstack.pop()
                if c.start_time + start > batch_end:
                    cstack.append((c, account))
                    break

                st = time.time()
                pid = c.replay_in_fork_with_delay(start, context, account)
                children[pid] = c
                t = time.time()
                elapsed = t - st
                fork_time += elapsed
                fork_n += 1
                print("forked %s in pid %s (in %fs)" % (c, pid,
                                                        elapsed),
                      file=sys.stderr)

            if fork_n:
                print(("forked %d times in %f seconds (avg %f)" %
                       (fork_n, fork_time, fork_time / fork_n)),
                      file=sys.stderr)
            elif cstack:
                debug(2, "no forks in batch ending %f" % batch_end)

            while time.time() < batch_end - 1.0:
                time.sleep(0.01)
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != errno.ECHILD:  # no child processes
                        raise
                    break
                if pid:
                    c = children.pop(pid, None)
                    print(("process %d finished conversation %s;"
                           " %d to go" %
                           (pid, c, len(children))), file=sys.stderr)

            if time.time() >= end:
                print("time to stop", file=sys.stderr)
                break

    except Exception:
        print("EXCEPTION in parent", file=sys.stderr)
        traceback.print_exc()
    finally:
        for s in (15, 15, 9):
            print(("killing %d children with -%d" %
                   (len(children), s)), file=sys.stderr)
            for pid in children:
                try:
                    os.kill(pid, s)
                except OSError as e:
                    # don't fail if it has already died
                    if e.errno != errno.ESRCH:
                        raise
            time.sleep(0.5)
            end = time.time() + 1
            while children:
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != errno.ECHILD:
                        raise
                    # Nothing left to reap; don't fall through and reuse
                    # the pid from a previous iteration.
                    break
                if pid != 0:
                    c = children.pop(pid, None)
                    print(("kill -%d %d KILLED conversation %s; "
                           "%d to go" %
                           (s, pid, c, len(children))),
                          file=sys.stderr)
                if time.time() >= end:
                    break

            if not children:
                break
            time.sleep(1)

        if children:
            print("%d children are missing" % len(children),
                  file=sys.stderr)

        # there may be stragglers that were forked just as ^C was hit
        # and don't appear in the list of children. We can get them
        # with killpg, but that will also kill us, so this is^H^H would be
        # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
        # so as not to have to fuss around writing signal handlers.
        try:
            os.killpg(0, 2)
        except KeyboardInterrupt:
            print("ignoring fake ^C", file=sys.stderr)
1533
1534
def openLdb(host, creds, lp):
    """Open a SamDB connection to `host` over LDAP, using `creds` and `lp`
    with a system session, and return it."""
    return SamDB(url="ldap://%s" % host,
                 session_info=system_session(),
                 credentials=creds,
                 lp=lp)
1542
1543
def ou_name(ldb, instance_id):
    """Return the DN of the ou for traffic instance `instance_id`: an
    ou=instance-<id> entry under ou=traffic_replay in the domain DN."""
    instance_ou = "ou=instance-%d" % instance_id
    return "%s,ou=traffic_replay,%s" % (instance_ou, ldb.domain_dn())
1548
1549
def create_ou(ldb, instance_id):
    """Create an ou, all created user and machine accounts will belong to it.

    This allows all the created resources to be cleaned up easily.  The
    parent container (the ou's DN minus its first component) is created
    first.  An entry that already exists (LDB error 68) is not an error.

    :return: the DN of the per-instance ou.
    """
    ou = ou_name(ldb, instance_id)
    parent = ou.split(',', 1)[1]
    for dn in (parent, ou):
        try:
            ldb.add({"dn":          dn,
                     "objectclass": "organizationalunit"})
        except LdbError as e:
            # ignore already exists.  e.args works on Python 2 and 3,
            # whereas unpacking the exception itself only works on Python 2.
            if e.args[0] != 68:
                raise
    return ou
1573
1574
class ConversationAccounts(object):
    """The machine and user account credentials a conversation runs as.

    Attributes: netbios_name and machinepass for the machine account;
    username and userpass for the user account.
    """
    def __init__(self, netbios_name, machinepass, username, userpass):
        self.netbios_name, self.machinepass = netbios_name, machinepass
        self.username, self.userpass = username, userpass
1583
1584
def generate_replay_accounts(ldb, instance_id, number, password):
    """Make sure `number` machine/user account pairs exist and return a
    ConversationAccounts for each.

    The accounts are named STGM-<instance>-<i> (machine) and
    STGU-<instance>-<i> (user), for i from 1 to number.  Every account
    uses `password`.
    """
    generate_traffic_accounts(ldb, instance_id, number, password)
    return [ConversationAccounts("STGM-%d-%d" % (instance_id, i),
                                 password,
                                 "STGU-%d-%d" % (instance_id, i),
                                 password)
            for i in range(1, number + 1)]
1598
1599
def generate_traffic_accounts(ldb, instance_id, number, password):
    """Create the specified number of user and machine accounts.

    Accounts are not explicitly deleted between runs.  So, for each kind of
    account, this starts with the highest-numbered one and works backwards.
    It stops either when it finds one that already exists (LDB error 68) or
    when it has created all the required accounts.  All accounts use
    `password`.
    """
    print(("Generating machine and conversation accounts, "
           "as required for %d conversations" % number),
          file=sys.stderr)

    kinds = (("STGM-%d-%d", create_machine_account, "machine"),
             ("STGU-%d-%d", create_user_account, "user"))
    for name_format, create_account, kind in kinds:
        added = 0
        for i in range(number, 0, -1):
            name = name_format % (instance_id, i)
            try:
                create_account(ldb, instance_id, name, password)
            except LdbError as e:
                # e.args works on Python 2 and 3; unpacking the exception
                # itself only works on Python 2.
                if e.args[0] == 68:
                    break
                raise
            added += 1
        if added > 0:
            print("Added %d new %s accounts" % (added, kind),
                  file=sys.stderr)
1643
1644
def create_machine_account(ldb, instance_id, netbios_name, machinepass):
    """Create a machine account via ldap.

    The computer object is created as cn=<netbios_name> under the
    instance's OU, with sAMAccountName "<netbios_name>$" and
    workstation-trust / password-not-required account control flags.  A
    tab-separated timing record for the add is printed to stdout.

    :param machinepass: the account password, as text or UTF-8 bytes.
    :raises LdbError: if the add fails (callers treat result 68 as "the
        account already exists").
    """
    ou = ou_name(ldb, instance_id)
    dn = "cn=%s,%s" % (netbios_name, ou)
    # unicodePwd takes the password wrapped in double quotes and encoded as
    # UTF-16-LE.  Normalise to text first: the old unicode() call was
    # Python-2-only and broke on non-ASCII byte strings.
    if isinstance(machinepass, bytes):
        machinepass = machinepass.decode('utf-8')
    utf16pw = ('"%s"' % machinepass).encode('utf-16-le')
    start = time.time()
    ldb.add({
        "dn": dn,
        "objectclass": "computer",
        "sAMAccountName": "%s$" % netbios_name,
        "userAccountControl":
        str(UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD),
        "unicodePwd": utf16pw})
    end = time.time()
    duration = end - start
    print("%f\t0\tcreate\tmachine\t%f\tTrue\t" % (end, duration))
1664
1665
def create_user_account(ldb, instance_id, username, userpass):
    """Create a user account via ldap.

    The user object is created as cn=<username> under the instance's OU,
    with sAMAccountName <username> and the UF_NORMAL_ACCOUNT flag.  A
    tab-separated timing record for the add is printed to stdout.

    :param userpass: the account password, as text or UTF-8 bytes.
    :raises LdbError: if the add fails (callers treat result 68 as "the
        account already exists").
    """
    ou = ou_name(ldb, instance_id)
    user_dn = "cn=%s,%s" % (username, ou)
    # unicodePwd takes the password wrapped in double quotes and encoded as
    # UTF-16-LE.  Normalise to text first: the old unicode() call was
    # Python-2-only and broke on non-ASCII byte strings.
    if isinstance(userpass, bytes):
        userpass = userpass.decode('utf-8')
    utf16pw = ('"%s"' % userpass).encode('utf-16-le')
    start = time.time()
    ldb.add({
        "dn": user_dn,
        "objectclass": "user",
        "sAMAccountName": username,
        "userAccountControl": str(UF_NORMAL_ACCOUNT),
        "unicodePwd": utf16pw
    })
    end = time.time()
    duration = end - start
    print("%f\t0\tcreate\tuser\t%f\tTrue\t" % (end, duration))
1684
1685
def create_group(ldb, instance_id, name):
    """Create a group via ldap.

    The group is added as cn=<name> under the instance's OU, and a
    tab-separated timing record for the add is printed to stdout.
    """
    group_dn = "cn=%s,%s" % (name, ou_name(ldb, instance_id))
    started = time.time()
    ldb.add({"dn": group_dn, "objectclass": "group"})
    finished = time.time()
    print("%f\t0\tcreate\tgroup\t%f\tTrue\t" % (finished, finished - started))
1699
1700
def user_name(instance_id, i):
    """Return the account name of user number *i* for this instance.

    Names have the form STGU-<instance_id>-<i>.
    """
    prefix = "STGU"
    return "%s-%d-%d" % (prefix, instance_id, i)
1704
1705
def generate_users(ldb, instance_id, number, password):
    """Add users to the server.

    Works backwards from user *number* to user 1, creating each account.
    It stops at the first account that already exists (LDAP result 68),
    because accounts are not deleted between runs and lower-numbered ones
    are assumed to be present already.

    :return: the number of user accounts actually created.
    :raises LdbError: for any failure other than "already exists".
    """
    users = 0
    for i in range(number, 0, -1):
        username = user_name(instance_id, i)
        try:
            create_user_account(ldb, instance_id, username, password)
        except LdbError as e:
            # e.args works on both Python 2 and 3.
            (status, _) = e.args
            # Stop if entry exists
            if status == 68:
                break
            raise
        users += 1

    return users
1723
1724
def group_name(instance_id, i):
    """Return the name of group number *i* for this instance.

    Names have the form STGG-<instance_id>-<i>.
    """
    prefix = "STGG"
    return "%s-%d-%d" % (prefix, instance_id, i)
1728
1729
def generate_groups(ldb, instance_id, number):
    """Create the required number of groups on the server.

    Works backwards from group *number* to group 1 and stops at the first
    group that already exists (LDAP result 68).

    :return: the number of groups actually created.
    :raises LdbError: for any failure other than "already exists".
    """
    groups = 0
    for i in range(number, 0, -1):
        name = group_name(instance_id, i)
        try:
            create_group(ldb, instance_id, name)
        except LdbError as e:
            # e.args works on both Python 2 and 3.
            (status, _) = e.args
            # Stop if entry exists
            if status == 68:
                break
            raise
        groups += 1
    return groups
1746
1747
def clean_up_accounts(ldb, instance_id):
    """Remove the created accounts and groups from the server.

    Deletes the instance's OU and everything under it, using the
    "tree_delete:1" control.  A missing OU (LDAP result 32) is ignored, so
    the call is safe to repeat.

    :raises LdbError: for any failure other than "does not exist".
    """
    ou = ou_name(ldb, instance_id)
    try:
        ldb.delete(ou, ["tree_delete:1"])
    except LdbError as e:
        # e.args works on both Python 2 and 3.
        (status, _) = e.args
        # ignore does not exist
        if status != 32:
            raise
1758
1759
def generate_users_and_groups(ldb, instance_id, password,
                              number_of_users, number_of_groups,
                              group_memberships):
    """Generate the required users and groups, allocating the users to
       those groups.

    Creates the instance OU, then any missing users and groups.  If
    *group_memberships* is positive, it also assigns users to groups.
    Progress and a final summary are written to stderr.
    """
    create_ou(ldb, instance_id)

    print("Generating dummy user accounts", file=sys.stderr)
    users_added = generate_users(ldb, instance_id, number_of_users, password)

    groups_added = 0
    if number_of_groups > 0:
        print("Generating dummy groups", file=sys.stderr)
        groups_added = generate_groups(ldb, instance_id, number_of_groups)

    assignments = []
    if group_memberships > 0:
        print("Assigning users to groups", file=sys.stderr)
        assignments = assign_groups(number_of_groups, groups_added,
                                    number_of_users, users_added,
                                    group_memberships)
        print("Adding users to groups", file=sys.stderr)
        add_users_to_groups(ldb, instance_id, assignments)

    # NOTE(review): with users_added == 0 no memberships are generated at
    # all, so new groups stay empty even when every group is new.  That
    # case is excluded from this warning; confirm the exclusion is intended.
    new_groups_left_empty = (groups_added > 0 and users_added == 0 and
                             number_of_groups != groups_added)
    if new_groups_left_empty:
        print("Warning: the added groups will contain no members",
              file=sys.stderr)

    summary = ("Added %d users, %d groups and %d group memberships" %
               (users_added, groups_added, len(assignments)))
    print(summary, file=sys.stderr)
1795
1796
def assign_groups(number_of_groups,
                  groups_added,
                  number_of_users,
                  users_added,
                  group_memberships):
    """Allocate users to groups.

    The intention is to have a few users that belong to most groups, while
    the majority of users belong to a few groups.

    A few groups will contain most users, with the remaining only having a
    few users.

    Users are numbered 1..number_of_users and groups 1..number_of_groups.
    The newly added ones are taken to be the highest-numbered
    *users_added* users and *groups_added* groups.  Only pairs involving a
    new user or a new group are generated; pairs where both the user and
    the group pre-existed are never produced.

    :param group_memberships: requested number of memberships.  It is
        scaled down by the fraction of users that are new.
    :return: a set of 1-based (user, group) tuples.  The count is capped at
        the number of eligible pairs, so an over-large request cannot make
        the sampling loop run forever.
    """

    def generate_user_distribution(n):
        """Probability weight of each user (index 0..n-1) being chosen."""
        return [1 / (x + 0.001) for x in range(1, n + 1)]

    def generate_group_distribution(n):
        """Probability weight of each group (index 0..n-1) being chosen."""
        return [1 / (x**1.3) for x in range(1, n + 1)]

    assignments = set()
    # With no users there is nothing to assign (and the scaling below would
    # divide by zero).
    if group_memberships <= 0 or number_of_users <= 0:
        return assignments

    group_dist = generate_group_distribution(number_of_groups)
    user_dist  = generate_user_distribution(number_of_users)

    # Calculate the number of group memberships required
    group_memberships = math.ceil(
        float(group_memberships) *
        (float(users_added) / float(number_of_users)))

    # Highest 0-based index of a pre-existing user/group (-1 if none).
    existing_users  = number_of_users  - users_added  - 1
    existing_groups = number_of_groups - groups_added - 1

    # Number of distinct pairs that involve at least one new user or group.
    # Requesting more than this would make the loop below never terminate.
    # When number_of_groups == 0 this is 0, which also avoids calling
    # randint on an empty range.
    eligible_pairs = (number_of_users * number_of_groups -
                      (existing_users + 1) * (existing_groups + 1))
    group_memberships = min(group_memberships, eligible_pairs)

    while len(assignments) < group_memberships:
        user        = random.randint(0, number_of_users - 1)
        group       = random.randint(0, number_of_groups - 1)
        probability = group_dist[group] * user_dist[user]

        if ((random.random() < probability * 10000) and
           (group > existing_groups or user > existing_users)):
            # the + 1 converts the array index to the corresponding
            # group or user number
            assignments.add(((user + 1), (group + 1)))

    return assignments
1854
1855
def add_users_to_groups(db, instance_id, assignments):
    """Add users to their assigned groups.

    Takes the (user, group) tuples generated by assign_groups and adds each
    user as a member of its assigned group.  User and group numbers are
    turned into names with user_name/group_name and resolved under the
    instance's OU.  A tab-separated timing record is printed to stdout for
    each modification.
    """
    # The connection parameter is called `db` (not `ldb`) so that the `ldb`
    # module stays reachable for Message/Dn/MessageElement below.
    ou = ou_name(db, instance_id)

    def build_dn(name):
        return "cn=%s,%s" % (name, ou)

    for (user, group) in assignments:
        user_dn  = build_dn(user_name(instance_id, user))
        group_dn = build_dn(group_name(instance_id, group))

        m = ldb.Message()
        m.dn = ldb.Dn(db, group_dn)
        m["member"] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
        start = time.time()
        db.modify(m)
        end = time.time()
        duration = end - start
        print("%f\t0\tadd\tuser\t%f\tTrue\t" % (end, duration))
1879
1880
def generate_stats(statsdir, timing_file):
    """Generate and print the summary stats for a run.

    Reads every file in *statsdir*.  Each line is a tab-separated record:
    end time, conversation, protocol, packet type, latency, success flag
    ('True' means success) and error.  Malformed lines are echoed to stderr
    and skipped entirely.  Valid lines are copied to *timing_file* (unless
    it is None), after a header row.

    Prints overall totals and rates, then per-(protocol, packet type)
    latency statistics, to stdout.  The layout is column-aligned when
    stdout is a terminal and tab-separated otherwise.
    """
    first      = sys.float_info.max
    last       = 0
    successful = 0
    failed     = 0
    latencies  = {}
    failures   = {}
    conversations = set()

    if timing_file is not None:
        tw = timing_file.write
    else:
        def tw(x):
            pass

    tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")

    for filename in os.listdir(statsdir):
        path = os.path.join(statsdir, filename)
        with open(path, 'r') as f:
            for line in f:
                # Parse the whole record before touching any totals, so a
                # malformed line (e.g. missing the success column) is
                # ignored completely rather than partially counted.
                try:
                    fields       = line.rstrip('\n').split('\t')
                    end_time     = float(fields[0])
                    conversation = fields[1]
                    protocol     = fields[2]
                    packet_type  = fields[3]
                    latency      = float(fields[4])
                    succeeded    = fields[5] == 'True'
                except (ValueError, IndexError):
                    # not a valid line print and ignore
                    print(line, file=sys.stderr)
                    continue

                first = min(end_time - latency, first)
                last  = max(end_time, last)

                latencies.setdefault(protocol, {}).setdefault(
                    packet_type, []).append(latency)
                proto_failures = failures.setdefault(protocol, {})
                proto_failures.setdefault(packet_type, 0)

                if succeeded:
                    successful += 1
                else:
                    failed += 1
                    proto_failures[packet_type] += 1

                conversations.add(conversation)
                tw(line)

    duration = last - first
    # duration may be zero (e.g. a single zero-latency record) or negative
    # (no valid records).  Report a rate of 0 rather than divide by it.
    if successful == 0 or duration <= 0:
        success_rate = 0
    else:
        success_rate = successful / duration
    if failed == 0 or duration <= 0:
        failure_rate = 0
    else:
        failure_rate = failed / duration

    # print the stats in more human-readable format when stdout is going to
    # the console (as opposed to being redirected to a file)
    human_readable = sys.stdout.isatty()
    if human_readable:
        print("Total conversations:   %10d" % len(conversations))
        print("Successful operations: %10d (%.3f per second)"
              % (successful, success_rate))
        print("Failed operations:     %10d (%.3f per second)"
              % (failed, failure_rate))
    else:
        print("(%d, %d, %d, %.3f, %.3f)" %
              (len(conversations), successful, failed,
               success_rate, failure_rate))

    if human_readable:
        print("Protocol    Op Code  Description                               "
              " Count       Failed         Mean       Median          "
              "95%        Range          Max")
    else:
        print("proto\top_code\tdesc\tcount\tfailed\tmean\tmedian\t95%\trange"
              "\tmax")
    for protocol in sorted(latencies.keys()):
        packet_types = sorted(latencies[protocol], key=opcode_key)
        for packet_type in packet_types:
            values     = sorted(latencies[protocol][packet_type])
            count      = len(values)
            n_failed   = failures[protocol][packet_type]
            mean       = sum(values) / count
            median     = calc_percentile(values, 0.50)
            percentile = calc_percentile(values, 0.95)
            rng        = values[-1] - values[0]
            maxv       = values[-1]
            desc       = OP_DESCRIPTIONS.get((protocol, packet_type), '')
            # Previously `sys.stdout.isatty` (uncalled, always truthy) was
            # tested here, so redirected output got the aligned format.
            if human_readable:
                print("%-12s   %4s  %-35s %12d %12d %12.6f "
                      "%12.6f %12.6f %12.6f %12.6f"
                      % (protocol,
                         packet_type,
                         desc,
                         count,
                         n_failed,
                         mean,
                         median,
                         percentile,
                         rng,
                         maxv))
            else:
                print("%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f"
                      % (protocol,
                         packet_type,
                         desc,
                         count,
                         n_failed,
                         mean,
                         median,
                         percentile,
                         rng,
                         maxv))
2008
2009
def opcode_key(v):
    """Sort key for the operation code to ensure that it sorts numerically.

    Op codes that parse as integers are zero-padded to three digits, so
    "2" sorts before "10".  Any other value is returned unchanged.
    """
    try:
        return "%03d" % int(v)
    except (ValueError, TypeError):
        # Not an integer op code.  Catch only conversion errors: a bare
        # except would also swallow KeyboardInterrupt/SystemExit.
        return v
2016
2017
def calc_percentile(values, percentile):
    """Return the *percentile* point of the sorted list *values*.

    *values* must already be in ascending order.  *percentile* is a
    fraction (0.5 for the median, 0.95 for the 95th percentile).  The
    position (len(values) - 1) * percentile is computed.  If it lands
    exactly on an element, that element is returned.  Otherwise the two
    neighbouring elements are interpolated linearly.  An empty list
    yields 0.
    """
    if not values:
        return 0
    pos = (len(values) - 1) * percentile
    lo = math.floor(pos)
    hi = math.ceil(pos)
    if lo == hi:
        return values[int(pos)]
    lo_weight = hi - pos
    hi_weight = pos - lo
    return values[int(lo)] * lo_weight + values[int(hi)] * hi_weight
2034
2035
def mk_masked_dir(*path):
    """Create a directory under a restrictive umask and return its path.

    In a testenv we end up with 0777 directories that look an alarming
    green colour with ls.  Setting the umask to 0o077 while creating the
    directory avoids that.  The previous umask is restored even if
    os.mkdir fails.

    :param path: path components, joined with os.path.join.
    :return: the joined path of the newly created directory.
    :raises OSError: if the directory cannot be created.
    """
    d = os.path.join(*path)
    mask = os.umask(0o077)
    try:
        os.mkdir(d)
    finally:
        # umask is process-wide; never leave it changed after a failure.
        os.umask(mask)
    return d