1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
4 # Copyright (C) Catalyst IT Ltd. 2017
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 # GNU General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 from __future__ import print_function, division
29 from collections import OrderedDict, Counter, defaultdict, namedtuple
30 from samba.emulate import traffic_packets
31 from samba.samdb import SamDB
33 from ldb import LdbError
34 from samba.dcerpc import ClientConnection
35 from samba.dcerpc import security, drsuapi, lsa
36 from samba.dcerpc import netlogon
37 from samba.dcerpc.netlogon import netr_Authenticator
38 from samba.dcerpc import srvsvc
39 from samba.dcerpc import samr
40 from samba.drs_utils import drs_DsBind
42 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
43 from samba.auth import system_session
44 from samba.dsdb import (
46 UF_SERVER_TRUST_ACCOUNT,
47 UF_TRUSTED_FOR_DELEGATION,
48 UF_WORKSTATION_TRUST_ACCOUNT
50 from samba.dcerpc.misc import SEC_CHAN_BDC
51 from samba import gensec
52 from samba import sd_utils
53 from samba.compat import get_string
54 from samba.logger import get_samba_logger
# Module-level constants for the traffic model and replay engine.
# NOTE(review): this is an elided listing — the embedded source line numbers
# jump (59→65, 74→76, 85→87), so the assignment headers opening the two
# weight dictionaries below are missing from view. Packet.client_score later
# reads CLIENT_CLUES and SERVER_CLUES, so these fragments presumably belong
# to those two dicts — confirm against the full source.
57 CURRENT_MODEL_VERSION = 2 # save as this
58 REQUIRED_MODEL_VERSION = 2 # load accepts this or greater
61 # we don't use None, because it complicates [de]serialisation
# (protocol, opcode) -> weight entries suggesting the sender is a CLIENT.
65 ('dns', '0'): 1.0, # query
66 ('smb', '0x72'): 1.0, # Negotiate protocol
67 ('ldap', '0'): 1.0, # bind
68 ('ldap', '3'): 1.0, # searchRequest
69 ('ldap', '2'): 1.0, # unbindRequest
71 ('dcerpc', '11'): 1.0, # bind
72 ('dcerpc', '14'): 1.0, # Alter_context
73 ('nbns', '0'): 1.0, # query
# (protocol, opcode) -> weight entries suggesting the sender is a SERVER.
77 ('dns', '1'): 1.0, # response
78 ('ldap', '1'): 1.0, # bind response
79 ('ldap', '4'): 1.0, # search result
80 ('ldap', '5'): 1.0, # search done
82 ('dcerpc', '12'): 1.0, # bind_ack
83 ('dcerpc', '13'): 1.0, # bind_nak
84 ('dcerpc', '15'): 1.0, # Alter_context response
# Protocols the replayer does not attempt to reproduce at all.
87 SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}
# NOTE(review): WAIT_SCALE is defined on an elided line (between 87 and 90);
# gaps longer than WAIT_THRESHOLD become explicit 'wait:' states in the model.
90 WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
# log-time range used to synthesise tiny inter-packet gaps (see
# TrafficModel.construct_conversation_sequence).
91 NO_WAIT_LOG_TIME_RANGE = (-10, -3)
93 # DEBUG_LEVEL can be changed by scripts with -d
96 LOGGER = get_samba_logger(name=__name__)
# Leveled stderr logging helper; honours the module-global DEBUG_LEVEL.
99 def debug(level, msg, *args):
100 """Print a formatted debug message to standard error.
103 :param level: The debug level, message will be printed if it is <= the
104 currently set debug level. The debug level can be set with
106 :param msg: The message to be logged, can contain C-Style format
108 :param args: The parameters required by the format specifiers
# NOTE(review): the branch selecting between the two print calls is elided
# (original lines 111/113 missing) — presumably "if not args:" chooses the
# unformatted form; confirm against the full source.
110 if level <= DEBUG_LEVEL:
112 print(msg, file=sys.stderr)
114 print(msg % tuple(args), file=sys.stderr)
# Debug print that prefixes the caller's function name and line number
# (pulled from the traceback) in yellow ANSI colour.
117 def debug_lineno(*args):
118 """ Print an unformatted log message to stderr, contaning the line number
120 tb = traceback.extract_stack(limit=2)
# tb[0][2] is the caller's function name, tb[0][1] its line number.
121 print((" %s:" "\033[01;33m"
122 "%s " "\033[00m" % (tb[0][2], tb[0][1])), end=' ',
# NOTE(review): the loop header over args is elided (lines 123-124 missing).
125 print(a, file=sys.stderr)
126 print(file=sys.stderr)
# Factory returning a closure that prints to stderr in a colour derived
# from the supplied seed values (used to tell conversations apart visually).
130 def random_colour_print(seeds):
131 """Return a function that prints a coloured line to stderr. The colour
132 of the line depends on a sort of hash of the integer arguments."""
# NOTE(review): derivation of `s` from `seeds` and the branch between the
# coloured and plain printers are elided (lines 133-138, 140-148 missing);
# DnsHammer calls this with no arguments, so the no-seed path presumably
# returns the plain printer — confirm against the full source.
139 prefix = "\033[38;5;%dm" % (18 + s)
144 print("%s%s\033[00m" % (prefix, a), file=sys.stderr)
149 print(a, file=sys.stderr)
# Raised when generated/replayed packet data is inconsistent (e.g. a packet's
# endpoints don't match its conversation — see Conversation.add_packet).
154 class FakePacketError(Exception):
# One parsed line of a traffic summary: who talked to whom, when, and with
# what protocol/opcode. Also knows how to replay itself via traffic_packets.
158 class Packet(object):
159 """Details of a network packet"""
# __slots__ keeps per-instance memory small; captures can contain millions
# of packets. NOTE(review): remaining slot names elided (lines 161-169).
160 __slots__ = ('timestamp',
170 def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
171 protocol, opcode, desc, extra):
172 self.timestamp = timestamp
173 self.ip_protocol = ip_protocol
174 self.stream_number = stream_number
177 self.protocol = protocol
# endpoints is a canonically ordered (low, high) pair so that both
# directions of a flow share one key.
181 if self.src < self.dest:
182 self.endpoints = (self.src, self.dest)
184 self.endpoints = (self.dest, self.src)
# Alternate constructor: parse one tab-separated traffic-summary line.
# NOTE(review): the @classmethod decorator and field unpacking are elided
# (lines 186, 189-198 missing).
187 def from_line(cls, line):
188 fields = line.rstrip('\n').split('\t')
199 timestamp = float(timestamp)
203 return cls(timestamp, ip_protocol, stream_number, src, dest,
204 protocol, opcode, desc, extra)
# Inverse of from_line: render as (absolute_time, summary_line).
206 def as_summary(self, time_offset=0.0):
207 """Format the packet as a traffic_summary line.
209 extra = '\t'.join(self.extra)
210 t = self.timestamp + time_offset
211 return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
214 self.stream_number or '',
# Human-readable form (presumably __str__; its def line is elided).
223 return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
224 (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
225 self.stream_number, self.protocol, self.opcode, self.desc,
226 ('«' + ' '.join(self.extra) + '»' if self.extra else '')))
# Presumably __repr__ — def line elided.
229 return "<Packet @%s>" % self
# Presumably a copy method — constructs a new instance of the same class.
232 self.__class__(self.timestamp,
242 def as_packet_type(self):
# canonical "protocol:opcode" key used by the ngram model.
243 t = '%s:%s' % (self.protocol, self.opcode)
246 def client_score(self):
247 """A positive number means we think it is a client; a negative number
248 means we think it is a server. Zero means no idea. range: -1 to 1.
250 key = (self.protocol, self.opcode)
251 if key in CLIENT_CLUES:
252 return CLIENT_CLUES[key]
253 if key in SERVER_CLUES:
254 return -SERVER_CLUES[key]
# Replay this packet by dispatching to traffic_packets.packet_<proto>_<op>,
# timing the call and writing a stats line to stdout.
257 def play(self, conversation, context):
258 """Send the packet over the network, if required.
260 Some packets are ignored, i.e. for protocols not handled,
261 server response messages, or messages that are generated by the
262 protocol layer associated with other packets.
264 fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
266 fn = getattr(traffic_packets, fn_name)
268 except AttributeError as e:
269 print("Conversation(%s) Missing handler %s" %
270 (conversation.conversation_id, fn_name),
274 # Don't display a message for kerberos packets, they're not directly
275 # generated they're used to indicate kerberos should be used
276 if self.protocol != "kerberos":
277 debug(2, "Conversation(%s) Calling handler %s" %
278 (conversation.conversation_id, fn_name))
# NOTE(review): start/end timestamps are set on elided lines (280/285).
282 if fn(self, conversation, context):
283 # Only collect timing data for functions that generate
284 # network traffic, or fail
286 duration = end - start
287 print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
288 (end, conversation.conversation_id, self.protocol,
289 self.opcode, duration))
290 except Exception as e:
292 duration = end - start
293 print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
294 (end, conversation.conversation_id, self.protocol,
295 self.opcode, duration, e))
# NOTE(review): __cmp__ is Python 2 only; under Python 3 ordering would
# need __lt__ etc. — the from __future__ imports suggest 2/3 transition code.
297 def __cmp__(self, other):
298 return self.timestamp - other.timestamp
300 def is_really_a_packet(self, missing_packet_stats=None):
301 return is_a_real_packet(self.protocol, self.opcode)
# True when the (protocol, opcode) pair has a real replay handler; such
# packets affect replay, others can be dropped from conversations.
304 def is_a_real_packet(protocol, opcode):
305 """Is the packet one that can be ignored?
307 If so removing it will have no effect on the replay
309 if protocol in SKIPPED_PROTOCOLS:
310 # Ignore any packets for the protocols we're not interested in.
312 if protocol == "ldap" and opcode == '':
313 # skip ldap continuation packets
# Look up the handler the same way Packet.play does.
316 fn_name = 'packet_%s_%s' % (protocol, opcode)
317 fn = getattr(traffic_packets, fn_name, None)
# NOTE(review): `file=sys.stderr` is not a valid kwarg for logging methods;
# LOGGER.debug would raise/ignore depending on version — looks like a leftover
# from a print() call. Also the `if fn is None:` guard is elided (line 318).
319 LOGGER.debug("missing packet %s" % fn_name, file=sys.stderr)
# null_packet handlers deliberately do nothing, so such packets are not real.
321 if fn is traffic_packets.null_packet:
# Per-conversation replay state: credentials (good and deliberately bad),
# cached protocol connections (ldap/dcerpc/lsarpc/drsuapi/srvsvc/samr/
# netlogon), and DN lookup tables mined from the target DC.
326 class ReplayContext(object):
327 """State/Context for a conversation between an simulated client and a
328 server. Some of the context is shared amongst all conversations
329 and should be generated before the fork, while other context is
330 specific to a particular conversation and should be generated
331 *after* the fork, in generate_process_local_config().
# NOTE(review): the __init__ def line and most parameter assignments are
# elided (original lines 332-336, 339-346, 348-349, 351, 353, 358 missing).
337 badpassword_frequency=None,
338 prefer_kerberos=None,
347 self.netlogon_connection = None
350 self.prefer_kerberos = prefer_kerberos
352 self.base_dn = base_dn
354 self.statsdir = statsdir
355 self.global_tempdir = tempdir
356 self.domain_sid = domain_sid
357 self.realm = lp.get('realm')
359 # Bad password attempt controls
# These flags prevent two consecutive bad-password attempts per protocol
# (see with_random_bad_credentials).
360 self.badpassword_frequency = badpassword_frequency
361 self.last_lsarpc_bad = False
362 self.last_lsarpc_named_bad = False
363 self.last_simple_bind_bad = False
364 self.last_bind_bad = False
365 self.last_srvsvc_bad = False
366 self.last_drsuapi_bad = False
367 self.last_netlogon_bad = False
368 self.last_samlogon_bad = False
369 self.generate_ldap_search_tables()
# Build dn_map (DN-shape pattern -> candidate DNs) and attribute_clue_map
# by searching the whole domain tree once, so replayed ldap searches can
# pick plausible base DNs.
371 def generate_ldap_search_tables(self):
372 session = system_session()
374 db = SamDB(url="ldap://%s" % self.server,
375 session_info=session,
376 credentials=self.creds,
379 res = db.search(db.domain_dn(),
380 scope=ldb.SCOPE_SUBTREE,
381 controls=["paged_results:1:1000"],
384 # find a list of dns for each pattern
385 # e.g. CN,CN,CN,DC,DC
387 attribute_clue_map = {
# NOTE(review): the loop over search results and dn extraction are elided
# (lines 388-392 missing); `dn` below comes from that loop.
393 pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
394 dns = dn_map.setdefault(pattern, [])
396 if dn.startswith('CN=NTDS Settings,'):
397 attribute_clue_map['invocationId'].append(dn)
399 # extend the map in case we are working with a different
400 # number of DC components.
401 # for k, v in self.dn_map.items():
402 # print >>sys.stderr, k, len(v)
# list() is needed because the dict is mutated during iteration.
404 for k in list(dn_map.keys()):
# Strip trailing ',DC' components to alias shorter patterns to the
# same DN list (p derived from k on elided lines 405-407).
408 while p[-3:] == ',DC':
412 if p != k and p in dn_map:
413 print('dn_map collison %s %s' % (k, p),
416 dn_map[p] = dn_map[k]
419 self.attribute_clue_map = attribute_clue_map
# Called AFTER fork: binds this context to one conversation's machine/user
# account and sets up per-process smb.conf paths and credentials.
421 def generate_process_local_config(self, account, conversation):
422 self.ldap_connections = []
423 self.dcerpc_connections = []
424 self.lsarpc_connections = []
425 self.lsarpc_connections_named = []
426 self.drsuapi_connections = []
427 self.srvsvc_connections = []
428 self.samr_contexts = []
429 self.netbios_name = account.netbios_name
430 self.machinepass = account.machinepass
431 self.username = account.username
432 self.userpass = account.userpass
434 self.tempdir = mk_masked_dir(self.global_tempdir,
436 conversation.conversation_id)
# Isolate this process's samba state under its own tempdir.
438 self.lp.set("private dir", self.tempdir)
439 self.lp.set("lock dir", self.tempdir)
440 self.lp.set("state directory", self.tempdir)
441 self.lp.set("tls verify peer", "no_check")
443 # If the domain was not specified, check for the environment
445 if self.domain is None:
446 self.domain = os.environ["DOMAIN"]
448 self.remoteAddress = "/root/ncalrpc_as_system"
449 self.samlogon_dn = ("cn=%s,%s" %
450 (self.netbios_name, self.ou))
451 self.user_dn = ("cn=%s,%s" %
452 (self.username, self.ou))
454 self.generate_machine_creds()
455 self.generate_user_creds()
# Run logon function f, sometimes (badpassword_frequency) first with bad
# creds, then retried with good creds; returns (result, failed_this_time).
457 def with_random_bad_credentials(self, f, good, bad, failed_last_time):
458 """Execute the supplied logon function, randomly choosing the
461 Based on the frequency in badpassword_frequency randomly perform the
462 function with the supplied bad credentials.
463 If run with bad credentials, the function is re-run with the good
465 failed_last_time is used to prevent consecutive bad credential
466 attempts. So the over all bad credential frequency will be lower
467 than that requested, but not significantly.
469 if not failed_last_time:
470 if (self.badpassword_frequency and self.badpassword_frequency > 0
471 and random.random() < self.badpassword_frequency):
# NOTE(review): the try/except around f(bad) is elided (lines 472-477).
475 # Ignore any exceptions as the operation may fail
476 # as it's being performed with bad credentials
478 failed_last_time = True
480 failed_last_time = False
# result comes from the good-credentials call on an elided line (482).
483 return (result, failed_last_time)
485 def generate_user_creds(self):
486 """Generate the conversation specific user Credentials.
488 Each Conversation has an associated user account used to simulate
489 any non Administrative user traffic.
491 Generates user credentials with good and bad passwords and ldap
492 simple bind credentials with good and bad passwords.
494 self.user_creds = Credentials()
495 self.user_creds.guess(self.lp)
496 self.user_creds.set_workstation(self.netbios_name)
497 self.user_creds.set_password(self.userpass)
498 self.user_creds.set_username(self.username)
499 self.user_creds.set_domain(self.domain)
500 if self.prefer_kerberos:
501 self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
503 self.user_creds.set_kerberos_state(DONT_USE_KERBEROS)
# "Bad" variants truncate the password by 4 chars to guarantee a mismatch.
505 self.user_creds_bad = Credentials()
506 self.user_creds_bad.guess(self.lp)
507 self.user_creds_bad.set_workstation(self.netbios_name)
508 self.user_creds_bad.set_password(self.userpass[:-4])
509 self.user_creds_bad.set_username(self.username)
510 if self.prefer_kerberos:
511 self.user_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
513 self.user_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
515 # Credentials for ldap simple bind.
516 self.simple_bind_creds = Credentials()
517 self.simple_bind_creds.guess(self.lp)
518 self.simple_bind_creds.set_workstation(self.netbios_name)
519 self.simple_bind_creds.set_password(self.userpass)
520 self.simple_bind_creds.set_username(self.username)
521 self.simple_bind_creds.set_gensec_features(
522 self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
523 if self.prefer_kerberos:
524 self.simple_bind_creds.set_kerberos_state(MUST_USE_KERBEROS)
526 self.simple_bind_creds.set_kerberos_state(DONT_USE_KERBEROS)
527 self.simple_bind_creds.set_bind_dn(self.user_dn)
529 self.simple_bind_creds_bad = Credentials()
530 self.simple_bind_creds_bad.guess(self.lp)
531 self.simple_bind_creds_bad.set_workstation(self.netbios_name)
532 self.simple_bind_creds_bad.set_password(self.userpass[:-4])
533 self.simple_bind_creds_bad.set_username(self.username)
534 self.simple_bind_creds_bad.set_gensec_features(
535 self.simple_bind_creds_bad.get_gensec_features() |
537 if self.prefer_kerberos:
538 self.simple_bind_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
540 self.simple_bind_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
541 self.simple_bind_creds_bad.set_bind_dn(self.user_dn)
543 def generate_machine_creds(self):
544 """Generate the conversation specific machine Credentials.
546 Each Conversation has an associated machine account.
548 Generates machine credentials with good and bad passwords.
551 self.machine_creds = Credentials()
552 self.machine_creds.guess(self.lp)
553 self.machine_creds.set_workstation(self.netbios_name)
# SEC_CHAN_BDC: the simulated machine acts as a BDC-type secure channel.
554 self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
555 self.machine_creds.set_password(self.machinepass)
556 self.machine_creds.set_username(self.netbios_name + "$")
557 self.machine_creds.set_domain(self.domain)
558 if self.prefer_kerberos:
559 self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
561 self.machine_creds.set_kerberos_state(DONT_USE_KERBEROS)
563 self.machine_creds_bad = Credentials()
564 self.machine_creds_bad.guess(self.lp)
565 self.machine_creds_bad.set_workstation(self.netbios_name)
566 self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
567 self.machine_creds_bad.set_password(self.machinepass[:-4])
568 self.machine_creds_bad.set_username(self.netbios_name + "$")
569 if self.prefer_kerberos:
570 self.machine_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
572 self.machine_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
# Pick a plausible DN for a search with the given shape pattern, preferring
# attribute-specific clues, then progressively shorter patterns.
574 def get_matching_dn(self, pattern, attributes=None):
575 # If the pattern is an empty string, we assume ROOTDSE,
576 # Otherwise we try adding or removing DC suffixes, then
577 # shorter leading patterns until we hit one.
578 # e.g if there is no CN,CN,CN,CN,DC,DC
579 # we first try CN,CN,CN,CN,DC
580 # and CN,CN,CN,CN,DC,DC,DC
581 # then change to CN,CN,CN,DC,DC
582 # and as last resort we use the base_dn
583 attr_clue = self.attribute_clue_map.get(attributes)
585 return random.choice(attr_clue)
587 pattern = pattern.upper()
# NOTE(review): loop header elided (line 588) — this lookup presumably
# repeats until pattern is exhausted, falling back to base_dn.
589 if pattern in self.dn_map:
590 return random.choice(self.dn_map[pattern])
591 # chop one off the front and try it all again.
592 pattern = pattern[3:]
# The connection getters below share a pattern: reuse the most recent
# cached connection unless new=True, else connect (possibly with a random
# bad-password attempt first) and cache the result.
596 def get_dcerpc_connection(self, new=False):
597 guid = '12345678-1234-abcd-ef00-01234567cffb' # RPC_NETLOGON UUID
598 if self.dcerpc_connections and not new:
599 return self.dcerpc_connections[-1]
600 c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
602 self.dcerpc_connections.append(c)
605 def get_srvsvc_connection(self, new=False):
606 if self.srvsvc_connections and not new:
607 return self.srvsvc_connections[-1]
# nested connect() closure (def line elided, original 609).
610 return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
614 (c, self.last_srvsvc_bad) = \
615 self.with_random_bad_credentials(connect,
618 self.last_srvsvc_bad)
620 self.srvsvc_connections.append(c)
623 def get_lsarpc_connection(self, new=False):
624 if self.lsarpc_connections and not new:
625 return self.lsarpc_connections[-1]
# schannel+seal+sign: authenticated, encrypted, signed lsarpc over TCP.
628 binding_options = 'schannel,seal,sign'
629 return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
630 (self.server, binding_options),
634 (c, self.last_lsarpc_bad) = \
635 self.with_random_bad_credentials(connect,
637 self.machine_creds_bad,
638 self.last_lsarpc_bad)
640 self.lsarpc_connections.append(c)
643 def get_lsarpc_named_pipe_connection(self, new=False):
644 if self.lsarpc_connections_named and not new:
645 return self.lsarpc_connections_named[-1]
648 return lsa.lsarpc("ncacn_np:%s" % (self.server),
652 (c, self.last_lsarpc_named_bad) = \
653 self.with_random_bad_credentials(connect,
655 self.machine_creds_bad,
656 self.last_lsarpc_named_bad)
658 self.lsarpc_connections_named.append(c)
661 def get_drsuapi_connection_pair(self, new=False, unbind=False):
662 """get a (drs, drs_handle) tuple"""
663 if self.drsuapi_connections and not new:
664 c = self.drsuapi_connections[-1]
668 binding_options = 'seal'
669 binding_string = "ncacn_ip_tcp:%s[%s]" %\
670 (self.server, binding_options)
671 return drsuapi.drsuapi(binding_string, self.lp, creds)
673 (drs, self.last_drsuapi_bad) = \
674 self.with_random_bad_credentials(connect,
677 self.last_drsuapi_bad)
# DsBind yields the handle needed for subsequent drsuapi calls.
679 (drs_handle, supported_extensions) = drs_DsBind(drs)
680 c = (drs, drs_handle)
681 self.drsuapi_connections.append(c)
684 def get_ldap_connection(self, new=False, simple=False):
685 if self.ldap_connections and not new:
686 return self.ldap_connections[-1]
688 def simple_bind(creds):
690 To run simple bind against Windows, we need to run
691 following commands in PowerShell:
693 Install-windowsfeature ADCS-Cert-Authority
694 Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
# simple bind requires TLS, hence ldaps.
698 return SamDB('ldaps://%s' % self.server,
702 def sasl_bind(creds):
703 return SamDB('ldap://%s' % self.server,
# NOTE(review): the `if simple:`/`else:` split between the two bind paths
# is elided (lines 706/712 missing).
707 (samdb, self.last_simple_bind_bad) = \
708 self.with_random_bad_credentials(simple_bind,
709 self.simple_bind_creds,
710 self.simple_bind_creds_bad,
711 self.last_simple_bind_bad)
713 (samdb, self.last_bind_bad) = \
714 self.with_random_bad_credentials(sasl_bind,
719 self.ldap_connections.append(samdb)
722 def get_samr_context(self, new=False):
723 if not self.samr_contexts or new:
724 self.samr_contexts.append(
725 SamrContext(self.server, lp=self.lp, creds=self.creds))
726 return self.samr_contexts[-1]
728 def get_netlogon_connection(self):
730 if self.netlogon_connection:
731 return self.netlogon_connection
734 return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
738 (c, self.last_netlogon_bad) = \
739 self.with_random_bad_credentials(connect,
741 self.machine_creds_bad,
742 self.last_netlogon_bad)
743 self.netlogon_connection = c
# Cheap deterministic DNS query target: an A lookup of the realm itself.
746 def guess_a_dns_lookup(self):
747 return (self.realm, 'A')
749 def get_authenticator(self):
750 auth = self.machine_creds.new_client_authenticator()
751 current = netr_Authenticator()
# credential bytes may arrive as str or bytes depending on Python version,
# hence the int/ord normalisation.
752 current.cred.data = [x if isinstance(x, int) else ord(x) for x in auth["credential"]]
753 current.timestamp = auth["timestamp"]
755 subsequent = netr_Authenticator()
756 return (current, subsequent)
# Lazily-connected SAMR session plus the handles a SAMR conversation
# accumulates (connect -> domain -> group/user).
759 class SamrContext(object):
760 """State/Context associated with a samr connection.
762 def __init__(self, server, lp=None, creds=None):
763 self.connection = None
# Handles are filled in progressively by the samr traffic handlers.
765 self.domain_handle = None
766 self.domain_sid = None
767 self.group_handle = None
768 self.user_handle = None
# NOTE(review): assignments of server/lp/creds are elided (lines 769-773).
774 def get_connection(self):
775 if not self.connection:
776 self.connection = samr.samr(
777 "ncacn_ip_tcp:%s[seal]" % (self.server),
779 credentials=self.creds)
781 return self.connection
783 def get_handle(self):
# NOTE(review): the `if not self.handle:` guard is presumably on elided
# line 784; Connect2 with SEC_FLAG_MAXIMUM_ALLOWED opens the server handle.
785 c = self.get_connection()
786 self.handle = c.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
# An ordered packet exchange between two endpoints, with client/server
# inference, time normalisation, and forked replay.
790 class Conversation(object):
791 """Details of a converation between a simulated client and a server."""
792 def __init__(self, start_time=None, endpoints=None, seq=(),
793 conversation_id=None):
794 self.start_time = start_time
795 self.endpoints = endpoints
797 self.msg = random_colour_print(endpoints)
# client_balance sums signed client_score()s; sign decides who is client.
798 self.client_balance = 0.0
799 self.conversation_id = conversation_id
# seed packets come as short tuples (loop header elided, line 800).
801 self.add_short_packet(*p)
# NOTE(review): Python-2-style ordering; conversations sort by start_time,
# None sorting first/equal per the elided branches.
803 def __cmp__(self, other):
804 if self.start_time is None:
805 if other.start_time is None:
808 if other.start_time is None:
810 return self.start_time - other.start_time
812 def add_packet(self, packet):
813 """Add a packet object to this conversation, making a local copy with
814 a conversation-relative timestamp."""
# p is a copy of packet (copy call elided, line 815-816).
817 if self.start_time is None:
818 self.start_time = p.timestamp
820 if self.endpoints is None:
821 self.endpoints = p.endpoints
823 if p.endpoints != self.endpoints:
824 raise FakePacketError("Conversation endpoints %s don't match"
825 "packet endpoints %s" %
826 (self.endpoints, p.endpoints))
828 p.timestamp -= self.start_time
830 if p.src == p.endpoints[0]:
831 self.client_balance -= p.client_score()
833 self.client_balance += p.client_score()
# Only packets with real handlers are kept for replay.
835 if p.is_really_a_packet():
836 self.packets.append(p)
838 def add_short_packet(self, timestamp, protocol, opcode, extra,
840 """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
841 (possibly empty) list of extra data. If client is True, assume
842 this packet is from the client to the server.
844 src, dest = self.guess_client_server()
846 src, dest = dest, src
847 key = (protocol, opcode)
848 desc = OP_DESCRIPTIONS[key] if key in OP_DESCRIPTIONS else ''
849 if protocol in IP_PROTOCOLS:
850 ip_protocol = IP_PROTOCOLS[protocol]
853 packet = Packet(timestamp - self.start_time, ip_protocol,
855 protocol, opcode, desc, extra)
856 # XXX we're assuming the timestamp is already adjusted for
858 # XXX should we adjust client balance for guessed packets?
859 if packet.src == packet.endpoints[0]:
860 self.client_balance -= packet.client_score()
862 self.client_balance += packet.client_score()
863 if packet.is_really_a_packet():
864 self.packets.append(packet)
# Presumably __str__/__repr__ (def line elided).
867 return ("<Conversation %s %s starting %.3f %d packets>" %
868 (self.conversation_id, self.endpoints, self.start_time,
# __iter__ / __len__ delegate to the packet list.
874 return iter(self.packets)
877 return len(self.packets)
879 def get_duration(self):
880 if len(self.packets) < 2:
882 return self.packets[-1].timestamp - self.packets[0].timestamp
884 def replay_as_summary_lines(self):
886 for p in self.packets:
887 lines.append(p.as_summary(self.start_time))
890 def replay_in_fork_with_delay(self, start, context=None, account=None):
891 """Fork a new process and replay the conversation.
893 def signal_handler(signal, frame):
894 """Signal handler closes standard out and error.
896 Triggered by a sigterm, ensures that the log messages are flushed
897 to disk and not lost.
904 now = time.time() - start
906 # we are replaying strictly in order, so it is safe to sleep
907 # in the main process if the gap is big enough. This reduces
908 # the number of concurrent threads, which allows us to make
# NOTE(review): `and False` disables this main-process sleep entirely —
# looks deliberate (kept for experimentation) but worth confirming.
910 if gap > 0.15 and False:
911 print("sleeping for %f in main process" % (gap - 0.1),
913 time.sleep(gap - 0.1)
914 now = time.time() - start
916 print("gap is now %f" % gap, file=sys.stderr)
918 self.conversation_id = next(context.next_conversation_id)
# Child process path (the fork and pid test are elided, lines 919-922).
923 signal.signal(signal.SIGTERM, signal_handler)
924 # we must never return, or we'll end up running parts of the
925 # parent's clean-up code. So we work in a try...finally, and
926 # try to print any exceptions.
929 context.generate_process_local_config(account, self)
# Each child writes its own stats file and makes it stdout.
932 filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
933 self.conversation_id)
935 sys.stdout = open(filename, 'w')
937 sleep_time = gap - SLEEP_OVERHEAD
939 time.sleep(sleep_time)
941 miss = t - (time.time() - start)
942 self.msg("starting %s [miss %.3f pid %d]" % (self, miss, pid))
945 print(("EXCEPTION in child PID %d, conversation %s" % (pid, self)),
# NOTE(review): traceback.print_exc(sys.stderr) passes the file as the
# `limit` positional arg — should be print_exc(file=sys.stderr).
947 traceback.print_exc(sys.stderr)
# Replay each packet at (approximately) its recorded relative time.
953 def replay(self, context=None):
956 for p in self.packets:
957 now = time.time() - start
958 gap = p.timestamp - now
959 sleep_time = gap - SLEEP_OVERHEAD
961 time.sleep(sleep_time)
963 miss = p.timestamp - (time.time() - start)
965 self.msg("packet %s [miss %.3f pid %d]" % (p, miss,
968 p.play(self, context)
970 def guess_client_server(self, server_clue=None):
971 """Have a go at deciding who is the server and who is the client.
972 returns (client, server)
974 a, b = self.endpoints
976 if self.client_balance < 0:
979 # in the absense of a clue, we will fall through to assuming
980 # the lowest number is the server (which is usually true).
982 if self.client_balance == 0 and server_clue == b:
987 def forget_packets_outside_window(self, s, e):
988 """Prune any packets outside the timne window we're interested in
990 :param s: start of the window
991 :param e: end of the window
993 self.packets = [p for p in self.packets if s <= p.timestamp <= e]
994 self.start_time = self.packets[0].timestamp if self.packets else None
996 def renormalise_times(self, start_time):
997 """Adjust the packet start times relative to the new start time."""
998 for p in self.packets:
999 p.timestamp -= start_time
1001 if self.start_time is not None:
1002 self.start_time -= start_time
# Synthetic conversation that fires dns:0 queries at `dns_rate` per second
# at uniformly random times over `duration`, bypassing the ngram model.
1005 class DnsHammer(Conversation):
1006 """A lightweight conversation that generates a lot of dns:0 packets on
1009 def __init__(self, dns_rate, duration):
1010 n = int(dns_rate * duration)
1011 self.times = [random.uniform(0, duration) for i in range(n)]
1013 self.rate = dns_rate
1014 self.duration = duration
# NOTE(review): random_colour_print() called with no args here, but its
# def takes a positional `seeds` — confirm it has a default in full source.
1016 self.msg = random_colour_print()
# Presumably __str__ (def line elided).
1019 return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
1020 (len(self.times), self.duration, self.rate))
1022 def replay_in_fork_with_delay(self, start, context=None, account=None):
# delegate to the base-class fork/replay machinery.
1023 return Conversation.replay_in_fork_with_delay(self,
1028 def replay(self, context=None):
1030 fn = traffic_packets.packet_dns_0
1031 for t in self.times:
1032 now = time.time() - start
1034 sleep_time = gap - SLEEP_OVERHEAD
1036 time.sleep(sleep_time)
1039 miss = t - (time.time() - start)
1040 self.msg("packet %s [miss %.3f pid %d]" % (t, miss,
# Time each send and emit a stats line, mirroring Packet.play's format.
1044 packet_start = time.time()
1046 fn(self, self, context)
1048 duration = end - packet_start
1049 print("%f\tDNS\tdns\t0\t%f\tTrue\t" % (end, duration))
1050 except Exception as e:
1052 duration = end - packet_start
1053 print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" % (end, duration, e))
# Parse traffic-summary files into Conversation objects, separating dns
# opcode counts (dns traffic is modelled by rate, not by conversation).
1056 def ingest_summaries(files, dns_mode='count'):
1057 """Load a summary traffic summary file and generated Converations from it.
1060 dns_counts = defaultdict(int)
# files may be paths or open file objects (open() call elided, line 1064).
1063 if isinstance(f, str):
1065 print("Ingesting %s" % (f.name,), file=sys.stderr)
1067 p = Packet.from_line(line)
1068 if p.protocol == 'dns' and dns_mode != 'include':
1069 dns_counts[p.opcode] += 1
1078 start_time = min(p.timestamp for p in packets)
1079 last_packet = max(p.timestamp for p in packets)
1081 print("gathering packets into conversations", file=sys.stderr)
1082 conversations = OrderedDict()
# Group packets by canonical endpoint pair; ids start at 2
# (NOTE(review): why 2 and not 0/1 is not visible here — confirm).
1083 for i, p in enumerate(packets):
1084 p.timestamp -= start_time
1085 c = conversations.get(p.endpoints)
1087 c = Conversation(conversation_id=(i + 2))
1088 conversations[p.endpoints] = c
1091 # We only care about conversations with actual traffic, so we
1092 # filter out conversations with nothing to say. We do that here,
1093 # rather than earlier, because those empty packets contain useful
1094 # hints as to which end of the conversation was the client.
1095 conversation_list = []
1096 for c in conversations.values():
1098 conversation_list.append(c)
1100 # This is obviously not correct, as many conversations will appear
1101 # to start roughly simultaneously at the beginning of the snapshot.
1102 # To which we say: oh well, so be it.
1103 duration = float(last_packet - start_time)
# NOTE(review): named mean_interval but computed as conversations per
# second (a rate, the reciprocal of an interval) — confirm callers' intent.
1104 mean_interval = len(conversations) / duration
1106 return conversation_list, mean_interval, duration, dns_counts
# The server is assumed to be the endpoint that appears in the most
# conversations. NOTE(review): most_common(1)[0] returns an
# (address, count) tuple, not the bare address — confirm callers expect that.
1109 def guess_server_address(conversations):
1110 # we guess the most common address.
1111 addresses = Counter()
1112 for c in conversations:
1113 addresses.update(c.endpoints)
1115 return addresses.most_common(1)[0]
# JSON round-trip helpers: JSON object keys must be strings, so tuple keys
# are flattened with tabs on save and split back on load.
1118 def stringify_keys(x):
1120 for k, v in x.items():
# (tab-join of k and result assembly are elided, lines 1119/1121-1125.)
1126 def unstringify_keys(x):
1128 for k, v in x.items():
1129 t = tuple(str(k).split('\t'))
# Markov-style ngram model of traffic: learns (n-1)-gram -> next-packet-type
# transitions plus per-type query details, serialises to JSON, and generates
# synthetic conversations.
1134 class TrafficModel(object):
1135 def __init__(self, n=3):
# query_details maps packet-type -> list of observed `extra` tuples.
1137 self.query_details = {}
1139 self.dns_opcounts = defaultdict(int)
1140 self.cumulative_duration = 0.0
# packet_rate is [packet_count, elapsed_seconds].
1141 self.packet_rate = [0, 1]
# NOTE(review): mutable default dns_opcounts={} — safe only because it is
# read, not mutated; self.dns_opcounts is what accumulates.
1143 def learn(self, conversations, dns_opcounts={}):
# sentinel start key: (n-1) NON_PACKET states.
1146 key = (NON_PACKET,) * (self.n - 1)
1148 server = guess_server_address(conversations)
1150 for k, v in dns_opcounts.items():
1151 self.dns_opcounts[k] += v
1153 if len(conversations) > 1:
1154 first = conversations[0].start_time
1157 for c in conversations:
1159 last = max(last, c.packets[-1].timestamp)
1161 self.packet_rate[0] = total
1162 self.packet_rate[1] = last - first
1164 for c in conversations:
1165 client, server = c.guess_client_server(server)
1166 cum_duration += c.get_duration()
1167 key = (NON_PACKET,) * (self.n - 1)
# per-packet loop and prev-timestamp bookkeeping elided (1168-1171).
1172 elapsed = p.timestamp - prev
1174 if elapsed > WAIT_THRESHOLD:
1175 # add the wait as an extra state
# wait states are bucketed on a log scale.
1176 wait = 'wait:%d' % (math.log(max(1.0,
1177 elapsed * WAIT_SCALE)))
1178 self.ngrams.setdefault(key, []).append(wait)
1179 key = key[1:] + (wait,)
1181 short_p = p.as_packet_type()
1182 self.query_details.setdefault(short_p,
1183 []).append(tuple(p.extra))
1184 self.ngrams.setdefault(key, []).append(short_p)
1185 key = key[1:] + (short_p,)
1187 self.cumulative_duration += cum_duration
# terminate each conversation with the NON_PACKET end state.
1189 self.ngrams.setdefault(key, []).append(NON_PACKET)
# save(): serialise counts (not raw lists) to JSON (def line elided).
1193 for k, v in self.ngrams.items():
1195 ngrams[k] = dict(Counter(v))
1198 for k, v in self.query_details.items():
# '-' stands in for an empty extra tuple in the serialised form.
1199 query_details[k] = dict(Counter('\t'.join(x) if x else '-'
1204 'query_details': query_details,
1205 'cumulative_duration': self.cumulative_duration,
1206 'packet_rate': self.packet_rate,
1207 'version': CURRENT_MODEL_VERSION
1209 d['dns'] = self.dns_opcounts
1211 if isinstance(f, str):
1214 json.dump(d, f, indent=2)
# load(): inverse of save, with model version check (def line elided).
1217 if isinstance(f, str):
1223 version = d["version"]
1224 if version < REQUIRED_MODEL_VERSION:
1225 raise ValueError("the model file is version %d; "
1226 "version %d is required" %
1227 (version, REQUIRED_MODEL_VERSION))
1229 raise ValueError("the model file lacks a version number; "
1230 "version %d is required" %
1231 (REQUIRED_MODEL_VERSION))
# Re-expand counts into repeated-sample lists so random.choice reproduces
# the learned distribution.
1233 for k, v in d['ngrams'].items():
1234 k = tuple(str(k).split('\t'))
1235 values = self.ngrams.setdefault(k, [])
1236 for p, count in v.items():
1237 values.extend([str(p)] * count)
1240 for k, v in d['query_details'].items():
1241 values = self.query_details.setdefault(str(k), [])
1242 for p, count in v.items():
1244 values.extend([()] * count)
1246 values.extend([tuple(str(p).split('\t'))] * count)
1250 for k, v in d['dns'].items():
1251 self.dns_opcounts[k] += v
1253 self.cumulative_duration = d['cumulative_duration']
1254 self.packet_rate = d['packet_rate']
# Walk the ngram chain from the start key, emitting (timestamp, protocol,
# opcode, extra) tuples until NON_PACKET or hard_stop.
1256 def construct_conversation_sequence(self, timestamp=0.0,
1260 """Construct an individual conversation packet sequence from the
1264 key = (NON_PACKET,) * (self.n - 1)
1265 if ignore_before is None:
1266 ignore_before = timestamp - 1
1269 p = random.choice(self.ngrams.get(key, (NON_PACKET,)))
1273 if p in self.query_details:
1274 extra = random.choice(self.query_details[p])
1278 protocol, opcode = p.split(':', 1)
1279 if protocol == 'wait':
# invert the log-bucketing done in learn(); replay_speed scales time.
1280 log_wait_time = int(opcode) + random.random()
1281 wait = math.exp(log_wait_time) / (WAIT_SCALE * replay_speed)
1284 log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
1285 wait = math.exp(log_wait) / replay_speed
1287 if hard_stop is not None and timestamp > hard_stop:
1289 if timestamp >= ignore_before:
1290 c.append((timestamp, protocol, opcode, extra))
1292 key = key[1:] + (p,)
1296 def generate_conversations(self, scale, duration, replay_speed=1,
1297 server=1, client=2):
1298 """Generate a list of conversations from the model."""
1300 # We run the simulation for ten times as long as our desired
1301 # duration, and take the section at the end.
1302 lead_in = 9 * duration
1303 rate_n, rate_t = self.packet_rate
1304 target_packets = int(duration * scale * rate_n / rate_t)
1309 while n_packets < target_packets:
1310 start = random.uniform(-lead_in, duration)
1311 c = self.construct_conversation_sequence(start,
1313 replay_speed=replay_speed,
1315 conversations.append(c)
1318 print(("we have %d packets (target %d) in %d conversations at scale %f"
1319 % (n_packets, target_packets, len(conversations), scale)),
1321 conversations.sort() # sorts by first element == start time
1322 return seq_to_conversations(conversations)
# NOTE(review): incomplete excerpt — the loop header lines (orig 1326-1328)
# are missing from this listing; the visible code builds Conversation
# objects from packet sequences.
1325 def seq_to_conversations(seq, server=1, client=2):
1329 c = Conversation(s[0][0], (server, client), s)
1331 conversations.append(c)
1332 return conversations
# NOTE(review): fragments of two module-level tables — the tail of
# IP_PROTOCOLS (protocol name -> IP protocol number string) and the body of
# OP_DESCRIPTIONS ((protocol, opcode) -> human-readable description). The
# dict openings are missing from this excerpt.
1337 'rpc_netlogon': '06',
1338 'kerberos': '06', # ratio 16248:258
1349 'smb_netlogon': '11',
1355 ('browser', '0x01'): 'Host Announcement (0x01)',
1356 ('browser', '0x02'): 'Request Announcement (0x02)',
1357 ('browser', '0x08'): 'Browser Election Request (0x08)',
1358 ('browser', '0x09'): 'Get Backup List Request (0x09)',
1359 ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
1360 ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
1361 ('cldap', '3'): 'searchRequest',
1362 ('cldap', '5'): 'searchResDone',
1363 ('dcerpc', '0'): 'Request',
1364 ('dcerpc', '11'): 'Bind',
1365 ('dcerpc', '12'): 'Bind_ack',
1366 ('dcerpc', '13'): 'Bind_nak',
1367 ('dcerpc', '14'): 'Alter_context',
1368 ('dcerpc', '15'): 'Alter_context_resp',
1369 ('dcerpc', '16'): 'AUTH3',
1370 ('dcerpc', '2'): 'Response',
1371 ('dns', '0'): 'query',
1372 ('dns', '1'): 'response',
1373 ('drsuapi', '0'): 'DsBind',
1374 ('drsuapi', '12'): 'DsCrackNames',
1375 ('drsuapi', '13'): 'DsWriteAccountSpn',
1376 ('drsuapi', '1'): 'DsUnbind',
1377 ('drsuapi', '2'): 'DsReplicaSync',
1378 ('drsuapi', '3'): 'DsGetNCChanges',
1379 ('drsuapi', '4'): 'DsReplicaUpdateRefs',
1380 ('epm', '3'): 'Map',
1381 ('kerberos', ''): '',
1382 ('ldap', '0'): 'bindRequest',
1383 ('ldap', '1'): 'bindResponse',
1384 ('ldap', '2'): 'unbindRequest',
1385 ('ldap', '3'): 'searchRequest',
1386 ('ldap', '4'): 'searchResEntry',
1387 ('ldap', '5'): 'searchResDone',
1388 ('ldap', ''): '*** Unknown ***',
1389 ('lsarpc', '14'): 'lsa_LookupNames',
1390 ('lsarpc', '15'): 'lsa_LookupSids',
1391 ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
1392 ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
1393 ('lsarpc', '6'): 'lsa_OpenPolicy',
1394 ('lsarpc', '76'): 'lsa_LookupSids3',
1395 ('lsarpc', '77'): 'lsa_LookupNames4',
1396 ('nbns', '0'): 'query',
1397 ('nbns', '1'): 'response',
1398 ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
1399 ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
1400 ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
1401 ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
1402 ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
1403 ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
1404 ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
1405 ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
1406 ('samr', '0',): 'Connect',
1407 ('samr', '16'): 'GetAliasMembership',
1408 ('samr', '17'): 'LookupNames',
1409 ('samr', '18'): 'LookupRids',
1410 ('samr', '19'): 'OpenGroup',
1411 ('samr', '1'): 'Close',
1412 ('samr', '25'): 'QueryGroupMember',
1413 ('samr', '34'): 'OpenUser',
1414 ('samr', '36'): 'QueryUserInfo',
1415 ('samr', '39'): 'GetGroupsForUser',
1416 ('samr', '3'): 'QuerySecurity',
1417 ('samr', '5'): 'LookupDomain',
1418 ('samr', '64'): 'Connect5',
1419 ('samr', '6'): 'EnumDomains',
1420 ('samr', '7'): 'OpenDomain',
1421 ('samr', '8'): 'QueryDomainInfo',
1422 ('smb', '0x04'): 'Close (0x04)',
1423 ('smb', '0x24'): 'Locking AndX (0x24)',
1424 ('smb', '0x2e'): 'Read AndX (0x2e)',
1425 ('smb', '0x32'): 'Trans2 (0x32)',
1426 ('smb', '0x71'): 'Tree Disconnect (0x71)',
1427 ('smb', '0x72'): 'Negotiate Protocol (0x72)',
1428 ('smb', '0x73'): 'Session Setup AndX (0x73)',
1429 ('smb', '0x74'): 'Logoff AndX (0x74)',
1430 ('smb', '0x75'): 'Tree Connect AndX (0x75)',
1431 ('smb', '0xa2'): 'NT Create AndX (0xa2)',
1432 ('smb2', '0'): 'NegotiateProtocol',
1433 ('smb2', '11'): 'Ioctl',
1434 ('smb2', '14'): 'Find',
1435 ('smb2', '16'): 'GetInfo',
1436 ('smb2', '18'): 'Break',
1437 ('smb2', '1'): 'SessionSetup',
1438 ('smb2', '2'): 'SessionLogoff',
1439 ('smb2', '3'): 'TreeConnect',
1440 ('smb2', '4'): 'TreeDisconnect',
1441 ('smb2', '5'): 'Create',
1442 ('smb2', '6'): 'Close',
1443 ('smb2', '8'): 'Read',
1444 ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
1445 ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
1446 'user unknown (0x17)'),
1447 ('srvsvc', '16'): 'NetShareGetInfo',
1448 ('srvsvc', '21'): 'NetSrvGetInfo',
# NOTE(review): builds a tab-separated replay line from a short
# "protocol:opcode" packet key. The numbering gap shows orig line 1458 is
# missing here (presumably the line that appends `extra` to `line` —
# confirm against the full source before editing).
1452 def expand_short_packet(p, timestamp, src, dest, extra):
1453 protocol, opcode = p.split(':', 1)
1454 desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
1455 ip_protocol = IP_PROTOCOLS.get(protocol, '06')
1457 line = [timestamp, ip_protocol, '', src, dest, protocol, opcode, desc]
1459 return '\t'.join(line)
# NOTE(review): heavily elided excerpt of replay() — forks one child per
# conversation in timed batches, reaps finished children with waitpid, and
# on shutdown escalates kill signals (15, 15, 9). Many interior lines are
# missing; do not modify behavior from this fragment alone.
1462 def replay(conversations,
1471 context = ReplayContext(server=host,
1476 if len(accounts) < len(conversations):
1477 print(("we have %d accounts but %d conversations" %
1478 (accounts, conversations)), file=sys.stderr)
1481 sorted(conversations, key=lambda x: x.start_time, reverse=True),
1484 # Set the process group so that the calling scripts are not killed
1485 # when the forked child processes are killed.
1490 if duration is None:
1491 # end 1 second after the last packet of the last conversation
1492 # to start. Conversations other than the last could still be
1493 # going, but we don't care.
1494 duration = cstack[0][0].packets[-1].timestamp + 1.0
1495 print("We will stop after %.1f seconds" % duration,
1498 end = start + duration
1500 LOGGER.info("Replaying traffic for %u conversations over %d seconds"
1501 % (len(conversations), duration))
1505 dns_hammer = DnsHammer(dns_rate, duration)
1506 cstack.append((dns_hammer, None))
1510 # we spawn a batch, wait for finishers, then spawn another
1512 batch_end = min(now + 2.0, end)
1516 c, account = cstack.pop()
1517 if c.start_time + start > batch_end:
1518 cstack.append((c, account))
1522 pid = c.replay_in_fork_with_delay(start, context, account)
1526 fork_time += elapsed
1528 print("forked %s in pid %s (in %fs)" % (c, pid,
1533 print(("forked %d times in %f seconds (avg %f)" %
1534 (fork_n, fork_time, fork_time / fork_n)),
1537 debug(2, "no forks in batch ending %f" % batch_end)
1539 while time.time() < batch_end - 1.0:
1542 pid, status = os.waitpid(-1, os.WNOHANG)
1543 except OSError as e:
1544 if e.errno != 10: # no child processes
1548 c = children.pop(pid, None)
1549 print(("process %d finished conversation %s;"
1551 (pid, c, len(children))), file=sys.stderr)
1553 if time.time() >= end:
1554 print("time to stop", file=sys.stderr)
1558 print("EXCEPTION in parent", file=sys.stderr)
1559 traceback.print_exc()
1561 for s in (15, 15, 9):
1562 print(("killing %d children with -%d" %
1563 (len(children), s)), file=sys.stderr)
1564 for pid in children:
1567 except OSError as e:
1568 if e.errno != 3: # don't fail if it has already died
1571 end = time.time() + 1
1574 pid, status = os.waitpid(-1, os.WNOHANG)
1575 except OSError as e:
1579 c = children.pop(pid, None)
1580 print(("kill -%d %d KILLED conversation %s; "
1582 (s, pid, c, len(children))),
1584 if time.time() >= end:
1592 print("%d children are missing" % len(children),
1595 # there may be stragglers that were forked just as ^C was hit
1596 # and don't appear in the list of children. We can get them
1597 # with killpg, but that will also kill us, so this is^H^H would be
1598 # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
1599 # so as not to have to fuss around writing signal handlers.
1602 except KeyboardInterrupt:
1603 print("ignoring fake ^C", file=sys.stderr)
# NOTE(review): fragments of openLdb (opens a SamDB LDAP connection),
# ou_name (builds the per-instance OU DN — the continuation line of its
# return statement is missing here), create_ou (adds parent and instance
# OUs, ignoring already-exists errors), and the ConversationAccounts
# namedtuple declaration (field list missing).
1606 def openLdb(host, creds, lp):
1607 session = system_session()
1608 ldb = SamDB(url="ldap://%s" % host,
1609 session_info=session,
1610 options=['modules:paged_searches'],
1616 def ou_name(ldb, instance_id):
1617 """Generate an ou name from the instance id"""
1618 return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id,
1622 def create_ou(ldb, instance_id):
1623 """Create an ou, all created user and machine accounts will belong to it.
1625 This allows all the created resources to be cleaned up easily.
1627 ou = ou_name(ldb, instance_id)
1629 ldb.add({"dn": ou.split(',', 1)[1],
1630 "objectclass": "organizationalunit"})
1631 except LdbError as e:
1632 (status, _) = e.args
1633 # ignore already exists
1638 "objectclass": "organizationalunit"})
1639 except LdbError as e:
1640 (status, _) = e.args
1641 # ignore already exists
1647 # ConversationAccounts holds details of the machine and user accounts
1648 # associated with a conversation.
1650 # We use a named tuple to reduce shared memory usage.
1651 ConversationAccounts = namedtuple('ConversationAccounts',
# NOTE(review): elided excerpts of the account-creation helpers. The
# visible code shows their contracts: generate_replay_accounts builds
# ConversationAccounts tuples; create_machine_account / create_user_account
# / create_group add LDAP objects under the instance OU (passwords are
# encoded as quoted UTF-16-LE unicodePwd values).
1658 def generate_replay_accounts(ldb, instance_id, number, password):
1659 """Generate a series of unique machine and user account names."""
1662 for i in range(1, number + 1):
1663 netbios_name = machine_name(instance_id, i)
1664 username = user_name(instance_id, i)
1666 account = ConversationAccounts(netbios_name, password, username,
1668 accounts.append(account)
1672 def create_machine_account(ldb, instance_id, netbios_name, machinepass,
1673 traffic_account=True):
1674 """Create a machine account via ldap."""
1676 ou = ou_name(ldb, instance_id)
1677 dn = "cn=%s,%s" % (netbios_name, ou)
1678 utf16pw = ('"%s"' % get_string(machinepass)).encode('utf-16-le')
1681 # we set these bits for the machine account otherwise the replayed
1682 # traffic throws up NT_STATUS_NO_TRUST_SAM_ACCOUNT errors
1683 account_controls = str(UF_TRUSTED_FOR_DELEGATION |
1684 UF_SERVER_TRUST_ACCOUNT)
1687 account_controls = str(UF_WORKSTATION_TRUST_ACCOUNT)
1691 "objectclass": "computer",
1692 "sAMAccountName": "%s$" % netbios_name,
1693 "userAccountControl": account_controls,
1694 "unicodePwd": utf16pw})
1697 def create_user_account(ldb, instance_id, username, userpass):
1698 """Create a user account via ldap."""
1699 ou = ou_name(ldb, instance_id)
1700 user_dn = "cn=%s,%s" % (username, ou)
1701 utf16pw = ('"%s"' % get_string(userpass)).encode('utf-16-le')
1704 "objectclass": "user",
1705 "sAMAccountName": username,
1706 "userAccountControl": str(UF_NORMAL_ACCOUNT),
1707 "unicodePwd": utf16pw
1710 # grant user write permission to do things like write account SPN
1711 sdutils = sd_utils.SDUtils(ldb)
1712 sdutils.dacl_add_ace(user_dn, "(A;;WP;;;PS)")
1715 def create_group(ldb, instance_id, name):
1716 """Create a group via ldap."""
1718 ou = ou_name(ldb, instance_id)
1719 dn = "cn=%s,%s" % (name, ou)
1722 "objectclass": "group",
1723 "sAMAccountName": name,
def user_name(instance_id, i):
    """Return the replay user account name for user *i* of an instance."""
    template = "STGU-%d-%d"
    return template % (instance_id, i)
# NOTE(review): elided excerpts. search_objectclass returns the named
# attribute of all matching objects as a set of strings; generate_users /
# generate_machine_accounts create only the accounts that do not already
# exist (machine sAMAccountNames carry a trailing '$'). machine_name
# returns "STGM-..." for traffic accounts and "PC-..." for filler
# computer accounts; the branching line itself is missing from this view.
1732 def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
1733 """Search objectclass, return attr in a set"""
1735 expression="(objectClass={})".format(objectclass),
1738 return {str(obj[attr]) for obj in objs}
1741 def generate_users(ldb, instance_id, number, password):
1742 """Add users to the server"""
1743 existing_objects = search_objectclass(ldb, objectclass='user')
1745 for i in range(number, 0, -1):
1746 name = user_name(instance_id, i)
1747 if name not in existing_objects:
1748 create_user_account(ldb, instance_id, name, password)
1751 LOGGER.info("Created %u/%u users" % (users, number))
1756 def machine_name(instance_id, i, traffic_account=True):
1757 """Generate a machine account name from instance id."""
1759 # traffic accounts correspond to a given user, and use different
1760 # userAccountControl flags to ensure packets get processed correctly
1762 return "STGM-%d-%d" % (instance_id, i)
1764 # Otherwise we're just generating computer accounts to simulate a
1765 # semi-realistic network. These use the default computer
1766 # userAccountControl flags, so we use a different account name so that
1767 # we don't try to use them when generating packets
1768 return "PC-%d-%d" % (instance_id, i)
1771 def generate_machine_accounts(ldb, instance_id, number, password,
1772 traffic_account=True):
1773 """Add machine accounts to the server"""
1774 existing_objects = search_objectclass(ldb, objectclass='computer')
1776 for i in range(number, 0, -1):
1777 name = machine_name(instance_id, i, traffic_account)
1778 if name + "$" not in existing_objects:
1779 create_machine_account(ldb, instance_id, name, password,
1783 LOGGER.info("Created %u/%u machine accounts" % (added, number))
def group_name(instance_id, i):
    """Return the replay group name for group *i* of an instance."""
    template = "STGG-%d-%d"
    return template % (instance_id, i)
# NOTE(review): elided excerpts of generate_groups (creates missing
# groups, logging every 1000), clean_up_accounts (tree-deletes the
# instance OU, ignoring does-not-exist), and generate_users_and_groups
# (the orchestrator: OU, users, machines, groups, then membership
# assignment). Interior lines are missing throughout.
1793 def generate_groups(ldb, instance_id, number):
1794 """Create the required number of groups on the server."""
1795 existing_objects = search_objectclass(ldb, objectclass='group')
1797 for i in range(number, 0, -1):
1798 name = group_name(instance_id, i)
1799 if name not in existing_objects:
1800 create_group(ldb, instance_id, name)
1802 if groups % 1000 == 0:
1803 LOGGER.info("Created %u/%u groups" % (groups, number))
1808 def clean_up_accounts(ldb, instance_id):
1809 """Remove the created accounts and groups from the server."""
1810 ou = ou_name(ldb, instance_id)
1812 ldb.delete(ou, ["tree_delete:1"])
1813 except LdbError as e:
1814 (status, _) = e.args
1815 # ignore does not exist
1820 def generate_users_and_groups(ldb, instance_id, password,
1821 number_of_users, number_of_groups,
1822 group_memberships, max_members,
1823 machine_accounts, traffic_accounts=True):
1824 """Generate the required users and groups, allocating the users to
1826 memberships_added = 0
1830 create_ou(ldb, instance_id)
1832 LOGGER.info("Generating dummy user accounts")
1833 users_added = generate_users(ldb, instance_id, number_of_users, password)
1835 LOGGER.info("Generating dummy machine accounts")
1836 computers_added = generate_machine_accounts(ldb, instance_id,
1837 machine_accounts, password,
1840 if number_of_groups > 0:
1841 LOGGER.info("Generating dummy groups")
1842 groups_added = generate_groups(ldb, instance_id, number_of_groups)
1844 if group_memberships > 0:
1845 LOGGER.info("Assigning users to groups")
1846 assignments = GroupAssignments(number_of_groups,
1852 LOGGER.info("Adding users to groups")
1853 add_users_to_groups(ldb, instance_id, assignments)
1854 memberships_added = assignments.total()
1856 if (groups_added > 0 and users_added == 0 and
1857 number_of_groups != groups_added):
1858 LOGGER.warning("The added groups will contain no members")
1860 LOGGER.info("Added %d users (%d machines), %d groups and %d memberships" %
1861 (users_added, computers_added, groups_added,
# NOTE(review): elided excerpt of the GroupAssignments class head —
# __init__ builds group and user probability distributions and assigns
# memberships; cumulative_distribution normalises weights into a 0.0-1.0
# cumulative list (its accumulator setup and return lines are missing
# here); the two generate_*_distribution methods weight users by a Pareto
# draw and groups by a decreasing function of group id.
1865 class GroupAssignments(object):
1866 def __init__(self, number_of_groups, groups_added, number_of_users,
1867 users_added, group_memberships, max_members):
1870 self.generate_group_distribution(number_of_groups)
1871 self.generate_user_distribution(number_of_users, group_memberships)
1872 self.max_members = max_members
1873 self.assignments = defaultdict(list)
1874 self.assign_groups(number_of_groups, groups_added, number_of_users,
1875 users_added, group_memberships)
1877 def cumulative_distribution(self, weights):
1878 # make sure the probabilities conform to a cumulative distribution
1879 # spread between 0.0 and 1.0. Dividing by the weighted total gives each
1880 # probability a proportional share of 1.0. Higher probabilities get a
1881 # bigger share, so are more likely to be picked. We use the cumulative
1882 # value, so we can use random.random() as a simple index into the list
1884 total = sum(weights)
1889 for probability in weights:
1890 cumulative += probability
1891 dist.append(cumulative / total)
1894 def generate_user_distribution(self, num_users, num_memberships):
1895 """Probability distribution of a user belonging to a group.
1897 # Assign a weighted probability to each user. Use the Pareto
1898 # Distribution so that some users are in a lot of groups, and the
1899 # bulk of users are in only a few groups. If we're assigning a large
1900 # number of group memberships, use a higher shape. This means slightly
1901 # fewer outlying users that are in large numbers of groups. The aim is
1902 # to have no users belonging to more than ~500 groups.
1903 if num_memberships > 5000000:
1905 elif num_memberships > 2000000:
1907 elif num_memberships > 300000:
1913 for x in range(1, num_users + 1):
1914 p = random.paretovariate(shape)
1917 # convert the weights to a cumulative distribution between 0.0 and 1.0
1918 self.user_dist = self.cumulative_distribution(weights)
1920 def generate_group_distribution(self, n):
1921 """Probability distribution of a group containing a user."""
1923 # Assign a weighted probability to each group. Probability decreases
1924 # as the group-ID increases
1926 for x in range(1, n + 1):
1930 # convert the weights to a cumulative distribution between 0.0 and 1.0
1931 self.group_weights = weights
1932 self.group_dist = self.cumulative_distribution(weights)
1934 def generate_random_membership(self):
1935 """Returns a randomly generated user-group membership"""
1937 # the list items are cumulative distribution values between 0.0 and
1938 # 1.0, which makes random() a handy way to index the list to get a
1939 # weighted random user/group. (Here the user/group returned are
1940 # zero-based array indexes)
1941 user = bisect.bisect(self.user_dist, random.random())
1942 group = bisect.bisect(self.group_dist, random.random())
1946 def users_in_group(self, group):
1947 return self.assignments[group]
1949 def get_groups(self):
1950 return self.assignments.keys()
1952 def cap_group_membership(self, group, max_members):
1953 """Prevent the group's membership from exceeding the max specified"""
1954 num_members = len(self.assignments[group])
1955 if num_members >= max_members:
1956 LOGGER.info("Group {0} has {1} members".format(group, num_members))
1958 # remove this group and then recalculate the cumulative
1959 # distribution, so this group is no longer selected
1960 self.group_weights[group - 1] = 0
1961 new_dist = self.cumulative_distribution(self.group_weights)
1962 self.group_dist = new_dist
# NOTE(review): elided excerpt. add_assignment records a (user, group)
# pair once per group and applies the max_members cap; assign_groups
# scales the requested membership count by the fraction of users actually
# added, then draws random memberships until the total is reached. The
# membership counter increment (orig 1970) is missing from this view.
1964 def add_assignment(self, user, group):
1965 # the assignments are stored in a dictionary where key=group,
1966 # value=list-of-users-in-group (indexing by group-ID allows us to
1967 # optimize for DB membership writes)
1968 if user not in self.assignments[group]:
1969 self.assignments[group].append(user)
1972 # check if there's a cap on how big the groups can grow
1973 if self.max_members:
1974 self.cap_group_membership(group, self.max_members)
1976 def assign_groups(self, number_of_groups, groups_added,
1977 number_of_users, users_added, group_memberships):
1978 """Allocate users to groups.
1980 The intention is to have a few users that belong to most groups, while
1981 the majority of users belong to a few groups.
1983 A few groups will contain most users, with the remaining only having a
1987 if group_memberships <= 0:
1990 # Calculate the number of group memberships required
1991 group_memberships = math.ceil(
1992 float(group_memberships) *
1993 (float(users_added) / float(number_of_users)))
1995 if self.max_members:
1996 group_memberships = min(group_memberships,
1997 self.max_members * number_of_groups)
1999 existing_users = number_of_users - users_added - 1
2000 existing_groups = number_of_groups - groups_added - 1
2001 while self.total() < group_memberships:
2002 user, group = self.generate_random_membership()
2004 if group > existing_groups or user > existing_users:
2005 # the + 1 converts the array index to the corresponding
2006 # group or user number
2007 self.add_assignment(user + 1, group + 1)
# NOTE(review): elided excerpts. add_users_to_groups writes memberships in
# chunks of 1000 users per LDAP modify; add_group_members builds one
# ldb modify message with a FLAG_MOD_ADD element per user; generate_stats
# parses tab-separated per-packet result files and prints latency/failure
# summaries per (protocol, packet_type).
2013 def add_users_to_groups(db, instance_id, assignments):
2014 """Takes the assignments of users to groups and applies them to the DB."""
2016 total = assignments.total()
2020 for group in assignments.get_groups():
2021 users_in_group = assignments.users_in_group(group)
2022 if len(users_in_group) == 0:
2025 # Split up the users into chunks, so we write no more than 1K at a
2026 # time. (Minimizing the DB modifies is more efficient, but writing
2027 # 10K+ users to a single group becomes inefficient memory-wise)
2028 for chunk in range(0, len(users_in_group), 1000):
2029 chunk_of_users = users_in_group[chunk:chunk + 1000]
2030 add_group_members(db, instance_id, group, chunk_of_users)
2032 added += len(chunk_of_users)
2035 LOGGER.info("Added %u/%u memberships" % (added, total))
2037 def add_group_members(db, instance_id, group, users_in_group):
2038 """Adds the given users to group specified."""
2040 ou = ou_name(db, instance_id)
2043 return("cn=%s,%s" % (name, ou))
2045 group_dn = build_dn(group_name(instance_id, group))
2047 m.dn = ldb.Dn(db, group_dn)
2049 for user in users_in_group:
2050 user_dn = build_dn(user_name(instance_id, user))
2051 idx = "member-" + str(user)
2052 m[idx] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
2057 def generate_stats(statsdir, timing_file):
2058 """Generate and print the summary stats for a run."""
2059 first = sys.float_info.max
2065 unique_converations = set()
2068 if timing_file is not None:
2069 tw = timing_file.write
2074 tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")
2076 for filename in os.listdir(statsdir):
2077 path = os.path.join(statsdir, filename)
2078 with open(path, 'r') as f:
2081 fields = line.rstrip('\n').split('\t')
2082 conversation = fields[1]
2083 protocol = fields[2]
2084 packet_type = fields[3]
2085 latency = float(fields[4])
2086 first = min(float(fields[0]) - latency, first)
2087 last = max(float(fields[0]), last)
2089 if protocol not in latencies:
2090 latencies[protocol] = {}
2091 if packet_type not in latencies[protocol]:
2092 latencies[protocol][packet_type] = []
2094 latencies[protocol][packet_type].append(latency)
2096 if protocol not in failures:
2097 failures[protocol] = {}
2098 if packet_type not in failures[protocol]:
2099 failures[protocol][packet_type] = 0
2101 if fields[5] == 'True':
2105 failures[protocol][packet_type] += 1
2107 if conversation not in unique_converations:
2108 unique_converations.add(conversation)
2112 except (ValueError, IndexError):
2113 # not a valid line; print and ignore
2114 print(line, file=sys.stderr)
2116 duration = last - first
2120 success_rate = successful / duration
2124 failure_rate = failed / duration
2126 print("Total conversations: %10d" % conversations)
2127 print("Successful operations: %10d (%.3f per second)"
2128 % (successful, success_rate))
2129 print("Failed operations: %10d (%.3f per second)"
2130 % (failed, failure_rate))
2132 print("Protocol Op Code Description "
2133 " Count Failed Mean Median "
2136 protocols = sorted(latencies.keys())
2137 for protocol in protocols:
2138 packet_types = sorted(latencies[protocol], key=opcode_key)
2139 for packet_type in packet_types:
2140 values = latencies[protocol][packet_type]
2141 values = sorted(values)
2143 failed = failures[protocol][packet_type]
2144 mean = sum(values) / count
2145 median = calc_percentile(values, 0.50)
2146 percentile = calc_percentile(values, 0.95)
2147 rng = values[-1] - values[0]
2149 desc = OP_DESCRIPTIONS.get((protocol, packet_type), '')
# NOTE(review): 'sys.stdout.isatty' below is a bound method reference and
# is therefore always truthy — upstream almost certainly meant
# 'sys.stdout.isatty()'. Fix when editing against the full source.
2150 if sys.stdout.isatty:
2151 print("%-12s %4s %-35s %12d %12d %12.6f "
2152 "%12.6f %12.6f %12.6f %12.6f"
2164 print("%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f"
# NOTE(review): fragments of opcode_key (zero-pads numeric opcodes so they
# sort numerically; its def line and try/except are missing here) and
# calc_percentile (linear-interpolated percentile over a pre-sorted list;
# the floor/ceil setup and final return are missing).
2178 """Sort key for the operation code to ensure that it sorts numerically"""
2180 return "%03d" % int(v)
2185 def calc_percentile(values, percentile):
2186 """Calculate the specified percentile from the list of values.
2188 Assumes the list is sorted in ascending order.
2193 k = (len(values) - 1) * percentile
2197 return values[int(k)]
2198 d0 = values[int(f)] * (c - k)
2199 d1 = values[int(c)] * (k - f)
2203 def mk_masked_dir(*path):
2204 """In a testenv we end up with 0777 directories that look an alarming
2205 green colour with ls. Use umask to avoid that."""
2206 # py3 os.mkdir can do this
2207 d = os.path.join(*path)
2208 mask = os.umask(0o077)