1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
4 # Copyright (C) Catalyst IT Ltd. 2017
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 # GNU General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 from __future__ import print_function, division
29 from collections import OrderedDict, Counter, defaultdict, namedtuple
30 from samba.emulate import traffic_packets
31 from samba.samdb import SamDB
33 from ldb import LdbError
34 from samba.dcerpc import ClientConnection
35 from samba.dcerpc import security, drsuapi, lsa
36 from samba.dcerpc import netlogon
37 from samba.dcerpc.netlogon import netr_Authenticator
38 from samba.dcerpc import srvsvc
39 from samba.dcerpc import samr
40 from samba.drs_utils import drs_DsBind
42 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
43 from samba.auth import system_session
44 from samba.dsdb import (
46 UF_SERVER_TRUST_ACCOUNT,
47 UF_TRUSTED_FOR_DELEGATION,
48 UF_WORKSTATION_TRUST_ACCOUNT
50 from samba.dcerpc.misc import SEC_CHAN_BDC
51 from samba import gensec
52 from samba import sd_utils
53 from samba.compat import get_string
54 from samba.logger import get_samba_logger
59 # we don't use None, because it complicates [de]serialisation
63 ('dns', '0'): 1.0, # query
64 ('smb', '0x72'): 1.0, # Negotiate protocol
65 ('ldap', '0'): 1.0, # bind
66 ('ldap', '3'): 1.0, # searchRequest
67 ('ldap', '2'): 1.0, # unbindRequest
69 ('dcerpc', '11'): 1.0, # bind
70 ('dcerpc', '14'): 1.0, # Alter_context
71 ('nbns', '0'): 1.0, # query
75 ('dns', '1'): 1.0, # response
76 ('ldap', '1'): 1.0, # bind response
77 ('ldap', '4'): 1.0, # search result
78 ('ldap', '5'): 1.0, # search done
80 ('dcerpc', '12'): 1.0, # bind_ack
81 ('dcerpc', '13'): 1.0, # bind_nak
82 ('dcerpc', '15'): 1.0, # Alter_context response
85 SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}
88 WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
89 NO_WAIT_LOG_TIME_RANGE = (-10, -3)
91 # DEBUG_LEVEL can be changed by scripts with -d
94 LOGGER = get_samba_logger(name=__name__)
def debug(level, msg, *args):
    """Print a formatted debug message to standard error.

    :param level: The debug level, message will be printed if it is <= the
                  currently set debug level. The debug level can be set with
    :param msg: The message to be logged, can contain C-Style format
    :param args: The parameters required by the format specifiers
    """
    # NOTE(review): lines are missing from this view; presumably an
    # if/else chooses between the two prints depending on whether `args`
    # is empty -- confirm against the full file.
    if level <= DEBUG_LEVEL:
        print(msg, file=sys.stderr)
        print(msg % tuple(args), file=sys.stderr)
def debug_lineno(*args):
    """ Print an unformatted log message to stderr, containing the line number
    """
    # Grab the caller's frame so the message can be tagged with its
    # function name and line number.
    tb = traceback.extract_stack(limit=2)
    print((" %s:" "\033[01;33m"
           "%s " "\033[00m" % (tb[0][2], tb[0][1])), end=' ',
    # NOTE(review): lines are missing from this view -- the print() call
    # above is unterminated here, and `a` below implies a missing
    # `for a in args:` loop. Confirm against the full file.
    print(a, file=sys.stderr)
    print(file=sys.stderr)
def random_colour_print(seeds):
    """Return a function that prints a coloured line to stderr. The colour
    of the line depends on a sort of hash of the integer arguments."""
    # NOTE(review): this view is missing the computation that derives `s`
    # from `seeds`, the inner printer definitions, and the final return --
    # confirm against the full file.
    # ANSI 256-colour escape; 18 + s stays inside the colour-cube range.
    prefix = "\033[38;5;%dm" % (18 + s)
            print("%s%s\033[00m" % (prefix, a), file=sys.stderr)
            print(a, file=sys.stderr)
class FakePacketError(Exception):
    """Raised when packet data is inconsistent -- e.g. a packet whose
    endpoints don't match its conversation (see Conversation.add_packet)."""
class Packet(object):
    """Details of a network packet"""
    # Slots keep per-packet memory low; captures can contain many packets.
    __slots__ = ('timestamp',
    # NOTE(review): the remainder of the __slots__ tuple is missing from
    # this view.
    def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
                 protocol, opcode, desc, extra):
        # Time of the packet (capture time; later made conversation-relative
        # by Conversation.add_packet).
        self.timestamp = timestamp
        self.ip_protocol = ip_protocol
        self.stream_number = stream_number
        self.protocol = protocol
        # NOTE(review): assignments for src/dest/opcode/desc/extra and the
        # `else:` of the branch below are missing from this view.
        # Canonicalise endpoints to (low, high) so both directions of a
        # flow map to the same key.
        if self.src < self.dest:
            self.endpoints = (self.src, self.dest)
            self.endpoints = (self.dest, self.src)
    def from_line(cls, line):
        # A traffic-summary line is tab-separated.
        fields = line.rstrip('\n').split('\t')
        # NOTE(review): the unpacking of `fields` into the named variables
        # (and an apparent @classmethod decorator) is missing from this view.
        timestamp = float(timestamp)
        return cls(timestamp, ip_protocol, stream_number, src, dest,
                   protocol, opcode, desc, extra)
    def as_summary(self, time_offset=0.0):
        """Format the packet as a traffic_summary line.
        """
        extra = '\t'.join(self.extra)
        t = self.timestamp + time_offset
        # NOTE(review): part of the format-argument tuple is missing from
        # this view; the visible fields are a fragment of it.
        return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
                self.stream_number or '',
    # NOTE(review): the lines below are the bodies of __str__, __repr__,
    # and (apparently) a copy() method; their `def` headers are missing
    # from this view.
        return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
                (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
                 self.stream_number, self.protocol, self.opcode, self.desc,
                 ('«' + ' '.join(self.extra) + '»' if self.extra else '')))
        return "<Packet @%s>" % self
        return self.__class__(self.timestamp,
    def as_packet_type(self):
        # "protocol:opcode" string used as a state label by TrafficModel.
        t = '%s:%s' % (self.protocol, self.opcode)
        # NOTE(review): the `return t` line is missing from this view.
    def client_score(self):
        """A positive number means we think it is a client; a negative number
        means we think it is a server. Zero means no idea. range: -1 to 1.
        """
        key = (self.protocol, self.opcode)
        if key in CLIENT_CLUES:
            return CLIENT_CLUES[key]
        # Server clues are negated: server-typical opcodes push the score
        # towards -1.
        if key in SERVER_CLUES:
            return -SERVER_CLUES[key]
        # NOTE(review): the fall-through `return 0.0` is missing from this
        # view.
    def play(self, conversation, context):
        """Send the packet over the network, if required.

        Some packets are ignored, i.e. for protocols not handled,
        server response messages, or messages that are generated by the
        protocol layer associated with other packets.
        """
        # Handlers live in traffic_packets, named packet_<protocol>_<opcode>.
        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        # NOTE(review): the `try:` wrapping this getattr is missing from
        # this view.
            fn = getattr(traffic_packets, fn_name)
        except AttributeError as e:
            print("Conversation(%s) Missing handler %s" %
                  (conversation.conversation_id, fn_name),
        # Don't display a message for kerberos packets, they're not directly
        # generated they're used to indicate kerberos should be used
        if self.protocol != "kerberos":
            debug(2, "Conversation(%s) Calling handler %s" %
                  (conversation.conversation_id, fn_name))
        # NOTE(review): `start = time.time()`, the second `try:`, and the
        # `end = time.time()` lines are missing from this view.
            # A truthy handler result means real traffic was generated, so
            # a timing line is written to stdout (redirected to a stats file
            # in replay_in_fork_with_delay).
            if fn(self, conversation, context):
                # Only collect timing data for functions that generate
                # network traffic, or fail
                duration = end - start
                print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
                      (end, conversation.conversation_id, self.protocol,
                       self.opcode, duration))
        except Exception as e:
            duration = end - start
            print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
                  (end, conversation.conversation_id, self.protocol,
                   self.opcode, duration, e))
    def __cmp__(self, other):
        # Python 2 ordering hook: order packets by timestamp.
        # NOTE(review): __cmp__ is ignored by Python 3 -- confirm whether
        # rich comparisons exist elsewhere in the full file.
        return self.timestamp - other.timestamp

    def is_really_a_packet(self, missing_packet_stats=None):
        # Delegates to the module-level helper; `missing_packet_stats`
        # is unused here.
        return is_a_real_packet(self.protocol, self.opcode)
def is_a_real_packet(protocol, opcode):
    """Return True if (protocol, opcode) identifies a replayable packet.

    Packets for which this returns False can be removed with no effect
    on the replay: protocols we skip entirely, ldap continuations, and
    opcodes whose handler is absent or is the null handler.

    :param protocol: protocol name, e.g. "ldap", "dcerpc"
    :param opcode: opcode string as recorded in the traffic summary
    """
    if protocol in SKIPPED_PROTOCOLS:
        # Ignore any packets for the protocols we're not interested in.
        return False
    if protocol == "ldap" and opcode == '':
        # skip ldap continuation packets
        return False

    fn_name = 'packet_%s_%s' % (protocol, opcode)
    fn = getattr(traffic_packets, fn_name, None)
    if fn is None:
        # BUG FIX: logging methods accept no `file` keyword argument;
        # the previous `LOGGER.debug(..., file=sys.stderr)` raised
        # TypeError whenever a handler was missing.
        LOGGER.debug("missing packet %s" % fn_name)
        return False
    if fn is traffic_packets.null_packet:
        # the handler is a deliberate no-op: nothing to replay.
        return False
    return True
class ReplayContext(object):
    """State/Context for a conversation between an simulated client and a
    server. Some of the context is shared amongst all conversations
    and should be generated before the fork, while other context is
    specific to a particular conversation and should be generated
    *after* the fork, in generate_process_local_config().
    """
    # NOTE(review): the `def __init__(self, ...)` opening line and most of
    # its keyword parameters/assignments are missing from this view.
                 badpassword_frequency=None,
                 prefer_kerberos=None,
        # Lazily-established netlogon connection (see
        # get_netlogon_connection).
        self.netlogon_connection = None
        self.prefer_kerberos = prefer_kerberos
        self.base_dn = base_dn
        self.statsdir = statsdir
        self.global_tempdir = tempdir
        self.domain_sid = domain_sid
        self.realm = lp.get('realm')

        # Bad password attempt controls
        self.badpassword_frequency = badpassword_frequency
        # `last_*_bad` flags stop two consecutive bad-password attempts on
        # the same connection type (see with_random_bad_credentials).
        self.last_lsarpc_bad = False
        self.last_lsarpc_named_bad = False
        self.last_simple_bind_bad = False
        self.last_bind_bad = False
        self.last_srvsvc_bad = False
        self.last_drsuapi_bad = False
        self.last_netlogon_bad = False
        self.last_samlogon_bad = False
        self.generate_ldap_search_tables()
    def generate_ldap_search_tables(self):
        """Build self.attribute_clue_map (and apparently a dn_map) from a
        paged subtree search of the domain, for later use by
        get_matching_dn()."""
        session = system_session()

        db = SamDB(url="ldap://%s" % self.server,
                   session_info=session,
                   credentials=self.creds,
        # NOTE(review): remaining SamDB arguments and the close of this
        # call are missing from this view.

        res = db.search(db.domain_dn(),
                        scope=ldb.SCOPE_SUBTREE,
                        controls=["paged_results:1:1000"],
        # NOTE(review): remaining search arguments are missing from view.

        # find a list of dns for each pattern
        # e.g. CN,CN,CN,DC,DC
        attribute_clue_map = {
        # NOTE(review): the dict body, `dn_map = {}`, and the loop over the
        # search results are missing from this view.
            # Pattern is the comma-joined first-two letters of each RDN,
            # e.g. "CN,CN,CN,DC,DC".
            pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
            dns = dn_map.setdefault(pattern, [])
            if dn.startswith('CN=NTDS Settings,'):
                attribute_clue_map['invocationId'].append(dn)

        # extend the map in case we are working with a different
        # number of DC components.
        # for k, v in self.dn_map.items():
        #     print >>sys.stderr, k, len(v)

        for k in list(dn_map.keys()):
            # NOTE(review): the derivation of `p` from `k` is missing from
            # this view.
            while p[-3:] == ',DC':
                if p != k and p in dn_map:
                    print('dn_map collison %s %s' % (k, p),
                dn_map[p] = dn_map[k]

        self.attribute_clue_map = attribute_clue_map
    def generate_process_local_config(self, account, conversation):
        """Set up the per-process (post-fork) state: connection caches, the
        conversation's account identity, a private temp dir, and per-process
        loadparm settings; then generate machine and user credentials."""
        # Per-protocol connection caches; the most recent entry is reused
        # unless a new connection is explicitly requested.
        self.ldap_connections = []
        self.dcerpc_connections = []
        self.lsarpc_connections = []
        self.lsarpc_connections_named = []
        self.drsuapi_connections = []
        self.srvsvc_connections = []
        self.samr_contexts = []
        self.netbios_name = account.netbios_name
        self.machinepass = account.machinepass
        self.username = account.username
        self.userpass = account.userpass

        self.tempdir = mk_masked_dir(self.global_tempdir,
        # NOTE(review): a middle argument line (presumably a
        # 'conversation-%d' format string) is missing from this view.
                                     conversation.conversation_id)

        # Keep each simulated client's runtime state isolated in its own
        # temp dir.
        self.lp.set("private dir", self.tempdir)
        self.lp.set("lock dir", self.tempdir)
        self.lp.set("state directory", self.tempdir)
        self.lp.set("tls verify peer", "no_check")

        # If the domain was not specified, check for the environment
        if self.domain is None:
            self.domain = os.environ["DOMAIN"]

        self.remoteAddress = "/root/ncalrpc_as_system"
        self.samlogon_dn = ("cn=%s,%s" %
                            (self.netbios_name, self.ou))
        self.user_dn = ("cn=%s,%s" %
                        (self.username, self.ou))

        self.generate_machine_creds()
        self.generate_user_creds()
    def with_random_bad_credentials(self, f, good, bad, failed_last_time):
        """Execute the supplied logon function, randomly choosing the

        Based on the frequency in badpassword_frequency randomly perform the
        function with the supplied bad credentials.
        If run with bad credentials, the function is re-run with the good
        failed_last_time is used to prevent consecutive bad credential
        attempts. So the over all bad credential frequency will be lower
        than that requested, but not significantly.

        :returns: (result, failed_last_time) -- result of the good-creds
                  call plus the updated flag for the caller to store.
        """
        if not failed_last_time:
            if (self.badpassword_frequency and self.badpassword_frequency > 0
                and random.random() < self.badpassword_frequency):
                # NOTE(review): the try/except invocation of `f(bad)` is
                # missing from this view.
                # Ignore any exceptions as the operation may fail
                # as it's being performed with bad credentials
                failed_last_time = True
            # NOTE(review): an `else:` appears to be missing before the
            # next line.
            failed_last_time = False

        # NOTE(review): `result = f(good)` is missing from this view.
        return (result, failed_last_time)
    def generate_user_creds(self):
        """Generate the conversation specific user Credentials.

        Each Conversation has an associated user account used to simulate
        any non Administrative user traffic.

        Generates user credentials with good and bad passwords and ldap
        simple bind credentials with good and bad passwords.
        """
        self.user_creds = Credentials()
        self.user_creds.guess(self.lp)
        self.user_creds.set_workstation(self.netbios_name)
        self.user_creds.set_password(self.userpass)
        self.user_creds.set_username(self.username)
        self.user_creds.set_domain(self.domain)
        # NOTE(review): each kerberos-state `if` below originally pairs the
        # MUST_USE_KERBEROS branch with an `else:` DONT_USE_KERBEROS branch;
        # the `else:` lines are missing from this view.
        if self.prefer_kerberos:
            self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
            self.user_creds.set_kerberos_state(DONT_USE_KERBEROS)

        # "Bad" credentials: same identity, password truncated by four
        # characters.
        self.user_creds_bad = Credentials()
        self.user_creds_bad.guess(self.lp)
        self.user_creds_bad.set_workstation(self.netbios_name)
        self.user_creds_bad.set_password(self.userpass[:-4])
        self.user_creds_bad.set_username(self.username)
        if self.prefer_kerberos:
            self.user_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
            self.user_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)

        # Credentials for ldap simple bind.
        self.simple_bind_creds = Credentials()
        self.simple_bind_creds.guess(self.lp)
        self.simple_bind_creds.set_workstation(self.netbios_name)
        self.simple_bind_creds.set_password(self.userpass)
        self.simple_bind_creds.set_username(self.username)
        self.simple_bind_creds.set_gensec_features(
            self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
        if self.prefer_kerberos:
            self.simple_bind_creds.set_kerberos_state(MUST_USE_KERBEROS)
            self.simple_bind_creds.set_kerberos_state(DONT_USE_KERBEROS)
        self.simple_bind_creds.set_bind_dn(self.user_dn)

        self.simple_bind_creds_bad = Credentials()
        self.simple_bind_creds_bad.guess(self.lp)
        self.simple_bind_creds_bad.set_workstation(self.netbios_name)
        self.simple_bind_creds_bad.set_password(self.userpass[:-4])
        self.simple_bind_creds_bad.set_username(self.username)
        self.simple_bind_creds_bad.set_gensec_features(
            self.simple_bind_creds_bad.get_gensec_features() |
        # NOTE(review): the `gensec.FEATURE_SEAL)` continuation is missing
        # from this view.
        if self.prefer_kerberos:
            self.simple_bind_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
            self.simple_bind_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
        self.simple_bind_creds_bad.set_bind_dn(self.user_dn)
    def generate_machine_creds(self):
        """Generate the conversation specific machine Credentials.

        Each Conversation has an associated machine account.

        Generates machine credentials with good and bad passwords.
        """
        self.machine_creds = Credentials()
        self.machine_creds.guess(self.lp)
        self.machine_creds.set_workstation(self.netbios_name)
        self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds.set_password(self.machinepass)
        # Machine account names carry a trailing "$".
        self.machine_creds.set_username(self.netbios_name + "$")
        self.machine_creds.set_domain(self.domain)
        # NOTE(review): the `else:` DONT_USE_KERBEROS branches of the two
        # kerberos-state checks are missing from this view.
        if self.prefer_kerberos:
            self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
            self.machine_creds.set_kerberos_state(DONT_USE_KERBEROS)

        # "Bad" credentials: same identity, password truncated by four
        # characters.
        self.machine_creds_bad = Credentials()
        self.machine_creds_bad.guess(self.lp)
        self.machine_creds_bad.set_workstation(self.netbios_name)
        self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds_bad.set_password(self.machinepass[:-4])
        self.machine_creds_bad.set_username(self.netbios_name + "$")
        if self.prefer_kerberos:
            self.machine_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
            self.machine_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
    def get_matching_dn(self, pattern, attributes=None):
        """Pick a DN matching the RDN-type pattern (e.g. "CN,CN,DC,DC"),
        preferring attribute-specific clues when available."""
        # If the pattern is an empty string, we assume ROOTDSE,
        # Otherwise we try adding or removing DC suffixes, then
        # shorter leading patterns until we hit one.
        # e.g if there is no CN,CN,CN,CN,DC,DC
        # we first try CN,CN,CN,CN,DC
        # and CN,CN,CN,CN,DC,DC,DC
        # then change to CN,CN,CN,DC,DC
        # and as last resort we use the base_dn
        attr_clue = self.attribute_clue_map.get(attributes)
        # NOTE(review): the `if attr_clue:` guard appears to be missing
        # from this view.
            return random.choice(attr_clue)

        pattern = pattern.upper()
        # NOTE(review): the retry loop header and DC-suffix adjustments are
        # missing from this view.
            if pattern in self.dn_map:
                return random.choice(self.dn_map[pattern])
            # chop one off the front and try it all again.
            pattern = pattern[3:]
        # NOTE(review): the base_dn fallback return is missing from view.
    def get_dcerpc_connection(self, new=False):
        """Return the cached dcerpc ClientConnection, or create one when
        none exists or `new` is True."""
        guid = '12345678-1234-abcd-ef00-01234567cffb'  # RPC_NETLOGON UUID
        if self.dcerpc_connections and not new:
            return self.dcerpc_connections[-1]
        c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
        # NOTE(review): remaining ClientConnection arguments (presumably
        # using `guid`) are missing from this view.
        self.dcerpc_connections.append(c)
        # NOTE(review): the `return c` line is missing from this view.
    def get_srvsvc_connection(self, new=False):
        """Return the cached srvsvc connection or create one, possibly
        first attempting a deliberately bad logon (see
        with_random_bad_credentials)."""
        if self.srvsvc_connections and not new:
            return self.srvsvc_connections[-1]

        # NOTE(review): the inner `def connect(creds):` header is missing
        # from this view.
            return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
        # NOTE(review): remaining srvsvc arguments are missing from view.

        (c, self.last_srvsvc_bad) = \
            self.with_random_bad_credentials(connect,
        # NOTE(review): the good/bad credential arguments are missing from
        # this view.
                                             self.last_srvsvc_bad)

        self.srvsvc_connections.append(c)
        # NOTE(review): the `return c` line is missing from this view.
    def get_lsarpc_connection(self, new=False):
        """Return the cached lsarpc-over-TCP (schannel) connection or create
        one, possibly first attempting a bad-credentials logon."""
        if self.lsarpc_connections and not new:
            return self.lsarpc_connections[-1]

        # NOTE(review): the inner `def connect(creds):` header is missing
        # from this view.
            binding_options = 'schannel,seal,sign'
            return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
                              (self.server, binding_options),
        # NOTE(review): remaining lsarpc arguments are missing from view.

        (c, self.last_lsarpc_bad) = \
            self.with_random_bad_credentials(connect,
        # NOTE(review): the good-credentials argument is missing from view.
                                             self.machine_creds_bad,
                                             self.last_lsarpc_bad)

        self.lsarpc_connections.append(c)
        # NOTE(review): the `return c` line is missing from this view.
    def get_lsarpc_named_pipe_connection(self, new=False):
        """Return the cached lsarpc named-pipe connection or create one,
        possibly first attempting a bad-credentials logon."""
        if self.lsarpc_connections_named and not new:
            return self.lsarpc_connections_named[-1]

        # NOTE(review): the inner `def connect(creds):` header is missing
        # from this view.
            return lsa.lsarpc("ncacn_np:%s" % (self.server),
        # NOTE(review): remaining lsarpc arguments are missing from view.

        (c, self.last_lsarpc_named_bad) = \
            self.with_random_bad_credentials(connect,
        # NOTE(review): the good-credentials argument is missing from view.
                                             self.machine_creds_bad,
                                             self.last_lsarpc_named_bad)

        self.lsarpc_connections_named.append(c)
        # NOTE(review): the `return c` line is missing from this view.
    def get_drsuapi_connection_pair(self, new=False, unbind=False):
        """get a (drs, drs_handle) tuple"""
        if self.drsuapi_connections and not new:
            c = self.drsuapi_connections[-1]
            # NOTE(review): the cached-case `return c` is missing from this
            # view.

        # NOTE(review): the inner `def connect(creds):` header is missing
        # from this view.
            binding_options = 'seal'
            binding_string = "ncacn_ip_tcp:%s[%s]" %\
                             (self.server, binding_options)
            return drsuapi.drsuapi(binding_string, self.lp, creds)

        (drs, self.last_drsuapi_bad) = \
            self.with_random_bad_credentials(connect,
        # NOTE(review): the good/bad credential arguments are missing from
        # this view.
                                             self.last_drsuapi_bad)

        # Bind immediately so callers always get a usable handle.
        (drs_handle, supported_extensions) = drs_DsBind(drs)
        c = (drs, drs_handle)
        self.drsuapi_connections.append(c)
        # NOTE(review): the `return c` line is missing from this view.
    def get_ldap_connection(self, new=False, simple=False):
        """Return the cached SamDB LDAP connection or create one, using
        either a simple bind over ldaps or a SASL bind over ldap."""
        if self.ldap_connections and not new:
            return self.ldap_connections[-1]

        def simple_bind(creds):
            """
            To run simple bind against Windows, we need to run
            following commands in PowerShell:

                Install-windowsfeature ADCS-Cert-Authority
                Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
            """
            return SamDB('ldaps://%s' % self.server,
        # NOTE(review): remaining SamDB arguments are missing from view.

        def sasl_bind(creds):
            return SamDB('ldap://%s' % self.server,
        # NOTE(review): remaining SamDB arguments are missing from view.

        # NOTE(review): the `if simple:` / `else:` branch headers around
        # the two bind attempts below are missing from this view.
            (samdb, self.last_simple_bind_bad) = \
                self.with_random_bad_credentials(simple_bind,
                                                 self.simple_bind_creds,
                                                 self.simple_bind_creds_bad,
                                                 self.last_simple_bind_bad)
            (samdb, self.last_bind_bad) = \
                self.with_random_bad_credentials(sasl_bind,
        # NOTE(review): the sasl good/bad credential arguments are missing
        # from this view.

        self.ldap_connections.append(samdb)
        # NOTE(review): the `return samdb` line is missing from this view.
720 def get_samr_context(self, new=False):
721 if not self.samr_contexts or new:
722 self.samr_contexts.append(
723 SamrContext(self.server, lp=self.lp, creds=self.creds))
724 return self.samr_contexts[-1]
    def get_netlogon_connection(self):
        """Return the cached netlogon connection, or establish one over
        schannel, possibly first attempting a bad-credentials logon."""
        if self.netlogon_connection:
            return self.netlogon_connection

        # NOTE(review): the inner `def connect(creds):` header is missing
        # from this view.
            return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
        # NOTE(review): remaining netlogon arguments are missing from view.

        (c, self.last_netlogon_bad) = \
            self.with_random_bad_credentials(connect,
        # NOTE(review): the good-credentials argument is missing from view.
                                             self.machine_creds_bad,
                                             self.last_netlogon_bad)
        self.netlogon_connection = c
        # NOTE(review): the `return c` line is missing from this view.
    def guess_a_dns_lookup(self):
        # Simulated DNS traffic just does an A-record lookup of the realm.
        return (self.realm, 'A')
747 def get_authenticator(self):
748 auth = self.machine_creds.new_client_authenticator()
749 current = netr_Authenticator()
750 current.cred.data = [x if isinstance(x, int) else ord(x) for x in auth["credential"]]
751 current.timestamp = auth["timestamp"]
753 subsequent = netr_Authenticator()
754 return (current, subsequent)
class SamrContext(object):
    """State/Context associated with a samr connection.
    """
    def __init__(self, server, lp=None, creds=None):
        # Connection and handles are created lazily (see get_connection /
        # get_handle); None until first use.
        self.connection = None
        self.domain_handle = None
        self.domain_sid = None
        self.group_handle = None
        self.user_handle = None
        # NOTE(review): the assignments of server/lp/creds (and apparently
        # self.handle) are missing from this view.

    def get_connection(self):
        """Return the samr connection, creating it on first use."""
        if not self.connection:
            self.connection = samr.samr(
                "ncacn_ip_tcp:%s[seal]" % (self.server),
                # NOTE(review): a loadparm argument line appears to be
                # missing from this view.
                credentials=self.creds)

        return self.connection

    def get_handle(self):
        """Return the samr connect handle, connecting on first use."""
        # NOTE(review): the `if not self.handle:` guard and the
        # `return self.handle` line are missing from this view.
        c = self.get_connection()
        self.handle = c.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
class Conversation(object):
    """Details of a conversation between a simulated client and a server."""
    def __init__(self, start_time=None, endpoints=None, seq=(),
                 conversation_id=None):
        self.start_time = start_time
        self.endpoints = endpoints
        # Coloured stderr printer keyed on the endpoints.
        self.msg = random_colour_print(endpoints)
        # Running score of which endpoint looks like the client; see
        # add_packet / guess_client_server.
        self.client_balance = 0.0
        self.conversation_id = conversation_id
        # NOTE(review): `self.packets = []` and the `for p in seq:` loop
        # header are missing from this view.
            self.add_short_packet(*p)

    def __cmp__(self, other):
        # Python 2 ordering hook: order conversations by start time, with
        # None handled specially.
        if self.start_time is None:
            if other.start_time is None:
            # NOTE(review): the return statements for the None cases are
            # missing from this view.
        if other.start_time is None:
        return self.start_time - other.start_time
    def add_packet(self, packet):
        """Add a packet object to this conversation, making a local copy with
        a conversation-relative timestamp."""
        # NOTE(review): the line creating the local copy `p` is missing
        # from this view.

        if self.start_time is None:
            self.start_time = p.timestamp

        if self.endpoints is None:
            self.endpoints = p.endpoints

        # NOTE(review): the concatenated message below lacks a space
        # between "match" and "packet" -- runtime string, not altered here.
        if p.endpoints != self.endpoints:
            raise FakePacketError("Conversation endpoints %s don't match"
                                  "packet endpoints %s" %
                                  (self.endpoints, p.endpoints))

        p.timestamp -= self.start_time

        # Track which endpoint behaves like the client: packets sent from
        # endpoints[0] shift the balance one way, replies the other.
        if p.src == p.endpoints[0]:
            self.client_balance -= p.client_score()
            # NOTE(review): an `else:` appears to be missing before the
            # next line.
            self.client_balance += p.client_score()

        # Only replayable packets are kept (see is_a_real_packet).
        if p.is_really_a_packet():
            self.packets.append(p)
    def add_short_packet(self, timestamp, protocol, opcode, extra,
    # NOTE(review): the closing parameter line (apparently `client=True):`)
    # is missing from this view.
        """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
        (possibly empty) list of extra data. If client is True, assume
        this packet is from the client to the server.
        """
        src, dest = self.guess_client_server()
        # NOTE(review): the `if not client:` guard is missing from view.
            src, dest = dest, src
        key = (protocol, opcode)
        desc = OP_DESCRIPTIONS[key] if key in OP_DESCRIPTIONS else ''
        if protocol in IP_PROTOCOLS:
            ip_protocol = IP_PROTOCOLS[protocol]
        # NOTE(review): the `else:` fallback for ip_protocol is missing
        # from this view.
        packet = Packet(timestamp - self.start_time, ip_protocol,
        # NOTE(review): a middle argument line (stream/src/dest) is missing
        # from this view.
                        protocol, opcode, desc, extra)
        # XXX we're assuming the timestamp is already adjusted for
        # XXX should we adjust client balance for guessed packets?
        if packet.src == packet.endpoints[0]:
            self.client_balance -= packet.client_score()
            # NOTE(review): an `else:` appears to be missing before the
            # next line.
            self.client_balance += packet.client_score()
        if packet.is_really_a_packet():
            self.packets.append(packet)
    # NOTE(review): the lines below belong to __repr__ (or __str__),
    # __iter__ and __len__; their `def` headers are missing from this view.
        return ("<Conversation %s %s starting %.3f %d packets>" %
                (self.conversation_id, self.endpoints, self.start_time,
        return iter(self.packets)
        return len(self.packets)

    def get_duration(self):
        # Span between first and last kept packet.
        # NOTE(review): the short-conversation return (for < 2 packets) is
        # missing from this view.
        if len(self.packets) < 2:
        return self.packets[-1].timestamp - self.packets[0].timestamp

    def replay_as_summary_lines(self):
        """Return the packets as (timestamp, summary-line) tuples."""
        # NOTE(review): `lines = []` and the final return are missing from
        # this view.
        for p in self.packets:
            lines.append(p.as_summary(self.start_time))
    def replay_in_fork_with_delay(self, start, context=None, account=None):
        """Fork a new process and replay the conversation.
        """
        def signal_handler(signal, frame):
            """Signal handler closes standard out and error.

            Triggered by a sigterm, ensures that the log messages are flushed
            to disk and not lost.
            """
            # NOTE(review): the handler body is missing from this view.

        now = time.time() - start
        # NOTE(review): the computation of `gap` is missing from this view.

        # we are replaying strictly in order, so it is safe to sleep
        # in the main process if the gap is big enough. This reduces
        # the number of concurrent threads, which allows us to make
        # NOTE(review): `and False` makes this branch dead code --
        # presumably a deliberate disable; confirm before removing.
        if gap > 0.15 and False:
            print("sleeping for %f in main process" % (gap - 0.1),
            # NOTE(review): the close of this print call is missing from
            # this view.
            time.sleep(gap - 0.1)
            now = time.time() - start
            print("gap is now %f" % gap, file=sys.stderr)

        self.conversation_id = next(context.next_conversation_id)
        # NOTE(review): the fork() call and the parent-branch return are
        # missing from this view; the code below runs in the child.
        signal.signal(signal.SIGTERM, signal_handler)
        # we must never return, or we'll end up running parts of the
        # parent's clean-up code. So we work in a try...finally, and
        # try to print any exceptions.
        context.generate_process_local_config(account, self)
        # Redirect stdout so the timing lines from Packet.play land in a
        # per-conversation stats file.
        filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
                                self.conversation_id)
        sys.stdout = open(filename, 'w')
        sleep_time = gap - SLEEP_OVERHEAD
        # NOTE(review): the positive-sleep guard is missing from this view.
        time.sleep(sleep_time)
        miss = t - (time.time() - start)
        self.msg("starting %s [miss %.3f pid %d]" % (self, miss, pid))
        print(("EXCEPTION in child PID %d, conversation %s" % (pid, self)),
        # NOTE(review): the close of the print call is missing from view.
        # NOTE(review): traceback.print_exc's first parameter is `limit`;
        # passing sys.stderr positionally looks wrong -- it should be
        # `file=sys.stderr`. Confirm against the full file before fixing.
        traceback.print_exc(sys.stderr)
    def replay(self, context=None):
        """Replay the conversation's packets in order, sleeping to match
        the recorded inter-packet gaps."""
        # NOTE(review): `start = time.time()` is missing from this view.
        for p in self.packets:
            now = time.time() - start
            gap = p.timestamp - now
            # SLEEP_OVERHEAD compensates for scheduling/wakeup latency.
            sleep_time = gap - SLEEP_OVERHEAD
            # NOTE(review): the positive-sleep guard is missing from view.
            time.sleep(sleep_time)

            # How late the packet is relative to its scheduled time.
            miss = p.timestamp - (time.time() - start)
            self.msg("packet %s [miss %.3f pid %d]" % (p, miss,
            # NOTE(review): the close of the msg() call is missing from
            # this view.
            p.play(self, context)
    def guess_client_server(self, server_clue=None):
        """Have a go at deciding who is the server and who is the client.
        returns (client, server)
        """
        a, b = self.endpoints

        # A negative balance means endpoint a acted more like a server.
        if self.client_balance < 0:
            # NOTE(review): the return statements for the balance cases are
            # missing from this view.

        # in the absence of a clue, we will fall through to assuming
        # the lowest number is the server (which is usually true).

        if self.client_balance == 0 and server_clue == b:
        # NOTE(review): this branch's return and the final fallback return
        # are missing from this view.
    def forget_packets_outside_window(self, s, e):
        """Prune any packets outside the time window we're interested in

        :param s: start of the window
        :param e: end of the window
        """
        self.packets = [p for p in self.packets if s <= p.timestamp <= e]
        # Re-anchor start_time on the first surviving packet (None if the
        # window emptied the conversation).
        self.start_time = self.packets[0].timestamp if self.packets else None
    def renormalise_times(self, start_time):
        """Adjust the packet start times relative to the new start time."""
        for p in self.packets:
            p.timestamp -= start_time

        # Keep the conversation's own start consistent with its packets.
        if self.start_time is not None:
            self.start_time -= start_time
class DnsHammer(Conversation):
    """A lightweight conversation that generates a lot of dns:0 packets on
    """
    def __init__(self, dns_rate, duration):
        # Fire n queries at uniformly random times across the duration.
        n = int(dns_rate * duration)
        self.times = [random.uniform(0, duration) for i in range(n)]
        # NOTE(review): a `self.times.sort()` appears to be missing from
        # this view (replay assumes ordered times).
        self.rate = dns_rate
        self.duration = duration
        # NOTE(review): random_colour_print is defined above to take a
        # `seeds` argument; this zero-argument call would raise TypeError
        # -- confirm which signature the full file actually has.
        self.msg = random_colour_print()

    # NOTE(review): the `def __str__(self):` header is missing from this
    # view.
        return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
                (len(self.times), self.duration, self.rate))
    def replay_in_fork_with_delay(self, start, context=None, account=None):
        # Delegate to the base Conversation implementation.
        return Conversation.replay_in_fork_with_delay(self,
        # NOTE(review): the remaining arguments of this call are missing
        # from this view.
    def replay(self, context=None):
        """Send a dns:0 query at each scheduled time, logging timing stats
        to stdout in the same tab-separated form as Packet.play."""
        # NOTE(review): `start = time.time()` is missing from this view.
        fn = traffic_packets.packet_dns_0
        for t in self.times:
            now = time.time() - start
            # NOTE(review): the computation of `gap` is missing from view.
            sleep_time = gap - SLEEP_OVERHEAD
            # NOTE(review): the positive-sleep guard is missing from view.
            time.sleep(sleep_time)

            miss = t - (time.time() - start)
            self.msg("packet %s [miss %.3f pid %d]" % (t, miss,
            # NOTE(review): the close of the msg() call is missing from
            # this view.

            packet_start = time.time()
            # NOTE(review): the `try:` wrapping the call below is missing
            # from this view, as are the `end = time.time()` lines.
                fn(self, self, context)
                duration = end - packet_start
                print("%f\tDNS\tdns\t0\t%f\tTrue\t" % (end, duration))
            except Exception as e:
                duration = end - packet_start
                print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" % (end, duration, e))
def ingest_summaries(files, dns_mode='count'):
    """Load summary traffic files and generate Conversations from them.
    """
    dns_counts = defaultdict(int)
    # NOTE(review): `packets = []` and the loop over `files` are missing
    # from this view.
        if isinstance(f, str):
            # NOTE(review): presumably `f = open(f)` here -- confirm.
        print("Ingesting %s" % (f.name,), file=sys.stderr)
        # NOTE(review): the `for line in f:` header is missing from view.
            p = Packet.from_line(line)
            # DNS packets are usually only counted, not replayed verbatim.
            if p.protocol == 'dns' and dns_mode != 'include':
                dns_counts[p.opcode] += 1
            # NOTE(review): the non-dns `packets.append(p)` path and the
            # file close are missing from this view.

    start_time = min(p.timestamp for p in packets)
    last_packet = max(p.timestamp for p in packets)

    print("gathering packets into conversations", file=sys.stderr)
    conversations = OrderedDict()
    for i, p in enumerate(packets):
        p.timestamp -= start_time
        c = conversations.get(p.endpoints)
        # NOTE(review): the `if c is None:` guard and a `c.add_packet(p)`
        # call are missing from this view.
            c = Conversation(conversation_id=(i + 2))
            conversations[p.endpoints] = c

    # We only care about conversations with actual traffic, so we
    # filter out conversations with nothing to say. We do that here,
    # rather than earlier, because those empty packets contain useful
    # hints as to which end of the conversation was the client.
    conversation_list = []
    for c in conversations.values():
        # NOTE(review): the non-empty guard is missing from this view.
        conversation_list.append(c)

    # This is obviously not correct, as many conversations will appear
    # to start roughly simultaneously at the beginning of the snapshot.
    # To which we say: oh well, so be it.
    duration = float(last_packet - start_time)
    # NOTE(review): despite the name, this is conversations-per-second
    # (a rate, not an interval) -- confirm against callers.
    mean_interval = len(conversations) / duration

    return conversation_list, mean_interval, duration, dns_counts
def guess_server_address(conversations):
    """Guess the server endpoint: the address occurring most often."""
    # we guess the most common address.
    addresses = Counter()
    for c in conversations:
        addresses.update(c.endpoints)
    # NOTE(review): an `if addresses:` guard appears to be missing from
    # this view.
        return addresses.most_common(1)[0]
def stringify_keys(x):
    # Apparently converts tuple keys to flat strings (tab-joined, given
    # unstringify_keys below) so the dict can be JSON-serialised -- see
    # TrafficModel save/load. NOTE(review): the body (`y = {}`, the key
    # assignment, and the return) is missing from this view.
    for k, v in x.items():


def unstringify_keys(x):
    # Inverse of stringify_keys: rebuild tuple keys from tab-joined
    # strings. NOTE(review): `y = {}`, the assignment, and the return are
    # missing from this view.
    for k, v in x.items():
        t = tuple(str(k).split('\t'))
class TrafficModel(object):
    # An n-gram Markov-style model of observed traffic: `learn` ingests
    # conversations, `construct_conversation`/`generate_conversations`
    # sample new ones from it.
    def __init__(self, n=3):
        # NOTE(review): `self.n = n` and the n-gram table initialisation
        # are missing from this view.
        # extra data (e.g. ldap search details) per packet type.
        self.query_details = {}
        self.dns_opcounts = defaultdict(int)
        self.cumulative_duration = 0.0
        # [conversation-count, elapsed-seconds] used to estimate the rate.
        self.conversation_rate = [0, 1]
    def learn(self, conversations, dns_opcounts={}):
        """Fold the given conversations (and dns opcode counts) into the
        model's n-gram and query-details tables."""
        # NOTE(review): mutable default argument; appears to be read-only
        # here, but confirm against the full file.
        # NOTE(review): `cum_duration = 0.0` is missing from this view.
        # The initial state: n-1 NON_PACKET markers.
        key = (NON_PACKET,) * (self.n - 1)

        server = guess_server_address(conversations)

        for k, v in dns_opcounts.items():
            self.dns_opcounts[k] += v

        if len(conversations) > 1:
            # NOTE(review): the `elapsed = \` line-continuation header is
            # missing from this view.
                conversations[-1].start_time - conversations[0].start_time
            self.conversation_rate[0] = len(conversations)
            self.conversation_rate[1] = elapsed

        for c in conversations:
            client, server = c.guess_client_server(server)
            cum_duration += c.get_duration()
            # Reset state at each conversation boundary.
            key = (NON_PACKET,) * (self.n - 1)
            # NOTE(review): `prev = 0.0` and the `for p in c:` loop header
            # are missing from this view (as is the `prev` update).
                elapsed = p.timestamp - prev
                if elapsed > WAIT_THRESHOLD:
                    # add the wait as an extra state
                    wait = 'wait:%d' % (math.log(max(1.0,
                                                     elapsed * WAIT_SCALE)))
                    self.ngrams.setdefault(key, []).append(wait)
                    key = key[1:] + (wait,)

                short_p = p.as_packet_type()
                self.query_details.setdefault(short_p,
                                              []).append(tuple(p.extra))
                self.ngrams.setdefault(key, []).append(short_p)
                key = key[1:] + (short_p,)

        self.cumulative_duration += cum_duration
        # Terminate the final state so sampling knows where chains end.
        self.ngrams.setdefault(key, []).append(NON_PACKET)
        # NOTE(review): the `def save(self, f):` header and `ngrams = {}`
        # are missing from this view; the code below serialises the model
        # to JSON (counts rather than raw lists, for compactness).
        for k, v in self.ngrams.items():
            # NOTE(review): the tab-joined-key line is missing from view.
            ngrams[k] = dict(Counter(v))

        # NOTE(review): `query_details = {}` is missing from this view.
        for k, v in self.query_details.items():
            query_details[k] = dict(Counter('\t'.join(x) if x else '-'
            # NOTE(review): the rest of this generator expression and the
            # opening of the dict `d` are missing from this view.
            'query_details': query_details,
            'cumulative_duration': self.cumulative_duration,
            'conversation_rate': self.conversation_rate,

        d['dns'] = self.dns_opcounts

        if isinstance(f, str):
            # NOTE(review): presumably `f = open(f, 'w')` here -- confirm.
        json.dump(d, f, indent=2)
        # NOTE(review): the `def load(self, f):` header is missing from
        # this view; the code below rebuilds the in-memory tables from the
        # JSON written by save().
        if isinstance(f, str):
            # NOTE(review): presumably `f = open(f)` and `d = json.load(f)`
            # here -- confirm.

        for k, v in d['ngrams'].items():
            # Keys were tab-joined for JSON; restore the tuple form.
            k = tuple(str(k).split('\t'))
            values = self.ngrams.setdefault(k, [])
            for p, count in v.items():
                # Expand counts back into repeated entries for sampling.
                values.extend([str(p)] * count)

        for k, v in d['query_details'].items():
            values = self.query_details.setdefault(str(k), [])
            for p, count in v.items():
                # NOTE(review): the branch header distinguishing the '-'
                # (empty-extra) marker is missing from this view.
                    values.extend([()] * count)
                    values.extend([tuple(str(p).split('\t'))] * count)

        # NOTE(review): a guard around the dns section may be missing from
        # this view.
        for k, v in d['dns'].items():
            self.dns_opcounts[k] += v

        self.cumulative_duration = d['cumulative_duration']
        self.conversation_rate = d['conversation_rate']
    def construct_conversation(self, timestamp=0.0, client=2, server=1,
                               hard_stop=None, replay_speed=1):
        """Construct an individual conversation from the model."""
        c = Conversation(timestamp, (server, client), conversation_id=client)

        # Start from the initial (all NON_PACKET) state.
        key = (NON_PACKET,) * (self.n - 1)

        while key in self.ngrams:
            # Weighted choice: entries were stored once per observation.
            p = random.choice(self.ngrams.get(key, NON_PACKET))
            # NOTE(review): the end-of-chain (NON_PACKET) break appears to
            # be missing from this view.

            if p in self.query_details:
                extra = random.choice(self.query_details[p])
            # NOTE(review): the `else:` empty-extra fallback is missing
            # from this view.

            protocol, opcode = p.split(':', 1)
            if protocol == 'wait':
                # The model stored log-scaled waits; invert that here.
                log_wait_time = int(opcode) + random.random()
                wait = math.exp(log_wait_time) / (WAIT_SCALE * replay_speed)
                # NOTE(review): the timestamp update and the `else:` branch
                # header are missing from this view.
                # Sub-threshold gap: sample a small log-uniform wait.
                log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
                wait = math.exp(log_wait) / replay_speed
                # NOTE(review): the timestamp update is missing here too.
                if hard_stop is not None and timestamp > hard_stop:
                    # NOTE(review): the `break` is missing from this view.
                c.add_short_packet(timestamp, protocol, opcode, extra)

            key = key[1:] + (p,)
        # NOTE(review): the `return c` line is missing from this view.
    def generate_conversations(self, rate, duration, replay_speed=1):
        """Generate a list of conversations from the model."""
        # We run the simulation for at least ten times as long as our
        # desired duration, and take a section near the start.
        rate_n, rate_t = self.conversation_rate

        duration2 = max(rate_t, duration * 2)
        # Scale the learned conversation rate to the requested one.
        n = rate * duration2 * rate_n / rate_t
        # NOTE(review): initialisation of `client`, `conversations`, and
        # `end` is missing from this view.
        start = end - duration

        while client < n + 2:
            start = random.uniform(0, duration2)
            c = self.construct_conversation(start,
            # NOTE(review): the client/server arguments are missing from
            # this view.
                                            hard_stop=(duration2 * 5),
                                            replay_speed=replay_speed)

            # Keep only the slice of each conversation that falls in the
            # window, re-anchored at its start.
            c.forget_packets_outside_window(start, end)
            c.renormalise_times(start)
            # NOTE(review): the non-empty check and the `client` increment
            # are missing from this view.
                conversations.append(c)

        print(("we have %d conversations at rate %f" %
               (len(conversations), rate)), file=sys.stderr)
        conversations.sort()
        return conversations
    # NOTE(review): fragment of the module-level IP_PROTOCOLS mapping;
    # the opening of the dict literal is elided in this view.
    # Values look like hex IP protocol numbers ('06' TCP, '11' UDP) — confirm.
    'rpc_netlogon': '06',
    'kerberos': '06',      # ratio 16248:258
    'smb_netlogon': '11',
    # NOTE(review): fragment of the module-level OP_DESCRIPTIONS mapping;
    # the opening and closing of the dict literal are elided in this view.
    # Maps (protocol, opcode) -> human-readable operation description,
    # sorted lexically by (protocol, opcode-as-string).
    ('browser', '0x01'): 'Host Announcement (0x01)',
    ('browser', '0x02'): 'Request Announcement (0x02)',
    ('browser', '0x08'): 'Browser Election Request (0x08)',
    ('browser', '0x09'): 'Get Backup List Request (0x09)',
    ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
    ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
    ('cldap', '3'): 'searchRequest',
    ('cldap', '5'): 'searchResDone',
    ('dcerpc', '0'): 'Request',
    ('dcerpc', '11'): 'Bind',
    ('dcerpc', '12'): 'Bind_ack',
    ('dcerpc', '13'): 'Bind_nak',
    ('dcerpc', '14'): 'Alter_context',
    ('dcerpc', '15'): 'Alter_context_resp',
    ('dcerpc', '16'): 'AUTH3',
    ('dcerpc', '2'): 'Response',
    ('dns', '0'): 'query',
    ('dns', '1'): 'response',
    ('drsuapi', '0'): 'DsBind',
    ('drsuapi', '12'): 'DsCrackNames',
    ('drsuapi', '13'): 'DsWriteAccountSpn',
    ('drsuapi', '1'): 'DsUnbind',
    ('drsuapi', '2'): 'DsReplicaSync',
    ('drsuapi', '3'): 'DsGetNCChanges',
    ('drsuapi', '4'): 'DsReplicaUpdateRefs',
    ('epm', '3'): 'Map',
    ('kerberos', ''): '',
    ('ldap', '0'): 'bindRequest',
    ('ldap', '1'): 'bindResponse',
    ('ldap', '2'): 'unbindRequest',
    ('ldap', '3'): 'searchRequest',
    ('ldap', '4'): 'searchResEntry',
    ('ldap', '5'): 'searchResDone',
    ('ldap', ''): '*** Unknown ***',
    ('lsarpc', '14'): 'lsa_LookupNames',
    ('lsarpc', '15'): 'lsa_LookupSids',
    ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
    ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
    ('lsarpc', '6'): 'lsa_OpenPolicy',
    ('lsarpc', '76'): 'lsa_LookupSids3',
    ('lsarpc', '77'): 'lsa_LookupNames4',
    ('nbns', '0'): 'query',
    ('nbns', '1'): 'response',
    ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
    ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
    ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
    ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
    ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
    ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
    ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
    ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
    ('samr', '0',): 'Connect',
    ('samr', '16'): 'GetAliasMembership',
    ('samr', '17'): 'LookupNames',
    ('samr', '18'): 'LookupRids',
    ('samr', '19'): 'OpenGroup',
    ('samr', '1'): 'Close',
    ('samr', '25'): 'QueryGroupMember',
    ('samr', '34'): 'OpenUser',
    ('samr', '36'): 'QueryUserInfo',
    ('samr', '39'): 'GetGroupsForUser',
    ('samr', '3'): 'QuerySecurity',
    ('samr', '5'): 'LookupDomain',
    ('samr', '64'): 'Connect5',
    ('samr', '6'): 'EnumDomains',
    ('samr', '7'): 'OpenDomain',
    ('samr', '8'): 'QueryDomainInfo',
    ('smb', '0x04'): 'Close (0x04)',
    ('smb', '0x24'): 'Locking AndX (0x24)',
    ('smb', '0x2e'): 'Read AndX (0x2e)',
    ('smb', '0x32'): 'Trans2 (0x32)',
    ('smb', '0x71'): 'Tree Disconnect (0x71)',
    ('smb', '0x72'): 'Negotiate Protocol (0x72)',
    ('smb', '0x73'): 'Session Setup AndX (0x73)',
    ('smb', '0x74'): 'Logoff AndX (0x74)',
    ('smb', '0x75'): 'Tree Connect AndX (0x75)',
    ('smb', '0xa2'): 'NT Create AndX (0xa2)',
    ('smb2', '0'): 'NegotiateProtocol',
    ('smb2', '11'): 'Ioctl',
    ('smb2', '14'): 'Find',
    ('smb2', '16'): 'GetInfo',
    ('smb2', '18'): 'Break',
    ('smb2', '1'): 'SessionSetup',
    ('smb2', '2'): 'SessionLogoff',
    ('smb2', '3'): 'TreeConnect',
    ('smb2', '4'): 'TreeDisconnect',
    ('smb2', '5'): 'Create',
    ('smb2', '6'): 'Close',
    ('smb2', '8'): 'Read',
    ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
    ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
                               'user unknown (0x17)'),
    ('srvsvc', '16'): 'NetShareGetInfo',
    ('srvsvc', '21'): 'NetSrvGetInfo',
def expand_short_packet(p, timestamp, src, dest, extra):
    """Expand a short packet spec into a tab-separated trace line.

    NOTE(review): `extra` is unused in the visible lines — a statement
    appending it to `line` appears elided in this view; confirm.
    """
    protocol, opcode = p.split(':', 1)
    desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
    # default to '06' (TCP) when the protocol has no table entry
    ip_protocol = IP_PROTOCOLS.get(protocol, '06')
    line = [timestamp, ip_protocol, '', src, dest, protocol, opcode, desc]
    return '\t'.join(line)
def replay(conversations,
    # NOTE(review): the rest of the signature and much of the body are
    # elided in this view; comments below annotate only visible lines.
    # Forks one child process per conversation and reaps them in batches.
    context = ReplayContext(server=host,
    # sanity check: we want at least one account per conversation
    if len(accounts) < len(conversations):
        print(("we have %d accounts but %d conversations" %
               (accounts, conversations)), file=sys.stderr)
    # latest-starting conversation first, so pop() yields the next due one
        sorted(conversations, key=lambda x: x.start_time, reverse=True),
    # Set the process group so that the calling scripts are not killed
    # when the forked child processes are killed.
    if duration is None:
        # end 1 second after the last packet of the last conversation
        # to start. Conversations other than the last could still be
        # going, but we don't care.
        duration = cstack[0][0].packets[-1].timestamp + 1.0
        print("We will stop after %.1f seconds" % duration,
    end = start + duration
    LOGGER.info("Replaying traffic for %u conversations over %d seconds"
                % (len(conversations), duration))
    # optional extra load: hammer DNS at the requested rate
    dns_hammer = DnsHammer(dns_rate, duration)
    cstack.append((dns_hammer, None))
    # we spawn a batch, wait for finishers, then spawn another
    batch_end = min(now + 2.0, end)
    c, account = cstack.pop()
    if c.start_time + start > batch_end:
        # not due yet — push it back and stop forking this batch
        cstack.append((c, account))
    pid = c.replay_in_fork_with_delay(start, context, account)
    fork_time += elapsed
    print("forked %s in pid %s (in %fs)" % (c, pid,
    print(("forked %d times in %f seconds (avg %f)" %
           (fork_n, fork_time, fork_time / fork_n)),
    debug(2, "no forks in batch ending %f" % batch_end)
    # reap children that finish during this batch window
    while time.time() < batch_end - 1.0:
        pid, status = os.waitpid(-1, os.WNOHANG)
    except OSError as e:
        if e.errno != 10:  # no child processes
    c = children.pop(pid, None)
    print(("process %d finished conversation %s;"
           (pid, c, len(children))), file=sys.stderr)
    if time.time() >= end:
        print("time to stop", file=sys.stderr)
    print("EXCEPTION in parent", file=sys.stderr)
    traceback.print_exc()
    # escalate from SIGTERM (15) to SIGKILL (9) for remaining children
    for s in (15, 15, 9):
        print(("killing %d children with -%d" %
               (len(children), s)), file=sys.stderr)
        for pid in children:
        except OSError as e:
            if e.errno != 3:  # don't fail if it has already died
    end = time.time() + 1
    pid, status = os.waitpid(-1, os.WNOHANG)
    except OSError as e:
    c = children.pop(pid, None)
    print(("kill -%d %d KILLED conversation %s; "
           (s, pid, c, len(children))),
    if time.time() >= end:
    print("%d children are missing" % len(children),
    # there may be stragglers that were forked just as ^C was hit
    # and don't appear in the list of children. We can get them
    # with killpg, but that will also kill us, so this is^H^H would be
    # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
    # so as not to have to fuss around writing signal handlers.
    except KeyboardInterrupt:
        print("ignoring fake ^C", file=sys.stderr)
def openLdb(host, creds, lp):
    """Open an LDAP connection to `host` as a SamDB.

    NOTE(review): the trailing keyword arguments and return statement
    are elided in this view.
    """
    session = system_session()
    ldb = SamDB(url="ldap://%s" % host,
                session_info=session,
                # paged searches keep large result sets manageable
                options=['modules:paged_searches'],
def ou_name(ldb, instance_id):
    """Generate an ou name from the instance id"""
    # NOTE(review): the continuation line supplying the base DN is
    # elided in this view.
    return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id,
def create_ou(ldb, instance_id):
    """Create an ou, all created user and machine accounts will belong to it.

    This allows all the created resources to be cleaned up easily.

    NOTE(review): the `try:` lines, the status check/re-raise, and the
    second ldb.add() opening are elided in this view.
    """
    ou = ou_name(ldb, instance_id)
    # first ensure the parent ou (everything after the first comma) exists
        ldb.add({"dn": ou.split(',', 1)[1],
                 "objectclass": "organizationalunit"})
    except LdbError as e:
        (status, _) = e.args
        # ignore already exists
    # then add the per-instance ou itself
                 "objectclass": "organizationalunit"})
    except LdbError as e:
        (status, _) = e.args
        # ignore already exists
# ConversationAccounts holds details of the machine and user accounts
# associated with a conversation.
#
# We use a named tuple to reduce shared memory usage.
# NOTE(review): the field-name list and closing paren are elided in this view.
ConversationAccounts = namedtuple('ConversationAccounts',
def generate_replay_accounts(ldb, instance_id, number, password):
    """Generate a series of unique machine and user account names.

    NOTE(review): the accumulator initialisation, a trailing constructor
    argument, and the return statement are elided in this view.
    """
    for i in range(1, number + 1):
        # names are derived deterministically from the instance id and index
        netbios_name = machine_name(instance_id, i)
        username = user_name(instance_id, i)
        account = ConversationAccounts(netbios_name, password, username,
        accounts.append(account)
def create_machine_account(ldb, instance_id, netbios_name, machinepass,
                           traffic_account=True):
    """Create a machine account via ldap.

    NOTE(review): the if/else around the account_controls assignments
    and the ldb.add() opening are elided in this view.
    """
    ou = ou_name(ldb, instance_id)
    dn = "cn=%s,%s" % (netbios_name, ou)
    # unicodePwd must be the double-quoted password encoded as UTF-16-LE
    utf16pw = ('"%s"' % get_string(machinepass)).encode('utf-16-le')
    # we set these bits for the machine account otherwise the replayed
    # traffic throws up NT_STATUS_NO_TRUST_SAM_ACCOUNT errors
    account_controls = str(UF_TRUSTED_FOR_DELEGATION |
                           UF_SERVER_TRUST_ACCOUNT)
    # non-traffic accounts are plain workstation trust accounts
    account_controls = str(UF_WORKSTATION_TRUST_ACCOUNT)
             "objectclass": "computer",
             "sAMAccountName": "%s$" % netbios_name,
             "userAccountControl": account_controls,
             "unicodePwd": utf16pw})
def create_user_account(ldb, instance_id, username, userpass):
    """Create a user account via ldap.

    NOTE(review): the ldb.add() opening and closing lines are elided in
    this view.
    """
    ou = ou_name(ldb, instance_id)
    user_dn = "cn=%s,%s" % (username, ou)
    # unicodePwd must be the double-quoted password encoded as UTF-16-LE
    utf16pw = ('"%s"' % get_string(userpass)).encode('utf-16-le')
             "objectclass": "user",
             "sAMAccountName": username,
             "userAccountControl": str(UF_NORMAL_ACCOUNT),
             "unicodePwd": utf16pw
    # grant user write permission to do things like write account SPN
    sdutils = sd_utils.SDUtils(ldb)
    sdutils.dacl_add_ace(user_dn, "(A;;WP;;;PS)")
def create_group(ldb, instance_id, name):
    """Create a group via ldap.

    NOTE(review): the ldb.add() opening and closing lines are elided in
    this view.
    """
    ou = ou_name(ldb, instance_id)
    dn = "cn=%s,%s" % (name, ou)
             "objectclass": "group",
             "sAMAccountName": name,
def user_name(instance_id, i):
    """Generate a user name based on the instance id.

    :param instance_id: id distinguishing this traffic-replay run
    :param i: 1-based user index within the run
    :return: deterministic user account name, e.g. "STGU-7-1"
    """
    return "STGU-%d-%d" % (instance_id, i)
def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
    """Search objectclass, return attr in a set

    NOTE(review): the ldb.search(...) call that this expression belongs
    to is elided in this view.
    """
        expression="(objectClass={})".format(objectclass),
    # collect the requested attribute of every matching object
    return {str(obj[attr]) for obj in objs}
def generate_users(ldb, instance_id, number, password):
    """Add users to the server

    NOTE(review): the `users` counter initialisation/increment and the
    return statement are elided in this view.
    """
    existing_objects = search_objectclass(ldb, objectclass='user')
    for i in range(number, 0, -1):
        name = user_name(instance_id, i)
        # skip accounts that already exist from a previous run
        if name not in existing_objects:
            create_user_account(ldb, instance_id, name, password)
    LOGGER.info("Created %u/%u users" % (users, number))
def machine_name(instance_id, i, traffic_account=True):
    """Generate a machine account name from instance id."""
    # NOTE(review): the `if traffic_account:` guard between the two
    # return statements appears elided in this view — confirm.
    # traffic accounts correspond to a given user, and use different
    # userAccountControl flags to ensure packets get processed correctly
    return "STGM-%d-%d" % (instance_id, i)
    # Otherwise we're just generating computer accounts to simulate a
    # semi-realistic network. These use the default computer
    # userAccountControl flags, so we use a different account name so that
    # we don't try to use them when generating packets
    return "PC-%d-%d" % (instance_id, i)
def generate_machine_accounts(ldb, instance_id, number, password,
                              traffic_account=True):
    """Add machine accounts to the server

    NOTE(review): the `added` counter handling, a trailing argument,
    and the return statement are elided in this view.
    """
    existing_objects = search_objectclass(ldb, objectclass='computer')
    for i in range(number, 0, -1):
        name = machine_name(instance_id, i, traffic_account)
        # computer sAMAccountNames carry a trailing '$'
        if name + "$" not in existing_objects:
            create_machine_account(ldb, instance_id, name, password,
    LOGGER.info("Created %u/%u machine accounts" % (added, number))
def group_name(instance_id, i):
    """Return the deterministic group account name for an instance/index."""
    template = "STGG-%d-%d"
    return template % (instance_id, i)
def generate_groups(ldb, instance_id, number):
    """Create the required number of groups on the server.

    NOTE(review): the `groups` counter initialisation/increment and the
    return statement are elided in this view.
    """
    existing_objects = search_objectclass(ldb, objectclass='group')
    for i in range(number, 0, -1):
        name = group_name(instance_id, i)
        # skip groups that already exist from a previous run
        if name not in existing_objects:
            create_group(ldb, instance_id, name)
            # periodic progress logging for large runs
            if groups % 1000 == 0:
                LOGGER.info("Created %u/%u groups" % (groups, number))
def clean_up_accounts(ldb, instance_id):
    """Remove the created accounts and groups from the server.

    NOTE(review): the `try:` line and the status check after the except
    are elided in this view.
    """
    ou = ou_name(ldb, instance_id)
    # tree_delete removes the ou and everything beneath it in one call
        ldb.delete(ou, ["tree_delete:1"])
    except LdbError as e:
        (status, _) = e.args
        # ignore does not exist
def generate_users_and_groups(ldb, instance_id, password,
                              number_of_users, number_of_groups,
                              group_memberships, max_members,
                              machine_accounts, traffic_accounts=True):
    """Generate the required users and groups, allocating the users to
    # NOTE(review): the rest of the docstring, some counter
    # initialisations and trailing arguments are elided in this view.
    memberships_added = 0
    create_ou(ldb, instance_id)
    LOGGER.info("Generating dummy user accounts")
    users_added = generate_users(ldb, instance_id, number_of_users, password)
    LOGGER.info("Generating dummy machine accounts")
    computers_added = generate_machine_accounts(ldb, instance_id,
                                                machine_accounts, password,
    if number_of_groups > 0:
        LOGGER.info("Generating dummy groups")
        groups_added = generate_groups(ldb, instance_id, number_of_groups)
    if group_memberships > 0:
        LOGGER.info("Assigning users to groups")
        # plan the memberships first, then apply them in bulk
        assignments = GroupAssignments(number_of_groups,
        LOGGER.info("Adding users to groups")
        add_users_to_groups(ldb, instance_id, assignments)
        memberships_added = assignments.total()
    # warn when this run created groups but no users to populate them
    if (groups_added > 0 and users_added == 0 and
            number_of_groups != groups_added):
        LOGGER.warning("The added groups will contain no members")
    LOGGER.info("Added %d users (%d machines), %d groups and %d memberships" %
                (users_added, computers_added, groups_added,
class GroupAssignments(object):
    """Plans which users belong to which groups.

    Uses weighted-random distributions so a few users belong to many
    groups and a few groups contain many users.

    NOTE(review): several members are elided in this view (e.g. the
    total() accessor, weight-list initialisations, and some returns);
    comments below annotate only the visible lines.
    """
    def __init__(self, number_of_groups, groups_added, number_of_users,
                 users_added, group_memberships, max_members):
        # build the probability distributions first, then assign
        self.generate_group_distribution(number_of_groups)
        self.generate_user_distribution(number_of_users, group_memberships)
        self.max_members = max_members
        # key = group number, value = list of users in that group
        self.assignments = defaultdict(list)
        self.assign_groups(number_of_groups, groups_added, number_of_users,
                           users_added, group_memberships)

    def cumulative_distribution(self, weights):
        # make sure the probabilities conform to a cumulative distribution
        # spread between 0.0 and 1.0. Dividing by the weighted total gives each
        # probability a proportional share of 1.0. Higher probabilities get a
        # bigger share, so are more likely to be picked. We use the cumulative
        # value, so we can use random.random() as a simple index into the list
        total = sum(weights)
        # NOTE(review): `cumulative`/`dist` initialisation and the return
        # statement are elided in this view.
        for probability in weights:
            cumulative += probability
            dist.append(cumulative / total)

    def generate_user_distribution(self, num_users, num_memberships):
        """Probability distribution of a user belonging to a group.
        """
        # Assign a weighted probability to each user. Use the Pareto
        # Distribution so that some users are in a lot of groups, and the
        # bulk of users are in only a few groups. If we're assigning a large
        # number of group memberships, use a higher shape. This means slightly
        # fewer outlying users that are in large numbers of groups. The aim is
        # to have no users belonging to more than ~500 groups.
        # NOTE(review): the `shape` assignments and the else-branch are
        # elided in this view, as is the weights-list handling.
        if num_memberships > 5000000:
        elif num_memberships > 2000000:
        elif num_memberships > 300000:
        for x in range(1, num_users + 1):
            p = random.paretovariate(shape)
        # convert the weights to a cumulative distribution between 0.0 and 1.0
        self.user_dist = self.cumulative_distribution(weights)

    def generate_group_distribution(self, n):
        """Probability distribution of a group containing a user."""
        # Assign a weighted probability to each user. Probability decreases
        # as the group-ID increases
        # NOTE(review): the loop body computing the weights is elided.
        for x in range(1, n + 1):
        # convert the weights to a cumulative distribution between 0.0 and 1.0
        self.group_weights = weights
        self.group_dist = self.cumulative_distribution(weights)

    def generate_random_membership(self):
        """Returns a randomly generated user-group membership"""
        # the list items are cumulative distribution values between 0.0 and
        # 1.0, which makes random() a handy way to index the list to get a
        # weighted random user/group. (Here the user/group returned are
        # zero-based array indexes)
        user = bisect.bisect(self.user_dist, random.random())
        group = bisect.bisect(self.group_dist, random.random())

    def users_in_group(self, group):
        # list of user numbers currently assigned to this group
        return self.assignments[group]

    def get_groups(self):
        # all group numbers that received at least one assignment
        return self.assignments.keys()

    def cap_group_membership(self, group, max_members):
        """Prevent the group's membership from exceeding the max specified"""
        num_members = len(self.assignments[group])
        if num_members >= max_members:
            LOGGER.info("Group {0} has {1} members".format(group, num_members))
            # remove this group and then recalculate the cumulative
            # distribution, so this group is no longer selected
            self.group_weights[group - 1] = 0
            new_dist = self.cumulative_distribution(self.group_weights)
            self.group_dist = new_dist

    def add_assignment(self, user, group):
        # the assignments are stored in a dictionary where key=group,
        # value=list-of-users-in-group (indexing by group-ID allows us to
        # optimize for DB membership writes)
        if user not in self.assignments[group]:
            self.assignments[group].append(user)

        # check if there's a cap on how big the groups can grow
        if self.max_members:
            self.cap_group_membership(group, self.max_members)

    def assign_groups(self, number_of_groups, groups_added,
                      number_of_users, users_added, group_memberships):
        """Allocate users to groups.

        The intention is to have a few users that belong to most groups, while
        the majority of users belong to a few groups.

        A few groups will contain most users, with the remaining only having a
        """
        if group_memberships <= 0:
        # Calculate the number of group memberships required
        group_memberships = math.ceil(
            float(group_memberships) *
            (float(users_added) / float(number_of_users)))
        # with a per-group cap, total memberships cannot exceed cap * groups
        if self.max_members:
            group_memberships = min(group_memberships,
                                    self.max_members * number_of_groups)
        # zero-based indexes of the last pre-existing user and group
        existing_users = number_of_users - users_added - 1
        existing_groups = number_of_groups - groups_added - 1
        while self.total() < group_memberships:
            user, group = self.generate_random_membership()
            # only record memberships involving a newly-added user or group
            if group > existing_groups or user > existing_users:
                # the + 1 converts the array index to the corresponding
                # group or user number
                self.add_assignment(user + 1, group + 1)
def add_users_to_groups(db, instance_id, assignments):
    """Takes the assignments of users to groups and applies them to the DB.

    NOTE(review): the `added` counter initialisation and some loop
    control lines are elided in this view.
    """
    total = assignments.total()
    for group in assignments.get_groups():
        users_in_group = assignments.users_in_group(group)
        # nothing to write for empty groups
        if len(users_in_group) == 0:
        # Split up the users into chunks, so we write no more than 1K at a
        # time. (Minimizing the DB modifies is more efficient, but writing
        # 10K+ users to a single group becomes inefficient memory-wise)
        for chunk in range(0, len(users_in_group), 1000):
            chunk_of_users = users_in_group[chunk:chunk + 1000]
            add_group_members(db, instance_id, group, chunk_of_users)
            added += len(chunk_of_users)
    LOGGER.info("Added %u/%u memberships" % (added, total))
def add_group_members(db, instance_id, group, users_in_group):
    """Adds the given users to group specified.

    NOTE(review): the nested build_dn() def line, the ldb.Message()
    construction, and the final db.modify() call are elided in this view.
    """
    ou = ou_name(db, instance_id)
        return("cn=%s,%s" % (name, ou))
    group_dn = build_dn(group_name(instance_id, group))
    m.dn = ldb.Dn(db, group_dn)
    for user in users_in_group:
        user_dn = build_dn(user_name(instance_id, user))
        # unique element name per user so one Message carries many adds
        idx = "member-" + str(user)
        m[idx] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
def generate_stats(statsdir, timing_file):
    """Generate and print the summary stats for a run.

    NOTE(review): several accumulator initialisations, the per-line
    `try:`, and the tail of the print statements are elided in this view.
    """
    first = sys.float_info.max
    # (sic) conversation ids seen so far — variable name typo in original
    unique_converations = set()
    if timing_file is not None:
        tw = timing_file.write
    tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")
    # each stats file holds tab-separated per-packet result lines
    for filename in os.listdir(statsdir):
        path = os.path.join(statsdir, filename)
        with open(path, 'r') as f:
                fields = line.rstrip('\n').split('\t')
                conversation = fields[1]
                protocol = fields[2]
                packet_type = fields[3]
                latency = float(fields[4])
                # track the overall time window of the run
                first = min(float(fields[0]) - latency, first)
                last = max(float(fields[0]), last)
                if protocol not in latencies:
                    latencies[protocol] = {}
                if packet_type not in latencies[protocol]:
                    latencies[protocol][packet_type] = []
                latencies[protocol][packet_type].append(latency)
                if protocol not in failures:
                    failures[protocol] = {}
                if packet_type not in failures[protocol]:
                    failures[protocol][packet_type] = 0
                # fields[5] is the per-packet 'successful' flag
                if fields[5] == 'True':
                failures[protocol][packet_type] += 1
                if conversation not in unique_converations:
                    unique_converations.add(conversation)
            except (ValueError, IndexError):
                # not a valid line print and ignore
                print(line, file=sys.stderr)
    duration = last - first
    success_rate = successful / duration
    failure_rate = failed / duration
    print("Total conversations: %10d" % conversations)
    print("Successful operations: %10d (%.3f per second)"
          % (successful, success_rate))
    print("Failed operations: %10d (%.3f per second)"
          % (failed, failure_rate))
    print("Protocol Op Code Description "
          " Count Failed Mean Median "
    protocols = sorted(latencies.keys())
    for protocol in protocols:
        # sort opcodes numerically where possible
        packet_types = sorted(latencies[protocol], key=opcode_key)
        for packet_type in packet_types:
            values = latencies[protocol][packet_type]
            values = sorted(values)
            failed = failures[protocol][packet_type]
            mean = sum(values) / count
            median = calc_percentile(values, 0.50)
            percentile = calc_percentile(values, 0.95)
            # rng is the spread of observed latencies (max - min)
            rng = values[-1] - values[0]
            desc = OP_DESCRIPTIONS.get((protocol, packet_type), '')
            # NOTE(review): isatty is not called here (no parentheses),
            # so this condition is always truthy — looks like a bug; confirm.
            if sys.stdout.isatty:
                print("%-12s %4s %-35s %12d %12d %12.6f "
                      "%12.6f %12.6f %12.6f %12.6f"
            print("%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f"
    # NOTE(review): the enclosing def header (opcode_key) is elided in
    # this view.
    """Sort key for the operation code to ensure that it sorts numerically"""
    # pad to three digits so e.g. '2' sorts before '11'
    return "%03d" % int(v)
def calc_percentile(values, percentile):
    """Calculate the specified percentile from the list of values.

    Assumes the list is sorted in ascending order.

    NOTE(review): the empty-list guard, the floor/ceil computation of
    `f` and `c`, and the final combined return are elided in this view.
    """
    # fractional rank into the sorted list
    k = (len(values) - 1) * percentile
    # exact-rank case: no interpolation needed
    return values[int(k)]
    # linear interpolation between the two neighbouring ranks
    d0 = values[int(f)] * (c - k)
    d1 = values[int(c)] * (k - f)
2177 def mk_masked_dir(*path):
2178 """In a testenv we end up with 0777 directories that look an alarming
2179 green colour with ls. Use umask to avoid that."""
2180 # py3 os.mkdir can do this
2181 d = os.path.join(*path)
2182 mask = os.umask(0o077)