1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
4 # Copyright (C) Catalyst IT Ltd. 2017
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 # GNU General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 from __future__ import print_function, division
28 from errno import ECHILD, ESRCH
30 from collections import OrderedDict, Counter, defaultdict, namedtuple
31 from dns.resolver import query as dns_query
33 from samba.emulate import traffic_packets
34 from samba.samdb import SamDB
36 from ldb import LdbError
37 from samba.dcerpc import ClientConnection
38 from samba.dcerpc import security, drsuapi, lsa
39 from samba.dcerpc import netlogon
40 from samba.dcerpc.netlogon import netr_Authenticator
41 from samba.dcerpc import srvsvc
42 from samba.dcerpc import samr
43 from samba.drs_utils import drs_DsBind
45 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
46 from samba.auth import system_session
47 from samba.dsdb import (
49 UF_SERVER_TRUST_ACCOUNT,
50 UF_TRUSTED_FOR_DELEGATION,
51 UF_WORKSTATION_TRUST_ACCOUNT
53 from samba.dcerpc.misc import SEC_CHAN_BDC
54 from samba import gensec
55 from samba import sd_utils
56 from samba.common import get_string
57 from samba.logger import get_samba_logger
60 CURRENT_MODEL_VERSION = 2 # save as this
61 REQUIRED_MODEL_VERSION = 2 # load accepts this or greater
64 # we don't use None, because it complicates [de]serialisation
68 ('dns', '0'): 1.0, # query
69 ('smb', '0x72'): 1.0, # Negotiate protocol
70 ('ldap', '0'): 1.0, # bind
71 ('ldap', '3'): 1.0, # searchRequest
72 ('ldap', '2'): 1.0, # unbindRequest
74 ('dcerpc', '11'): 1.0, # bind
75 ('dcerpc', '14'): 1.0, # Alter_context
76 ('nbns', '0'): 1.0, # query
80 ('dns', '1'): 1.0, # response
81 ('ldap', '1'): 1.0, # bind response
82 ('ldap', '4'): 1.0, # search result
83 ('ldap', '5'): 1.0, # search done
85 ('dcerpc', '12'): 1.0, # bind_ack
86 ('dcerpc', '13'): 1.0, # bind_nak
87 ('dcerpc', '15'): 1.0, # Alter_context response
90 SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}
93 WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
94 NO_WAIT_LOG_TIME_RANGE = (-10, -3)
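# A worked example of the log-scale wait bucketing used by
# TrafficModel.learn() and construct_conversation_sequence() below
# (illustrative only, assuming WAIT_SCALE == 10, so WAIT_THRESHOLD == 0.1):
#
#     elapsed = 2.5                                   # gap between packets
#     elapsed > WAIT_THRESHOLD                        # True, so record a wait
#     bucket = int(math.log(max(1.0, elapsed * 10)))  # log(25) -> 3
#     state = 'wait:%d' % bucket                      # extra ngram state
#
# On replay 'wait:3' becomes math.exp(3 + random.random()) / 10, roughly
# 2 to 5.5 seconds, so waits are reproduced only to the nearest bucket.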
96 # DEBUG_LEVEL can be changed by scripts with -d
99 LOGGER = get_samba_logger(name=__name__)
102 def debug(level, msg, *args):
103 """Print a formatted debug message to standard error.
106 :param level: The debug level, message will be printed if it is <= the
107 currently set debug level. The debug level can be set with
109 :param msg: The message to be logged, can contain C-Style format
111 :param args: The parameters required by the format specifiers
113 if level <= DEBUG_LEVEL:
115 print(msg, file=sys.stderr)
117 print(msg % tuple(args), file=sys.stderr)
120 def debug_lineno(*args):
121 """ Print an unformatted log message to stderr, contaning the line number
123 tb = traceback.extract_stack(limit=2)
124 print((" %s:" "\033[01;33m"
125 "%s " "\033[00m" % (tb[0][2], tb[0][1])), end=' ',
128 print(a, file=sys.stderr)
129 print(file=sys.stderr)
133 def random_colour_print(seeds):
134 """Return a function that prints a coloured line to stderr. The colour
135 of the line depends on a sort of hash of the integer arguments."""
142 prefix = "\033[38;5;%dm" % (18 + s)
147 print("%s%s\033[00m" % (prefix, a), file=sys.stderr)
152 print(a, file=sys.stderr)
157 class FakePacketError(Exception):
161 class Packet(object):
162 """Details of a network packet"""
163 __slots__ = ('timestamp',
173 def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
174 protocol, opcode, desc, extra):
175 self.timestamp = timestamp
176 self.ip_protocol = ip_protocol
177 self.stream_number = stream_number
180 self.protocol = protocol
184 if self.src < self.dest:
185 self.endpoints = (self.src, self.dest)
187 self.endpoints = (self.dest, self.src)
190 def from_line(cls, line):
191 fields = line.rstrip('\n').split('\t')
202 timestamp = float(timestamp)
206 return cls(timestamp, ip_protocol, stream_number, src, dest,
207 protocol, opcode, desc, extra)
209 def as_summary(self, time_offset=0.0):
210 """Format the packet as a traffic_summary line.
212 extra = '\t'.join(self.extra)
213 t = self.timestamp + time_offset
214 return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
217 self.stream_number or '',
226 return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
227 (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
228 self.stream_number, self.protocol, self.opcode, self.desc,
229 ('«' + ' '.join(self.extra) + '»' if self.extra else '')))
232 return "<Packet @%s>" % self
235 return self.__class__(self.timestamp,
245 def as_packet_type(self):
246 t = '%s:%s' % (self.protocol, self.opcode)
249 def client_score(self):
250 """A positive number means we think it is a client; a negative number
251 means we think it is a server. Zero means no idea. range: -1 to 1.
253 key = (self.protocol, self.opcode)
254 if key in CLIENT_CLUES:
255 return CLIENT_CLUES[key]
256 if key in SERVER_CLUES:
257 return -SERVER_CLUES[key]
260 def play(self, conversation, context):
261 """Send the packet over the network, if required.
263 Some packets are ignored, i.e. for protocols not handled,
264 server response messages, or messages that are generated by the
265 protocol layer associated with other packets.
267 fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
269 fn = getattr(traffic_packets, fn_name)
271 except AttributeError as e:
272 print("Conversation(%s) Missing handler %s" %
273 (conversation.conversation_id, fn_name),
277 # Don't display a message for kerberos packets; they're not directly
278 # generated, they just indicate that kerberos should be used
279 if self.protocol != "kerberos":
280 debug(2, "Conversation(%s) Calling handler %s" %
281 (conversation.conversation_id, fn_name))
285 if fn(self, conversation, context):
286 # Only collect timing data for functions that generate
287 # network traffic, or fail
289 duration = end - start
290 print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
291 (end, conversation.conversation_id, self.protocol,
292 self.opcode, duration))
293 except Exception as e:
295 duration = end - start
296 print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
297 (end, conversation.conversation_id, self.protocol,
298 self.opcode, duration, e))
300 def __cmp__(self, other):
301 return self.timestamp - other.timestamp
303 def is_really_a_packet(self, missing_packet_stats=None):
304 return is_a_real_packet(self.protocol, self.opcode)
307 def is_a_real_packet(protocol, opcode):
308 """Is the packet one that can be ignored?
310 If so removing it will have no effect on the replay
312 if protocol in SKIPPED_PROTOCOLS:
313 # Ignore any packets for the protocols we're not interested in.
315 if protocol == "ldap" and opcode == '':
316 # skip ldap continuation packets
319 fn_name = 'packet_%s_%s' % (protocol, opcode)
320 fn = getattr(traffic_packets, fn_name, None)
322 LOGGER.debug("missing packet %s" % fn_name)
324 if fn is traffic_packets.null_packet:
329 def is_a_traffic_generating_packet(protocol, opcode):
330 """Return true if a packet generates traffic in its own right. Some of
331 these will generate traffic in certain contexts (e.g. ldap unbind
332 after a bind) but not if the conversation consists only of these packets.
334 if protocol == 'wait':
337 if (protocol, opcode) in (
344 return is_a_real_packet(protocol, opcode)
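# For example, an ldap unbind ('ldap', '2') is a real packet that gets
# replayed, but (as the docstring above notes) it generates nothing by
# itself, so a conversation consisting only of unbinds is not worth a
# process. Anything in SKIPPED_PROTOCOLS is neither real nor
# traffic-generating.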
347 class ReplayContext(object):
348 """State/Context for a conversation between an simulated client and a
349 server. Some of the context is shared amongst all conversations
350 and should be generated before the fork, while other context is
351 specific to a particular conversation and should be generated
352 *after* the fork, in generate_process_local_config().
358 total_conversations=None,
359 badpassword_frequency=None,
360 prefer_kerberos=None,
365 domain=os.environ.get("DOMAIN"),
369 self.netlogon_connection = None
373 self.kerberos_state = MUST_USE_KERBEROS
375 self.kerberos_state = DONT_USE_KERBEROS
377 self.base_dn = base_dn
379 self.statsdir = statsdir
380 self.global_tempdir = tempdir
381 self.domain_sid = domain_sid
382 self.realm = lp.get('realm')
383 self.instance_id = instance_id
385 # Bad password attempt controls
386 self.badpassword_frequency = badpassword_frequency
387 self.last_lsarpc_bad = False
388 self.last_lsarpc_named_bad = False
389 self.last_simple_bind_bad = False
390 self.last_bind_bad = False
391 self.last_srvsvc_bad = False
392 self.last_drsuapi_bad = False
393 self.last_netlogon_bad = False
394 self.last_samlogon_bad = False
395 self.total_conversations = total_conversations
396 self.generate_ldap_search_tables()
398 def generate_ldap_search_tables(self):
399 session = system_session()
401 db = SamDB(url="ldap://%s" % self.server,
402 session_info=session,
403 credentials=self.creds,
406 res = db.search(db.domain_dn(),
407 scope=ldb.SCOPE_SUBTREE,
408 controls=["paged_results:1:1000"],
411 # find a list of DNs matching each pattern,
412 # e.g. CN,CN,CN,DC,DC
414 attribute_clue_map = {
420 pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
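# e.g. 'CN=Administrator,CN=Users,DC=samba,DC=example,DC=com'
# gives the pattern 'CN,CN,DC,DC,DC'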
421 dns = dn_map.setdefault(pattern, [])
423 if dn.startswith('CN=NTDS Settings,'):
424 attribute_clue_map['invocationId'].append(dn)
426 # extend the map in case we are working with a different
427 # number of DC components.
428 # for k, v in self.dn_map.items():
429 # print >>sys.stderr, k, len(v)
431 for k in list(dn_map.keys()):
435 while p[-3:] == ',DC':
439 if p != k and p in dn_map:
440 print('dn_map collision %s %s' % (k, p),
443 dn_map[p] = dn_map[k]
446 self.attribute_clue_map = attribute_clue_map
448 # pre-populate DN-based search filters (it's simplest to generate them
449 # once, when the test starts). These are used by guess_search_filter()
450 # to avoid full-scans
451 self.search_filters = {}
453 # lookup all the GPO DNs
454 res = db.search(db.domain_dn(), scope=ldb.SCOPE_SUBTREE, attrs=['dn'],
455 expression='(objectclass=groupPolicyContainer)')
456 gpos_by_dn = "".join("(distinguishedName={0})".format(msg['dn']) for msg in res)
458 # a search for the 'gPCFileSysPath' attribute is probably a GPO search
459 # (as per the MS-GPOL spec) which searches for GPOs by DN
460 self.search_filters['gPCFileSysPath'] = "(|{0})".format(gpos_by_dn)
462 # likewise, a search for gpLink is probably the Domain SOM search part
463 # of the MS-GPOL, in which case it's looking up a few OUs by DN
465 for ou in ["Domain Controllers,", "traffic_replay,", ""]:
466 ou_str += "(distinguishedName={0}{1})".format(ou, db.domain_dn())
467 self.search_filters['gpLink'] = "(|{0})".format(ou_str)
469 # The CEP Web Service can query the AD DC to get pKICertificateTemplate
470 # objects (as per MS-WCCE)
471 self.search_filters['pKIExtendedKeyUsage'] = \
472 '(objectCategory=pKICertificateTemplate)'
474 # assume that anything querying the usnChanged is some kind of
475 # synchronization tool, e.g. AD Change Detection Connector
476 res = db.search('', scope=ldb.SCOPE_BASE, attrs=['highestCommittedUSN'])
477 self.search_filters['usnChanged'] = \
478 '(usnChanged>={0})'.format(res[0]['highestCommittedUSN'])
480 # The traffic_learner script doesn't preserve the LDAP search filter, and
481 # having no filter can result in a full DB scan. This is costly for a large
482 DB, and not necessarily representative of real-world traffic. As there
483 are several standard LDAP queries that AD tools use, we can apply some
484 logic and guess what the search filter might have been originally.
485 def guess_search_filter(self, attrs, dn_sig, dn):
487 # there are some standard spec-based searches that query fairly unique
488 # attributes. Check if the search is likely one of these
489 for key in self.search_filters.keys():
491 return self.search_filters[key]
493 # if it's the top-level domain, assume we're looking up a single user,
494 # e.g. like powershell Get-ADUser or a similar tool
495 if dn_sig == 'DC,DC':
496 random_user_id = random.randint(1, self.total_conversations)
497 account_name = user_name(self.instance_id, random_user_id)
498 return '(&(sAMAccountName=%s)(objectClass=user))' % account_name
500 # otherwise just return everything in the sub-tree
501 return '(objectClass=*)'
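# A sketch of the kind of results this returns (values hypothetical):
#
#     self.guess_search_filter(['pKIExtendedKeyUsage'], 'CN,CN,DC,DC', dn)
#     -> '(objectCategory=pKICertificateTemplate)'
#     self.guess_search_filter(['sn'], 'DC,DC', dn)
#     -> '(&(sAMAccountName=STGU-5-3)(objectClass=user))'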
503 def generate_process_local_config(self, account, conversation):
504 self.ldap_connections = []
505 self.dcerpc_connections = []
506 self.lsarpc_connections = []
507 self.lsarpc_connections_named = []
508 self.drsuapi_connections = []
509 self.srvsvc_connections = []
510 self.samr_contexts = []
511 self.netbios_name = account.netbios_name
512 self.machinepass = account.machinepass
513 self.username = account.username
514 self.userpass = account.userpass
516 self.tempdir = mk_masked_dir(self.global_tempdir,
518 conversation.conversation_id)
520 self.lp.set("private dir", self.tempdir)
521 self.lp.set("lock dir", self.tempdir)
522 self.lp.set("state directory", self.tempdir)
523 self.lp.set("tls verify peer", "no_check")
525 self.remoteAddress = "/root/ncalrpc_as_system"
526 self.samlogon_dn = ("cn=%s,%s" %
527 (self.netbios_name, self.ou))
528 self.user_dn = ("cn=%s,%s" %
529 (self.username, self.ou))
531 self.generate_machine_creds()
532 self.generate_user_creds()
534 def with_random_bad_credentials(self, f, good, bad, failed_last_time):
535 """Execute the supplied logon function, randomly choosing the
538 Based on the frequency in badpassword_frequency randomly perform the
539 function with the supplied bad credentials.
540 If run with bad credentials, the function is re-run with the good
542 failed_last_time is used to prevent consecutive bad credential
543 attempts. So the overall bad credential frequency will be lower
544 than that requested, but not significantly.
546 if not failed_last_time:
547 if (self.badpassword_frequency and
548 random.random() < self.badpassword_frequency):
552 # Ignore any exceptions as the operation may fail
553 # as it's being performed with bad credentials
555 failed_last_time = True
557 failed_last_time = False
560 return (result, failed_last_time)
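# A minimal sketch of the pattern the get_*_connection() methods below
# follow (some_rpc_binding is a stand-in for the real binding call):
#
#     def connect(creds):
#         return some_rpc_binding(self.server, self.lp, creds)
#     (c, self.last_bind_bad) = self.with_random_bad_credentials(
#         connect, self.user_creds, self.user_creds_bad, self.last_bind_bad)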
562 def generate_user_creds(self):
563 """Generate the conversation specific user Credentials.
565 Each Conversation has an associated user account used to simulate
566 any non-administrative user traffic.
568 Generates user credentials with good and bad passwords and ldap
569 simple bind credentials with good and bad passwords.
571 self.user_creds = Credentials()
572 self.user_creds.guess(self.lp)
573 self.user_creds.set_workstation(self.netbios_name)
574 self.user_creds.set_password(self.userpass)
575 self.user_creds.set_username(self.username)
576 self.user_creds.set_domain(self.domain)
577 self.user_creds.set_kerberos_state(self.kerberos_state)
579 self.user_creds_bad = Credentials()
580 self.user_creds_bad.guess(self.lp)
581 self.user_creds_bad.set_workstation(self.netbios_name)
582 self.user_creds_bad.set_password(self.userpass[:-4])
583 self.user_creds_bad.set_username(self.username)
584 self.user_creds_bad.set_kerberos_state(self.kerberos_state)
586 # Credentials for ldap simple bind.
587 self.simple_bind_creds = Credentials()
588 self.simple_bind_creds.guess(self.lp)
589 self.simple_bind_creds.set_workstation(self.netbios_name)
590 self.simple_bind_creds.set_password(self.userpass)
591 self.simple_bind_creds.set_username(self.username)
592 self.simple_bind_creds.set_gensec_features(
593 self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
594 self.simple_bind_creds.set_kerberos_state(self.kerberos_state)
595 self.simple_bind_creds.set_bind_dn(self.user_dn)
597 self.simple_bind_creds_bad = Credentials()
598 self.simple_bind_creds_bad.guess(self.lp)
599 self.simple_bind_creds_bad.set_workstation(self.netbios_name)
600 self.simple_bind_creds_bad.set_password(self.userpass[:-4])
601 self.simple_bind_creds_bad.set_username(self.username)
602 self.simple_bind_creds_bad.set_gensec_features(
603 self.simple_bind_creds_bad.get_gensec_features() |
605 self.simple_bind_creds_bad.set_kerberos_state(self.kerberos_state)
606 self.simple_bind_creds_bad.set_bind_dn(self.user_dn)
608 def generate_machine_creds(self):
609 """Generate the conversation specific machine Credentials.
611 Each Conversation has an associated machine account.
613 Generates machine credentials with good and bad passwords.
616 self.machine_creds = Credentials()
617 self.machine_creds.guess(self.lp)
618 self.machine_creds.set_workstation(self.netbios_name)
619 self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
620 self.machine_creds.set_password(self.machinepass)
621 self.machine_creds.set_username(self.netbios_name + "$")
622 self.machine_creds.set_domain(self.domain)
623 self.machine_creds.set_kerberos_state(self.kerberos_state)
625 self.machine_creds_bad = Credentials()
626 self.machine_creds_bad.guess(self.lp)
627 self.machine_creds_bad.set_workstation(self.netbios_name)
628 self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
629 self.machine_creds_bad.set_password(self.machinepass[:-4])
630 self.machine_creds_bad.set_username(self.netbios_name + "$")
631 self.machine_creds_bad.set_kerberos_state(self.kerberos_state)
633 def get_matching_dn(self, pattern, attributes=None):
634 # If the pattern is an empty string, we assume ROOTDSE.
635 # Otherwise we try adding or removing DC suffixes, then
636 # shorter leading patterns until we hit one.
637 # e.g if there is no CN,CN,CN,CN,DC,DC
638 # we first try CN,CN,CN,CN,DC
639 # and CN,CN,CN,CN,DC,DC,DC
640 # then change to CN,CN,CN,DC,DC
641 # and as last resort we use the base_dn
642 attr_clue = self.attribute_clue_map.get(attributes)
644 return random.choice(attr_clue)
646 pattern = pattern.upper()
648 if pattern in self.dn_map:
649 return random.choice(self.dn_map[pattern])
650 # chop one off the front and try it all again.
651 pattern = pattern[3:]
655 def get_dcerpc_connection(self, new=False):
656 guid = '12345678-1234-abcd-ef00-01234567cffb' # RPC_NETLOGON UUID
657 if self.dcerpc_connections and not new:
658 return self.dcerpc_connections[-1]
659 c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
661 self.dcerpc_connections.append(c)
664 def get_srvsvc_connection(self, new=False):
665 if self.srvsvc_connections and not new:
666 return self.srvsvc_connections[-1]
669 return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
673 (c, self.last_srvsvc_bad) = \
674 self.with_random_bad_credentials(connect,
677 self.last_srvsvc_bad)
679 self.srvsvc_connections.append(c)
682 def get_lsarpc_connection(self, new=False):
683 if self.lsarpc_connections and not new:
684 return self.lsarpc_connections[-1]
687 binding_options = 'schannel,seal,sign'
688 return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
689 (self.server, binding_options),
693 (c, self.last_lsarpc_bad) = \
694 self.with_random_bad_credentials(connect,
696 self.machine_creds_bad,
697 self.last_lsarpc_bad)
699 self.lsarpc_connections.append(c)
702 def get_lsarpc_named_pipe_connection(self, new=False):
703 if self.lsarpc_connections_named and not new:
704 return self.lsarpc_connections_named[-1]
707 return lsa.lsarpc("ncacn_np:%s" % (self.server),
711 (c, self.last_lsarpc_named_bad) = \
712 self.with_random_bad_credentials(connect,
714 self.machine_creds_bad,
715 self.last_lsarpc_named_bad)
717 self.lsarpc_connections_named.append(c)
720 def get_drsuapi_connection_pair(self, new=False, unbind=False):
721 """get a (drs, drs_handle) tuple"""
722 if self.drsuapi_connections and not new:
723 c = self.drsuapi_connections[-1]
727 binding_options = 'seal'
728 binding_string = "ncacn_ip_tcp:%s[%s]" %\
729 (self.server, binding_options)
730 return drsuapi.drsuapi(binding_string, self.lp, creds)
732 (drs, self.last_drsuapi_bad) = \
733 self.with_random_bad_credentials(connect,
736 self.last_drsuapi_bad)
738 (drs_handle, supported_extensions) = drs_DsBind(drs)
739 c = (drs, drs_handle)
740 self.drsuapi_connections.append(c)
743 def get_ldap_connection(self, new=False, simple=False):
744 if self.ldap_connections and not new:
745 return self.ldap_connections[-1]
747 def simple_bind(creds):
749 To run simple bind against Windows, we need to run the
750 following commands in PowerShell:
752 Install-windowsfeature ADCS-Cert-Authority
753 Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
757 return SamDB('ldaps://%s' % self.server,
761 def sasl_bind(creds):
762 return SamDB('ldap://%s' % self.server,
766 (samdb, self.last_simple_bind_bad) = \
767 self.with_random_bad_credentials(simple_bind,
768 self.simple_bind_creds,
769 self.simple_bind_creds_bad,
770 self.last_simple_bind_bad)
772 (samdb, self.last_bind_bad) = \
773 self.with_random_bad_credentials(sasl_bind,
778 self.ldap_connections.append(samdb)
781 def get_samr_context(self, new=False):
782 if not self.samr_contexts or new:
783 self.samr_contexts.append(
784 SamrContext(self.server, lp=self.lp, creds=self.creds))
785 return self.samr_contexts[-1]
787 def get_netlogon_connection(self):
789 if self.netlogon_connection:
790 return self.netlogon_connection
793 return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
797 (c, self.last_netlogon_bad) = \
798 self.with_random_bad_credentials(connect,
800 self.machine_creds_bad,
801 self.last_netlogon_bad)
802 self.netlogon_connection = c
805 def guess_a_dns_lookup(self):
806 return (self.realm, 'A')
808 def get_authenticator(self):
809 auth = self.machine_creds.new_client_authenticator()
810 current = netr_Authenticator()
811 current.cred.data = [x if isinstance(x, int) else ord(x)
812 for x in auth["credential"]]
813 current.timestamp = auth["timestamp"]
815 subsequent = netr_Authenticator()
816 return (current, subsequent)
818 def write_stats(self, filename, **kwargs):
819 """Write arbitrary key/value pairs to a file in our stats directory in
820 order for them to be picked up later by another process working out
822 filename = os.path.join(self.statsdir, filename)
823 f = open(filename, 'w')
824 for k, v in kwargs.items():
825 print("%s: %s" % (k, v), file=f)
829 class SamrContext(object):
830 """State/Context associated with a samr connection.
832 def __init__(self, server, lp=None, creds=None):
833 self.connection = None
835 self.domain_handle = None
836 self.domain_sid = None
837 self.group_handle = None
838 self.user_handle = None
844 def get_connection(self):
845 if not self.connection:
846 self.connection = samr.samr(
847 "ncacn_ip_tcp:%s[seal]" % (self.server),
849 credentials=self.creds)
851 return self.connection
853 def get_handle(self):
855 c = self.get_connection()
856 self.handle = c.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
860 class Conversation(object):
861 """Details of a converation between a simulated client and a server."""
862 def __init__(self, start_time=None, endpoints=None, seq=(),
863 conversation_id=None):
864 self.start_time = start_time
865 self.endpoints = endpoints
867 self.msg = random_colour_print(endpoints)
868 self.client_balance = 0.0
869 self.conversation_id = conversation_id
871 self.add_short_packet(*p)
873 def __cmp__(self, other):
874 if self.start_time is None:
875 if other.start_time is None:
878 if other.start_time is None:
880 return self.start_time - other.start_time
882 def add_packet(self, packet):
883 """Add a packet object to this conversation, making a local copy with
884 a conversation-relative timestamp."""
887 if self.start_time is None:
888 self.start_time = p.timestamp
890 if self.endpoints is None:
891 self.endpoints = p.endpoints
893 if p.endpoints != self.endpoints:
894 raise FakePacketError("Conversation endpoints %s don't match "
895 "packet endpoints %s" %
896 (self.endpoints, p.endpoints))
898 p.timestamp -= self.start_time
900 if p.src == p.endpoints[0]:
901 self.client_balance -= p.client_score()
903 self.client_balance += p.client_score()
905 if p.is_really_a_packet():
906 self.packets.append(p)
908 def add_short_packet(self, timestamp, protocol, opcode, extra,
909 client=True, skip_unused_packets=True):
910 """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
911 (possibly empty) list of extra data. If client is True, assume
912 this packet is from the client to the server.
914 if skip_unused_packets and not is_a_real_packet(protocol, opcode):
917 src, dest = self.guess_client_server()
919 src, dest = dest, src
920 key = (protocol, opcode)
921 desc = OP_DESCRIPTIONS.get(key, '')
922 ip_protocol = IP_PROTOCOLS.get(protocol, '06')
923 packet = Packet(timestamp - self.start_time, ip_protocol,
925 protocol, opcode, desc, extra)
926 # XXX we're assuming the timestamp is already adjusted for
928 # XXX should we adjust client balance for guessed packets?
929 if packet.src == packet.endpoints[0]:
930 self.client_balance -= packet.client_score()
932 self.client_balance += packet.client_score()
933 if packet.is_really_a_packet():
934 self.packets.append(packet)
937 return ("<Conversation %s %s starting %.3f %d packets>" %
938 (self.conversation_id, self.endpoints, self.start_time,
944 return iter(self.packets)
947 return len(self.packets)
949 def get_duration(self):
950 if len(self.packets) < 2:
952 return self.packets[-1].timestamp - self.packets[0].timestamp
954 def replay_as_summary_lines(self):
955 return [p.as_summary(self.start_time) for p in self.packets]
957 def replay_with_delay(self, start, context=None, account=None):
958 """Replay the conversation at the right time.
959 (We're already in a fork)."""
960 # first we sleep until the first packet
962 now = time.time() - start
964 sleep_time = gap - SLEEP_OVERHEAD
966 time.sleep(sleep_time)
968 miss = (time.time() - start) - t
969 self.msg("starting %s [miss %.3f]" % (self, miss))
973 # packet times are relative to conversation start
974 p_start = time.time()
975 for p in self.packets:
976 now = time.time() - p_start
977 gap = now - p.timestamp
981 sleep_time = -gap - SLEEP_OVERHEAD
983 time.sleep(sleep_time)
984 t = time.time() - p_start
985 if t - p.timestamp > max_sleep_miss:
986 max_sleep_miss = t - p.timestamp
988 p.play(self, context)
990 return max_gap, miss, max_sleep_miss
992 def guess_client_server(self, server_clue=None):
993 """Have a go at deciding who is the server and who is the client.
994 returns (client, server)
996 a, b = self.endpoints
998 if self.client_balance < 0:
1001 # in the absence of a clue, we will fall through to assuming
1002 # the lowest number is the server (which is usually true).
1004 if self.client_balance == 0 and server_clue == b:
1009 def forget_packets_outside_window(self, s, e):
1010 """Prune any packets outside the timne window we're interested in
1012 :param s: start of the window
1013 :param e: end of the window
1015 self.packets = [p for p in self.packets if s <= p.timestamp <= e]
1016 self.start_time = self.packets[0].timestamp if self.packets else None
1018 def renormalise_times(self, start_time):
1019 """Adjust the packet start times relative to the new start time."""
1020 for p in self.packets:
1021 p.timestamp -= start_time
1023 if self.start_time is not None:
1024 self.start_time -= start_time
1027 class DnsHammer(Conversation):
1028 """A lightweight conversation that generates a lot of dns:0 packets on
1031 def __init__(self, dns_rate, duration, query_file=None):
1032 n = int(dns_rate * duration)
1033 self.times = [random.uniform(0, duration) for i in range(n)]
1035 self.rate = dns_rate
1036 self.duration = duration
1038 self.query_choices = self._get_query_choices(query_file=query_file)
1041 return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
1042 (len(self.times), self.duration, self.rate))
1044 def _get_query_choices(self, query_file=None):
1046 Read dns query choices from a file, or return the defaults.
1048 rname may contain a format string like `{realm}`;
1049 the realm is fetched from context.realm at replay time.
1053 with open(query_file, 'r') as f:
1056 for line in text.splitlines():
1058 if line and not line.startswith('#'):
1059 args = line.split(',')
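# a query line is "opcode,rname,rtype,expect_answer",
# e.g. "0,{realm},A,yes" (cf. the defaults below)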
1060 assert len(args) == 4
1061 choices.append(args)
1065 (0, '{realm}', 'A', 'yes'),
1066 (1, '{realm}', 'NS', 'yes'),
1067 (2, '*.{realm}', 'A', 'no'),
1068 (3, '*.{realm}', 'NS', 'no'),
1069 (10, '_msdcs.{realm}', 'A', 'yes'),
1070 (11, '_msdcs.{realm}', 'NS', 'yes'),
1071 (20, 'nx.realm.com', 'A', 'no'),
1072 (21, 'nx.realm.com', 'NS', 'no'),
1073 (22, '*.nx.realm.com', 'A', 'no'),
1074 (23, '*.nx.realm.com', 'NS', 'no'),
1077 def replay(self, context=None):
1079 assert context.realm
1081 for t in self.times:
1082 now = time.time() - start
1084 sleep_time = gap - SLEEP_OVERHEAD
1086 time.sleep(sleep_time)
1088 opcode, rname, rtype, exist = random.choice(self.query_choices)
1089 rname = rname.format(realm=context.realm)
1091 packet_start = time.time()
1093 answers = dns_query(rname, rtype)
1094 if exist == 'yes' and not len(answers):
1095 # we expected answers but got none: fail
1101 duration = end - packet_start
1102 print("%f\tDNS\tdns\t%s\t%f\t%s\t" % (end, opcode, duration, success))
1105 def ingest_summaries(files, dns_mode='count'):
1106 """Load a summary traffic summary file and generated Converations from it.
1109 dns_counts = defaultdict(int)
1112 if isinstance(f, str):
1114 print("Ingesting %s" % (f.name,), file=sys.stderr)
1116 p = Packet.from_line(line)
1117 if p.protocol == 'dns' and dns_mode != 'include':
1118 dns_counts[p.opcode] += 1
1127 start_time = min(p.timestamp for p in packets)
1128 last_packet = max(p.timestamp for p in packets)
1130 print("gathering packets into conversations", file=sys.stderr)
1131 conversations = OrderedDict()
1132 for i, p in enumerate(packets):
1133 p.timestamp -= start_time
1134 c = conversations.get(p.endpoints)
1136 c = Conversation(conversation_id=(i + 2))
1137 conversations[p.endpoints] = c
1140 # We only care about conversations with actual traffic, so we
1141 # filter out conversations with nothing to say. We do that here,
1142 # rather than earlier, because those empty packets contain useful
1143 # hints as to which end of the conversation was the client.
1144 conversation_list = []
1145 for c in conversations.values():
1147 conversation_list.append(c)
1149 # This is obviously not correct, as many conversations will appear
1150 # to start roughly simultaneously at the beginning of the snapshot.
1151 # To which we say: oh well, so be it.
1152 duration = float(last_packet - start_time)
1153 mean_interval = len(conversations) / duration
1155 return conversation_list, mean_interval, duration, dns_counts
1158 def guess_server_address(conversations):
1159 # we guess the most common address.
1160 addresses = Counter()
1161 for c in conversations:
1162 addresses.update(c.endpoints)
1164 return addresses.most_common(1)[0][0]
1167 def stringify_keys(x):
1169 for k, v in x.items():
1175 def unstringify_keys(x):
1177 for k, v in x.items():
1178 t = tuple(str(k).split('\t'))
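# e.g. a stored key 'dns\t0' is turned back into the tuple ('dns', '0')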
1183 class TrafficModel(object):
1184 def __init__(self, n=3):
1186 self.query_details = {}
1188 self.dns_opcounts = defaultdict(int)
1189 self.cumulative_duration = 0.0
1190 self.packet_rate = [0, 1]
1192 def learn(self, conversations, dns_opcounts={}):
1195 key = (NON_PACKET,) * (self.n - 1)
1197 server = guess_server_address(conversations)
1199 for k, v in dns_opcounts.items():
1200 self.dns_opcounts[k] += v
1202 if len(conversations) > 1:
1203 first = conversations[0].start_time
1206 for c in conversations:
1208 last = max(last, c.packets[-1].timestamp)
1210 self.packet_rate[0] = total
1211 self.packet_rate[1] = last - first
1213 for c in conversations:
1214 client, server = c.guess_client_server(server)
1215 cum_duration += c.get_duration()
1216 key = (NON_PACKET,) * (self.n - 1)
1221 elapsed = p.timestamp - prev
1223 if elapsed > WAIT_THRESHOLD:
1224 # add the wait as an extra state
1225 wait = 'wait:%d' % (math.log(max(1.0,
1226 elapsed * WAIT_SCALE)))
1227 self.ngrams.setdefault(key, []).append(wait)
1228 key = key[1:] + (wait,)
1230 short_p = p.as_packet_type()
1231 self.query_details.setdefault(short_p,
1232 []).append(tuple(p.extra))
1233 self.ngrams.setdefault(key, []).append(short_p)
1234 key = key[1:] + (short_p,)
1236 self.cumulative_duration += cum_duration
1238 self.ngrams.setdefault(key, []).append(NON_PACKET)
1242 for k, v in self.ngrams.items():
1244 ngrams[k] = dict(Counter(v))
1247 for k, v in self.query_details.items():
1248 query_details[k] = dict(Counter('\t'.join(x) if x else '-'
1253 'query_details': query_details,
1254 'cumulative_duration': self.cumulative_duration,
1255 'packet_rate': self.packet_rate,
1256 'version': CURRENT_MODEL_VERSION
1258 d['dns'] = self.dns_opcounts
1260 if isinstance(f, str):
1263 json.dump(d, f, indent=2)
1266 if isinstance(f, str):
1272 version = d["version"]
1273 if version < REQUIRED_MODEL_VERSION:
1274 raise ValueError("the model file is version %d; "
1275 "version %d is required" %
1276 (version, REQUIRED_MODEL_VERSION))
1278 raise ValueError("the model file lacks a version number; "
1279 "version %d is required" %
1280 (REQUIRED_MODEL_VERSION))
1282 for k, v in d['ngrams'].items():
1283 k = tuple(str(k).split('\t'))
1284 values = self.ngrams.setdefault(k, [])
1285 for p, count in v.items():
1286 values.extend([str(p)] * count)
1289 for k, v in d['query_details'].items():
1290 values = self.query_details.setdefault(str(k), [])
1291 for p, count in v.items():
1293 values.extend([()] * count)
1295 values.extend([tuple(str(p).split('\t'))] * count)
1299 for k, v in d['dns'].items():
1300 self.dns_opcounts[k] += v
1302 self.cumulative_duration = d['cumulative_duration']
1303 self.packet_rate = d['packet_rate']
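# The file written by save() is JSON of roughly this shape (values are
# invented; tuple keys are tab-joined by stringify_keys(), and '-' here
# stands for the NON_PACKET placeholder):
#
#     {
#       "ngrams": {"-\t-": {"ldap:3": 45, "wait:3": 12}},
#       "query_details": {"ldap:3": {"DC,DC\tcn": 3}},
#       "dns": {"0": 9, "1": 9},
#       "cumulative_duration": 420.5,
#       "packet_rate": [1000, 100],
#       "version": 2
#     }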
1305 def construct_conversation_sequence(self, timestamp=0.0,
1310 """Construct an individual conversation packet sequence from the
1314 key = (NON_PACKET,) * (self.n - 1)
1315 if ignore_before is None:
1316 ignore_before = timestamp - 1
1319 p = random.choice(self.ngrams.get(key, (NON_PACKET,)))
1321 if timestamp < ignore_before:
1323 if random.random() > persistence:
1324 print("ending after %s (persistence %.1f)" % (key, persistence),
1328 p = 'wait:%d' % random.randrange(5, 12)
1329 print("trying %s instead of end" % p, file=sys.stderr)
1331 if p in self.query_details:
1332 extra = random.choice(self.query_details[p])
1336 protocol, opcode = p.split(':', 1)
1337 if protocol == 'wait':
1338 log_wait_time = int(opcode) + random.random()
1339 wait = math.exp(log_wait_time) / (WAIT_SCALE * replay_speed)
1342 log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
1343 wait = math.exp(log_wait) / replay_speed
1345 if hard_stop is not None and timestamp > hard_stop:
1347 if timestamp >= ignore_before:
1348 c.append((timestamp, protocol, opcode, extra))
1350 key = key[1:] + (p,)
1351 if key[-2][:5] == 'wait:' and key[-1][:5] == 'wait:':
1352 # two waits in a row can only be caused by "persistence"
1353 # tricks, and will not result in any packets being found.
1354 # Instead we pretend this is a fresh start.
1355 key = (NON_PACKET,) * (self.n - 1)
1359 def scale_to_packet_rate(self, scale):
1360 rate_n, rate_t = self.packet_rate
1361 return scale * rate_n / rate_t
1363 def packet_rate_to_scale(self, pps):
1364 rate_n, rate_t = self.packet_rate
1365 return pps * rate_t / rate_n
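# For example, if the training data contained 1000 packets over 100
# seconds (packet_rate == [1000, 100], i.e. 10 packets/s), then
# packet_rate_to_scale(40) == 4.0: the model must be scaled up four
# times to reach 40 packets/s, and scale_to_packet_rate(4.0) == 40.0.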
1367 def generate_conversation_sequences(self, packet_rate, duration, replay_speed=1,
1369 """Generate a list of conversation descriptions from the model."""
1371 # We run the simulation for ten times as long as our desired
1372 # duration, and take the section at the end.
1373 lead_in = 9 * duration
1374 target_packets = int(packet_rate * duration)
1378 while n_packets < target_packets:
1379 start = random.uniform(-lead_in, duration)
1380 c = self.construct_conversation_sequence(start,
1382 replay_speed=replay_speed,
1384 persistence=persistence)
1385 # will these "packets" generate actual traffic?
1386 # some (e.g. ldap unbind) will not generate anything
1387 # if the previous packets are not there, and if the
1388 # conversation consists only of those, it wastes a process doing nothing.
1389 for timestamp, protocol, opcode, extra in c:
1390 if is_a_traffic_generating_packet(protocol, opcode):
1395 conversations.append(c)
1398 scale = self.packet_rate_to_scale(packet_rate)
1399 print(("we have %d packets (target %d) in %d conversations at %.1f/s "
1400 "(scale %f)" % (n_packets, target_packets, len(conversations),
1401 packet_rate, scale)),
1403 conversations.sort() # sorts by first element == start time
1404 return conversations
1407 def seq_to_conversations(seq, server=1, client=2):
1411 c = Conversation(s[0][0], (server, client), s)
1413 conversations.append(c)
1414 return conversations
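# Putting the model pieces together (a sketch; the file name and rates
# are illustrative):
#
#     model = TrafficModel()
#     model.load('traffic-model.json')
#     seqs = model.generate_conversation_sequences(packet_rate=10,
#                                                  duration=60)
#     conversations = seq_to_conversations(seqs)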
1419 'rpc_netlogon': '06',
1420 'kerberos': '06', # ratio 16248:258
1431 'smb_netlogon': '11',
1437 ('browser', '0x01'): 'Host Announcement (0x01)',
1438 ('browser', '0x02'): 'Request Announcement (0x02)',
1439 ('browser', '0x08'): 'Browser Election Request (0x08)',
1440 ('browser', '0x09'): 'Get Backup List Request (0x09)',
1441 ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
1442 ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
1443 ('cldap', '3'): 'searchRequest',
1444 ('cldap', '5'): 'searchResDone',
1445 ('dcerpc', '0'): 'Request',
1446 ('dcerpc', '11'): 'Bind',
1447 ('dcerpc', '12'): 'Bind_ack',
1448 ('dcerpc', '13'): 'Bind_nak',
1449 ('dcerpc', '14'): 'Alter_context',
1450 ('dcerpc', '15'): 'Alter_context_resp',
1451 ('dcerpc', '16'): 'AUTH3',
1452 ('dcerpc', '2'): 'Response',
1453 ('dns', '0'): 'query',
1454 ('dns', '1'): 'response',
1455 ('drsuapi', '0'): 'DsBind',
1456 ('drsuapi', '12'): 'DsCrackNames',
1457 ('drsuapi', '13'): 'DsWriteAccountSpn',
1458 ('drsuapi', '1'): 'DsUnbind',
1459 ('drsuapi', '2'): 'DsReplicaSync',
1460 ('drsuapi', '3'): 'DsGetNCChanges',
1461 ('drsuapi', '4'): 'DsReplicaUpdateRefs',
1462 ('epm', '3'): 'Map',
1463 ('kerberos', ''): '',
1464 ('ldap', '0'): 'bindRequest',
1465 ('ldap', '1'): 'bindResponse',
1466 ('ldap', '2'): 'unbindRequest',
1467 ('ldap', '3'): 'searchRequest',
1468 ('ldap', '4'): 'searchResEntry',
1469 ('ldap', '5'): 'searchResDone',
1470 ('ldap', ''): '*** Unknown ***',
1471 ('lsarpc', '14'): 'lsa_LookupNames',
1472 ('lsarpc', '15'): 'lsa_LookupSids',
1473 ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
1474 ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
1475 ('lsarpc', '6'): 'lsa_OpenPolicy',
1476 ('lsarpc', '76'): 'lsa_LookupSids3',
1477 ('lsarpc', '77'): 'lsa_LookupNames4',
1478 ('nbns', '0'): 'query',
1479 ('nbns', '1'): 'response',
1480 ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
1481 ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
1482 ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
1483 ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
1484 ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
1485 ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
1486 ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
1487 ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
1488 ('samr', '0',): 'Connect',
1489 ('samr', '16'): 'GetAliasMembership',
1490 ('samr', '17'): 'LookupNames',
1491 ('samr', '18'): 'LookupRids',
1492 ('samr', '19'): 'OpenGroup',
1493 ('samr', '1'): 'Close',
1494 ('samr', '25'): 'QueryGroupMember',
1495 ('samr', '34'): 'OpenUser',
1496 ('samr', '36'): 'QueryUserInfo',
1497 ('samr', '39'): 'GetGroupsForUser',
1498 ('samr', '3'): 'QuerySecurity',
1499 ('samr', '5'): 'LookupDomain',
1500 ('samr', '64'): 'Connect5',
1501 ('samr', '6'): 'EnumDomains',
1502 ('samr', '7'): 'OpenDomain',
1503 ('samr', '8'): 'QueryDomainInfo',
1504 ('smb', '0x04'): 'Close (0x04)',
1505 ('smb', '0x24'): 'Locking AndX (0x24)',
1506 ('smb', '0x2e'): 'Read AndX (0x2e)',
1507 ('smb', '0x32'): 'Trans2 (0x32)',
1508 ('smb', '0x71'): 'Tree Disconnect (0x71)',
1509 ('smb', '0x72'): 'Negotiate Protocol (0x72)',
1510 ('smb', '0x73'): 'Session Setup AndX (0x73)',
1511 ('smb', '0x74'): 'Logoff AndX (0x74)',
1512 ('smb', '0x75'): 'Tree Connect AndX (0x75)',
1513 ('smb', '0xa2'): 'NT Create AndX (0xa2)',
1514 ('smb2', '0'): 'NegotiateProtocol',
1515 ('smb2', '11'): 'Ioctl',
1516 ('smb2', '14'): 'Find',
1517 ('smb2', '16'): 'GetInfo',
1518 ('smb2', '18'): 'Break',
1519 ('smb2', '1'): 'SessionSetup',
1520 ('smb2', '2'): 'SessionLogoff',
1521 ('smb2', '3'): 'TreeConnect',
1522 ('smb2', '4'): 'TreeDisconnect',
1523 ('smb2', '5'): 'Create',
1524 ('smb2', '6'): 'Close',
1525 ('smb2', '8'): 'Read',
1526 ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
1527 ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
1528 'user unknown (0x17)'),
1529 ('srvsvc', '16'): 'NetShareGetInfo',
1530 ('srvsvc', '21'): 'NetSrvGetInfo',
1534 def expand_short_packet(p, timestamp, src, dest, extra):
1535 protocol, opcode = p.split(':', 1)
1536 desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
1537 ip_protocol = IP_PROTOCOLS.get(protocol, '06')
1539 line = [timestamp, ip_protocol, '', src, dest, protocol, opcode, desc]
1541 return '\t'.join(line)
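# e.g. (assuming string arguments, which the tab-join requires):
#     expand_short_packet('ldap:3', '3.140000', '2', '1', [])
#     -> '3.140000\t06\t\t2\t1\tldap\t3\tsearchRequest'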
1544 def flushing_signal_handler(signal, frame):
1545 """Signal handler closes standard out and error.
1547 Triggered by a SIGTERM, it ensures that the log messages are flushed
1548 to disk and not lost.
1555 def replay_seq_in_fork(cs, start, context, account, client_id, server_id=1):
1556 """Fork a new process and replay the conversation sequence."""
1557 # We will need to reseed the random number generator or all the
1558 # clients will end up using the same sequence of random
1559 # numbers. random.randint() is mixed in so the initial seed will
1560 # have an effect here.
1561 seed = client_id * 1000 + random.randint(0, 999)
1563 # flush our buffers so messages won't be written by both sides
1570 # we must never return, or we'll end up running parts of the
1571 # parent's clean-up code. So we work in a try...finally, and
1572 # try to print any exceptions.
1575 endpoints = (server_id, client_id)
1578 c = Conversation(t, endpoints, seq=cs, conversation_id=client_id)
1579 signal.signal(signal.SIGTERM, flushing_signal_handler)
1581 context.generate_process_local_config(account, c)
1584 filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
1586 f = open(filename, 'w')
1590 except IOError as e:
1591 LOGGER.info("stdout closing failed with %s" % e)
1595 now = time.time() - start
1597 sleep_time = gap - SLEEP_OVERHEAD
1599 time.sleep(sleep_time)
1601 max_lag, start_lag, max_sleep_miss = c.replay_with_delay(start=start,
1603 print("Maximum lag: %f" % max_lag)
1604 print("Start lag: %f" % start_lag)
1605 print("Max sleep miss: %f" % max_sleep_miss)
1609 print(("EXCEPTION in child PID %d, conversation %s" % (os.getpid(), c)),
1611 traceback.print_exc(file=sys.stderr)
1619 def dnshammer_in_fork(dns_rate, duration, context, query_file=None):
1632 except IOError as e:
1633 LOGGER.warning("stdout closing failed with %s" % e)
1635 filename = os.path.join(context.statsdir, 'stats-dns')
1636 sys.stdout = open(filename, 'w')
1640 signal.signal(signal.SIGTERM, flushing_signal_handler)
1641 hammer = DnsHammer(dns_rate, duration, query_file=query_file)
1642 hammer.replay(context=context)
1645 print(("EXCEPTION in child PID %d, the DNS hammer" % (os.getpid())),
1647 traceback.print_exc(file=sys.stderr)
1654 def replay(conversation_seq,
1660 dns_query_file=None,
1662 latency_timeout=1.0,
1663 stop_on_any_error=False,
1666 context = ReplayContext(server=host,
1669 total_conversations=len(conversation_seq),
1672 if len(accounts) < len(conversation_seq):
1673 raise ValueError(("we have %d accounts but %d conversations" %
1674 (len(accounts), len(conversation_seq))))
1676 # Set the process group so that the calling scripts are not killed
1677 # when the forked child processes are killed.
1680 # we delay the start by a bit to allow all the forks to get up and
1682 delay = len(conversation_seq) * 0.02
1683 start = time.time() + delay
1685 if duration is None:
1686 # end slightly after the last packet of the last conversation
1687 # to start. Conversations other than the last could still be
1688 # going, but we don't care.
1689 duration = conversation_seq[-1][-1][0] + latency_timeout
1691 print("We will start in %.1f seconds" % delay,
1693 print("We will stop after %.1f seconds" % (duration + delay),
1695 print("runtime %.1f seconds" % duration,
1698 # give one second grace for packets to finish before killing begins
1699 end = start + duration + 1.0
1701 LOGGER.info("Replaying traffic for %u conversations over %d seconds"
1702 % (len(conversation_seq), duration))
1704 context.write_stats('intentions',
1705 Planned_conversations=len(conversation_seq),
1706 Planned_packets=sum(len(x) for x in conversation_seq))
1711 pid = dnshammer_in_fork(dns_rate, duration, context,
1712 query_file=dns_query_file)
1715 for i, cs in enumerate(conversation_seq):
1716 account = accounts[i]
1718 pid = replay_seq_in_fork(cs, start, context, account, client_id)
1719 children[pid] = client_id
1721 # HERE, we are past all the forks
1723 print("all forks done in %.1f seconds, waiting %.1f" %
1724 (t - start + delay, t - start),
1727 while time.time() < end and children:
1730 pid, status = os.waitpid(-1, os.WNOHANG)
1731 except OSError as e:
1732 if e.errno != ECHILD: # no child processes
1736 c = children.pop(pid, None)
1738 print(("process %d finished conversation %d;"
1740 (pid, c, len(children))), file=sys.stderr)
1741 if stop_on_any_error and status != 0:
1745 print("EXCEPTION in parent", file=sys.stderr)
1746 traceback.print_exc()
1748 context.write_stats('unfinished',
1749 Unfinished_conversations=len(children))
1751 for s in (15, 15, 9):
1752 print(("killing %d children with -%d" %
1753 (len(children), s)), file=sys.stderr)
1754 for pid in children:
1757 except OSError as e:
1758 if e.errno != ESRCH: # don't fail if it has already died
1761 end = time.time() + 1
1764 pid, status = os.waitpid(-1, os.WNOHANG)
1765 except OSError as e:
1766 if e.errno != ECHILD:
1769 c = children.pop(pid, None)
1771 print("children is %s, no pid found" % children)
1775 print(("kill -%d %d KILLED conversation; "
1777 (s, pid, len(children))),
1779 if time.time() >= end:
1787 print("%d children are missing" % len(children),
1790 # there may be stragglers that were forked just as ^C was hit
1791 # and don't appear in the list of children. We can get them
1792 # with killpg, but that will also kill us, so this is^H^H would be
1793 # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
1794 # so as not to have to fuss around writing signal handlers.
1797 except KeyboardInterrupt:
1798 print("ignoring fake ^C", file=sys.stderr)
1801 def openLdb(host, creds, lp):
1802 session = system_session()
1803 ldb = SamDB(url="ldap://%s" % host,
1804 session_info=session,
1805 options=['modules:paged_searches'],
1811 def ou_name(ldb, instance_id):
1812 """Generate an ou name from the instance id"""
1813 return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id,
1817 def create_ou(ldb, instance_id):
1818 """Create an ou, all created user and machine accounts will belong to it.
1820 This allows all the created resources to be cleaned up easily.
1822 ou = ou_name(ldb, instance_id)
1824 ldb.add({"dn": ou.split(',', 1)[1],
1825 "objectclass": "organizationalunit"})
1826 except LdbError as e:
1827 (status, _) = e.args
1828 # ignore already exists
1833 "objectclass": "organizationalunit"})
1834 except LdbError as e:
1835 (status, _) = e.args
1836 # ignore already exists
1842 # ConversationAccounts holds details of the machine and user accounts
1843 # associated with a conversation.
1845 # We use a named tuple to reduce shared memory usage.
1846 ConversationAccounts = namedtuple('ConversationAccounts',
1853 def generate_replay_accounts(ldb, instance_id, number, password):
1854 """Generate a series of unique machine and user account names."""
1857 for i in range(1, number + 1):
1858 netbios_name = machine_name(instance_id, i)
1859 username = user_name(instance_id, i)
1861 account = ConversationAccounts(netbios_name, password, username,
1863 accounts.append(account)
1867 def create_machine_account(ldb, instance_id, netbios_name, machinepass,
1868 traffic_account=True):
1869 """Create a machine account via ldap."""
1871 ou = ou_name(ldb, instance_id)
1872 dn = "cn=%s,%s" % (netbios_name, ou)
1873 utf16pw = ('"%s"' % get_string(machinepass)).encode('utf-16-le')
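# AD expects unicodePwd to be the literal password wrapped in double
# quotes and encoded as UTF-16-LE, hence the dance above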
1876 # we set these bits for the machine account, otherwise the replayed
1877 # traffic throws up NT_STATUS_NO_TRUST_SAM_ACCOUNT errors
1878 account_controls = str(UF_TRUSTED_FOR_DELEGATION |
1879 UF_SERVER_TRUST_ACCOUNT)
1882 account_controls = str(UF_WORKSTATION_TRUST_ACCOUNT)
1886 "objectclass": "computer",
1887 "sAMAccountName": "%s$" % netbios_name,
1888 "userAccountControl": account_controls,
1889 "unicodePwd": utf16pw})
1892 def create_user_account(ldb, instance_id, username, userpass):
1893 """Create a user account via ldap."""
1894 ou = ou_name(ldb, instance_id)
1895 user_dn = "cn=%s,%s" % (username, ou)
1896 utf16pw = ('"%s"' % get_string(userpass)).encode('utf-16-le')
1899 "objectclass": "user",
1900 "sAMAccountName": username,
1901 "userAccountControl": str(UF_NORMAL_ACCOUNT),
1902 "unicodePwd": utf16pw
1905 # grant user write permission to do things like write account SPN
1906 sdutils = sd_utils.SDUtils(ldb)
1907 sdutils.dacl_add_ace(user_dn, "(A;;WP;;;PS)")
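# "(A;;WP;;;PS)" is SDDL for an allow ACE granting Write-Property (WP)
# to Principal Self (PS)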
1910 def create_group(ldb, instance_id, name):
1911 """Create a group via ldap."""
1913 ou = ou_name(ldb, instance_id)
1914 dn = "cn=%s,%s" % (name, ou)
1917 "objectclass": "group",
1918 "sAMAccountName": name,
1922 def user_name(instance_id, i):
1923 """Generate a user name based in the instance id"""
1924 return "STGU-%d-%d" % (instance_id, i)
1927 def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
1928 """Seach objectclass, return attr in a set"""
1930 expression="(objectClass={})".format(objectclass),
1933 return {str(obj[attr]) for obj in objs}
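# e.g. search_objectclass(ldb, objectclass='group') might return
# {'STGG-5-1', 'STGG-5-2', ...} for the groups created by this module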
1936 def generate_users(ldb, instance_id, number, password):
1937 """Add users to the server"""
1938 existing_objects = search_objectclass(ldb, objectclass='user')
1940 for i in range(number, 0, -1):
1941 name = user_name(instance_id, i)
1942 if name not in existing_objects:
1943 create_user_account(ldb, instance_id, name, password)
1946 LOGGER.info("Created %u/%u users" % (users, number))
1951 def machine_name(instance_id, i, traffic_account=True):
1952 """Generate a machine account name from instance id."""
1954 # traffic accounts correspond to a given user, and use different
1955 # userAccountControl flags to ensure packets get processed correctly
1957 return "STGM-%d-%d" % (instance_id, i)
1959 # Otherwise we're just generating computer accounts to simulate a
1960 # semi-realistic network. These use the default computer
1961 # userAccountControl flags, so we use a different account name so that
1962 # we don't try to use them when generating packets
1963 return "PC-%d-%d" % (instance_id, i)
1966 def generate_machine_accounts(ldb, instance_id, number, password,
1967 traffic_account=True):
1968 """Add machine accounts to the server"""
1969 existing_objects = search_objectclass(ldb, objectclass='computer')
1971 for i in range(number, 0, -1):
1972 name = machine_name(instance_id, i, traffic_account)
1973 if name + "$" not in existing_objects:
1974 create_machine_account(ldb, instance_id, name, password,
1978 LOGGER.info("Created %u/%u machine accounts" % (added, number))
1983 def group_name(instance_id, i):
1984 """Generate a group name from instance id."""
1985 return "STGG-%d-%d" % (instance_id, i)
1988 def generate_groups(ldb, instance_id, number):
1989 """Create the required number of groups on the server."""
1990 existing_objects = search_objectclass(ldb, objectclass='group')
1992 for i in range(number, 0, -1):
1993 name = group_name(instance_id, i)
1994 if name not in existing_objects:
1995 create_group(ldb, instance_id, name)
1997 if groups % 1000 == 0:
1998 LOGGER.info("Created %u/%u groups" % (groups, number))
2003 def clean_up_accounts(ldb, instance_id):
2004 """Remove the created accounts and groups from the server."""
2005 ou = ou_name(ldb, instance_id)
2007 ldb.delete(ou, ["tree_delete:1"])
2008 except LdbError as e:
2009 (status, _) = e.args
2010 # ignore does not exist
2015 def generate_users_and_groups(ldb, instance_id, password,
2016 number_of_users, number_of_groups,
2017 group_memberships, max_members,
2018 machine_accounts, traffic_accounts=True):
2019 """Generate the required users and groups, allocating the users to
2021 memberships_added = 0
2025 create_ou(ldb, instance_id)
2027 LOGGER.info("Generating dummy user accounts")
2028 users_added = generate_users(ldb, instance_id, number_of_users, password)
2030 LOGGER.info("Generating dummy machine accounts")
2031 computers_added = generate_machine_accounts(ldb, instance_id,
2032 machine_accounts, password,
2035 if number_of_groups > 0:
2036 LOGGER.info("Generating dummy groups")
2037 groups_added = generate_groups(ldb, instance_id, number_of_groups)
2039 if group_memberships > 0:
2040 LOGGER.info("Assigning users to groups")
2041 assignments = GroupAssignments(number_of_groups,
2047 LOGGER.info("Adding users to groups")
2048 add_users_to_groups(ldb, instance_id, assignments)
2049 memberships_added = assignments.total()
2051 if (groups_added > 0 and users_added == 0 and
2052 number_of_groups != groups_added):
2053 LOGGER.warning("The added groups will contain no members")
2055 LOGGER.info("Added %d users (%d machines), %d groups and %d memberships" %
2056 (users_added, computers_added, groups_added,
2060 class GroupAssignments(object):
2061 def __init__(self, number_of_groups, groups_added, number_of_users,
2062 users_added, group_memberships, max_members):
2065 self.generate_group_distribution(number_of_groups)
2066 self.generate_user_distribution(number_of_users, group_memberships)
2067 self.max_members = max_members
2068 self.assignments = defaultdict(list)
2069 self.assign_groups(number_of_groups, groups_added, number_of_users,
2070 users_added, group_memberships)
2072 def cumulative_distribution(self, weights):
2073 # make sure the probabilities conform to a cumulative distribution
2074 # spread between 0.0 and 1.0. Dividing by the weighted total gives each
2075 # probability a proportional share of 1.0. Higher probabilities get a
2076 # bigger share, so are more likely to be picked. We use the cumulative
2077 # value, so we can use random.random() as a simple index into the list
2079 total = sum(weights)
2084 for probability in weights:
2085 cumulative += probability
2086 dist.append(cumulative / total)
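# e.g. weights [1, 3, 1] -> dist [0.2, 0.8, 1.0]; indexing with
# bisect.bisect(dist, random.random()) then picks index 1 three times
# as often as index 0 or index 2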
2089 def generate_user_distribution(self, num_users, num_memberships):
2090 """Probability distribution of a user belonging to a group.
2092 # Assign a weighted probability to each user. Use the Pareto
2093 # Distribution so that some users are in a lot of groups, and the
2094 # bulk of users are in only a few groups. If we're assigning a large
2095 # number of group memberships, use a higher shape. This means slightly
2096 # fewer outlying users that are in large numbers of groups. The aim is
2097 # to have no users belonging to more than ~500 groups.
2098 if num_memberships > 5000000:
2100 elif num_memberships > 2000000:
2102 elif num_memberships > 300000:
2108 for x in range(1, num_users + 1):
2109 p = random.paretovariate(shape)
2112 # convert the weights to a cumulative distribution between 0.0 and 1.0
2113 self.user_dist = self.cumulative_distribution(weights)
2115 def generate_group_distribution(self, n):
2116 """Probability distribution of a group containing a user."""
2118 # Assign a weighted probability to each group. Probability decreases
2119 # as the group-ID increases
2121 for x in range(1, n + 1):
2125 # convert the weights to a cumulative distribution between 0.0 and 1.0
2126 self.group_weights = weights
2127 self.group_dist = self.cumulative_distribution(weights)
2129 def generate_random_membership(self):
2130 """Returns a randomly generated user-group membership"""
2132 # the list items are cumulative distribution values between 0.0 and
2133 # 1.0, which makes random() a handy way to index the list to get a
2134 # weighted random user/group. (Here the user/group returned are
2135 # zero-based array indexes)
2136 user = bisect.bisect(self.user_dist, random.random())
2137 group = bisect.bisect(self.group_dist, random.random())
2141 def users_in_group(self, group):
2142 return self.assignments[group]
2144 def get_groups(self):
2145 return self.assignments.keys()
2147 def cap_group_membership(self, group, max_members):
2148 """Prevent the group's membership from exceeding the max specified"""
2149 num_members = len(self.assignments[group])
2150 if num_members >= max_members:
2151 LOGGER.info("Group {0} has {1} members".format(group, num_members))
2153 # remove this group and then recalculate the cumulative
2154 # distribution, so this group is no longer selected
2155 self.group_weights[group - 1] = 0
2156 new_dist = self.cumulative_distribution(self.group_weights)
2157 self.group_dist = new_dist
2159 def add_assignment(self, user, group):
2160 # the assignments are stored in a dictionary where key=group,
2161 # value=list-of-users-in-group (indexing by group-ID allows us to
2162 # optimize for DB membership writes)
2163 if user not in self.assignments[group]:
2164 self.assignments[group].append(user)
2167 # check if there's a cap on how big the groups can grow
2168 if self.max_members:
2169 self.cap_group_membership(group, self.max_members)
2171 def assign_groups(self, number_of_groups, groups_added,
2172 number_of_users, users_added, group_memberships):
2173 """Allocate users to groups.
2175 The intention is to have a few users that belong to most groups, while
2176 the majority of users belong to a few groups.
2178 A few groups will contain most users, with the remaining only having a
2182 if group_memberships <= 0:
2185 # Calculate the number of group memberships required
2186 group_memberships = math.ceil(
2187 float(group_memberships) *
2188 (float(users_added) / float(number_of_users)))
2190 if self.max_members:
2191 group_memberships = min(group_memberships,
2192 self.max_members * number_of_groups)
2194 existing_users = number_of_users - users_added - 1
2195 existing_groups = number_of_groups - groups_added - 1
2196 while self.total() < group_memberships:
2197 user, group = self.generate_random_membership()
2199 if group > existing_groups or user > existing_users:
2200 # the + 1 converts the array index to the corresponding
2201 # group or user number
2202 self.add_assignment(user + 1, group + 1)
2208 def add_users_to_groups(db, instance_id, assignments):
2209 """Takes the assignments of users to groups and applies them to the DB."""
2211 total = assignments.total()
2215 for group in assignments.get_groups():
2216 users_in_group = assignments.users_in_group(group)
2217 if len(users_in_group) == 0:
2220 # Split up the users into chunks, so we write no more than 1K at a
2221 # time. (Batching the DB modifies is more efficient, but writing
2222 # 10K+ users to a single group becomes inefficient memory-wise)
2223 for chunk in range(0, len(users_in_group), 1000):
2224 chunk_of_users = users_in_group[chunk:chunk + 1000]
2225 add_group_members(db, instance_id, group, chunk_of_users)
2227 added += len(chunk_of_users)
2230 LOGGER.info("Added %u/%u memberships" % (added, total))
2232 def add_group_members(db, instance_id, group, users_in_group):
2233 """Adds the given users to group specified."""
2235 ou = ou_name(db, instance_id)
2238 return("cn=%s,%s" % (name, ou))
2240 group_dn = build_dn(group_name(instance_id, group))
2242 m.dn = ldb.Dn(db, group_dn)
2244 for user in users_in_group:
2245 user_dn = build_dn(user_name(instance_id, user))
2246 idx = "member-" + str(user)
2247 m[idx] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
2252 def generate_stats(statsdir, timing_file):
2253 """Generate and print the summary stats for a run."""
2254 first = sys.float_info.max
2259 failures = Counter()
2260 unique_conversations = set()
2261 if timing_file is not None:
2262 tw = timing_file.write
2267 tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")
2272 'Max sleep miss': 0,
2275 'Planned_conversations': 0,
2276 'Planned_packets': 0,
2277 'Unfinished_conversations': 0,
2280 for filename in os.listdir(statsdir):
2281 path = os.path.join(statsdir, filename)
2282 with open(path, 'r') as f:
2285 fields = line.rstrip('\n').split('\t')
2286 conversation = fields[1]
2287 protocol = fields[2]
2288 packet_type = fields[3]
2289 latency = float(fields[4])
2290 t = float(fields[0])
2291 first = min(t - latency, first)
2294 op = (protocol, packet_type)
2295 latencies.setdefault(op, []).append(latency)
2296 if fields[5] == 'True':
2302 unique_conversations.add(conversation)
2305 except (ValueError, IndexError):
2307 k, v = line.split(':', 1)
2308 if k in float_values:
2309 float_values[k] = max(float(v),
2311 elif k in int_values:
2312 int_values[k] = max(int(v),
2315 print(line, file=sys.stderr)
2317 # not a valid line: print it and ignore
2318 print(line, file=sys.stderr)
2320 duration = last - first
2324 success_rate = successful / duration
2328 failure_rate = failed / duration
2330 conversations = len(unique_conversations)
2332 print("Total conversations: %10d" % conversations)
2333 print("Successful operations: %10d (%.3f per second)"
2334 % (successful, success_rate))
2335 print("Failed operations: %10d (%.3f per second)"
2336 % (failed, failure_rate))
2338 for k, v in sorted(float_values.items()):
2339 print("%-28s %f" % (k.replace('_', ' ') + ':', v))
2340 for k, v in sorted(int_values.items()):
2341 print("%-28s %d" % (k.replace('_', ' ') + ':', v))
2343 print("Protocol Op Code Description "
2344 " Count Failed Mean Median "
2348 for proto, packet in latencies:
2349 if proto not in ops:
2351 ops[proto].add(packet)
2352 protocols = sorted(ops.keys())
2354 for protocol in protocols:
2355 packet_types = sorted(ops[protocol], key=opcode_key)
2356 for packet_type in packet_types:
2357 op = (protocol, packet_type)
2358 values = latencies[op]
2359 values = sorted(values)
2361 failed = failures[op]
2362 mean = sum(values) / count
2363 median = calc_percentile(values, 0.50)
2364 percentile = calc_percentile(values, 0.95)
2365 rng = values[-1] - values[0]
2367 desc = OP_DESCRIPTIONS.get(op, '')
2368 print("%-12s %4s %-35s %12d %12d %12.6f "
2369 "%12.6f %12.6f %12.6f %12.6f"
2383 """Sort key for the operation code to ensure that it sorts numerically"""
2385 return "%03d" % int(v)
2390 def calc_percentile(values, percentile):
2391 """Calculate the specified percentile from the list of values.
2393 Assumes the list is sorted in ascending order.
2398 k = (len(values) - 1) * percentile
2402 return values[int(k)]
2403 d0 = values[int(f)] * (c - k)
2404 d1 = values[int(c)] * (k - f)
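# Worked example: values=[10, 20, 30, 40, 50], percentile=0.95 gives
# k = 4 * 0.95 = 3.8 with f = 3 and c = 4, so the result is
# 40 * (4 - 3.8) + 50 * (3.8 - 3) = 8.0 + 40.0 = 48.0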
2408 def mk_masked_dir(*path):
2409 """In a testenv we end up with 0777 directories that look an alarming
2410 green colour with ls. Use umask to avoid that."""
2411 # py3 os.mkdir can do this
2412 d = os.path.join(*path)
2413 mask = os.umask(0o077)