1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
4 # Copyright (C) Catalyst IT Ltd. 2017
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 # GNU General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 from __future__ import print_function, division
28 from errno import ECHILD, ESRCH
30 from collections import OrderedDict, Counter, defaultdict, namedtuple
31 from samba.emulate import traffic_packets
32 from samba.samdb import SamDB
34 from ldb import LdbError
35 from samba.dcerpc import ClientConnection
36 from samba.dcerpc import security, drsuapi, lsa
37 from samba.dcerpc import netlogon
38 from samba.dcerpc.netlogon import netr_Authenticator
39 from samba.dcerpc import srvsvc
40 from samba.dcerpc import samr
41 from samba.drs_utils import drs_DsBind
43 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
44 from samba.auth import system_session
45 from samba.dsdb import (
47 UF_SERVER_TRUST_ACCOUNT,
48 UF_TRUSTED_FOR_DELEGATION,
49 UF_WORKSTATION_TRUST_ACCOUNT
51 from samba.dcerpc.misc import SEC_CHAN_BDC
52 from samba import gensec
53 from samba import sd_utils
54 from samba.compat import get_string
55 from samba.logger import get_samba_logger
# Model file format version: files are written with CURRENT_MODEL_VERSION,
# and any file at REQUIRED_MODEL_VERSION or newer is accepted on load.
CURRENT_MODEL_VERSION = 2   # save as this
REQUIRED_MODEL_VERSION = 2  # load accepts this or greater

# we don't use None, because it complicates [de]serialisation
# NOTE(review): the assignments opening the two clue tables (apparently
# CLIENT_CLUES and SERVER_CLUES, dicts keyed by (protocol, opcode) —
# judging from their use in Packet.client_score) appear to be elided
# from this excerpt.
    ('dns', '0'): 1.0,      # query
    ('smb', '0x72'): 1.0,   # Negotiate protocol
    ('ldap', '0'): 1.0,     # bind
    ('ldap', '3'): 1.0,     # searchRequest
    ('ldap', '2'): 1.0,     # unbindRequest
    ('dcerpc', '11'): 1.0,  # bind
    ('dcerpc', '14'): 1.0,  # Alter_context
    ('nbns', '0'): 1.0,     # query

    ('dns', '1'): 1.0,      # response
    ('ldap', '1'): 1.0,     # bind response
    ('ldap', '4'): 1.0,     # search result
    ('ldap', '5'): 1.0,     # search done
    ('dcerpc', '12'): 1.0,  # bind_ack
    ('dcerpc', '13'): 1.0,  # bind_nak
    ('dcerpc', '15'): 1.0,  # Alter_context response

# Protocols that are never replayed at all.
SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}

# NOTE(review): WAIT_SCALE is referenced here but its definition is not
# visible in this excerpt.
WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
NO_WAIT_LOG_TIME_RANGE = (-10, -3)

# DEBUG_LEVEL can be changed by scripts with -d

LOGGER = get_samba_logger(name=__name__)
def debug(level, msg, *args):
    """Print a formatted debug message to standard error.

    :param level: The debug level, message will be printed if it is <= the
                  currently set debug level. The debug level can be set with
                  the -d option.
    :param msg: The message to be logged, can contain C-Style format
                specifiers.
    :param args: The parameters required by the format specifiers.
    """
    if level <= DEBUG_LEVEL:
        if not args:
            # No args: print the message verbatim, so a literal '%' in
            # msg cannot raise a string-formatting error.
            print(msg, file=sys.stderr)
        else:
            print(msg % tuple(args), file=sys.stderr)
def debug_lineno(*args):
    """ Print an unformatted log message to stderr, contaning the line number
    """
    # NOTE(review): this excerpt appears to be missing lines — the close
    # of the first print() call and the loop header over args.
    # tb[0] is the caller's frame summary: [2] is the function name,
    # [1] the line number; the escape codes colour the prefix yellow.
    tb = traceback.extract_stack(limit=2)
    print((" %s:" "\033[01;33m"
           "%s " "\033[00m" % (tb[0][2], tb[0][1])), end=' ',
        print(a, file=sys.stderr)
    print(file=sys.stderr)
def random_colour_print(seeds):
    """Return a function that prints a coloured line to stderr. The colour
    of the line depends on a sort of hash of the integer arguments."""
    # NOTE(review): the branch structure and the inner printer function
    # definitions appear to be elided from this excerpt; only the
    # colour-prefix computation and the two print forms are visible.
        prefix = "\033[38;5;%dm" % (18 + s)
            print("%s%s\033[00m" % (prefix, a), file=sys.stderr)
            # Fallback: plain (uncoloured) printing.
            print(a, file=sys.stderr)
class FakePacketError(Exception):
    """Raised when a packet or conversation is internally inconsistent."""
class Packet(object):
    """Details of a network packet"""
    # NOTE(review): this excerpt appears to have lines elided throughout
    # this class — the __slots__ tuple is truncated, several method
    # headers (__str__, __repr__, copy), else: branches and return
    # statements are missing.  Compare against the complete file.
    __slots__ = ('timestamp',

    def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
                 protocol, opcode, desc, extra):
        self.timestamp = timestamp
        self.ip_protocol = ip_protocol
        self.stream_number = stream_number
        self.protocol = protocol
        # endpoints is the (src, dest) pair in sorted order, so the same
        # tuple identifies the conversation regardless of direction.
        if self.src < self.dest:
            self.endpoints = (self.src, self.dest)
            self.endpoints = (self.dest, self.src)

    def from_line(cls, line):
        # Alternate constructor: parse one tab-separated traffic-summary
        # line.  NOTE(review): written classmethod-style; the @classmethod
        # decorator is not visible in this excerpt.
        fields = line.rstrip('\n').split('\t')
        timestamp = float(timestamp)
        return cls(timestamp, ip_protocol, stream_number, src, dest,
                   protocol, opcode, desc, extra)

    def as_summary(self, time_offset=0.0):
        """Format the packet as a traffic_summary line.
        """
        extra = '\t'.join(self.extra)
        t = self.timestamp + time_offset
        return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
                self.stream_number or '',

        # (body of an elided __str__ method)
        return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
                (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
                 self.stream_number, self.protocol, self.opcode, self.desc,
                 ('«' + ' '.join(self.extra) + '»' if self.extra else '')))

        # (body of an elided __repr__ method)
        return "<Packet @%s>" % self

        # (start of an elided copy-style method)
        return self.__class__(self.timestamp,

    def as_packet_type(self):
        # 'protocol:opcode' label, used as a state name in the model.
        t = '%s:%s' % (self.protocol, self.opcode)

    def client_score(self):
        """A positive number means we think it is a client; a negative number
        means we think it is a server. Zero means no idea. range: -1 to 1.
        """
        key = (self.protocol, self.opcode)
        if key in CLIENT_CLUES:
            return CLIENT_CLUES[key]
        if key in SERVER_CLUES:
            return -SERVER_CLUES[key]

    def play(self, conversation, context):
        """Send the packet over the network, if required.

        Some packets are ignored, i.e. for protocols not handled,
        server response messages, or messages that are generated by the
        protocol layer associated with other packets.
        """
        # Dispatch to the handler named packet_<protocol>_<opcode> in
        # the traffic_packets module.
        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
            fn = getattr(traffic_packets, fn_name)
        except AttributeError as e:
            print("Conversation(%s) Missing handler %s" %
                  (conversation.conversation_id, fn_name),

        # Don't display a message for kerberos packets, they're not directly
        # generated they're used to indicate kerberos should be used
        if self.protocol != "kerberos":
            debug(2, "Conversation(%s) Calling handler %s" %
                  (conversation.conversation_id, fn_name))

            if fn(self, conversation, context):
                # Only collect timing data for functions that generate
                # network traffic, or fail
                duration = end - start
                print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
                      (end, conversation.conversation_id, self.protocol,
                       self.opcode, duration))
        except Exception as e:
            duration = end - start
            print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
                  (end, conversation.conversation_id, self.protocol,
                   self.opcode, duration, e))

    def __cmp__(self, other):
        # Python 2 style ordering by timestamp (not used by Python 3).
        return self.timestamp - other.timestamp

    def is_really_a_packet(self, missing_packet_stats=None):
        # Delegates to the module-level predicate.
        return is_a_real_packet(self.protocol, self.opcode)
def is_a_real_packet(protocol, opcode):
    """Is the packet one that can be ignored?

    If so removing it will have no effect on the replay.

    :param protocol: the protocol name (e.g. 'ldap', 'dns')
    :param opcode: the operation code within that protocol
    :return: True if replaying this (protocol, opcode) does real work;
        False if the packet can be dropped from the replay.
    """
    if protocol in SKIPPED_PROTOCOLS:
        # Ignore any packets for the protocols we're not interested in.
        return False
    if protocol == "ldap" and opcode == '':
        # skip ldap continuation packets
        return False

    fn_name = 'packet_%s_%s' % (protocol, opcode)
    fn = getattr(traffic_packets, fn_name, None)
    if fn is None:
        # Logger.debug() does not accept a file= keyword (that belongs
        # to print()); use lazy %-style args instead.
        LOGGER.debug("missing packet %s", fn_name)
        return False
    if fn is traffic_packets.null_packet:
        # null_packet handlers mark known-but-deliberately-ignored ops.
        return False
    return True
def is_a_traffic_generating_packet(protocol, opcode):
    """Return true if a packet generates traffic in its own right. Some of
    these will generate traffic in certain contexts (e.g. ldap unbind
    after a bind) but not if the conversation consists only of these packets.
    """
    if protocol == 'wait':
        # 'wait' pseudo-packets only pace the replay; they never touch
        # the network.

    # NOTE(review): the tuple of non-generating (protocol, opcode) pairs
    # below, and the early return statements, appear to be elided from
    # this excerpt.
    if (protocol, opcode) in (

    return is_a_real_packet(protocol, opcode)
class ReplayContext(object):
    """State/Context for a conversation between an simulated client and a
    server. Some of the context is shared amongst all conversations
    and should be generated before the fork, while other context is
    specific to a particular conversation and should be generated
    *after* the fork, in generate_process_local_config().
    """
    # NOTE(review): this excerpt appears to have lines elided throughout
    # this class — the __init__ signature above these keyword parameters,
    # several assignments, try:/else: lines and inner connect() helper
    # headers are missing.  Compare against the complete file.
                 badpassword_frequency=None,
                 prefer_kerberos=None,
                 domain=os.environ.get("DOMAIN"),
        self.netlogon_connection = None
        self.prefer_kerberos = prefer_kerberos
        self.base_dn = base_dn
        self.statsdir = statsdir
        self.global_tempdir = tempdir
        self.domain_sid = domain_sid
        self.realm = lp.get('realm')

        # Bad password attempt controls
        # The last_*_bad flags prevent two consecutive deliberately-bad
        # logons on the same kind of channel.
        self.badpassword_frequency = badpassword_frequency
        self.last_lsarpc_bad = False
        self.last_lsarpc_named_bad = False
        self.last_simple_bind_bad = False
        self.last_bind_bad = False
        self.last_srvsvc_bad = False
        self.last_drsuapi_bad = False
        self.last_netlogon_bad = False
        self.last_samlogon_bad = False
        self.generate_ldap_search_tables()

    def generate_ldap_search_tables(self):
        # Build lookup tables of existing DNs, grouped by their RDN-type
        # pattern (e.g. "CN,CN,CN,DC,DC"), used later to choose plausible
        # search bases for replayed ldap queries.
        session = system_session()

        db = SamDB(url="ldap://%s" % self.server,
                   session_info=session,
                   credentials=self.creds,

        res = db.search(db.domain_dn(),
                        scope=ldb.SCOPE_SUBTREE,
                        controls=["paged_results:1:1000"],

        # find a list of dns for each pattern
        # e.g. CN,CN,CN,DC,DC
        attribute_clue_map = {
            pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
            dns = dn_map.setdefault(pattern, [])
            if dn.startswith('CN=NTDS Settings,'):
                attribute_clue_map['invocationId'].append(dn)

        # extend the map in case we are working with a different
        # number of DC components.
        # for k, v in self.dn_map.items():
        #     print >>sys.stderr, k, len(v)

        for k in list(dn_map.keys()):
            while p[-3:] == ',DC':
            if p != k and p in dn_map:
                print('dn_map collison %s %s' % (k, p),
            dn_map[p] = dn_map[k]

        self.attribute_clue_map = attribute_clue_map

    def generate_process_local_config(self, account, conversation):
        # Per-process (post-fork) state: fresh connection caches and the
        # account identity this conversation will act as.
        self.ldap_connections = []
        self.dcerpc_connections = []
        self.lsarpc_connections = []
        self.lsarpc_connections_named = []
        self.drsuapi_connections = []
        self.srvsvc_connections = []
        self.samr_contexts = []
        self.netbios_name = account.netbios_name
        self.machinepass = account.machinepass
        self.username = account.username
        self.userpass = account.userpass

        self.tempdir = mk_masked_dir(self.global_tempdir,
                                     conversation.conversation_id)

        # Isolate samba runtime directories per conversation.
        self.lp.set("private dir", self.tempdir)
        self.lp.set("lock dir", self.tempdir)
        self.lp.set("state directory", self.tempdir)
        self.lp.set("tls verify peer", "no_check")

        self.remoteAddress = "/root/ncalrpc_as_system"
        self.samlogon_dn = ("cn=%s,%s" %
                            (self.netbios_name, self.ou))
        self.user_dn = ("cn=%s,%s" %
                        (self.username, self.ou))

        self.generate_machine_creds()
        self.generate_user_creds()

    def with_random_bad_credentials(self, f, good, bad, failed_last_time):
        """Execute the supplied logon function, randomly choosing the

        Based on the frequency in badpassword_frequency randomly perform the
        function with the supplied bad credentials.
        If run with bad credentials, the function is re-run with the good
        failed_last_time is used to prevent consecutive bad credential
        attempts. So the over all bad credential frequency will be lower
        than that requested, but not significantly.
        """
        if not failed_last_time:
            if (self.badpassword_frequency and
                    random.random() < self.badpassword_frequency):
                # Ignore any exceptions as the operation may fail
                # as it's being performed with bad credentials
                failed_last_time = True
                failed_last_time = False

        return (result, failed_last_time)

    def generate_user_creds(self):
        """Generate the conversation specific user Credentials.

        Each Conversation has an associated user account used to simulate
        any non Administrative user traffic.

        Generates user credentials with good and bad passwords and ldap
        simple bind credentials with good and bad passwords.
        """
        self.user_creds = Credentials()
        self.user_creds.guess(self.lp)
        self.user_creds.set_workstation(self.netbios_name)
        self.user_creds.set_password(self.userpass)
        self.user_creds.set_username(self.username)
        self.user_creds.set_domain(self.domain)
        if self.prefer_kerberos:
            self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
            self.user_creds.set_kerberos_state(DONT_USE_KERBEROS)

        # The "bad" variants truncate the real password so the logon is
        # guaranteed to fail.
        self.user_creds_bad = Credentials()
        self.user_creds_bad.guess(self.lp)
        self.user_creds_bad.set_workstation(self.netbios_name)
        self.user_creds_bad.set_password(self.userpass[:-4])
        self.user_creds_bad.set_username(self.username)
        if self.prefer_kerberos:
            self.user_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
            self.user_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)

        # Credentials for ldap simple bind.
        self.simple_bind_creds = Credentials()
        self.simple_bind_creds.guess(self.lp)
        self.simple_bind_creds.set_workstation(self.netbios_name)
        self.simple_bind_creds.set_password(self.userpass)
        self.simple_bind_creds.set_username(self.username)
        self.simple_bind_creds.set_gensec_features(
            self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
        if self.prefer_kerberos:
            self.simple_bind_creds.set_kerberos_state(MUST_USE_KERBEROS)
            self.simple_bind_creds.set_kerberos_state(DONT_USE_KERBEROS)
        self.simple_bind_creds.set_bind_dn(self.user_dn)

        self.simple_bind_creds_bad = Credentials()
        self.simple_bind_creds_bad.guess(self.lp)
        self.simple_bind_creds_bad.set_workstation(self.netbios_name)
        self.simple_bind_creds_bad.set_password(self.userpass[:-4])
        self.simple_bind_creds_bad.set_username(self.username)
        self.simple_bind_creds_bad.set_gensec_features(
            self.simple_bind_creds_bad.get_gensec_features() |
        if self.prefer_kerberos:
            self.simple_bind_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
            self.simple_bind_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
        self.simple_bind_creds_bad.set_bind_dn(self.user_dn)

    def generate_machine_creds(self):
        """Generate the conversation specific machine Credentials.

        Each Conversation has an associated machine account.

        Generates machine credentials with good and bad passwords.
        """

        self.machine_creds = Credentials()
        self.machine_creds.guess(self.lp)
        self.machine_creds.set_workstation(self.netbios_name)
        self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds.set_password(self.machinepass)
        self.machine_creds.set_username(self.netbios_name + "$")
        self.machine_creds.set_domain(self.domain)
        if self.prefer_kerberos:
            self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
            self.machine_creds.set_kerberos_state(DONT_USE_KERBEROS)

        self.machine_creds_bad = Credentials()
        self.machine_creds_bad.guess(self.lp)
        self.machine_creds_bad.set_workstation(self.netbios_name)
        self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds_bad.set_password(self.machinepass[:-4])
        self.machine_creds_bad.set_username(self.netbios_name + "$")
        if self.prefer_kerberos:
            self.machine_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
            self.machine_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)

    def get_matching_dn(self, pattern, attributes=None):
        # If the pattern is an empty string, we assume ROOTDSE,
        # Otherwise we try adding or removing DC suffixes, then
        # shorter leading patterns until we hit one.
        # e.g if there is no CN,CN,CN,CN,DC,DC
        # we first try CN,CN,CN,CN,DC
        # and CN,CN,CN,CN,DC,DC,DC
        # then change to CN,CN,CN,DC,DC
        # and as last resort we use the base_dn
        attr_clue = self.attribute_clue_map.get(attributes)
            return random.choice(attr_clue)

        pattern = pattern.upper()
            if pattern in self.dn_map:
                return random.choice(self.dn_map[pattern])
            # chop one off the front and try it all again.
            pattern = pattern[3:]

    def get_dcerpc_connection(self, new=False):
        # Cached DCE/RPC (netlogon UUID) connection, created lazily.
        guid = '12345678-1234-abcd-ef00-01234567cffb'  # RPC_NETLOGON UUID
        if self.dcerpc_connections and not new:
            return self.dcerpc_connections[-1]
        c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
        self.dcerpc_connections.append(c)

    def get_srvsvc_connection(self, new=False):
        if self.srvsvc_connections and not new:
            return self.srvsvc_connections[-1]

            return srvsvc.srvsvc("ncacn_np:%s" % (self.server),

        # Occasionally attempt the connection with bad credentials first.
        (c, self.last_srvsvc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.last_srvsvc_bad)

        self.srvsvc_connections.append(c)

    def get_lsarpc_connection(self, new=False):
        if self.lsarpc_connections and not new:
            return self.lsarpc_connections[-1]

            binding_options = 'schannel,seal,sign'
            return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
                              (self.server, binding_options),

        (c, self.last_lsarpc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_bad)

        self.lsarpc_connections.append(c)

    def get_lsarpc_named_pipe_connection(self, new=False):
        if self.lsarpc_connections_named and not new:
            return self.lsarpc_connections_named[-1]

            return lsa.lsarpc("ncacn_np:%s" % (self.server),

        (c, self.last_lsarpc_named_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_named_bad)

        self.lsarpc_connections_named.append(c)

    def get_drsuapi_connection_pair(self, new=False, unbind=False):
        """get a (drs, drs_handle) tuple"""
        if self.drsuapi_connections and not new:
            c = self.drsuapi_connections[-1]

            binding_options = 'seal'
            binding_string = "ncacn_ip_tcp:%s[%s]" %\
                             (self.server, binding_options)
            return drsuapi.drsuapi(binding_string, self.lp, creds)

        (drs, self.last_drsuapi_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.last_drsuapi_bad)

        # Bind before use; the handle is needed for drsuapi operations.
        (drs_handle, supported_extensions) = drs_DsBind(drs)
        c = (drs, drs_handle)
        self.drsuapi_connections.append(c)

    def get_ldap_connection(self, new=False, simple=False):
        if self.ldap_connections and not new:
            return self.ldap_connections[-1]

        def simple_bind(creds):
            """
            To run simple bind against Windows, we need to run
            following commands in PowerShell:

                Install-windowsfeature ADCS-Cert-Authority
                Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
            """
            return SamDB('ldaps://%s' % self.server,

        def sasl_bind(creds):
            return SamDB('ldap://%s' % self.server,

        (samdb, self.last_simple_bind_bad) = \
            self.with_random_bad_credentials(simple_bind,
                                             self.simple_bind_creds,
                                             self.simple_bind_creds_bad,
                                             self.last_simple_bind_bad)
        (samdb, self.last_bind_bad) = \
            self.with_random_bad_credentials(sasl_bind,

        self.ldap_connections.append(samdb)

    def get_samr_context(self, new=False):
        # Create a SamrContext on first use (or when new=True) and return
        # the most recent one.
        if not self.samr_contexts or new:
            self.samr_contexts.append(
                SamrContext(self.server, lp=self.lp, creds=self.creds))
        return self.samr_contexts[-1]

    def get_netlogon_connection(self):

        if self.netlogon_connection:
            return self.netlogon_connection

            return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %

        (c, self.last_netlogon_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds_bad,
                                             self.last_netlogon_bad)
        self.netlogon_connection = c

    def guess_a_dns_lookup(self):
        # Use the realm as a plausible name for generated DNS A queries.
        return (self.realm, 'A')

    def get_authenticator(self):
        # Build the (current, subsequent) netr_Authenticator pair used by
        # authenticated netlogon calls; credential bytes may arrive as
        # ints or single-character strings, hence the ord() fallback.
        auth = self.machine_creds.new_client_authenticator()
        current = netr_Authenticator()
        current.cred.data = [x if isinstance(x, int) else ord(x)
                             for x in auth["credential"]]
        current.timestamp = auth["timestamp"]

        subsequent = netr_Authenticator()
        return (current, subsequent)

    def write_stats(self, filename, **kwargs):
        """Write arbitrary key/value pairs to a file in our stats directory in
        order for them to be picked up later by another process working out
        """
        filename = os.path.join(self.statsdir, filename)
        f = open(filename, 'w')
        for k, v in kwargs.items():
            print("%s: %s" % (k, v), file=f)
class SamrContext(object):
    """State/Context associated with a samr connection.
    """
    # NOTE(review): some lines appear to be elided from this excerpt
    # (e.g. the assignments storing server/lp/creds, a handle guard,
    # and return statements).
    def __init__(self, server, lp=None, creds=None):
        self.connection = None
        self.domain_handle = None
        self.domain_sid = None
        self.group_handle = None
        self.user_handle = None

    def get_connection(self):
        # Lazily create and cache the samr connection.
        if not self.connection:
            self.connection = samr.samr(
                "ncacn_ip_tcp:%s[seal]" % (self.server),
                credentials=self.creds)

        return self.connection

    def get_handle(self):
        c = self.get_connection()
        self.handle = c.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
class Conversation(object):
    """Details of a converation between a simulated client and a server."""
    # NOTE(review): this excerpt appears to have lines elided throughout
    # this class (loop headers, else: branches, some return statements).
    def __init__(self, start_time=None, endpoints=None, seq=(),
                 conversation_id=None):
        self.start_time = start_time
        self.endpoints = endpoints
        self.msg = random_colour_print(endpoints)
        # client_balance accumulates client/server clue scores; its sign
        # is later used by guess_client_server().
        self.client_balance = 0.0
        self.conversation_id = conversation_id
            self.add_short_packet(*p)

    def __cmp__(self, other):
        # Python 2 style ordering by start time (None handled specially).
        if self.start_time is None:
            if other.start_time is None:
        if other.start_time is None:
        return self.start_time - other.start_time

    def add_packet(self, packet):
        """Add a packet object to this conversation, making a local copy with
        a conversation-relative timestamp."""

        if self.start_time is None:
            self.start_time = p.timestamp

        if self.endpoints is None:
            self.endpoints = p.endpoints

        if p.endpoints != self.endpoints:
            raise FakePacketError("Conversation endpoints %s don't match"
                                  "packet endpoints %s" %
                                  (self.endpoints, p.endpoints))

        # Store timestamps relative to the conversation start.
        p.timestamp -= self.start_time

        if p.src == p.endpoints[0]:
            self.client_balance -= p.client_score()
            self.client_balance += p.client_score()

        if p.is_really_a_packet():
            self.packets.append(p)

    def add_short_packet(self, timestamp, protocol, opcode, extra,
        """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
        (possibly empty) list of extra data. If client is True, assume
        this packet is from the client to the server.
        """
        src, dest = self.guess_client_server()
            src, dest = dest, src
        key = (protocol, opcode)
        desc = OP_DESCRIPTIONS[key] if key in OP_DESCRIPTIONS else ''
        if protocol in IP_PROTOCOLS:
            ip_protocol = IP_PROTOCOLS[protocol]

        packet = Packet(timestamp - self.start_time, ip_protocol,
                        protocol, opcode, desc, extra)
        # XXX we're assuming the timestamp is already adjusted for
        # XXX should we adjust client balance for guessed packets?
        if packet.src == packet.endpoints[0]:
            self.client_balance -= packet.client_score()
            self.client_balance += packet.client_score()
        if packet.is_really_a_packet():
            self.packets.append(packet)

        # (body of an elided __str__/__repr__ method)
        return ("<Conversation %s %s starting %.3f %d packets>" %
                (self.conversation_id, self.endpoints, self.start_time,

        # (body of an elided __iter__ method)
        return iter(self.packets)

        # (body of an elided __len__ method)
        return len(self.packets)

    def get_duration(self):
        if len(self.packets) < 2:
        return self.packets[-1].timestamp - self.packets[0].timestamp

    def replay_as_summary_lines(self):
        # Render each packet as a (timestamp, line) summary tuple.
        for p in self.packets:
            lines.append(p.as_summary(self.start_time))

    def replay_with_delay(self, start, context=None, account=None):
        """Replay the conversation at the right time.
        (We're already in a fork)."""
        # first we sleep until the first packet
        now = time.time() - start
        sleep_time = gap - SLEEP_OVERHEAD
            time.sleep(sleep_time)

        miss = (time.time() - start) - t
        self.msg("starting %s [miss %.3f]" % (self, miss))

        # packet times are relative to conversation start
        p_start = time.time()
        for p in self.packets:
            now = time.time() - p_start
            gap = now - p.timestamp
                sleep_time = -gap - SLEEP_OVERHEAD
                    time.sleep(sleep_time)
                    t = time.time() - p_start
                    if t - p.timestamp > max_sleep_miss:
                        max_sleep_miss = t - p.timestamp

            p.play(self, context)

        return max_gap, miss, max_sleep_miss

    def guess_client_server(self, server_clue=None):
        """Have a go at deciding who is the server and who is the client.
        returns (client, server)
        """
        a, b = self.endpoints

        if self.client_balance < 0:

        # in the absense of a clue, we will fall through to assuming
        # the lowest number is the server (which is usually true).

        if self.client_balance == 0 and server_clue == b:

    def forget_packets_outside_window(self, s, e):
        """Prune any packets outside the timne window we're interested in

        :param s: start of the window
        :param e: end of the window
        """
        self.packets = [p for p in self.packets if s <= p.timestamp <= e]
        self.start_time = self.packets[0].timestamp if self.packets else None

    def renormalise_times(self, start_time):
        """Adjust the packet start times relative to the new start time."""
        for p in self.packets:
            p.timestamp -= start_time

        if self.start_time is not None:
            self.start_time -= start_time
class DnsHammer(Conversation):
    """A lightweight conversation that generates a lot of dns:0 packets on
    """
    # NOTE(review): lines appear to be elided from this excerpt (e.g.
    # some attribute assignments and branch/try structure in replay()).
    def __init__(self, dns_rate, duration):
        # Pre-compute n random firing times spread across the duration.
        n = int(dns_rate * duration)
        self.times = [random.uniform(0, duration) for i in range(n)]
        self.duration = duration
        self.msg = random_colour_print()

        # (body of an elided __str__/__repr__ method)
        return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
                (len(self.times), self.duration, self.rate))

    def replay(self, context=None):
        # Fire one dns:0 packet at each pre-computed time, printing a
        # timing line per packet (True on success, False on error).
        fn = traffic_packets.packet_dns_0
        for t in self.times:
            now = time.time() - start
            sleep_time = gap - SLEEP_OVERHEAD
                time.sleep(sleep_time)

            packet_start = time.time()
                fn(None, None, context)
                duration = end - packet_start
                print("%f\tDNS\tdns\t0\t%f\tTrue\t" % (end, duration))
            except Exception as e:
                duration = end - packet_start
                print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" % (end, duration, e))
def ingest_summaries(files, dns_mode='count'):
    """Load a summary traffic summary file and generated Converations from it.
    """
    # NOTE(review): lines appear to be elided from this excerpt (file
    # opening, the per-line loop header, and packet accumulation).
    dns_counts = defaultdict(int)
        if isinstance(f, str):
        print("Ingesting %s" % (f.name,), file=sys.stderr)
            p = Packet.from_line(line)
            # DNS is typically counted rather than replayed per-packet,
            # unless dns_mode == 'include'.
            if p.protocol == 'dns' and dns_mode != 'include':
                dns_counts[p.opcode] += 1

    start_time = min(p.timestamp for p in packets)
    last_packet = max(p.timestamp for p in packets)

    print("gathering packets into conversations", file=sys.stderr)
    conversations = OrderedDict()
    for i, p in enumerate(packets):
        # Shift all timestamps so the capture starts at zero; group
        # packets by their (sorted) endpoint pair.
        p.timestamp -= start_time
        c = conversations.get(p.endpoints)
            c = Conversation(conversation_id=(i + 2))
            conversations[p.endpoints] = c

    # We only care about conversations with actual traffic, so we
    # filter out conversations with nothing to say. We do that here,
    # rather than earlier, because those empty packets contain useful
    # hints as to which end of the conversation was the client.
    conversation_list = []
    for c in conversations.values():
            conversation_list.append(c)

    # This is obviously not correct, as many conversations will appear
    # to start roughly simultaneously at the beginning of the snapshot.
    # To which we say: oh well, so be it.
    duration = float(last_packet - start_time)
    mean_interval = len(conversations) / duration

    return conversation_list, mean_interval, duration, dns_counts
def guess_server_address(conversations):
    """Guess which endpoint is the server.

    The server is taken to be the endpoint that appears in the most
    conversations; returns the (address, count) pair from the tally.
    """
    tally = Counter()
    for conversation in conversations:
        tally.update(conversation.endpoints)
    return tally.most_common(1)[0]
def stringify_keys(x):
    # Convert tuple keys into tab-joined string keys (for JSON
    # serialisation).  NOTE(review): the result-dict construction and
    # return statement appear to be elided from this excerpt.
    for k, v in x.items():
def unstringify_keys(x):
    # Inverse of stringify_keys: split tab-joined string keys back into
    # tuples.  NOTE(review): the result-dict construction and return
    # statement appear to be elided from this excerpt.
    for k, v in x.items():
        t = tuple(str(k).split('\t'))
class TrafficModel(object):
    """An n-gram model of traffic learned from observed conversations."""
    # NOTE(review): this excerpt appears to have lines elided throughout
    # this class — e.g. the save/load method headers, loop headers, and
    # parts of several bodies.
    def __init__(self, n=3):
        self.query_details = {}
        self.dns_opcounts = defaultdict(int)
        self.cumulative_duration = 0.0
        # packet_rate is a [packet_count, elapsed_seconds] pair.
        self.packet_rate = [0, 1]

    def learn(self, conversations, dns_opcounts={}):
        # NOTE(review): mutable default argument — it appears to be only
        # read here, but confirm no caller relies on mutating it.
        key = (NON_PACKET,) * (self.n - 1)

        server = guess_server_address(conversations)

        for k, v in dns_opcounts.items():
            self.dns_opcounts[k] += v

        if len(conversations) > 1:
            # Estimate the overall packet rate across the whole capture.
            first = conversations[0].start_time
            for c in conversations:
                last = max(last, c.packets[-1].timestamp)
            self.packet_rate[0] = total
            self.packet_rate[1] = last - first

        for c in conversations:
            client, server = c.guess_client_server(server)
            cum_duration += c.get_duration()
            key = (NON_PACKET,) * (self.n - 1)
                elapsed = p.timestamp - prev
                if elapsed > WAIT_THRESHOLD:
                    # add the wait as an extra state
                    wait = 'wait:%d' % (math.log(max(1.0,
                                                     elapsed * WAIT_SCALE)))
                    self.ngrams.setdefault(key, []).append(wait)
                    key = key[1:] + (wait,)

                short_p = p.as_packet_type()
                self.query_details.setdefault(short_p,
                                              []).append(tuple(p.extra))
                self.ngrams.setdefault(key, []).append(short_p)
                key = key[1:] + (short_p,)

        self.cumulative_duration += cum_duration
        # add a terminal node to the last n-gram state.
        self.ngrams.setdefault(key, []).append(NON_PACKET)

        # (start of an elided save path: convert observation lists into
        # {value: count} dicts for serialisation)
        for k, v in self.ngrams.items():
            ngrams[k] = dict(Counter(v))

        for k, v in self.query_details.items():
            query_details[k] = dict(Counter('\t'.join(x) if x else '-'

            'query_details': query_details,
            'cumulative_duration': self.cumulative_duration,
            'packet_rate': self.packet_rate,
            'version': CURRENT_MODEL_VERSION
        d['dns'] = self.dns_opcounts

        if isinstance(f, str):
        json.dump(d, f, indent=2)

        # (start of an elided load path)
        if isinstance(f, str):

        version = d["version"]
        if version < REQUIRED_MODEL_VERSION:
            raise ValueError("the model file is version %d; "
                             "version %d is required" %
                             (version, REQUIRED_MODEL_VERSION))
            raise ValueError("the model file lacks a version number; "
                             "version %d is required" %
                             (REQUIRED_MODEL_VERSION))

        # Expand {value: count} dicts back into observation lists.
        for k, v in d['ngrams'].items():
            k = tuple(str(k).split('\t'))
            values = self.ngrams.setdefault(k, [])
            for p, count in v.items():
                values.extend([str(p)] * count)

        for k, v in d['query_details'].items():
            values = self.query_details.setdefault(str(k), [])
            for p, count in v.items():
                    values.extend([()] * count)
                    values.extend([tuple(str(p).split('\t'))] * count)

        for k, v in d['dns'].items():
            self.dns_opcounts[k] += v

        self.cumulative_duration = d['cumulative_duration']
        self.packet_rate = d['packet_rate']

    def construct_conversation_sequence(self, timestamp=0.0,

        """Construct an individual conversation packet sequence from the
        """
        key = (NON_PACKET,) * (self.n - 1)
        if ignore_before is None:
            ignore_before = timestamp - 1

            # Random walk over the n-gram transition table.
            p = random.choice(self.ngrams.get(key, (NON_PACKET,)))

            if p in self.query_details:
                extra = random.choice(self.query_details[p])

            protocol, opcode = p.split(':', 1)
            if protocol == 'wait':
                # 'wait' states encode log-scaled pauses, not packets.
                log_wait_time = int(opcode) + random.random()
                wait = math.exp(log_wait_time) / (WAIT_SCALE * replay_speed)
                log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
                wait = math.exp(log_wait) / replay_speed

            if hard_stop is not None and timestamp > hard_stop:
            if timestamp >= ignore_before:
                c.append((timestamp, protocol, opcode, extra))

            key = key[1:] + (p,)

    def generate_conversation_sequences(self, scale, duration, replay_speed=1):
        """Generate a list of conversation descriptions from the model."""

        # We run the simulation for ten times as long as our desired
        # duration, and take the section at the end.
        lead_in = 9 * duration
        rate_n, rate_t = self.packet_rate
        target_packets = int(duration * scale * rate_n / rate_t)

        while n_packets < target_packets:
            start = random.uniform(-lead_in, duration)
            c = self.construct_conversation_sequence(start,
                                                     replay_speed=replay_speed,
            # will these "packets" generate actual traffic?
            # some (e.g. ldap unbind) will not generate anything
            # if the previous packets are not there, and if the
            # conversation only has those it wastes a process doing nothing.
            for timestamp, protocol, opcode, extra in c:
                if is_a_traffic_generating_packet(protocol, opcode):
            conversations.append(c)

        print(("we have %d packets (target %d) in %d conversations at scale %f"
               % (n_packets, target_packets, len(conversations), scale)),
        conversations.sort()  # sorts by first element == start time
        return conversations
1301 def seq_to_conversations(seq, server=1, client=2):
# Wrap each raw packet sequence in a Conversation keyed by its first
# packet's timestamp, with (server, client) as the endpoint pair.
# NOTE(review): the list setup and loop header are elided in this excerpt.
1305 c = Conversation(s[0][0], (server, client), s)
1307 conversations.append(c)
1308 return conversations
# Fragment of the protocol-name -> IP protocol number table (the dict's
# opening line is elided in this excerpt). '06' appears to mean TCP and
# '11' UDP (hex 0x11 == 17) — TODO confirm against the full table.
1313 'rpc_netlogon': '06',
1314 'kerberos': '06', # ratio 16248:258
1325 'smb_netlogon': '11',
# (protocol, opcode) -> human-readable operation description, used when
# expanding short packet descriptions into summary lines. NOTE(review):
# the dict's opening and closing lines are elided in this excerpt.
1331 ('browser', '0x01'): 'Host Announcement (0x01)',
1332 ('browser', '0x02'): 'Request Announcement (0x02)',
1333 ('browser', '0x08'): 'Browser Election Request (0x08)',
1334 ('browser', '0x09'): 'Get Backup List Request (0x09)',
1335 ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
1336 ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
1337 ('cldap', '3'): 'searchRequest',
1338 ('cldap', '5'): 'searchResDone',
1339 ('dcerpc', '0'): 'Request',
1340 ('dcerpc', '11'): 'Bind',
1341 ('dcerpc', '12'): 'Bind_ack',
1342 ('dcerpc', '13'): 'Bind_nak',
1343 ('dcerpc', '14'): 'Alter_context',
1344 ('dcerpc', '15'): 'Alter_context_resp',
1345 ('dcerpc', '16'): 'AUTH3',
1346 ('dcerpc', '2'): 'Response',
1347 ('dns', '0'): 'query',
1348 ('dns', '1'): 'response',
1349 ('drsuapi', '0'): 'DsBind',
1350 ('drsuapi', '12'): 'DsCrackNames',
1351 ('drsuapi', '13'): 'DsWriteAccountSpn',
1352 ('drsuapi', '1'): 'DsUnbind',
1353 ('drsuapi', '2'): 'DsReplicaSync',
1354 ('drsuapi', '3'): 'DsGetNCChanges',
1355 ('drsuapi', '4'): 'DsReplicaUpdateRefs',
1356 ('epm', '3'): 'Map',
1357 ('kerberos', ''): '',
1358 ('ldap', '0'): 'bindRequest',
1359 ('ldap', '1'): 'bindResponse',
1360 ('ldap', '2'): 'unbindRequest',
1361 ('ldap', '3'): 'searchRequest',
1362 ('ldap', '4'): 'searchResEntry',
1363 ('ldap', '5'): 'searchResDone',
1364 ('ldap', ''): '*** Unknown ***',
1365 ('lsarpc', '14'): 'lsa_LookupNames',
1366 ('lsarpc', '15'): 'lsa_LookupSids',
1367 ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
1368 ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
1369 ('lsarpc', '6'): 'lsa_OpenPolicy',
1370 ('lsarpc', '76'): 'lsa_LookupSids3',
1371 ('lsarpc', '77'): 'lsa_LookupNames4',
1372 ('nbns', '0'): 'query',
1373 ('nbns', '1'): 'response',
1374 ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
1375 ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
1376 ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
1377 ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
1378 ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
1379 ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
1380 ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
1381 ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
1382 ('samr', '0',): 'Connect',
1383 ('samr', '16'): 'GetAliasMembership',
1384 ('samr', '17'): 'LookupNames',
1385 ('samr', '18'): 'LookupRids',
1386 ('samr', '19'): 'OpenGroup',
1387 ('samr', '1'): 'Close',
1388 ('samr', '25'): 'QueryGroupMember',
1389 ('samr', '34'): 'OpenUser',
1390 ('samr', '36'): 'QueryUserInfo',
1391 ('samr', '39'): 'GetGroupsForUser',
1392 ('samr', '3'): 'QuerySecurity',
1393 ('samr', '5'): 'LookupDomain',
1394 ('samr', '64'): 'Connect5',
1395 ('samr', '6'): 'EnumDomains',
1396 ('samr', '7'): 'OpenDomain',
1397 ('samr', '8'): 'QueryDomainInfo',
1398 ('smb', '0x04'): 'Close (0x04)',
1399 ('smb', '0x24'): 'Locking AndX (0x24)',
1400 ('smb', '0x2e'): 'Read AndX (0x2e)',
1401 ('smb', '0x32'): 'Trans2 (0x32)',
1402 ('smb', '0x71'): 'Tree Disconnect (0x71)',
1403 ('smb', '0x72'): 'Negotiate Protocol (0x72)',
1404 ('smb', '0x73'): 'Session Setup AndX (0x73)',
1405 ('smb', '0x74'): 'Logoff AndX (0x74)',
1406 ('smb', '0x75'): 'Tree Connect AndX (0x75)',
1407 ('smb', '0xa2'): 'NT Create AndX (0xa2)',
1408 ('smb2', '0'): 'NegotiateProtocol',
1409 ('smb2', '11'): 'Ioctl',
1410 ('smb2', '14'): 'Find',
1411 ('smb2', '16'): 'GetInfo',
1412 ('smb2', '18'): 'Break',
1413 ('smb2', '1'): 'SessionSetup',
1414 ('smb2', '2'): 'SessionLogoff',
1415 ('smb2', '3'): 'TreeConnect',
1416 ('smb2', '4'): 'TreeDisconnect',
1417 ('smb2', '5'): 'Create',
1418 ('smb2', '6'): 'Close',
1419 ('smb2', '8'): 'Read',
1420 ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
1421 ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
1422 'user unknown (0x17)'),
1423 ('srvsvc', '16'): 'NetShareGetInfo',
1424 ('srvsvc', '21'): 'NetSrvGetInfo',
# Expand a terse "protocol:opcode" packet string into a tab-separated
# summary line. NOTE(review): the line that folds `extra` into the output
# is elided from this excerpt.
1428 def expand_short_packet(p, timestamp, src, dest, extra):
1429 protocol, opcode = p.split(':', 1)
# unknown operations get an empty description; unknown protocols default
# to IP protocol '06'
1430 desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
1431 ip_protocol = IP_PROTOCOLS.get(protocol, '06')
1433 line = [timestamp, ip_protocol, '', src, dest, protocol, opcode, desc]
1435 return '\t'.join(line)
1438 def flushing_signal_handler(signal, frame):
1439 """Signal handler closes standard out and error.
1441 Triggered by a sigterm, ensures that the log messages are flushed
1442 to disk and not lost.
# NOTE(review): the handler body is elided from this excerpt.
# Runs in a forked child; the child must never return to the parent's
# code path (see comment below). NOTE(review): several lines, including
# the fork itself and the child/parent branch, are elided in this excerpt.
1449 def replay_seq_in_fork(cs, start, context, account, client_id, server_id=1):
1450 """Fork a new process and replay the conversation sequence."""
1451 # We will need to reseed the random number generator or all the
1452 # clients will end up using the same sequence of random
1453 # numbers. random.randint() is mixed in so the initial seed will
1454 # have an effect here.
1455 seed = client_id * 1000 + random.randint(0, 999)
1457 # flush our buffers so messages won't be written by both sides
1464 # we must never return, or we'll end up running parts of the
1465 # parent's clean-up code. So we work in a try...finally, and
1466 # try to print any exceptions.
1469 endpoints = (server_id, client_id)
1472 c = Conversation(t, endpoints, seq=cs, conversation_id=client_id)
1473 signal.signal(signal.SIGTERM, flushing_signal_handler)
1475 context.generate_process_local_config(account, c)
# per-conversation stats are redirected to a file in the stats directory
1478 filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
1480 f = open(filename, 'w')
1484 except IOError as e:
1485 LOGGER.info("stdout closing failed with %s" % e)
# sleep until just before this conversation's scheduled start time
1489 now = time.time() - start
1491 sleep_time = gap - SLEEP_OVERHEAD
1493 time.sleep(sleep_time)
1495 max_lag, start_lag, max_sleep_miss = c.replay_with_delay(start=start,
1497 print("Maximum lag: %f" % max_lag)
1498 print("Start lag: %f" % start_lag)
1499 print("Max sleep miss: %f" % max_sleep_miss)
1503 print(("EXCEPTION in child PID %d, conversation %s" % (os.getpid(), c)),
1505 traceback.print_exc(sys.stderr)
# Fork a child that generates DNS load via DnsHammer for the given
# duration. NOTE(review): the fork and parent-return lines are elided.
1513 def dnshammer_in_fork(dns_rate, duration):
1521 signal.signal(signal.SIGTERM, flushing_signal_handler)
1522 hammer = DnsHammer(dns_rate, duration)
1526 print(("EXCEPTION in child PID %d, the DNS hammer" % (os.getpid())),
1528 traceback.print_exc(sys.stderr)
# Top-level replay driver: forks one child per conversation (plus an
# optional DNS hammer), reaps children as they finish, then escalates
# signals to stragglers. NOTE(review): many lines are elided here.
1535 def replay(conversation_seq,
1542 latency_timeout=1.0,
1543 stop_on_any_error=False,
1546 context = ReplayContext(server=host,
# every conversation needs its own account to run under
1551 if len(accounts) < len(conversation_seq):
1552 raise ValueError(("we have %d accounts but %d conversations" %
1553 (len(accounts), len(conversation_seq))))
1555 # Set the process group so that the calling scripts are not killed
1556 # when the forked child processes are killed.
1559 # we delay the start by a bit to allow all the forks to get up and
1561 delay = len(conversation_seq) * 0.02
1562 start = time.time() + delay
1564 if duration is None:
1565 # end slightly after the last packet of the last conversation
1566 # to start. Conversations other than the last could still be
1567 # going, but we don't care.
1568 duration = conversation_seq[-1][-1][0] + latency_timeout
1570 print("We will start in %.1f seconds" % delay,
1572 print("We will stop after %.1f seconds" % (duration + delay),
1574 print("runtime %.1f seconds" % duration,
1577 # give one second grace for packets to finish before killing begins
1578 end = start + duration + 1.0
1580 LOGGER.info("Replaying traffic for %u conversations over %d seconds"
1581 % (len(conversation_seq), duration))
1583 context.write_stats('intentions',
1584 Planned_conversations=len(conversation_seq),
1585 Planned_packets=sum(len(x) for x in conversation_seq))
1590 pid = dnshammer_in_fork(dns_rate, duration)
1593 for i, cs in enumerate(conversation_seq):
1594 account = accounts[i]
1596 pid = replay_seq_in_fork(cs, start, context, account, client_id)
1597 children[pid] = client_id
1599 # HERE, we are past all the forks
1601 print("all forks done in %.1f seconds, waiting %.1f" %
1602 (t - start + delay, t - start),
# reap finished children without blocking until the deadline passes
1605 while time.time() < end and children:
1608 pid, status = os.waitpid(-1, os.WNOHANG)
1609 except OSError as e:
1610 if e.errno != ECHILD: # no child processes
1614 c = children.pop(pid, None)
1616 print(("process %d finished conversation %d;"
1618 (pid, c, len(children))), file=sys.stderr)
1619 if stop_on_any_error and status != 0:
1623 print("EXCEPTION in parent", file=sys.stderr)
1624 traceback.print_exc()
1626 context.write_stats('unfinished',
1627 Unfinished_conversations=len(children))
# escalating kill: SIGTERM twice, then SIGKILL (signals 15, 15, 9)
1629 for s in (15, 15, 9):
1630 print(("killing %d children with -%d" %
1631 (len(children), s)), file=sys.stderr)
1632 for pid in children:
1635 except OSError as e:
1636 if e.errno != ESRCH: # don't fail if it has already died
1639 end = time.time() + 1
1642 pid, status = os.waitpid(-1, os.WNOHANG)
1643 except OSError as e:
1644 if e.errno != ECHILD:
1647 c = children.pop(pid, None)
1649 print("children is %s, no pid found" % children)
1653 print(("kill -%d %d KILLED conversation; "
1655 (s, pid, len(children))),
1657 if time.time() >= end:
1665 print("%d children are missing" % len(children),
1668 # there may be stragglers that were forked just as ^C was hit
1669 # and don't appear in the list of children. We can get them
1670 # with killpg, but that will also kill us, so this is^H^H would be
1671 # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
1672 # so as not to have to fuss around writing signal handlers.
1675 except KeyboardInterrupt:
1676 print("ignoring fake ^C", file=sys.stderr)
# Open a SamDB connection to the host over LDAP using the system session.
# NOTE(review): the remaining keyword arguments and return are elided.
1679 def openLdb(host, creds, lp):
1680 session = system_session()
1681 ldb = SamDB(url="ldap://%s" % host,
1682 session_info=session,
1683 options=['modules:paged_searches'],
1689 def ou_name(ldb, instance_id):
1690 """Generate an ou name from the instance id"""
# NOTE(review): the second formatting argument (the domain DN, presumably)
# is on an elided continuation line.
1691 return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id,
1695 def create_ou(ldb, instance_id):
1696 """Create an ou, all created user and machine accounts will belong to it.
1698 This allows all the created resources to be cleaned up easily.
1700 ou = ou_name(ldb, instance_id)
# first create the parent traffic_replay OU, then the per-instance one;
# an "already exists" LdbError is tolerated for both adds
1702 ldb.add({"dn": ou.split(',', 1)[1],
1703 "objectclass": "organizationalunit"})
1704 except LdbError as e:
1705 (status, _) = e.args
1706 # ignore already exists
1711 "objectclass": "organizationalunit"})
1712 except LdbError as e:
1713 (status, _) = e.args
1714 # ignore already exists
1720 # ConversationAccounts holds details of the machine and user accounts
1721 # associated with a conversation.
1723 # We use a named tuple to reduce shared memory usage.
# NOTE(review): the field-name list is on elided continuation lines.
1724 ConversationAccounts = namedtuple('ConversationAccounts',
1731 def generate_replay_accounts(ldb, instance_id, number, password):
1732 """Generate a series of unique machine and user account names."""
# one (machine, user) pair per conversation, numbered from 1
1735 for i in range(1, number + 1):
1736 netbios_name = machine_name(instance_id, i)
1737 username = user_name(instance_id, i)
1739 account = ConversationAccounts(netbios_name, password, username,
1741 accounts.append(account)
1745 def create_machine_account(ldb, instance_id, netbios_name, machinepass,
1746 traffic_account=True):
1747 """Create a machine account via ldap."""
1749 ou = ou_name(ldb, instance_id)
1750 dn = "cn=%s,%s" % (netbios_name, ou)
# unicodePwd must be the quoted password in UTF-16-LE, per AD convention
1751 utf16pw = ('"%s"' % get_string(machinepass)).encode('utf-16-le')
1754 # we set these bits for the machine account otherwise the replayed
1755 # traffic throws up NT_STATUS_NO_TRUST_SAM_ACCOUNT errors
1756 account_controls = str(UF_TRUSTED_FOR_DELEGATION |
1757 UF_SERVER_TRUST_ACCOUNT)
1760 account_controls = str(UF_WORKSTATION_TRUST_ACCOUNT)
1764 "objectclass": "computer",
1765 "sAMAccountName": "%s$" % netbios_name,
1766 "userAccountControl": account_controls,
1767 "unicodePwd": utf16pw})
1770 def create_user_account(ldb, instance_id, username, userpass):
1771 """Create a user account via ldap."""
1772 ou = ou_name(ldb, instance_id)
1773 user_dn = "cn=%s,%s" % (username, ou)
# unicodePwd must be the quoted password in UTF-16-LE, per AD convention
1774 utf16pw = ('"%s"' % get_string(userpass)).encode('utf-16-le')
1777 "objectclass": "user",
1778 "sAMAccountName": username,
1779 "userAccountControl": str(UF_NORMAL_ACCOUNT),
1780 "unicodePwd": utf16pw
1783 # grant user write permission to do things like write account SPN
1784 sdutils = sd_utils.SDUtils(ldb)
1785 sdutils.dacl_add_ace(user_dn, "(A;;WP;;;PS)")
1788 def create_group(ldb, instance_id, name):
1789 """Create a group via ldap."""
1790 ou = ou_name(ldb, instance_id)
1792 dn = "cn=%s,%s" % (name, ou)
1795 "objectclass": "group",
1796 "sAMAccountName": name,
def user_name(instance_id, i):
    """Generate a user name based on the instance id."""
    return "-".join(("STGU", "%d" % instance_id, "%d" % i))
1805 def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
1806 """Search objectclass, return attr in a set"""
# NOTE(review): the ldb.search() call is partially elided; only its
# expression keyword argument is visible here.
1808 expression="(objectClass={})".format(objectclass),
1811 return {str(obj[attr]) for obj in objs}
1814 def generate_users(ldb, instance_id, number, password):
1815 """Add users to the server"""
# pre-fetch existing names so accounts are only created once
1816 existing_objects = search_objectclass(ldb, objectclass='user')
1818 for i in range(number, 0, -1):
1819 name = user_name(instance_id, i)
1820 if name not in existing_objects:
1821 create_user_account(ldb, instance_id, name, password)
1824 LOGGER.info("Created %u/%u users" % (users, number))
1829 def machine_name(instance_id, i, traffic_account=True):
1830 """Generate a machine account name from instance id."""
1832 # traffic accounts correspond to a given user, and use different
1833 # userAccountControl flags to ensure packets get processed correctly
# NOTE(review): the `if traffic_account:` guard line is elided here.
1835 return "STGM-%d-%d" % (instance_id, i)
1837 # Otherwise we're just generating computer accounts to simulate a
1838 # semi-realistic network. These use the default computer
1839 # userAccountControl flags, so we use a different account name so that
1840 # we don't try to use them when generating packets
1841 return "PC-%d-%d" % (instance_id, i)
1844 def generate_machine_accounts(ldb, instance_id, number, password,
1845 traffic_account=True):
1846 """Add machine accounts to the server"""
1847 existing_objects = search_objectclass(ldb, objectclass='computer')
1849 for i in range(number, 0, -1):
1850 name = machine_name(instance_id, i, traffic_account)
# sAMAccountName for computers carries a trailing '$'
1851 if name + "$" not in existing_objects:
1852 create_machine_account(ldb, instance_id, name, password,
1856 LOGGER.info("Created %u/%u machine accounts" % (added, number))
def group_name(instance_id, i):
    """Generate a group name from the instance id."""
    return "-".join(("STGG", "%d" % instance_id, "%d" % i))
1866 def generate_groups(ldb, instance_id, number):
1867 """Create the required number of groups on the server."""
1868 existing_objects = search_objectclass(ldb, objectclass='group')
1870 for i in range(number, 0, -1):
1871 name = group_name(instance_id, i)
1872 if name not in existing_objects:
1873 create_group(ldb, instance_id, name)
# periodic progress logging for large runs
1875 if groups % 1000 == 0:
1876 LOGGER.info("Created %u/%u groups" % (groups, number))
1881 def clean_up_accounts(ldb, instance_id):
1882 """Remove the created accounts and groups from the server."""
1883 ou = ou_name(ldb, instance_id)
# tree_delete removes the OU and everything under it in one operation
1885 ldb.delete(ou, ["tree_delete:1"])
1886 except LdbError as e:
1887 (status, _) = e.args
1888 # ignore does not exist
1893 def generate_users_and_groups(ldb, instance_id, password,
1894 number_of_users, number_of_groups,
1895 group_memberships, max_members,
1896 machine_accounts, traffic_accounts=True):
1897 """Generate the required users and groups, allocating the users to
1899 memberships_added = 0
1903 create_ou(ldb, instance_id)
1905 LOGGER.info("Generating dummy user accounts")
1906 users_added = generate_users(ldb, instance_id, number_of_users, password)
1908 LOGGER.info("Generating dummy machine accounts")
1909 computers_added = generate_machine_accounts(ldb, instance_id,
1910 machine_accounts, password,
1913 if number_of_groups > 0:
1914 LOGGER.info("Generating dummy groups")
1915 groups_added = generate_groups(ldb, instance_id, number_of_groups)
1917 if group_memberships > 0:
1918 LOGGER.info("Assigning users to groups")
# GroupAssignments computes the membership plan; the DB writes happen
# in add_users_to_groups below
1919 assignments = GroupAssignments(number_of_groups,
1925 LOGGER.info("Adding users to groups")
1926 add_users_to_groups(ldb, instance_id, assignments)
1927 memberships_added = assignments.total()
1929 if (groups_added > 0 and users_added == 0 and
1930 number_of_groups != groups_added):
1931 LOGGER.warning("The added groups will contain no members")
1933 LOGGER.info("Added %d users (%d machines), %d groups and %d memberships" %
1934 (users_added, computers_added, groups_added,
# Plans user-to-group memberships with weighted random distributions so
# that membership counts resemble a real network (a few users in many
# groups, most users in few). NOTE(review): this listing is elided; some
# method bodies are missing lines.
1938 class GroupAssignments(object):
1939 def __init__(self, number_of_groups, groups_added, number_of_users,
1940 users_added, group_memberships, max_members):
1943 self.generate_group_distribution(number_of_groups)
1944 self.generate_user_distribution(number_of_users, group_memberships)
1945 self.max_members = max_members
1946 self.assignments = defaultdict(list)
1947 self.assign_groups(number_of_groups, groups_added, number_of_users,
1948 users_added, group_memberships)
1950 def cumulative_distribution(self, weights):
1951 # make sure the probabilities conform to a cumulative distribution
1952 # spread between 0.0 and 1.0. Dividing by the weighted total gives each
1953 # probability a proportional share of 1.0. Higher probabilities get a
1954 # bigger share, so are more likely to be picked. We use the cumulative
1955 # value, so we can use random.random() as a simple index into the list
1957 total = sum(weights)
1962 for probability in weights:
1963 cumulative += probability
1964 dist.append(cumulative / total)
1967 def generate_user_distribution(self, num_users, num_memberships):
1968 """Probability distribution of a user belonging to a group.
1970 # Assign a weighted probability to each user. Use the Pareto
1971 # Distribution so that some users are in a lot of groups, and the
1972 # bulk of users are in only a few groups. If we're assigning a large
1973 # number of group memberships, use a higher shape. This means slightly
1974 # fewer outlying users that are in large numbers of groups. The aim is
1975 # to have no users belonging to more than ~500 groups.
1976 if num_memberships > 5000000:
1978 elif num_memberships > 2000000:
1980 elif num_memberships > 300000:
1986 for x in range(1, num_users + 1):
1987 p = random.paretovariate(shape)
1990 # convert the weights to a cumulative distribution between 0.0 and 1.0
1991 self.user_dist = self.cumulative_distribution(weights)
1993 def generate_group_distribution(self, n):
1994 """Probability distribution of a group containing a user."""
1996 # Assign a weighted probability to each user. Probability decreases
1997 # as the group-ID increases
1999 for x in range(1, n + 1):
2003 # convert the weights to a cumulative distribution between 0.0 and 1.0
2004 self.group_weights = weights
2005 self.group_dist = self.cumulative_distribution(weights)
2007 def generate_random_membership(self):
2008 """Returns a randomly generated user-group membership"""
2010 # the list items are cumulative distribution values between 0.0 and
2011 # 1.0, which makes random() a handy way to index the list to get a
2012 # weighted random user/group. (Here the user/group returned are
2013 # zero-based array indexes)
2014 user = bisect.bisect(self.user_dist, random.random())
2015 group = bisect.bisect(self.group_dist, random.random())
2019 def users_in_group(self, group):
2020 return self.assignments[group]
2022 def get_groups(self):
2023 return self.assignments.keys()
2025 def cap_group_membership(self, group, max_members):
2026 """Prevent the group's membership from exceeding the max specified"""
2027 num_members = len(self.assignments[group])
2028 if num_members >= max_members:
2029 LOGGER.info("Group {0} has {1} members".format(group, num_members))
2031 # remove this group and then recalculate the cumulative
2032 # distribution, so this group is no longer selected
2033 self.group_weights[group - 1] = 0
2034 new_dist = self.cumulative_distribution(self.group_weights)
2035 self.group_dist = new_dist
2037 def add_assignment(self, user, group):
2038 # the assignments are stored in a dictionary where key=group,
2039 # value=list-of-users-in-group (indexing by group-ID allows us to
2040 # optimize for DB membership writes)
2041 if user not in self.assignments[group]:
2042 self.assignments[group].append(user)
2045 # check if there's a cap on how big the groups can grow
2046 if self.max_members:
2047 self.cap_group_membership(group, self.max_members)
2049 def assign_groups(self, number_of_groups, groups_added,
2050 number_of_users, users_added, group_memberships):
2051 """Allocate users to groups.
2053 The intention is to have a few users that belong to most groups, while
2054 the majority of users belong to a few groups.
2056 A few groups will contain most users, with the remaining only having a
2060 if group_memberships <= 0:
2063 # Calculate the number of group memberships required
2064 group_memberships = math.ceil(
2065 float(group_memberships) *
2066 (float(users_added) / float(number_of_users)))
2068 if self.max_members:
2069 group_memberships = min(group_memberships,
2070 self.max_members * number_of_groups)
2072 existing_users = number_of_users - users_added - 1
2073 existing_groups = number_of_groups - groups_added - 1
2074 while self.total() < group_memberships:
2075 user, group = self.generate_random_membership()
2077 if group > existing_groups or user > existing_users:
2078 # the + 1 converts the array index to the corresponding
2079 # group or user number
2080 self.add_assignment(user + 1, group + 1)
2086 def add_users_to_groups(db, instance_id, assignments):
2087 """Takes the assignments of users to groups and applies them to the DB."""
2089 total = assignments.total()
2093 for group in assignments.get_groups():
2094 users_in_group = assignments.users_in_group(group)
# skip empty groups — nothing to write
2095 if len(users_in_group) == 0:
2098 # Split up the users into chunks, so we write no more than 1K at a
2099 # time. (Minimizing the DB modifies is more efficient, but writing
2100 # 10K+ users to a single group becomes inefficient memory-wise)
2101 for chunk in range(0, len(users_in_group), 1000):
2102 chunk_of_users = users_in_group[chunk:chunk + 1000]
2103 add_group_members(db, instance_id, group, chunk_of_users)
2105 added += len(chunk_of_users)
2108 LOGGER.info("Added %u/%u memberships" % (added, total))
2110 def add_group_members(db, instance_id, group, users_in_group):
2111 """Adds the given users to group specified."""
2113 ou = ou_name(db, instance_id)
# NOTE(review): this is the body of a nested build_dn(name) helper whose
# def line is elided in this excerpt.
2116 return("cn=%s,%s" % (name, ou))
2118 group_dn = build_dn(group_name(instance_id, group))
2120 m.dn = ldb.Dn(db, group_dn)
# all member adds are batched into a single modify message; each element
# needs a distinct key, hence the per-user index
2122 for user in users_in_group:
2123 user_dn = build_dn(user_name(instance_id, user))
2124 idx = "member-" + str(user)
2125 m[idx] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
# Aggregate per-conversation stats files into a summary report: success
# and failure rates, run-level floats/ints, and per-operation latency
# percentiles. NOTE(review): several lines are elided in this excerpt.
2130 def generate_stats(statsdir, timing_file):
2131 """Generate and print the summary stats for a run."""
2132 first = sys.float_info.max
2137 failures = Counter()
2138 unique_conversations = set()
2139 if timing_file is not None:
2140 tw = timing_file.write
2145 tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")
2150 'Max sleep miss': 0,
2153 'Planned_conversations': 0,
2154 'Planned_packets': 0,
2155 'Unfinished_conversations': 0,
2158 for filename in os.listdir(statsdir):
2159 path = os.path.join(statsdir, filename)
2160 with open(path, 'r') as f:
# each packet line is tab-separated: time, conversation, protocol,
# type, latency, success flag, ...
2163 fields = line.rstrip('\n').split('\t')
2164 conversation = fields[1]
2165 protocol = fields[2]
2166 packet_type = fields[3]
2167 latency = float(fields[4])
2168 t = float(fields[0])
2169 first = min(t - latency, first)
2172 op = (protocol, packet_type)
2173 latencies.setdefault(op, []).append(latency)
2174 if fields[5] == 'True':
2180 unique_conversations.add(conversation)
# lines that are not packet records fall through to key:value parsing
2183 except (ValueError, IndexError):
2185 k, v = line.split(':', 1)
2186 if k in float_values:
2187 float_values[k] = max(float(v),
2189 elif k in int_values:
2190 int_values[k] = max(int(v),
2193 print(line, file=sys.stderr)
2195 # not a valid line print and ignore
2196 print(line, file=sys.stderr)
2198 duration = last - first
2202 success_rate = successful / duration
2206 failure_rate = failed / duration
2208 conversations = len(unique_conversations)
2210 print("Total conversations: %10d" % conversations)
2211 print("Successful operations: %10d (%.3f per second)"
2212 % (successful, success_rate))
2213 print("Failed operations: %10d (%.3f per second)"
2214 % (failed, failure_rate))
2216 for k, v in sorted(float_values.items()):
2217 print("%-28s %f" % (k.replace('_', ' ') + ':', v))
2218 for k, v in sorted(int_values.items()):
2219 print("%-28s %d" % (k.replace('_', ' ') + ':', v))
2221 print("Protocol Op Code Description "
2222 " Count Failed Mean Median "
# group the observed packet types by protocol for the latency table
2226 for proto, packet in latencies:
2227 if proto not in ops:
2229 ops[proto].add(packet)
2230 protocols = sorted(ops.keys())
2232 for protocol in protocols:
2233 packet_types = sorted(ops[protocol], key=opcode_key)
2234 for packet_type in packet_types:
2235 op = (protocol, packet_type)
2236 values = latencies[op]
2237 values = sorted(values)
2239 failed = failures[op]
2240 mean = sum(values) / count
2241 median = calc_percentile(values, 0.50)
2242 percentile = calc_percentile(values, 0.95)
2243 rng = values[-1] - values[0]
2245 desc = OP_DESCRIPTIONS.get(op, '')
2246 print("%-12s %4s %-35s %12d %12d %12.6f "
2247 "%12.6f %12.6f %12.6f %12.6f"
# NOTE(review): the `def opcode_key(v):` line is elided from this excerpt;
# below are the docstring and zero-padded numeric sort key it returns.
2261 """Sort key for the operation code to ensure that it sorts numerically"""
2263 return "%03d" % int(v)
2268 def calc_percentile(values, percentile):
2269 """Calculate the specified percentile from the list of values.
2271 Assumes the list is sorted in ascending order.
# fractional index of the requested percentile within the sorted list
2276 k = (len(values) - 1) * percentile
# NOTE(review): the computation of f (floor) and c (ceil) and the exact-
# index early-return guard are on elided lines; d0/d1 linearly interpolate
# between the two neighbouring values.
2280 return values[int(k)]
2281 d0 = values[int(f)] * (c - k)
2282 d1 = values[int(c)] * (k - f)
2286 def mk_masked_dir(*path):
2287 """In a testenv we end up with 0777 directories that look an alarming
2288 green colour with ls. Use umask to avoid that."""
2289 # py3 os.mkdir can do this
2290 d = os.path.join(*path)
2291 mask = os.umask(0o077)