1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
4 # Copyright (C) Catalyst IT Ltd. 2017
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 # GNU General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 from __future__ import print_function, division
30 from collections import OrderedDict, Counter, defaultdict
31 from samba.emulate import traffic_packets
32 from samba.samdb import SamDB
34 from ldb import LdbError
35 from samba.dcerpc import ClientConnection
36 from samba.dcerpc import security, drsuapi, lsa
37 from samba.dcerpc import netlogon
38 from samba.dcerpc.netlogon import netr_Authenticator
39 from samba.dcerpc import srvsvc
40 from samba.dcerpc import samr
41 from samba.drs_utils import drs_DsBind
43 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
44 from samba.auth import system_session
45 from samba.dsdb import (
47 UF_SERVER_TRUST_ACCOUNT,
48 UF_TRUSTED_FOR_DELEGATION
50 from samba.dcerpc.misc import SEC_CHAN_BDC
51 from samba import gensec
52 from samba import sd_utils
53 from samba.compat import get_string
54 from samba.logger import get_samba_logger
# NOTE(review): this file is a numbered, partially-elided dump — the left-hand
# numbers are original line numbers and many lines are missing, including the
# literal openers of the two clue tables below and the assignments of
# WAIT_SCALE and DEBUG_LEVEL. Code kept verbatim; only comments added.
59 # we don't use None, because it complicates [de]serialisation
# (protocol, opcode) -> weight entries that suggest the sender is the CLIENT
# (presumably the CLIENT_CLUES dict — opener elided; TODO confirm)
63 ('dns', '0'): 1.0, # query
64 ('smb', '0x72'): 1.0, # Negotiate protocol
65 ('ldap', '0'): 1.0, # bind
66 ('ldap', '3'): 1.0, # searchRequest
67 ('ldap', '2'): 1.0, # unbindRequest
69 ('dcerpc', '11'): 1.0, # bind
70 ('dcerpc', '14'): 1.0, # Alter_context
71 ('nbns', '0'): 1.0, # query
# (protocol, opcode) entries that suggest the sender is the SERVER
# (presumably the SERVER_CLUES dict — opener elided; TODO confirm)
75 ('dns', '1'): 1.0, # response
76 ('ldap', '1'): 1.0, # bind response
77 ('ldap', '4'): 1.0, # search result
78 ('ldap', '5'): 1.0, # search done
80 ('dcerpc', '12'): 1.0, # bind_ack
81 ('dcerpc', '13'): 1.0, # bind_nak
82 ('dcerpc', '15'): 1.0, # Alter_context response
# Protocols that are ignored during replay (see Packet.is_really_a_packet)
85 SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}
# Gaps longer than WAIT_THRESHOLD become explicit wait states in the model;
# WAIT_SCALE itself is assigned on one of the elided lines.
88 WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
89 NO_WAIT_LOG_TIME_RANGE = (-10, -3)
91 # DEBUG_LEVEL can be changed by scripts with -d
94 LOGGER = get_samba_logger(name=__name__)
def debug(level, msg, *args):
    """Print a formatted debug message to standard error.

    :param level: The debug level, message will be printed if it is <= the
                  currently set debug level. The debug level can be set with
                  the -d option.
    :param msg: The message to be logged, can contain C-Style format
                specifiers.
    :param args: The parameters required by the format specifiers.
    """
    # DEBUG_LEVEL is a module-level global, adjusted by callers via -d.
    if level <= DEBUG_LEVEL:
        if not args:
            # no args: print the message verbatim, so a literal '%' in msg
            # does not blow up the % formatting below
            print(msg, file=sys.stderr)
        else:
            print(msg % tuple(args), file=sys.stderr)
def debug_lineno(*args):
    """Print an unformatted log message to stderr, containing the line number
    and function name of the caller (highlighted in yellow).

    :param args: values to print, one per line
    """
    # limit=2 gives us the caller's frame; tb[0][2] is the function name,
    # tb[0][1] the line number.
    tb = traceback.extract_stack(limit=2)
    print((" %s:" "\033[01;33m"
           "%s " "\033[00m" % (tb[0][2], tb[0][1])), end=' ',
          file=sys.stderr)
    for a in args:
        print(a, file=sys.stderr)
    print(file=sys.stderr)
    sys.stderr.flush()
def random_colour_print():
    """Return a function that prints a randomly coloured line to stderr.

    Each conversation gets its own colour so interleaved child-process
    output can be told apart.
    """
    # pick one of the 214 non-system xterm-256 colours (18..231)
    n = 18 + random.randrange(214)
    prefix = "\033[38;5;%dm" % n

    def p(*args):
        # print each argument on its own line, wrapped in the colour prefix
        # and the reset escape
        for a in args:
            print("%s%s\033[00m" % (prefix, a), file=sys.stderr)

    return p
140 class FakePacketError(Exception):
# NOTE(review): elided dump — code kept verbatim (numbered, unindented),
# only comments added. Packet models one line of a traffic summary file.
144 class Packet(object):
145 """Details of a network packet"""
146 def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
147 protocol, opcode, desc, extra):
149 self.timestamp = timestamp
150 self.ip_protocol = ip_protocol
151 self.stream_number = stream_number
154 self.protocol = protocol
# endpoints are normalised to (low, high) so either direction of the same
# src/dest pair maps to the same conversation key
158 if self.src < self.dest:
159 self.endpoints = (self.src, self.dest)
161 self.endpoints = (self.dest, self.src)
# presumably a @classmethod/@staticmethod constructor (decorator elided);
# parses one tab-separated traffic_summary line — TODO confirm
164 def from_line(self, line):
165 fields = line.rstrip('\n').split('\t')
176 timestamp = float(timestamp)
180 return Packet(timestamp, ip_protocol, stream_number, src, dest,
181 protocol, opcode, desc, extra)
183 def as_summary(self, time_offset=0.0):
184 """Format the packet as a traffic_summary line.
186 extra = '\t'.join(self.extra)
187 t = self.timestamp + time_offset
# returns a (timestamp, tab-separated-line) pair
188 return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
191 self.stream_number or '',
# __str__ (def line elided): human-readable one-line rendering
200 return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
201 (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
202 self.stream_number, self.protocol, self.opcode, self.desc,
203 ('«' + ' '.join(self.extra) + '»' if self.extra else '')))
# __repr__ (def line elided) delegates to __str__
206 return "<Packet @%s>" % self
# copy/clone helper (def line elided): rebuilds via the constructor
209 return self.__class__(self.timestamp,
219 def as_packet_type(self):
# 'protocol:opcode' string used as a state in the ngram traffic model
220 t = '%s:%s' % (self.protocol, self.opcode)
223 def client_score(self):
224 """A positive number means we think it is a client; a negative number
225 means we think it is a server. Zero means no idea. range: -1 to 1.
227 key = (self.protocol, self.opcode)
228 if key in CLIENT_CLUES:
229 return CLIENT_CLUES[key]
230 if key in SERVER_CLUES:
231 return -SERVER_CLUES[key]
234 def play(self, conversation, context):
235 """Send the packet over the network, if required.
237 Some packets are ignored, i.e. for protocols not handled,
238 server response messages, or messages that are generated by the
239 protocol layer associated with other packets.
# handler lookup by naming convention: traffic_packets.packet_<proto>_<op>
241 fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
243 fn = getattr(traffic_packets, fn_name)
245 except AttributeError as e:
246 print("Conversation(%s) Missing handler %s" %
247 (conversation.conversation_id, fn_name),
251 # Don't display a message for kerberos packets, they're not directly
252 # generated they're used to indicate kerberos should be used
253 if self.protocol != "kerberos":
254 debug(2, "Conversation(%s) Calling handler %s" %
255 (conversation.conversation_id, fn_name))
# timing lines below are written to stdout in the same tab-separated
# stats format whether the handler succeeds (True) or raises (False)
259 if fn(self, conversation, context):
260 # Only collect timing data for functions that generate
261 # network traffic, or fail
263 duration = end - start
264 print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
265 (end, conversation.conversation_id, self.protocol,
266 self.opcode, duration))
267 except Exception as e:
269 duration = end - start
270 print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
271 (end, conversation.conversation_id, self.protocol,
272 self.opcode, duration, e))
# NOTE(review): __cmp__ is Python 2 only; under Python 3 ordering would
# need __lt__/__eq__ — elided lines may already handle this; TODO confirm
274 def __cmp__(self, other):
275 return self.timestamp - other.timestamp
277 def is_really_a_packet(self, missing_packet_stats=None):
278 """Is the packet one that can be ignored?
280 If so removing it will have no effect on the replay
282 if self.protocol in SKIPPED_PROTOCOLS:
283 # Ignore any packets for the protocols we're not interested in.
285 if self.protocol == "ldap" and self.opcode == '':
286 # skip ldap continuation packets
289 fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
290 fn = getattr(traffic_packets, fn_name, None)
292 print("missing packet %s" % fn_name, file=sys.stderr)
# null_packet handlers exist only to mark known-but-ignored packet types
294 if fn is traffic_packets.null_packet:
# NOTE(review): elided dump — code kept verbatim, only comments added.
# ReplayContext holds per-conversation state: credentials (good and
# deliberately-bad variants), cached protocol connections, and LDAP DN
# lookup tables used to synthesise realistic queries.
299 class ReplayContext(object):
300 """State/Context for an individual conversation between an simulated client
# (constructor signature largely elided; these are two of its kwargs)
308 badpassword_frequency=None,
309 prefer_kerberos=None,
# one cache list per protocol; getters below reuse the last entry
# unless called with new=True
318 self.ldap_connections = []
319 self.dcerpc_connections = []
320 self.lsarpc_connections = []
321 self.lsarpc_connections_named = []
322 self.drsuapi_connections = []
323 self.srvsvc_connections = []
324 self.samr_contexts = []
325 self.netlogon_connection = None
328 self.prefer_kerberos = prefer_kerberos
330 self.base_dn = base_dn
332 self.statsdir = statsdir
333 self.global_tempdir = tempdir
334 self.domain_sid = domain_sid
335 self.realm = lp.get('realm')
337 # Bad password attempt controls
# last_*_bad flags stop two consecutive bad-password attempts on the
# same connection type (see with_random_bad_credentials)
338 self.badpassword_frequency = badpassword_frequency
339 self.last_lsarpc_bad = False
340 self.last_lsarpc_named_bad = False
341 self.last_simple_bind_bad = False
342 self.last_bind_bad = False
343 self.last_srvsvc_bad = False
344 self.last_drsuapi_bad = False
345 self.last_netlogon_bad = False
346 self.last_samlogon_bad = False
347 self.generate_ldap_search_tables()
348 self.next_conversation_id = itertools.count()
# Builds dn_map (DN-shape pattern -> DNs) and attribute_clue_map so that
# replayed LDAP searches can pick realistic base DNs.
350 def generate_ldap_search_tables(self):
351 session = system_session()
353 db = SamDB(url="ldap://%s" % self.server,
354 session_info=session,
355 credentials=self.creds,
358 res = db.search(db.domain_dn(),
359 scope=ldb.SCOPE_SUBTREE,
360 controls=["paged_results:1:1000"],
363 # find a list of dns for each pattern
364 # e.g. CN,CN,CN,DC,DC
366 attribute_clue_map = {
# pattern is the first two chars of each RDN, e.g. 'CN,CN,DC,DC'
372 pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
373 dns = dn_map.setdefault(pattern, [])
375 if dn.startswith('CN=NTDS Settings,'):
376 attribute_clue_map['invocationId'].append(dn)
378 # extend the map in case we are working with a different
379 # number of DC components.
380 # for k, v in self.dn_map.items():
381 # print >>sys.stderr, k, len(v)
383 for k in list(dn_map.keys()):
387 while p[-3:] == ',DC':
391 if p != k and p in dn_map:
392 print('dn_map collison %s %s' % (k, p),
395 dn_map[p] = dn_map[k]
398 self.attribute_clue_map = attribute_clue_map
# Called in the forked child: binds this context to one user/machine
# account and gives the process private smb.conf directories.
400 def generate_process_local_config(self, account, conversation):
403 self.netbios_name = account.netbios_name
404 self.machinepass = account.machinepass
405 self.username = account.username
406 self.userpass = account.userpass
408 self.tempdir = mk_masked_dir(self.global_tempdir,
410 conversation.conversation_id)
412 self.lp.set("private dir", self.tempdir)
413 self.lp.set("lock dir", self.tempdir)
414 self.lp.set("state directory", self.tempdir)
415 self.lp.set("tls verify peer", "no_check")
417 # If the domain was not specified, check for the environment
419 if self.domain is None:
420 self.domain = os.environ["DOMAIN"]
422 self.remoteAddress = "/root/ncalrpc_as_system"
423 self.samlogon_dn = ("cn=%s,%s" %
424 (self.netbios_name, self.ou))
425 self.user_dn = ("cn=%s,%s" %
426 (self.username, self.ou))
428 self.generate_machine_creds()
429 self.generate_user_creds()
431 def with_random_bad_credentials(self, f, good, bad, failed_last_time):
432 """Execute the supplied logon function, randomly choosing the
435 Based on the frequency in badpassword_frequency randomly perform the
436 function with the supplied bad credentials.
437 If run with bad credentials, the function is re-run with the good
439 failed_last_time is used to prevent consecutive bad credential
440 attempts. So the over all bad credential frequency will be lower
441 than that requested, but not significantly.
443 if not failed_last_time:
444 if (self.badpassword_frequency and self.badpassword_frequency > 0
445 and random.random() < self.badpassword_frequency):
449 # Ignore any exceptions as the operation may fail
450 # as it's being performed with bad credentials
452 failed_last_time = True
454 failed_last_time = False
# returns the result of f(good) plus the updated flag
457 return (result, failed_last_time)
459 def generate_user_creds(self):
460 """Generate the conversation specific user Credentials.
462 Each Conversation has an associated user account used to simulate
463 any non Administrative user traffic.
465 Generates user credentials with good and bad passwords and ldap
466 simple bind credentials with good and bad passwords.
468 self.user_creds = Credentials()
469 self.user_creds.guess(self.lp)
470 self.user_creds.set_workstation(self.netbios_name)
471 self.user_creds.set_password(self.userpass)
472 self.user_creds.set_username(self.username)
473 self.user_creds.set_domain(self.domain)
474 if self.prefer_kerberos:
475 self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
477 self.user_creds.set_kerberos_state(DONT_USE_KERBEROS)
# the "bad" variants truncate the password to guarantee a wrong value
479 self.user_creds_bad = Credentials()
480 self.user_creds_bad.guess(self.lp)
481 self.user_creds_bad.set_workstation(self.netbios_name)
482 self.user_creds_bad.set_password(self.userpass[:-4])
483 self.user_creds_bad.set_username(self.username)
484 if self.prefer_kerberos:
485 self.user_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
487 self.user_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
489 # Credentials for ldap simple bind.
490 self.simple_bind_creds = Credentials()
491 self.simple_bind_creds.guess(self.lp)
492 self.simple_bind_creds.set_workstation(self.netbios_name)
493 self.simple_bind_creds.set_password(self.userpass)
494 self.simple_bind_creds.set_username(self.username)
495 self.simple_bind_creds.set_gensec_features(
496 self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
497 if self.prefer_kerberos:
498 self.simple_bind_creds.set_kerberos_state(MUST_USE_KERBEROS)
500 self.simple_bind_creds.set_kerberos_state(DONT_USE_KERBEROS)
501 self.simple_bind_creds.set_bind_dn(self.user_dn)
503 self.simple_bind_creds_bad = Credentials()
504 self.simple_bind_creds_bad.guess(self.lp)
505 self.simple_bind_creds_bad.set_workstation(self.netbios_name)
506 self.simple_bind_creds_bad.set_password(self.userpass[:-4])
507 self.simple_bind_creds_bad.set_username(self.username)
508 self.simple_bind_creds_bad.set_gensec_features(
509 self.simple_bind_creds_bad.get_gensec_features() |
511 if self.prefer_kerberos:
512 self.simple_bind_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
514 self.simple_bind_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
515 self.simple_bind_creds_bad.set_bind_dn(self.user_dn)
517 def generate_machine_creds(self):
518 """Generate the conversation specific machine Credentials.
520 Each Conversation has an associated machine account.
522 Generates machine credentials with good and bad passwords.
525 self.machine_creds = Credentials()
526 self.machine_creds.guess(self.lp)
527 self.machine_creds.set_workstation(self.netbios_name)
528 self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
529 self.machine_creds.set_password(self.machinepass)
530 self.machine_creds.set_username(self.netbios_name + "$")
531 self.machine_creds.set_domain(self.domain)
532 if self.prefer_kerberos:
533 self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
535 self.machine_creds.set_kerberos_state(DONT_USE_KERBEROS)
537 self.machine_creds_bad = Credentials()
538 self.machine_creds_bad.guess(self.lp)
539 self.machine_creds_bad.set_workstation(self.netbios_name)
540 self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
541 self.machine_creds_bad.set_password(self.machinepass[:-4])
542 self.machine_creds_bad.set_username(self.netbios_name + "$")
543 if self.prefer_kerberos:
544 self.machine_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
546 self.machine_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
548 def get_matching_dn(self, pattern, attributes=None):
549 # If the pattern is an empty string, we assume ROOTDSE,
550 # Otherwise we try adding or removing DC suffixes, then
551 # shorter leading patterns until we hit one.
552 # e.g if there is no CN,CN,CN,CN,DC,DC
553 # we first try CN,CN,CN,CN,DC
554 # and CN,CN,CN,CN,DC,DC,DC
555 # then change to CN,CN,CN,DC,DC
556 # and as last resort we use the base_dn
557 attr_clue = self.attribute_clue_map.get(attributes)
559 return random.choice(attr_clue)
561 pattern = pattern.upper()
563 if pattern in self.dn_map:
564 return random.choice(self.dn_map[pattern])
565 # chop one off the front and try it all again.
566 pattern = pattern[3:]
# Connection getters: each caches its connection(s) and reuses the most
# recent one unless new=True; most wrap the actual connect call in
# with_random_bad_credentials to simulate occasional failed logons.
570 def get_dcerpc_connection(self, new=False):
571 guid = '12345678-1234-abcd-ef00-01234567cffb' # RPC_NETLOGON UUID
572 if self.dcerpc_connections and not new:
573 return self.dcerpc_connections[-1]
574 c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
576 self.dcerpc_connections.append(c)
579 def get_srvsvc_connection(self, new=False):
580 if self.srvsvc_connections and not new:
581 return self.srvsvc_connections[-1]
584 return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
588 (c, self.last_srvsvc_bad) = \
589 self.with_random_bad_credentials(connect,
592 self.last_srvsvc_bad)
594 self.srvsvc_connections.append(c)
597 def get_lsarpc_connection(self, new=False):
598 if self.lsarpc_connections and not new:
599 return self.lsarpc_connections[-1]
602 binding_options = 'schannel,seal,sign'
603 return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
604 (self.server, binding_options),
608 (c, self.last_lsarpc_bad) = \
609 self.with_random_bad_credentials(connect,
611 self.machine_creds_bad,
612 self.last_lsarpc_bad)
614 self.lsarpc_connections.append(c)
617 def get_lsarpc_named_pipe_connection(self, new=False):
618 if self.lsarpc_connections_named and not new:
619 return self.lsarpc_connections_named[-1]
622 return lsa.lsarpc("ncacn_np:%s" % (self.server),
626 (c, self.last_lsarpc_named_bad) = \
627 self.with_random_bad_credentials(connect,
629 self.machine_creds_bad,
630 self.last_lsarpc_named_bad)
632 self.lsarpc_connections_named.append(c)
635 def get_drsuapi_connection_pair(self, new=False, unbind=False):
636 """get a (drs, drs_handle) tuple"""
637 if self.drsuapi_connections and not new:
638 c = self.drsuapi_connections[-1]
642 binding_options = 'seal'
643 binding_string = "ncacn_ip_tcp:%s[%s]" %\
644 (self.server, binding_options)
645 return drsuapi.drsuapi(binding_string, self.lp, creds)
647 (drs, self.last_drsuapi_bad) = \
648 self.with_random_bad_credentials(connect,
651 self.last_drsuapi_bad)
653 (drs_handle, supported_extensions) = drs_DsBind(drs)
654 c = (drs, drs_handle)
655 self.drsuapi_connections.append(c)
658 def get_ldap_connection(self, new=False, simple=False):
659 if self.ldap_connections and not new:
660 return self.ldap_connections[-1]
662 def simple_bind(creds):
664 To run simple bind against Windows, we need to run
665 following commands in PowerShell:
667 Install-windowsfeature ADCS-Cert-Authority
668 Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
672 return SamDB('ldaps://%s' % self.server,
676 def sasl_bind(creds):
677 return SamDB('ldap://%s' % self.server,
681 (samdb, self.last_simple_bind_bad) = \
682 self.with_random_bad_credentials(simple_bind,
683 self.simple_bind_creds,
684 self.simple_bind_creds_bad,
685 self.last_simple_bind_bad)
687 (samdb, self.last_bind_bad) = \
688 self.with_random_bad_credentials(sasl_bind,
693 self.ldap_connections.append(samdb)
696 def get_samr_context(self, new=False):
697 if not self.samr_contexts or new:
698 self.samr_contexts.append(
699 SamrContext(self.server, lp=self.lp, creds=self.creds))
700 return self.samr_contexts[-1]
702 def get_netlogon_connection(self):
704 if self.netlogon_connection:
705 return self.netlogon_connection
708 return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
712 (c, self.last_netlogon_bad) = \
713 self.with_random_bad_credentials(connect,
715 self.machine_creds_bad,
716 self.last_netlogon_bad)
717 self.netlogon_connection = c
720 def guess_a_dns_lookup(self):
721 return (self.realm, 'A')
723 def get_authenticator(self):
724 auth = self.machine_creds.new_client_authenticator()
725 current = netr_Authenticator()
# credential bytes may arrive as ints or single-char strings; normalise
726 current.cred.data = [x if isinstance(x, int) else ord(x) for x in auth["credential"]]
727 current.timestamp = auth["timestamp"]
729 subsequent = netr_Authenticator()
730 return (current, subsequent)
# NOTE(review): elided dump — code kept verbatim, only comments added.
# Caches a SAMR connection plus the handles SAMR operations reuse.
733 class SamrContext(object):
734 """State/Context associated with a samr connection.
736 def __init__(self, server, lp=None, creds=None):
737 self.connection = None
# (self.handle and the server/lp/creds assignments are on elided lines)
739 self.domain_handle = None
740 self.domain_sid = None
741 self.group_handle = None
742 self.user_handle = None
748 def get_connection(self):
# connect lazily on first use, then cache
749 if not self.connection:
750 self.connection = samr.samr(
751 "ncacn_ip_tcp:%s[seal]" % (self.server),
753 credentials=self.creds)
755 return self.connection
757 def get_handle(self):
# lazily obtain a connect handle with maximum allowed access
759 c = self.get_connection()
760 self.handle = c.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
# NOTE(review): elided dump — code kept verbatim, only comments added.
764 class Conversation(object):
765 """Details of a converation between a simulated client and a server."""
766 conversation_id = None
768 def __init__(self, start_time=None, endpoints=None):
769 self.start_time = start_time
770 self.endpoints = endpoints
772 self.msg = random_colour_print()
# client_balance accumulates client/server clue scores; its sign decides
# which endpoint is the client in guess_client_server()
773 self.client_balance = 0.0
# NOTE(review): __cmp__ is Python 2 only; Python 3 ordering would need
# __lt__ — possibly provided on elided lines; TODO confirm
775 def __cmp__(self, other):
776 if self.start_time is None:
777 if other.start_time is None:
780 if other.start_time is None:
782 return self.start_time - other.start_time
784 def add_packet(self, packet):
785 """Add a packet object to this conversation, making a local copy with
786 a conversation-relative timestamp."""
789 if self.start_time is None:
790 self.start_time = p.timestamp
792 if self.endpoints is None:
793 self.endpoints = p.endpoints
795 if p.endpoints != self.endpoints:
796 raise FakePacketError("Conversation endpoints %s don't match"
797 "packet endpoints %s" %
798 (self.endpoints, p.endpoints))
800 p.timestamp -= self.start_time
# packets sent from endpoints[0] (the lower address) count against it
# being the client; from the other end, in favour
802 if p.src == p.endpoints[0]:
803 self.client_balance -= p.client_score()
805 self.client_balance += p.client_score()
807 if p.is_really_a_packet():
808 self.packets.append(p)
810 def add_short_packet(self, timestamp, protocol, opcode, extra,
812 """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
813 (possibly empty) list of extra data. If client is True, assume
814 this packet is from the client to the server.
816 src, dest = self.guess_client_server()
818 src, dest = dest, src
819 key = (protocol, opcode)
820 desc = OP_DESCRIPTIONS[key] if key in OP_DESCRIPTIONS else ''
821 if protocol in IP_PROTOCOLS:
822 ip_protocol = IP_PROTOCOLS[protocol]
825 packet = Packet(timestamp - self.start_time, ip_protocol,
827 protocol, opcode, desc, extra)
828 # XXX we're assuming the timestamp is already adjusted for
830 # XXX should we adjust client balance for guessed packets?
831 if packet.src == packet.endpoints[0]:
832 self.client_balance -= packet.client_score()
834 self.client_balance += packet.client_score()
835 if packet.is_really_a_packet():
836 self.packets.append(packet)
# __str__ (def line elided)
839 return ("<Conversation %s %s starting %.3f %d packets>" %
840 (self.conversation_id, self.endpoints, self.start_time,
# __iter__ / __len__ (def lines elided)
846 return iter(self.packets)
849 return len(self.packets)
851 def get_duration(self):
852 if len(self.packets) < 2:
854 return self.packets[-1].timestamp - self.packets[0].timestamp
856 def replay_as_summary_lines(self):
858 for p in self.packets:
859 lines.append(p.as_summary(self.start_time))
862 def replay_in_fork_with_delay(self, start, context=None, account=None):
863 """Fork a new process and replay the conversation.
865 def signal_handler(signal, frame):
866 """Signal handler closes standard out and error.
868 Triggered by a sigterm, ensures that the log messages are flushed
869 to disk and not lost.
876 now = time.time() - start
878 # we are replaying strictly in order, so it is safe to sleep
879 # in the main process if the gap is big enough. This reduces
880 # the number of concurrent threads, which allows us to make
# NOTE(review): 'and False' disables this main-process sleep entirely —
# the branch is intentionally dead code kept for experimentation
882 if gap > 0.15 and False:
883 print("sleeping for %f in main process" % (gap - 0.1),
885 time.sleep(gap - 0.1)
886 now = time.time() - start
888 print("gap is now %f" % gap, file=sys.stderr)
890 self.conversation_id = next(context.next_conversation_id)
# --- child process path below (the os.fork() call is elided) ---
895 signal.signal(signal.SIGTERM, signal_handler)
896 # we must never return, or we'll end up running parts of the
897 # parent's clean-up code. So we work in a try...finally, and
898 # try to print any exceptions.
901 context.generate_process_local_config(account, self)
904 filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
905 self.conversation_id)
# each child redirects stdout to its own per-conversation stats file
907 sys.stdout = open(filename, 'w')
909 sleep_time = gap - SLEEP_OVERHEAD
911 time.sleep(sleep_time)
913 miss = t - (time.time() - start)
914 self.msg("starting %s [miss %.3f pid %d]" % (self, miss, pid))
917 print(("EXCEPTION in child PID %d, conversation %s" % (pid, self)),
919 traceback.print_exc(sys.stderr)
925 def replay(self, context=None):
928 for p in self.packets:
# sleep until just before the packet is due, then play it
929 now = time.time() - start
930 gap = p.timestamp - now
931 sleep_time = gap - SLEEP_OVERHEAD
933 time.sleep(sleep_time)
935 miss = p.timestamp - (time.time() - start)
937 self.msg("packet %s [miss %.3f pid %d]" % (p, miss,
940 p.play(self, context)
942 def guess_client_server(self, server_clue=None):
943 """Have a go at deciding who is the server and who is the client.
944 returns (client, server)
946 a, b = self.endpoints
948 if self.client_balance < 0:
951 # in the absense of a clue, we will fall through to assuming
952 # the lowest number is the server (which is usually true).
954 if self.client_balance == 0 and server_clue == b:
959 def forget_packets_outside_window(self, s, e):
960 """Prune any packets outside the timne window we're interested in
962 :param s: start of the window
963 :param e: end of the window
965 self.packets = [p for p in self.packets if s <= p.timestamp <= e]
966 self.start_time = self.packets[0].timestamp if self.packets else None
968 def renormalise_times(self, start_time):
969 """Adjust the packet start times relative to the new start time."""
970 for p in self.packets:
971 p.timestamp -= start_time
973 if self.start_time is not None:
974 self.start_time -= start_time
# NOTE(review): elided dump — code kept verbatim, only comments added.
977 class DnsHammer(Conversation):
978 """A lightweight conversation that generates a lot of dns:0 packets on
981 def __init__(self, dns_rate, duration):
# pre-generate n random send times spread uniformly over the duration
982 n = int(dns_rate * duration)
983 self.times = [random.uniform(0, duration) for i in range(n)]
986 self.duration = duration
988 self.msg = random_colour_print()
# __str__ (def line elided)
991 return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
992 (len(self.times), self.duration, self.rate))
994 def replay_in_fork_with_delay(self, start, context=None, account=None):
# delegate to the base class implementation
995 return Conversation.replay_in_fork_with_delay(self,
1000 def replay(self, context=None):
# fire the dns query handler once per pre-generated time, sleeping
# between packets, and write one tab-separated stats line per attempt
1002 fn = traffic_packets.packet_dns_0
1003 for t in self.times:
1004 now = time.time() - start
1006 sleep_time = gap - SLEEP_OVERHEAD
1008 time.sleep(sleep_time)
1011 miss = t - (time.time() - start)
1012 self.msg("packet %s [miss %.3f pid %d]" % (t, miss,
1016 packet_start = time.time()
1018 fn(self, self, context)
1020 duration = end - packet_start
1021 print("%f\tDNS\tdns\t0\t%f\tTrue\t" % (end, duration))
1022 except Exception as e:
1024 duration = end - packet_start
1025 print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" % (end, duration, e))
# NOTE(review): elided dump — code kept verbatim, only comments added.
1028 def ingest_summaries(files, dns_mode='count'):
1029 """Load a summary traffic summary file and generated Converations from it.
# dns packets are tallied separately (unless dns_mode == 'include')
1032 dns_counts = defaultdict(int)
1035 if isinstance(f, str):
1037 print("Ingesting %s" % (f.name,), file=sys.stderr)
1039 p = Packet.from_line(line)
1040 if p.protocol == 'dns' and dns_mode != 'include':
1041 dns_counts[p.opcode] += 1
1050 start_time = min(p.timestamp for p in packets)
1051 last_packet = max(p.timestamp for p in packets)
1053 print("gathering packets into conversations", file=sys.stderr)
# group packets into conversations keyed by normalised endpoint pair
1054 conversations = OrderedDict()
1056 p.timestamp -= start_time
1057 c = conversations.get(p.endpoints)
1060 conversations[p.endpoints] = c
1063 # We only care about conversations with actual traffic, so we
1064 # filter out conversations with nothing to say. We do that here,
1065 # rather than earlier, because those empty packets contain useful
1066 # hints as to which end of the conversation was the client.
1067 conversation_list = []
1068 for c in conversations.values():
1070 conversation_list.append(c)
1072 # This is obviously not correct, as many conversations will appear
1073 # to start roughly simultaneously at the beginning of the snapshot.
1074 # To which we say: oh well, so be it.
1075 duration = float(last_packet - start_time)
1076 mean_interval = len(conversations) / duration
# returns (conversations, mean conversation interval, capture duration,
# dns opcode counts)
1078 return conversation_list, mean_interval, duration, dns_counts
def guess_server_address(conversations):
    """Guess the server's address: the most common endpoint address.

    The server takes part in (nearly) every conversation while clients
    vary, so the most frequently seen address is almost certainly it.

    :param conversations: iterable of objects with an ``endpoints``
        attribute (a pair of addresses).
    :return: the most common ``(address, count)`` pair, or None when
        there are no conversations at all.
    """
    # we guess the most common address.
    addresses = Counter()
    for c in conversations:
        addresses.update(c.endpoints)
    # guard the empty case: most_common(1) on an empty Counter is []
    if addresses:
        return addresses.most_common(1)[0]
def stringify_keys(x):
    """Return a copy of dict *x* with each tuple key joined into a single
    tab-separated string, making the dict JSON-serialisable.

    Inverse of unstringify_keys.
    """
    y = {}
    for k, v in x.items():
        k2 = '\t'.join(k)
        y[k2] = v
    return y
def unstringify_keys(x):
    """Return a copy of dict *x* with each tab-separated string key split
    back into a tuple.

    Inverse of stringify_keys (used when loading the JSON model).
    """
    y = {}
    for k, v in x.items():
        t = tuple(str(k).split('\t'))
        y[t] = v
    return y
# NOTE(review): elided dump — code kept verbatim, only comments added.
# TrafficModel learns an order-(n-1) markov chain over packet types
# (with explicit 'wait:<k>' states for long gaps) and can replay it by
# generating synthetic conversations.
1106 class TrafficModel(object):
1107 def __init__(self, n=3):
# query_details maps packet type -> list of observed 'extra' tuples
1109 self.query_details = {}
1111 self.dns_opcounts = defaultdict(int)
1112 self.cumulative_duration = 0.0
# [number of conversations, elapsed seconds] observed while learning
1113 self.conversation_rate = [0, 1]
1115 def learn(self, conversations, dns_opcounts={}):
# NOTE(review): mutable default argument; harmless here only as long as
# dns_opcounts is never mutated — it is only read below
1118 key = (NON_PACKET,) * (self.n - 1)
1120 server = guess_server_address(conversations)
1122 for k, v in dns_opcounts.items():
1123 self.dns_opcounts[k] += v
1125 if len(conversations) > 1:
1127 conversations[-1].start_time - conversations[0].start_time
1128 self.conversation_rate[0] = len(conversations)
1129 self.conversation_rate[1] = elapsed
1131 for c in conversations:
1132 client, server = c.guess_client_server(server)
1133 cum_duration += c.get_duration()
# each conversation's chain starts from the all-NON_PACKET state
1134 key = (NON_PACKET,) * (self.n - 1)
1139 elapsed = p.timestamp - prev
1141 if elapsed > WAIT_THRESHOLD:
1142 # add the wait as an extra state
1143 wait = 'wait:%d' % (math.log(max(1.0,
1144 elapsed * WAIT_SCALE)))
1145 self.ngrams.setdefault(key, []).append(wait)
1146 key = key[1:] + (wait,)
1148 short_p = p.as_packet_type()
1149 self.query_details.setdefault(short_p,
1150 []).append(tuple(p.extra))
1151 self.ngrams.setdefault(key, []).append(short_p)
1152 key = key[1:] + (short_p,)
1154 self.cumulative_duration += cum_duration
# terminate the final state so generation knows where chains end
1156 self.ngrams.setdefault(key, []).append(NON_PACKET)
# --- save (def line elided): serialise counts to JSON ---
1160 for k, v in self.ngrams.items():
1162 ngrams[k] = dict(Counter(v))
1165 for k, v in self.query_details.items():
1166 query_details[k] = dict(Counter('\t'.join(x) if x else '-'
1171 'query_details': query_details,
1172 'cumulative_duration': self.cumulative_duration,
1173 'conversation_rate': self.conversation_rate,
1175 d['dns'] = self.dns_opcounts
1177 if isinstance(f, str):
1180 json.dump(d, f, indent=2)
# --- load (def line elided): inverse of save, expanding counts back
# into the sample lists that random.choice draws from ---
1183 if isinstance(f, str):
1188 for k, v in d['ngrams'].items():
1189 k = tuple(str(k).split('\t'))
1190 values = self.ngrams.setdefault(k, [])
1191 for p, count in v.items():
1192 values.extend([str(p)] * count)
1194 for k, v in d['query_details'].items():
1195 values = self.query_details.setdefault(str(k), [])
1196 for p, count in v.items():
# '-' was the placeholder for an empty extra tuple on save
1198 values.extend([()] * count)
1200 values.extend([tuple(str(p).split('\t'))] * count)
1203 for k, v in d['dns'].items():
1204 self.dns_opcounts[k] += v
1206 self.cumulative_duration = d['cumulative_duration']
1207 self.conversation_rate = d['conversation_rate']
1209 def construct_conversation(self, timestamp=0.0, client=2, server=1,
1210 hard_stop=None, packet_rate=1):
1211 """Construct a individual converation from the model."""
1213 c = Conversation(timestamp, (server, client))
1215 key = (NON_PACKET,) * (self.n - 1)
# walk the markov chain until it reaches a state with no successors
1217 while key in self.ngrams:
1218 p = random.choice(self.ngrams.get(key, NON_PACKET))
1221 if p in self.query_details:
1222 extra = random.choice(self.query_details[p])
1226 protocol, opcode = p.split(':', 1)
1227 if protocol == 'wait':
# reverse the log-compression applied when the wait was learned
1228 log_wait_time = int(opcode) + random.random()
1229 wait = math.exp(log_wait_time) / (WAIT_SCALE * packet_rate)
1232 log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
1233 wait = math.exp(log_wait) / packet_rate
1235 if hard_stop is not None and timestamp > hard_stop:
1237 c.add_short_packet(timestamp, protocol, opcode, extra)
1239 key = key[1:] + (p,)
1243 def generate_conversations(self, rate, duration, packet_rate=1):
1244 """Generate a list of conversations from the model."""
1246 # We run the simulation for at least ten times as long as our
1247 # desired duration, and take a section near the start.
1248 rate_n, rate_t = self.conversation_rate
1250 duration2 = max(rate_t, duration * 2)
# scale the learned conversation rate to the requested one
1251 n = rate * duration2 * rate_n / rate_t
1258 start = end - duration
1260 while client < n + 2:
1261 start = random.uniform(0, duration2)
1262 c = self.construct_conversation(start,
1265 hard_stop=(duration2 * 5),
1266 packet_rate=packet_rate)
1268 c.forget_packets_outside_window(start, end)
1269 c.renormalise_times(start)
1271 conversations.append(c)
1274 print(("we have %d conversations at rate %f" %
1275 (len(conversations), rate)), file=sys.stderr)
1276 conversations.sort()
1277 return conversations
1282 'rpc_netlogon': '06',
1283 'kerberos': '06', # ratio 16248:258
1294 'smb_netlogon': '11',
1300 ('browser', '0x01'): 'Host Announcement (0x01)',
1301 ('browser', '0x02'): 'Request Announcement (0x02)',
1302 ('browser', '0x08'): 'Browser Election Request (0x08)',
1303 ('browser', '0x09'): 'Get Backup List Request (0x09)',
1304 ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
1305 ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
1306 ('cldap', '3'): 'searchRequest',
1307 ('cldap', '5'): 'searchResDone',
1308 ('dcerpc', '0'): 'Request',
1309 ('dcerpc', '11'): 'Bind',
1310 ('dcerpc', '12'): 'Bind_ack',
1311 ('dcerpc', '13'): 'Bind_nak',
1312 ('dcerpc', '14'): 'Alter_context',
1313 ('dcerpc', '15'): 'Alter_context_resp',
1314 ('dcerpc', '16'): 'AUTH3',
1315 ('dcerpc', '2'): 'Response',
1316 ('dns', '0'): 'query',
1317 ('dns', '1'): 'response',
1318 ('drsuapi', '0'): 'DsBind',
1319 ('drsuapi', '12'): 'DsCrackNames',
1320 ('drsuapi', '13'): 'DsWriteAccountSpn',
1321 ('drsuapi', '1'): 'DsUnbind',
1322 ('drsuapi', '2'): 'DsReplicaSync',
1323 ('drsuapi', '3'): 'DsGetNCChanges',
1324 ('drsuapi', '4'): 'DsReplicaUpdateRefs',
1325 ('epm', '3'): 'Map',
1326 ('kerberos', ''): '',
1327 ('ldap', '0'): 'bindRequest',
1328 ('ldap', '1'): 'bindResponse',
1329 ('ldap', '2'): 'unbindRequest',
1330 ('ldap', '3'): 'searchRequest',
1331 ('ldap', '4'): 'searchResEntry',
1332 ('ldap', '5'): 'searchResDone',
1333 ('ldap', ''): '*** Unknown ***',
1334 ('lsarpc', '14'): 'lsa_LookupNames',
1335 ('lsarpc', '15'): 'lsa_LookupSids',
1336 ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
1337 ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
1338 ('lsarpc', '6'): 'lsa_OpenPolicy',
1339 ('lsarpc', '76'): 'lsa_LookupSids3',
1340 ('lsarpc', '77'): 'lsa_LookupNames4',
1341 ('nbns', '0'): 'query',
1342 ('nbns', '1'): 'response',
1343 ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
1344 ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
1345 ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
1346 ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
1347 ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
1348 ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
1349 ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
1350 ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
1351 ('samr', '0',): 'Connect',
1352 ('samr', '16'): 'GetAliasMembership',
1353 ('samr', '17'): 'LookupNames',
1354 ('samr', '18'): 'LookupRids',
1355 ('samr', '19'): 'OpenGroup',
1356 ('samr', '1'): 'Close',
1357 ('samr', '25'): 'QueryGroupMember',
1358 ('samr', '34'): 'OpenUser',
1359 ('samr', '36'): 'QueryUserInfo',
1360 ('samr', '39'): 'GetGroupsForUser',
1361 ('samr', '3'): 'QuerySecurity',
1362 ('samr', '5'): 'LookupDomain',
1363 ('samr', '64'): 'Connect5',
1364 ('samr', '6'): 'EnumDomains',
1365 ('samr', '7'): 'OpenDomain',
1366 ('samr', '8'): 'QueryDomainInfo',
1367 ('smb', '0x04'): 'Close (0x04)',
1368 ('smb', '0x24'): 'Locking AndX (0x24)',
1369 ('smb', '0x2e'): 'Read AndX (0x2e)',
1370 ('smb', '0x32'): 'Trans2 (0x32)',
1371 ('smb', '0x71'): 'Tree Disconnect (0x71)',
1372 ('smb', '0x72'): 'Negotiate Protocol (0x72)',
1373 ('smb', '0x73'): 'Session Setup AndX (0x73)',
1374 ('smb', '0x74'): 'Logoff AndX (0x74)',
1375 ('smb', '0x75'): 'Tree Connect AndX (0x75)',
1376 ('smb', '0xa2'): 'NT Create AndX (0xa2)',
1377 ('smb2', '0'): 'NegotiateProtocol',
1378 ('smb2', '11'): 'Ioctl',
1379 ('smb2', '14'): 'Find',
1380 ('smb2', '16'): 'GetInfo',
1381 ('smb2', '18'): 'Break',
1382 ('smb2', '1'): 'SessionSetup',
1383 ('smb2', '2'): 'SessionLogoff',
1384 ('smb2', '3'): 'TreeConnect',
1385 ('smb2', '4'): 'TreeDisconnect',
1386 ('smb2', '5'): 'Create',
1387 ('smb2', '6'): 'Close',
1388 ('smb2', '8'): 'Read',
1389 ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
1390 ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
1391 'user unknown (0x17)'),
1392 ('srvsvc', '16'): 'NetShareGetInfo',
1393 ('srvsvc', '21'): 'NetSrvGetInfo',
1397 def expand_short_packet(p, timestamp, src, dest, extra):
# Expand a 'protocol:opcode' short-packet key into a tab-separated traffic
# summary line.  NOTE(review): original line 1403 is missing from this
# listing -- presumably it folds `extra` into `line` (extra is otherwise
# unused here); confirm against the full source.
1398 protocol, opcode = p.split(':', 1)
1399 desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
# Default to '06' (presumably the IANA protocol number for TCP) when the
# protocol is not in the IP_PROTOCOLS map.
1400 ip_protocol = IP_PROTOCOLS.get(protocol, '06')
1402 line = [timestamp, ip_protocol, '', src, dest, protocol, opcode, desc]
1404 return '\t'.join(line)
1407 def replay(conversations,
# Replay the given conversations against a server by forking one child per
# conversation in ~2-second batches, reaping children as they finish, and
# finally killing stragglers with SIGTERM/SIGKILL.  NOTE(review): this
# listing is heavily gappy (signature continuation, ReplayContext kwargs,
# try/else scaffolding and several loop bodies are not visible), so the
# comments below only annotate the visible statements.
1416 context = ReplayContext(server=host,
1421 if len(accounts) < len(conversations):
# BUG(review): '%d' cannot format the accounts/conversations objects
# themselves -- this should almost certainly be
# (len(accounts), len(conversations)); as written it raises TypeError.
1422 print(("we have %d accounts but %d conversations" %
1423 (accounts, conversations)), file=sys.stderr)
# Stack is sorted by start time, latest first, so .pop() yields the next
# conversation to start.
1426 sorted(conversations, key=lambda x: x.start_time, reverse=True),
1429 # Set the process group so that the calling scripts are not killed
1430 # when the forked child processes are killed.
1435 if duration is None:
1436 # end 1 second after the last packet of the last conversation
1437 # to start. Conversations other than the last could still be
1438 # going, but we don't care.
1439 duration = cstack[0][0].packets[-1].timestamp + 1.0
1440 print("We will stop after %.1f seconds" % duration,
1443 end = start + duration
1445 LOGGER.info("Replaying traffic for %u conversations over %d seconds"
1446 % (len(conversations), duration))
# Optional background DNS load generator, scheduled like a conversation.
1450 dns_hammer = DnsHammer(dns_rate, duration)
1451 cstack.append((dns_hammer, None))
1455 # we spawn a batch, wait for finishers, then spawn another
1457 batch_end = min(now + 2.0, end)
1461 c, account = cstack.pop()
# Not yet due in this batch: push it back and stop spawning.
1462 if c.start_time + start > batch_end:
1463 cstack.append((c, account))
1467 pid = c.replay_in_fork_with_delay(start, context, account)
1471 fork_time += elapsed
1473 print("forked %s in pid %s (in %fs)" % (c, pid,
1478 print(("forked %d times in %f seconds (avg %f)" %
1479 (fork_n, fork_time, fork_time / fork_n)),
1482 debug(2, "no forks in batch ending %f" % batch_end)
# Reap finished children non-blockingly until near the batch boundary.
1484 while time.time() < batch_end - 1.0:
1487 pid, status = os.waitpid(-1, os.WNOHANG)
1488 except OSError as e:
# errno 10 is ECHILD on Linux -- hardcoded rather than errno.ECHILD.
1489 if e.errno != 10: # no child processes
1493 c = children.pop(pid, None)
1494 print(("process %d finished conversation %s;"
1496 (pid, c, len(children))), file=sys.stderr)
1498 if time.time() >= end:
1499 print("time to stop", file=sys.stderr)
1503 print("EXCEPTION in parent", file=sys.stderr)
1504 traceback.print_exc()
# Shutdown: two SIGTERM passes then SIGKILL for anything still alive.
1506 for s in (15, 15, 9):
1507 print(("killing %d children with -%d" %
1508 (len(children), s)), file=sys.stderr)
1509 for pid in children:
1512 except OSError as e:
# errno 3 is ESRCH (no such process).
1513 if e.errno != 3: # don't fail if it has already died
1516 end = time.time() + 1
1519 pid, status = os.waitpid(-1, os.WNOHANG)
1520 except OSError as e:
1524 c = children.pop(pid, None)
1525 print(("kill -%d %d KILLED conversation %s; "
1527 (s, pid, c, len(children))),
1529 if time.time() >= end:
1537 print("%d children are missing" % len(children),
1540 # there may be stragglers that were forked just as ^C was hit
1541 # and don't appear in the list of children. We can get them
1542 # with killpg, but that will also kill us, so this is^H^H would be
1543 # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
1544 # so as not to have to fuss around writing signal handlers.
1547 except KeyboardInterrupt:
1548 print("ignoring fake ^C", file=sys.stderr)
1551 def openLdb(host, creds, lp):
# Open a SamDB (LDAP) connection to `host` as the system session with
# paged searches enabled.  NOTE(review): the trailing SamDB kwargs
# (presumably credentials=creds, lp=lp) and the return statement are
# missing from this listing.
1552 session = system_session()
1553 ldb = SamDB(url="ldap://%s" % host,
1554 session_info=session,
1555 options=['modules:paged_searches'],
1561 def ou_name(ldb, instance_id):
1562 """Generate an ou name from the instance id"""
# NOTE(review): the continuation line of this return (presumably the
# domain DN obtained from `ldb`) is missing from this listing.
1563 return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id,
1567 def create_ou(ldb, instance_id):
1568 """Create an ou, all created user and machine accounts will belong to it.
1570 This allows all the created resources to be cleaned up easily.
# First add the parent 'ou=traffic_replay' container (the DN tail of the
# instance OU), then the per-instance OU itself.  NOTE(review): both
# 'try:' lines, the second ldb.add opening line, the re-raise branches and
# the return are missing from this listing.
1572 ou = ou_name(ldb, instance_id)
1574 ldb.add({"dn": ou.split(',', 1)[1],
1575 "objectclass": "organizationalunit"})
1576 except LdbError as e:
1577 (status, _) = e.args
1578 # ignore already exists
1583 "objectclass": "organizationalunit"})
1584 except LdbError as e:
1585 (status, _) = e.args
1586 # ignore already exists
class ConversationAccounts(object):
    """The credentials a single replayed conversation runs under.

    Bundles the machine account (netbios name and password) with the
    user account (username and password); consumers read the four
    attributes directly.
    """
    def __init__(self, netbios_name, machinepass, username, userpass):
        # Plain value container: store everything verbatim.
        self.netbios_name = netbios_name
        self.machinepass = machinepass
        self.username = username
        self.userpass = userpass
1602 def generate_replay_accounts(ldb, instance_id, number, password):
1603 """Generate a series of unique machine and user account names."""
# Ensure the accounts exist on the server, then build one
# ConversationAccounts per conversation.  NOTE(review): the `accounts`
# list initialisation, the ConversationAccounts continuation argument
# (presumably `password`) and the return statement are missing from this
# listing.  The STGM-/STGU- name patterns duplicate the machine/user
# naming used elsewhere in this file (see user_name()).
1605 generate_traffic_accounts(ldb, instance_id, number, password)
1607 for i in range(1, number + 1):
1608 netbios_name = "STGM-%d-%d" % (instance_id, i)
1609 username = "STGU-%d-%d" % (instance_id, i)
1611 account = ConversationAccounts(netbios_name, password, username,
1613 accounts.append(account)
1617 def generate_traffic_accounts(ldb, instance_id, number, password):
1618 """Create the specified number of user and machine accounts.
1620 As accounts are not explicitly deleted between runs. This function starts
1621 with the last account and iterates backwards stopping either when it
1622 finds an already existing account or it has generated all the required
# NOTE(review): the docstring terminator, both 'try:' lines, the `added`
# counter handling, the already-exists/early-break branches and the
# guards around the prints are missing from this listing.
1625 print(("Generating machine and conversation accounts, "
1626 "as required for %d conversations" % number),
# Machine accounts: iterate backwards so hitting an existing account
# means all lower-numbered ones already exist.
1629 for i in range(number, 0, -1):
1631 netbios_name = "STGM-%d-%d" % (instance_id, i)
1632 create_machine_account(ldb, instance_id, netbios_name, password)
1634 except LdbError as e:
1635 (status, _) = e.args
1641 print("Added %d new machine accounts" % added,
# User accounts: same backwards-iteration strategy.
1645 for i in range(number, 0, -1):
1647 username = "STGU-%d-%d" % (instance_id, i)
1648 create_user_account(ldb, instance_id, username, password)
1650 except LdbError as e:
1651 (status, _) = e.args
1658 print("Added %d new user accounts" % added,
1662 def create_machine_account(ldb, instance_id, netbios_name, machinepass):
1663 """Create a machine account via ldap."""
# NOTE(review): the start/end time.time() captures and the ldb.add opening
# line (with the "dn": dn entry) are missing from this listing.
1665 ou = ou_name(ldb, instance_id)
1666 dn = "cn=%s,%s" % (netbios_name, ou)
# unicodePwd must be the quoted password encoded as UTF-16-LE -- the
# format AD expects for password writes.
1667 utf16pw = ('"%s"' % get_string(machinepass)).encode('utf-16-le')
1672 "objectclass": "computer",
# Machine sAMAccountNames carry a trailing '$'.
1673 "sAMAccountName": "%s$" % netbios_name,
1674 "userAccountControl":
1675 str(UF_TRUSTED_FOR_DELEGATION | UF_SERVER_TRUST_ACCOUNT),
1676 "unicodePwd": utf16pw})
1678 duration = end - start
# Timing line in the same tab-separated format generate_stats() parses.
1679 LOGGER.info("%f\t0\tcreate\tmachine\t%f\tTrue\t" % (end, duration))
1682 def create_user_account(ldb, instance_id, username, userpass):
1683 """Create a user account via ldap."""
# NOTE(review): the start/end time.time() captures and the ldb.add opening
# line (with the "dn": user_dn entry) are missing from this listing.
1684 ou = ou_name(ldb, instance_id)
1685 user_dn = "cn=%s,%s" % (username, ou)
# unicodePwd must be the quoted password encoded as UTF-16-LE -- the
# format AD expects for password writes.
1686 utf16pw = ('"%s"' % get_string(userpass)).encode('utf-16-le')
1690 "objectclass": "user",
1691 "sAMAccountName": username,
1692 "userAccountControl": str(UF_NORMAL_ACCOUNT),
1693 "unicodePwd": utf16pw
1696 # grant user write permission to do things like write account SPN
# SDDL ACE: Allow Write-Property to Principal Self on the user's object.
1697 sdutils = sd_utils.SDUtils(ldb)
1698 sdutils.dacl_add_ace(user_dn, "(A;;WP;;;PS)")
1701 duration = end - start
# Timing line in the same tab-separated format generate_stats() parses.
1702 LOGGER.info("%f\t0\tcreate\tuser\t%f\tTrue\t" % (end, duration))
1705 def create_group(ldb, instance_id, name):
1706 """Create a group via ldap."""
# NOTE(review): the start/end time.time() captures and the ldb.add
# opening/closing lines are missing from this listing.
1708 ou = ou_name(ldb, instance_id)
1709 dn = "cn=%s,%s" % (name, ou)
1713 "objectclass": "group",
1714 "sAMAccountName": name,
1717 duration = end - start
# Timing line in the same tab-separated format generate_stats() parses.
1718 LOGGER.info("%f\t0\tcreate\tgroup\t%f\tTrue\t" % (end, duration))
def user_name(instance_id, i):
    """Return the replay user account name for user *i* of *instance_id*."""
    name = "STGU-%d-%d" % (instance_id, i)
    return name
1726 def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
1727 """Search objectclass, return attr in a set"""
# NOTE(review): the `objs = ldb.search(` opening line and the remaining
# search kwargs (presumably scope and attrs=[attr]) are missing from this
# listing.
1729 expression="(objectClass={})".format(objectclass),
# Collect the requested attribute of every match into a set of strings.
1732 return {str(obj[attr]) for obj in objs}
1735 def generate_users(ldb, instance_id, number, password):
1736 """Add users to the server"""
# Create only the accounts that do not already exist, checked against one
# up-front LDAP search rather than per-account lookups.  NOTE(review):
# the created-users counter initialisation/increment and the return
# statement are missing from this listing.
1737 existing_objects = search_objectclass(ldb, objectclass='user')
1739 for i in range(number, 0, -1):
1740 name = user_name(instance_id, i)
1741 if name not in existing_objects:
1742 create_user_account(ldb, instance_id, name, password)
def group_name(instance_id, i):
    """Return the replay group name for group *i* of *instance_id*."""
    name = "STGG-%d-%d" % (instance_id, i)
    return name
1753 def generate_groups(ldb, instance_id, number):
1754 """Create the required number of groups on the server."""
# Mirrors generate_users(): one up-front search, create only the missing
# groups.  NOTE(review): the created-groups counter
# initialisation/increment and the return statement are missing from this
# listing.
1755 existing_objects = search_objectclass(ldb, objectclass='group')
1757 for i in range(number, 0, -1):
1758 name = group_name(instance_id, i)
1759 if name not in existing_objects:
1760 create_group(ldb, instance_id, name)
1766 def clean_up_accounts(ldb, instance_id):
1767 """Remove the created accounts and groups from the server."""
# A single recursive delete of the instance OU removes every account and
# group created under it.  NOTE(review): the 'try:' line and the
# re-raise for statuses other than no-such-object are missing from this
# listing.
1768 ou = ou_name(ldb, instance_id)
1770 ldb.delete(ou, ["tree_delete:1"])
1771 except LdbError as e:
1772 (status, _) = e.args
1773 # ignore does not exist
1778 def generate_users_and_groups(ldb, instance_id, password,
1779 number_of_users, number_of_groups,
1781 """Generate the required users and groups, allocating the users to
# NOTE(review): the signature continuation (presumably group_memberships),
# the docstring terminator, `groups_added` initialisation and the
# GroupAssignments continuation arguments are missing from this listing.
1783 memberships_added = 0
1786 create_ou(ldb, instance_id)
1788 print("Generating dummy user accounts", file=sys.stderr)
1789 users_added = generate_users(ldb, instance_id, number_of_users, password)
1791 if number_of_groups > 0:
1792 print("Generating dummy groups", file=sys.stderr)
1793 groups_added = generate_groups(ldb, instance_id, number_of_groups)
1795 if group_memberships > 0:
1796 print("Assigning users to groups", file=sys.stderr)
1797 assignments = GroupAssignments(number_of_groups,
1802 print("Adding users to groups", file=sys.stderr)
1803 add_users_to_groups(ldb, instance_id, assignments.assignments)
1804 memberships_added = assignments.total()
# Warn when groups were added but no users were: the new groups will have
# no members assigned to them.
1806 if (groups_added > 0 and users_added == 0 and
1807 number_of_groups != groups_added):
1808 print("Warning: the added groups will contain no members",
1811 print(("Added %d users, %d groups and %d group memberships" %
1812 (users_added, groups_added, memberships_added)),
# Computes a random user->group membership assignment using weighted
# (Pareto / decreasing-weight) distributions, so a few users belong to
# many groups and a few groups contain many users.  NOTE(review): this
# listing is gappy throughout the class -- weight-list initialisations,
# several returns, the shape assignments and the total() def line are
# missing; comments below annotate only the visible statements.
1816 class GroupAssignments(object):
1817 def __init__(self, number_of_groups, groups_added, number_of_users,
1818 users_added, group_memberships):
1820 self.generate_group_distribution(number_of_groups)
1821 self.generate_user_distribution(number_of_users, group_memberships)
# The remaining assign_groups arguments are on lines missing from this
# listing.
1822 self.assignments = self.assign_groups(number_of_groups,
1828 def cumulative_distribution(self, weights):
1829 # make sure the probabilities conform to a cumulative distribution
1830 # spread between 0.0 and 1.0. Dividing by the weighted total gives each
1831 # probability a proportional share of 1.0. Higher probabilities get a
1832 # bigger share, so are more likely to be picked. We use the cumulative
1833 # value, so we can use random.random() as a simple index into the list
1835 total = sum(weights)
# NOTE(review): the initialisations of `cumulative` and `dist`, plus the
# final return, are on lines missing from this listing.
1837 for probability in weights:
1838 cumulative += probability
1839 dist.append(cumulative / total)
1842 def generate_user_distribution(self, num_users, num_memberships):
1843 """Probability distribution of a user belonging to a group.
1845 # Assign a weighted probability to each user. Use the Pareto
1846 # Distribution so that some users are in a lot of groups, and the
1847 # bulk of users are in only a few groups. If we're assigning a large
1848 # number of group memberships, use a higher shape. This means slightly
1849 # fewer outlying users that are in large numbers of groups. The aim is
1850 # to have no users belonging to more than ~500 groups.
# The `shape = ...` assignments for each threshold branch are on lines
# missing from this listing.
1851 if num_memberships > 5000000:
1853 elif num_memberships > 2000000:
1855 elif num_memberships > 300000:
1861 for x in range(1, num_users + 1):
1862 p = random.paretovariate(shape)
1865 # convert the weights to a cumulative distribution between 0.0 and 1.0
1866 self.user_dist = self.cumulative_distribution(weights)
1868 def generate_group_distribution(self, n):
1869 """Probability distribution of a group containing a user."""
1871 # Assign a weighted probability to each user. Probability decreases
1872 # as the group-ID increases
# The per-group weight computation/append lines are missing from this
# listing.
1874 for x in range(1, n + 1):
1878 # convert the weights to a cumulative distribution between 0.0 and 1.0
1879 self.group_dist = self.cumulative_distribution(weights)
1881 def generate_random_membership(self):
1882 """Returns a randomly generated user-group membership"""
1884 # the list items are cumulative distribution values between 0.0 and
1885 # 1.0, which makes random() a handy way to index the list to get a
1886 # weighted random user/group. (Here the user/group returned are
1887 # zero-based array indexes)
1888 user = bisect.bisect(self.user_dist, random.random())
1889 group = bisect.bisect(self.group_dist, random.random())
1893 def assign_groups(self, number_of_groups, groups_added,
1894 number_of_users, users_added, group_memberships):
1895 """Allocate users to groups.
1897 The intention is to have a few users that belong to most groups, while
1898 the majority of users belong to a few groups.
1900 A few groups will contain most users, with the remaining only having a
# NOTE(review): the docstring terminator, the `assignments` set
# initialisation, the early return and the final return are on lines
# missing from this listing.
1905 if group_memberships <= 0:
1908 # Calculate the number of group memberships required
# Scale the requested memberships down to the proportion of users that
# were actually added this run.
1909 group_memberships = math.ceil(
1910 float(group_memberships) *
1911 (float(users_added) / float(number_of_users)))
# Only newly-added users/groups receive memberships: indexes at or below
# these bounds refer to pre-existing objects and are skipped.
1913 existing_users = number_of_users - users_added - 1
1914 existing_groups = number_of_groups - groups_added - 1
1915 while len(assignments) < group_memberships:
1916 user, group = self.generate_random_membership()
1918 if group > existing_groups or user > existing_users:
1919 # the + 1 converts the array index to the corresponding
1920 # group or user number
1921 assignments.add(((user + 1), (group + 1)))
# NOTE(review): the `def total(self):` line is missing from this listing;
# it returns the number of memberships generated.
1926 return len(self.assignments)
1929 def add_users_to_groups(db, instance_id, assignments):
1930 """Add users to their assigned groups.
1932 Takes the list of (user, group) tuples generated by assign_groups and
1933 assign the users to their specified groups."""
# NOTE(review): the start/end time.time() captures, the
# `def build_dn(name):` header, the ldb.Message() construction and the
# db.modify(m) call are on lines missing from this listing.
1935 ou = ou_name(db, instance_id)
1938 return("cn=%s,%s" % (name, ou))
1940 for (user, group) in assignments:
1941 user_dn = build_dn(user_name(instance_id, user))
1942 group_dn = build_dn(group_name(instance_id, group))
# Each membership is an LDAP modify adding the user's DN to the group's
# 'member' attribute.
1945 m.dn = ldb.Dn(db, group_dn)
1946 m["member"] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
1950 duration = end - start
# Timing line in the same tab-separated format generate_stats() parses.
1951 LOGGER.info("%f\t0\tadd\tuser\t%f\tTrue\t" % (end, duration))
1954 def generate_stats(statsdir, timing_file):
1955 """Generate and print the summary stats for a run."""
# Parse the per-child tab-separated timing files in `statsdir`,
# accumulate per-protocol/per-opcode latency lists and failure counts,
# then print a summary table.  NOTE(review): this listing is gappy --
# the initialisations of `last`, `latencies`, `failures`, `successful`,
# `failed` and `conversations`, the inner 'try:'/'for line in f:' lines
# and the print argument tuples are missing.
1956 first = sys.float_info.max
# (sic) 'converations' is misspelled; left as-is since renaming a local
# is a code change.
1962 unique_converations = set()
1965 if timing_file is not None:
1966 tw = timing_file.write
1971 tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")
1973 for filename in os.listdir(statsdir):
1974 path = os.path.join(statsdir, filename)
1975 with open(path, 'r') as f:
# Fields: time, conversation, protocol, packet type, latency, success.
1978 fields = line.rstrip('\n').split('\t')
1979 conversation = fields[1]
1980 protocol = fields[2]
1981 packet_type = fields[3]
1982 latency = float(fields[4])
# Track the overall run window: earliest operation start, latest finish.
1983 first = min(float(fields[0]) - latency, first)
1984 last = max(float(fields[0]), last)
1986 if protocol not in latencies:
1987 latencies[protocol] = {}
1988 if packet_type not in latencies[protocol]:
1989 latencies[protocol][packet_type] = []
1991 latencies[protocol][packet_type].append(latency)
1993 if protocol not in failures:
1994 failures[protocol] = {}
1995 if packet_type not in failures[protocol]:
1996 failures[protocol][packet_type] = 0
1998 if fields[5] == 'True':
2002 failures[protocol][packet_type] += 1
2004 if conversation not in unique_converations:
2005 unique_converations.add(conversation)
2009 except (ValueError, IndexError):
2010 # not a valid line print and ignore
2011 print(line, file=sys.stderr)
2013 duration = last - first
2017 success_rate = successful / duration
2021 failure_rate = failed / duration
2023 print("Total conversations: %10d" % conversations)
2024 print("Successful operations: %10d (%.3f per second)"
2025 % (successful, success_rate))
2026 print("Failed operations: %10d (%.3f per second)"
2027 % (failed, failure_rate))
2029 print("Protocol Op Code Description "
2030 " Count Failed Mean Median "
2033 protocols = sorted(latencies.keys())
2034 for protocol in protocols:
# opcode_key sorts numeric opcodes numerically rather than as strings.
2035 packet_types = sorted(latencies[protocol], key=opcode_key)
2036 for packet_type in packet_types:
2037 values = latencies[protocol][packet_type]
2038 values = sorted(values)
2040 failed = failures[protocol][packet_type]
2041 mean = sum(values) / count
2042 median = calc_percentile(values, 0.50)
2043 percentile = calc_percentile(values, 0.95)
2044 rng = values[-1] - values[0]
2046 desc = OP_DESCRIPTIONS.get((protocol, packet_type), '')
# BUG(review): `sys.stdout.isatty` is the bound method, not its result --
# always truthy, so this branch is always taken. Should be isatty().
2047 if sys.stdout.isatty:
2048 print("%-12s %4s %-35s %12d %12d %12.6f "
2049 "%12.6f %12.6f %12.6f %12.6f"
2061 print("%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f"
2075 """Sort key for the operation code to ensure that it sorts numerically"""
2077 return "%03d" % int(v)
2082 def calc_percentile(values, percentile):
2083 """Calculate the specified percentile from the list of values.
2085 Assumes the list is sorted in ascending order.
# NOTE(review): this listing is missing the docstring terminator, the
# empty-input guard, the `f`/`c` assignments (presumably floor/ceil of k),
# the f == c test guarding the exact-index return below, and the final
# return combining d0 + d1 (linear interpolation between neighbours).
2090 k = (len(values) - 1) * percentile
2094 return values[int(k)]
2095 d0 = values[int(f)] * (c - k)
2096 d1 = values[int(c)] * (k - f)
2100 def mk_masked_dir(*path):
2101 """In a testenv we end up with 0777 diectories that look an alarming
2102 green colour with ls. Use umask to avoid that."""
2103 d = os.path.join(*path)
2104 mask = os.umask(0o077)