import math
import sys
import signal
-import itertools
+from errno import ECHILD, ESRCH
+
+from collections import OrderedDict, Counter, defaultdict, namedtuple
+from dns.resolver import query as dns_query
-from collections import OrderedDict, Counter, defaultdict
from samba.emulate import traffic_packets
from samba.samdb import SamDB
import ldb
from samba.dcerpc.misc import SEC_CHAN_BDC
from samba import gensec
from samba import sd_utils
-from samba.compat import get_string
+from samba.common import get_string
from samba.logger import get_samba_logger
import bisect
+CURRENT_MODEL_VERSION = 2 # save as this
+REQUIRED_MODEL_VERSION = 2 # load accepts this or greater
SLEEP_OVERHEAD = 3e-4
# we don't use None, because it complicates [de]serialisation
sys.stderr.flush()
-def random_colour_print():
- """Return a function that prints a randomly coloured line to stderr"""
- n = 18 + random.randrange(214)
- prefix = "\033[38;5;%dm" % n
-
- def p(*args):
- for a in args:
- print("%s%s\033[00m" % (prefix, a), file=sys.stderr)
def random_colour_print(seeds):
    """Build a debug printer whose colour is derived from *seeds*.

    The returned function writes each of its positional arguments on its
    own line to stderr, but only while the module-level DEBUG_LEVEL is
    greater than zero.

    When *seeds* is a non-empty sequence of integers, they are folded
    into a value in 0..213 and 18 is added to pick the colour number used
    in the escape-sequence prefix. A falsy *seeds* (empty or None) gives
    a printer that emits the text without any colour codes.
    """
    if not seeds:
        def p(*args):
            if DEBUG_LEVEL > 0:
                for a in args:
                    print(a, file=sys.stderr)
        return p

    # a cheap order-sensitive hash of the seeds
    colour = 214
    for seed in seeds:
        colour = ((colour + 17) * seed) % 214
    prefix = "\033[38;5;%dm" % (18 + colour)

    def p(*args):
        if DEBUG_LEVEL > 0:
            for a in args:
                print("%s%s\033[00m" % (prefix, a), file=sys.stderr)
    return p
class Packet(object):
"""Details of a network packet"""
+ __slots__ = ('timestamp',
+ 'ip_protocol',
+ 'stream_number',
+ 'src',
+ 'dest',
+ 'protocol',
+ 'opcode',
+ 'desc',
+ 'extra',
+ 'endpoints')
def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
protocol, opcode, desc, extra):
-
self.timestamp = timestamp
self.ip_protocol = ip_protocol
self.stream_number = stream_number
self.endpoints = (self.dest, self.src)
@classmethod
- def from_line(self, line):
+ def from_line(cls, line):
fields = line.rstrip('\n').split('\t')
(timestamp,
ip_protocol,
src = int(src)
dest = int(dest)
- return Packet(timestamp, ip_protocol, stream_number, src, dest,
- protocol, opcode, desc, extra)
+ return cls(timestamp, ip_protocol, stream_number, src, dest,
+ protocol, opcode, desc, extra)
def as_summary(self, time_offset=0.0):
"""Format the packet as a traffic_summary line.
return self.timestamp - other.timestamp
def is_really_a_packet(self, missing_packet_stats=None):
- """Is the packet one that can be ignored?
+ return is_a_real_packet(self.protocol, self.opcode)
- If so removing it will have no effect on the replay
- """
- if self.protocol in SKIPPED_PROTOCOLS:
- # Ignore any packets for the protocols we're not interested in.
- return False
- if self.protocol == "ldap" and self.opcode == '':
- # skip ldap continuation packets
- return False
- fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
- fn = getattr(traffic_packets, fn_name, None)
- if not fn:
- print("missing packet %s" % fn_name, file=sys.stderr)
- return False
- if fn is traffic_packets.null_packet:
- return False
- return True
def is_a_real_packet(protocol, opcode):
    """Return True if the packet matters to the replay, False if ignorable.

    A packet is ignorable (removing it has no effect on the replay) when
    any of the following holds:

    * its protocol is listed in SKIPPED_PROTOCOLS;
    * it is an LDAP continuation packet (protocol "ldap", empty opcode);
    * traffic_packets has no ``packet_<protocol>_<opcode>`` attribute;
    * that attribute is traffic_packets.null_packet.
    """
    if protocol in SKIPPED_PROTOCOLS:
        # Ignore any packets for the protocols we're not interested in.
        return False
    if protocol == "ldap" and opcode == '':
        # skip ldap continuation packets
        return False

    fn_name = 'packet_%s_%s' % (protocol, opcode)
    fn = getattr(traffic_packets, fn_name, None)
    if fn is None:
        # ``file=`` is a print() keyword, not a logging one; passing it to
        # LOGGER.debug() raises TypeError once debug logging is enabled.
        LOGGER.debug("missing packet %s", fn_name)
        return False
    return fn is not traffic_packets.null_packet
+
+
def is_a_traffic_generating_packet(protocol, opcode):
    """Tell whether a packet produces traffic on its own.

    'wait' pseudo-packets never do. A handful of (protocol, opcode)
    pairs only produce traffic in the right context (e.g. an ldap unbind
    following a bind) and are reported as non-generating here, because a
    conversation made up solely of them does nothing. Anything else is
    decided by is_a_real_packet().
    """
    context_dependent = (
        ('kerberos', ''),
        ('ldap', '2'),
        ('dcerpc', '15'),
        ('dcerpc', '16'),
    )
    if protocol == 'wait' or (protocol, opcode) in context_dependent:
        return False

    return is_a_real_packet(protocol, opcode)
class ReplayContext(object):
- """State/Context for an individual conversation between an simulated client
- and a server.
+ """State/Context for a conversation between an simulated client and a
+ server. Some of the context is shared amongst all conversations
+ and should be generated before the fork, while other context is
+ specific to a particular conversation and should be generated
+ *after* the fork, in generate_process_local_config().
"""
-
def __init__(self,
server=None,
lp=None,
creds=None,
+ total_conversations=None,
badpassword_frequency=None,
prefer_kerberos=None,
tempdir=None,
statsdir=None,
ou=None,
base_dn=None,
- domain=None,
- domain_sid=None):
-
+ domain=os.environ.get("DOMAIN"),
+ domain_sid=None,
+ instance_id=None):
self.server = server
- self.ldap_connections = []
- self.dcerpc_connections = []
- self.lsarpc_connections = []
- self.lsarpc_connections_named = []
- self.drsuapi_connections = []
- self.srvsvc_connections = []
- self.samr_contexts = []
self.netlogon_connection = None
self.creds = creds
self.lp = lp
- self.prefer_kerberos = prefer_kerberos
+ if prefer_kerberos:
+ self.kerberos_state = MUST_USE_KERBEROS
+ else:
+ self.kerberos_state = DONT_USE_KERBEROS
self.ou = ou
self.base_dn = base_dn
self.domain = domain
self.global_tempdir = tempdir
self.domain_sid = domain_sid
self.realm = lp.get('realm')
+ self.instance_id = instance_id
# Bad password attempt controls
self.badpassword_frequency = badpassword_frequency
self.last_drsuapi_bad = False
self.last_netlogon_bad = False
self.last_samlogon_bad = False
+ self.total_conversations = total_conversations
self.generate_ldap_search_tables()
- self.next_conversation_id = itertools.count()
def generate_ldap_search_tables(self):
session = system_session()
self.dn_map = dn_map
self.attribute_clue_map = attribute_clue_map
+ # pre-populate DN-based search filters (it's simplest to generate them
+ # once, when the test starts). These are used by guess_search_filter()
+ # to avoid full-scans
+ self.search_filters = {}
+
+ # lookup all the GPO DNs
+ res = db.search(db.domain_dn(), scope=ldb.SCOPE_SUBTREE, attrs=['dn'],
+ expression='(objectclass=groupPolicyContainer)')
+ gpos_by_dn = "".join("(distinguishedName={0})".format(msg['dn']) for msg in res)
+
+ # a search for the 'gPCFileSysPath' attribute is probably a GPO search
+ # (as per the MS-GPOL spec) which searches for GPOs by DN
+ self.search_filters['gPCFileSysPath'] = "(|{0})".format(gpos_by_dn)
+
+ # likewise, a search for gpLink is probably the Domain SOM search part
+ # of the MS-GPOL, in which case it's looking up a few OUs by DN
+ ou_str = ""
+ for ou in ["Domain Controllers,", "traffic_replay,", ""]:
+ ou_str += "(distinguishedName={0}{1})".format(ou, db.domain_dn())
+ self.search_filters['gpLink'] = "(|{0})".format(ou_str)
+
+ # The CEP Web Service can query the AD DC to get pKICertificateTemplate
+ # objects (as per MS-WCCE)
+ self.search_filters['pKIExtendedKeyUsage'] = \
+ '(objectCategory=pKICertificateTemplate)'
+
+ # assume that anything querying the usnChanged is some kind of
+ # synchronization tool, e.g. AD Change Detection Connector
+ res = db.search('', scope=ldb.SCOPE_BASE, attrs=['highestCommittedUSN'])
+ self.search_filters['usnChanged'] = \
+ '(usnChanged>={0})'.format(res[0]['highestCommittedUSN'])
+
+ # The traffic_learner script doesn't preserve the LDAP search filter, and
+ # having no filter can result in a full DB scan. This is costly for a large
+ # DB, and not necessarily representative of real world traffic. As there are
+ # several standard LDAP queries that get used by AD tools, we can apply
+ # some logic and guess what the search filter might have been originally.
def guess_search_filter(self, attrs, dn_sig, dn):
    """Guess a plausible LDAP search filter for a replayed search.

    :param attrs: the attributes requested by the search
    :param dn_sig: signature of the search base DN (e.g. 'DC,DC')
    :param dn: the search base DN (not used by the current heuristics)
    :return: an LDAP filter string

    The guesses are tried in this order:

    1. if any key of self.search_filters appears in *attrs*, the
       pre-generated filter stored under that key is returned;
    2. if *dn_sig* is 'DC,DC' (the top-level domain), a filter matching
       one randomly chosen replay user account is returned;
    3. otherwise '(objectClass=*)' is returned, matching everything in
       the sub-tree.
    """
    # there are some standard spec-based searches that query fairly unique
    # attributes. Check if the search is likely one of these
    for key, search_filter in self.search_filters.items():
        if key in attrs:
            return search_filter

    # if it's the top-level domain, assume we're looking up a single user,
    # e.g. like powershell Get-ADUser or a similar tool
    if dn_sig == 'DC,DC':
        # random.random() is a float in [0, 1), so taking it modulo the
        # conversation count was a no-op and never gave an integer user
        # index. randrange() picks uniformly from
        # 0..total_conversations-1, the range the modulo was aiming for.
        # NOTE(review): confirm whether account numbering starts at 0 or 1.
        random_user_id = random.randrange(self.total_conversations)
        account_name = user_name(self.instance_id, random_user_id)
        return '(&(sAMAccountName=%s)(objectClass=user))' % account_name

    # otherwise just return everything in the sub-tree
    return '(objectClass=*)'
+
def generate_process_local_config(self, account, conversation):
- if account is None:
- return
+ self.ldap_connections = []
+ self.dcerpc_connections = []
+ self.lsarpc_connections = []
+ self.lsarpc_connections_named = []
+ self.drsuapi_connections = []
+ self.srvsvc_connections = []
+ self.samr_contexts = []
self.netbios_name = account.netbios_name
self.machinepass = account.machinepass
self.username = account.username
self.lp.set("state directory", self.tempdir)
self.lp.set("tls verify peer", "no_check")
- # If the domain was not specified, check for the environment
- # variable.
- if self.domain is None:
- self.domain = os.environ["DOMAIN"]
-
self.remoteAddress = "/root/ncalrpc_as_system"
self.samlogon_dn = ("cn=%s,%s" %
(self.netbios_name, self.ou))
than that requested, but not significantly.
"""
if not failed_last_time:
- if (self.badpassword_frequency and self.badpassword_frequency > 0
- and random.random() < self.badpassword_frequency):
+ if (self.badpassword_frequency and
+ random.random() < self.badpassword_frequency):
try:
f(bad)
- except:
+ except Exception:
# Ignore any exceptions as the operation may fail
# as it's being performed with bad credentials
pass
self.user_creds.set_password(self.userpass)
self.user_creds.set_username(self.username)
self.user_creds.set_domain(self.domain)
- if self.prefer_kerberos:
- self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
- else:
- self.user_creds.set_kerberos_state(DONT_USE_KERBEROS)
+ self.user_creds.set_kerberos_state(self.kerberos_state)
self.user_creds_bad = Credentials()
self.user_creds_bad.guess(self.lp)
self.user_creds_bad.set_workstation(self.netbios_name)
self.user_creds_bad.set_password(self.userpass[:-4])
self.user_creds_bad.set_username(self.username)
- if self.prefer_kerberos:
- self.user_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
- else:
- self.user_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
+ self.user_creds_bad.set_kerberos_state(self.kerberos_state)
# Credentials for ldap simple bind.
self.simple_bind_creds = Credentials()
self.simple_bind_creds.set_username(self.username)
self.simple_bind_creds.set_gensec_features(
self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
- if self.prefer_kerberos:
- self.simple_bind_creds.set_kerberos_state(MUST_USE_KERBEROS)
- else:
- self.simple_bind_creds.set_kerberos_state(DONT_USE_KERBEROS)
+ self.simple_bind_creds.set_kerberos_state(self.kerberos_state)
self.simple_bind_creds.set_bind_dn(self.user_dn)
self.simple_bind_creds_bad = Credentials()
self.simple_bind_creds_bad.set_gensec_features(
self.simple_bind_creds_bad.get_gensec_features() |
gensec.FEATURE_SEAL)
- if self.prefer_kerberos:
- self.simple_bind_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
- else:
- self.simple_bind_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
+ self.simple_bind_creds_bad.set_kerberos_state(self.kerberos_state)
self.simple_bind_creds_bad.set_bind_dn(self.user_dn)
def generate_machine_creds(self):
self.machine_creds.set_password(self.machinepass)
self.machine_creds.set_username(self.netbios_name + "$")
self.machine_creds.set_domain(self.domain)
- if self.prefer_kerberos:
- self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
- else:
- self.machine_creds.set_kerberos_state(DONT_USE_KERBEROS)
+ self.machine_creds.set_kerberos_state(self.kerberos_state)
self.machine_creds_bad = Credentials()
self.machine_creds_bad.guess(self.lp)
self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
self.machine_creds_bad.set_password(self.machinepass[:-4])
self.machine_creds_bad.set_username(self.netbios_name + "$")
- if self.prefer_kerberos:
- self.machine_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
- else:
- self.machine_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
+ self.machine_creds_bad.set_kerberos_state(self.kerberos_state)
def get_matching_dn(self, pattern, attributes=None):
# If the pattern is an empty string, we assume ROOTDSE,
def get_authenticator(self):
auth = self.machine_creds.new_client_authenticator()
current = netr_Authenticator()
- current.cred.data = [x if isinstance(x, int) else ord(x) for x in auth["credential"]]
+ current.cred.data = [x if isinstance(x, int) else ord(x)
+ for x in auth["credential"]]
current.timestamp = auth["timestamp"]
subsequent = netr_Authenticator()
return (current, subsequent)
def write_stats(self, filename, **kwargs):
    """Write key/value pairs to a file in the stats directory.

    The file is created (or truncated) as *filename* inside
    self.statsdir and gets one "key: value" line per keyword argument,
    so that another process working out statistics can pick them up
    later.

    :param filename: name of the file, relative to self.statsdir
    :param kwargs: the key/value pairs to record
    """
    path = os.path.join(self.statsdir, filename)
    # the context manager closes the file even if formatting or writing
    # one of the values raises
    with open(path, 'w') as f:
        for k, v in kwargs.items():
            print("%s: %s" % (k, v), file=f)
+
class SamrContext(object):
"""State/Context associated with a samr connection.
class Conversation(object):
"""Details of a converation between a simulated client and a server."""
- conversation_id = None
-
- def __init__(self, start_time=None, endpoints=None):
+ def __init__(self, start_time=None, endpoints=None, seq=(),
+ conversation_id=None):
self.start_time = start_time
self.endpoints = endpoints
self.packets = []
- self.msg = random_colour_print()
+ self.msg = random_colour_print(endpoints)
self.client_balance = 0.0
+ self.conversation_id = conversation_id
+ for p in seq:
+ self.add_short_packet(*p)
def __cmp__(self, other):
if self.start_time is None:
self.packets.append(p)
def add_short_packet(self, timestamp, protocol, opcode, extra,
- client=True):
+ client=True, skip_unused_packets=True):
"""Create a packet from a timestamp, and 'protocol:opcode' pair, and a
(possibly empty) list of extra data. If client is True, assume
this packet is from the client to the server.
"""
+ if skip_unused_packets and not is_a_real_packet(protocol, opcode):
+ return
+
src, dest = self.guess_client_server()
if not client:
src, dest = dest, src
key = (protocol, opcode)
- desc = OP_DESCRIPTIONS[key] if key in OP_DESCRIPTIONS else ''
- if protocol in IP_PROTOCOLS:
- ip_protocol = IP_PROTOCOLS[protocol]
- else:
- ip_protocol = '06'
+ desc = OP_DESCRIPTIONS.get(key, '')
+ ip_protocol = IP_PROTOCOLS.get(protocol, '06')
packet = Packet(timestamp - self.start_time, ip_protocol,
'', src, dest,
protocol, opcode, desc, extra)
return self.packets[-1].timestamp - self.packets[0].timestamp
def replay_as_summary_lines(self):
- lines = []
- for p in self.packets:
- lines.append(p.as_summary(self.start_time))
- return lines
-
- def replay_in_fork_with_delay(self, start, context=None, account=None):
- """Fork a new process and replay the conversation.
- """
- def signal_handler(signal, frame):
- """Signal handler closes standard out and error.
-
- Triggered by a sigterm, ensures that the log messages are flushed
- to disk and not lost.
- """
- sys.stderr.close()
- sys.stdout.close()
- os._exit(0)
+ return [p.as_summary(self.start_time) for p in self.packets]
+ def replay_with_delay(self, start, context=None, account=None):
+ """Replay the conversation at the right time.
+ (We're already in a fork)."""
+ # first we sleep until the first packet
t = self.start_time
now = time.time() - start
gap = t - now
- # we are replaying strictly in order, so it is safe to sleep
- # in the main process if the gap is big enough. This reduces
- # the number of concurrent threads, which allows us to make
- # larger loads.
- if gap > 0.15 and False:
- print("sleeping for %f in main process" % (gap - 0.1),
- file=sys.stderr)
- time.sleep(gap - 0.1)
- now = time.time() - start
- gap = t - now
- print("gap is now %f" % gap, file=sys.stderr)
-
- self.conversation_id = next(context.next_conversation_id)
- pid = os.fork()
- if pid != 0:
- return pid
- pid = os.getpid()
- signal.signal(signal.SIGTERM, signal_handler)
- # we must never return, or we'll end up running parts of the
- # parent's clean-up code. So we work in a try...finally, and
- # try to print any exceptions.
-
- try:
- context.generate_process_local_config(account, self)
- sys.stdin.close()
- os.close(0)
- filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
- self.conversation_id)
- sys.stdout.close()
- sys.stdout = open(filename, 'w')
-
- sleep_time = gap - SLEEP_OVERHEAD
- if sleep_time > 0:
- time.sleep(sleep_time)
+ sleep_time = gap - SLEEP_OVERHEAD
+ if sleep_time > 0:
+ time.sleep(sleep_time)
- miss = t - (time.time() - start)
- self.msg("starting %s [miss %.3f pid %d]" % (self, miss, pid))
- self.replay(context)
- except Exception:
- print(("EXCEPTION in child PID %d, conversation %s" % (pid, self)),
- file=sys.stderr)
- traceback.print_exc(sys.stderr)
- finally:
- sys.stderr.close()
- sys.stdout.close()
- os._exit(0)
-
- def replay(self, context=None):
- start = time.time()
+ miss = (time.time() - start) - t
+ self.msg("starting %s [miss %.3f]" % (self, miss))
+ max_gap = 0.0
+ max_sleep_miss = 0.0
+ # packet times are relative to conversation start
+ p_start = time.time()
for p in self.packets:
- now = time.time() - start
- gap = p.timestamp - now
- sleep_time = gap - SLEEP_OVERHEAD
- if sleep_time > 0:
- time.sleep(sleep_time)
+ now = time.time() - p_start
+ gap = now - p.timestamp
+ if gap > max_gap:
+ max_gap = gap
+ if gap < 0:
+ sleep_time = -gap - SLEEP_OVERHEAD
+ if sleep_time > 0:
+ time.sleep(sleep_time)
+ t = time.time() - p_start
+ if t - p.timestamp > max_sleep_miss:
+ max_sleep_miss = t - p.timestamp
- miss = p.timestamp - (time.time() - start)
- if context is None:
- self.msg("packet %s [miss %.3f pid %d]" % (p, miss,
- os.getpid()))
- continue
p.play(self, context)
+ return max_gap, miss, max_sleep_miss
+
def guess_client_server(self, server_clue=None):
"""Have a go at deciding who is the server and who is the client.
returns (client, server)
"""A lightweight conversation that generates a lot of dns:0 packets on
the fly"""
- def __init__(self, dns_rate, duration):
+ def __init__(self, dns_rate, duration, query_file=None):
n = int(dns_rate * duration)
self.times = [random.uniform(0, duration) for i in range(n)]
self.times.sort()
self.rate = dns_rate
self.duration = duration
self.start_time = 0
- self.msg = random_colour_print()
+ self.query_choices = self._get_query_choices(query_file=query_file)
def __str__(self):
return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
(len(self.times), self.duration, self.rate))
- def replay_in_fork_with_delay(self, start, context=None, account=None):
- return Conversation.replay_in_fork_with_delay(self,
- start,
- context,
- account)
+ def _get_query_choices(self, query_file=None):
+ """
+ Read dns query choices from a file, or return default
+
+ rname may contain format string like `{realm}`
+ realm can be fetched from context.realm
+ """
+
+ if query_file:
+ with open(query_file, 'r') as f:
+ text = f.read()
+ choices = []
+ for line in text.splitlines():
+ line = line.strip()
+ if line and not line.startswith('#'):
+ args = line.split(',')
+ assert len(args) == 4
+ choices.append(args)
+ return choices
+ else:
+ return [
+ (0, '{realm}', 'A', 'yes'),
+ (1, '{realm}', 'NS', 'yes'),
+ (2, '*.{realm}', 'A', 'no'),
+ (3, '*.{realm}', 'NS', 'no'),
+ (10, '_msdcs.{realm}', 'A', 'yes'),
+ (11, '_msdcs.{realm}', 'NS', 'yes'),
+ (20, 'nx.realm.com', 'A', 'no'),
+ (21, 'nx.realm.com', 'NS', 'no'),
+ (22, '*.nx.realm.com', 'A', 'no'),
+ (23, '*.nx.realm.com', 'NS', 'no'),
+ ]
def replay(self, context=None):
+ assert context
+ assert context.realm
start = time.time()
- fn = traffic_packets.packet_dns_0
for t in self.times:
now = time.time() - start
gap = t - now
if sleep_time > 0:
time.sleep(sleep_time)
- if context is None:
- miss = t - (time.time() - start)
- self.msg("packet %s [miss %.3f pid %d]" % (t, miss,
- os.getpid()))
- continue
-
+ opcode, rname, rtype, exist = random.choice(self.query_choices)
+ rname = rname.format(realm=context.realm)
+ success = True
packet_start = time.time()
try:
- fn(self, self, context)
- end = time.time()
- duration = end - packet_start
- print("%f\tDNS\tdns\t0\t%f\tTrue\t" % (end, duration))
- except Exception as e:
+ answers = dns_query(rname, rtype)
+ if exist == 'yes' and not len(answers):
+ # expect answers but didn't get, fail
+ success = False
+ except Exception:
+ success = False
+ finally:
end = time.time()
duration = end - packet_start
- print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" % (end, duration, e))
+ print("%f\tDNS\tdns\t%s\t%f\t%s\t" % (end, opcode, duration, success))
def ingest_summaries(files, dns_mode='count'):
print("gathering packets into conversations", file=sys.stderr)
conversations = OrderedDict()
- for p in packets:
+ for i, p in enumerate(packets):
p.timestamp -= start_time
c = conversations.get(p.endpoints)
if c is None:
- c = Conversation()
+ c = Conversation(conversation_id=(i + 2))
conversations[p.endpoints] = c
c.add_packet(p)
self.n = n
self.dns_opcounts = defaultdict(int)
self.cumulative_duration = 0.0
- self.conversation_rate = [0, 1]
+ self.packet_rate = [0, 1]
def learn(self, conversations, dns_opcounts={}):
prev = 0.0
self.dns_opcounts[k] += v
if len(conversations) > 1:
- elapsed =\
- conversations[-1].start_time - conversations[0].start_time
- self.conversation_rate[0] = len(conversations)
- self.conversation_rate[1] = elapsed
+ first = conversations[0].start_time
+ total = 0
+ last = first + 0.1
+ for c in conversations:
+ total += len(c)
+ last = max(last, c.packets[-1].timestamp)
+
+ self.packet_rate[0] = total
+ self.packet_rate[1] = last - first
for c in conversations:
client, server = c.guess_client_server(server)
'ngrams': ngrams,
'query_details': query_details,
'cumulative_duration': self.cumulative_duration,
- 'conversation_rate': self.conversation_rate,
+ 'packet_rate': self.packet_rate,
+ 'version': CURRENT_MODEL_VERSION
}
d['dns'] = self.dns_opcounts
d = json.load(f)
+ try:
+ version = d["version"]
+ if version < REQUIRED_MODEL_VERSION:
+ raise ValueError("the model file is version %d; "
+ "version %d is required" %
+ (version, REQUIRED_MODEL_VERSION))
+ except KeyError:
+ raise ValueError("the model file lacks a version number; "
+ "version %d is required" %
+ (REQUIRED_MODEL_VERSION))
+
for k, v in d['ngrams'].items():
k = tuple(str(k).split('\t'))
values = self.ngrams.setdefault(k, [])
for p, count in v.items():
values.extend([str(p)] * count)
+ values.sort()
for k, v in d['query_details'].items():
values = self.query_details.setdefault(str(k), [])
values.extend([()] * count)
else:
values.extend([tuple(str(p).split('\t'))] * count)
+ values.sort()
if 'dns' in d:
for k, v in d['dns'].items():
self.dns_opcounts[k] += v
self.cumulative_duration = d['cumulative_duration']
- self.conversation_rate = d['conversation_rate']
-
- def construct_conversation(self, timestamp=0.0, client=2, server=1,
- hard_stop=None, packet_rate=1):
- """Construct a individual converation from the model."""
-
- c = Conversation(timestamp, (server, client))
-
+ self.packet_rate = d['packet_rate']
+
+ def construct_conversation_sequence(self, timestamp=0.0,
+ hard_stop=None,
+ replay_speed=1,
+ ignore_before=0,
+ persistence=0):
+ """Construct an individual conversation packet sequence from the
+ model.
+ """
+ c = []
key = (NON_PACKET,) * (self.n - 1)
+ if ignore_before is None:
+ ignore_before = timestamp - 1
- while key in self.ngrams:
- p = random.choice(self.ngrams.get(key, NON_PACKET))
+ while True:
+ p = random.choice(self.ngrams.get(key, (NON_PACKET,)))
if p == NON_PACKET:
- break
+ if timestamp < ignore_before:
+ break
+ if random.random() > persistence:
+ print("ending after %s (persistence %.1f)" % (key, persistence),
+ file=sys.stderr)
+ break
+
+ p = 'wait:%d' % random.randrange(5, 12)
+ print("trying %s instead of end" % p, file=sys.stderr)
+
if p in self.query_details:
extra = random.choice(self.query_details[p])
else:
protocol, opcode = p.split(':', 1)
if protocol == 'wait':
log_wait_time = int(opcode) + random.random()
- wait = math.exp(log_wait_time) / (WAIT_SCALE * packet_rate)
+ wait = math.exp(log_wait_time) / (WAIT_SCALE * replay_speed)
timestamp += wait
else:
log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
- wait = math.exp(log_wait) / packet_rate
+ wait = math.exp(log_wait) / replay_speed
timestamp += wait
if hard_stop is not None and timestamp > hard_stop:
break
- c.add_short_packet(timestamp, protocol, opcode, extra)
+ if timestamp >= ignore_before:
+ c.append((timestamp, protocol, opcode, extra))
key = key[1:] + (p,)
+ if key[-2][:5] == 'wait:' and key[-1][:5] == 'wait:':
+ # two waits in a row can only be caused by "persistence"
+ # tricks, and will not result in any packets being found.
+ # Instead we pretend this is a fresh start.
+ key = (NON_PACKET,) * (self.n - 1)
return c
- def generate_conversations(self, rate, duration, packet_rate=1):
- """Generate a list of conversations from the model."""
-
- # We run the simulation for at least ten times as long as our
- # desired duration, and take a section near the start.
- rate_n, rate_t = self.conversation_rate
def scale_to_packet_rate(self, scale):
    """Convert a traffic scale factor into a packet rate.

    self.packet_rate is a (count, time span) pair; the result is *scale*
    multiplied by count / time span.
    """
    count, span = self.packet_rate
    return scale * count / span
- duration2 = max(rate_t, duration * 2)
- n = rate * duration2 * rate_n / rate_t
def packet_rate_to_scale(self, pps):
    """Convert a packet rate into a traffic scale factor.

    This is the inverse of scale_to_packet_rate(): *pps* is multiplied
    by time span / count, both taken from the (count, time span) pair in
    self.packet_rate.
    """
    count, span = self.packet_rate
    return pps * span / count
- server = 1
- client = 2
+ def generate_conversation_sequences(self, packet_rate, duration, replay_speed=1,
+ persistence=0):
+ """Generate a list of conversation descriptions from the model."""
+ # We run the simulation for ten times as long as our desired
+ # duration, and take the section at the end.
+ lead_in = 9 * duration
+ target_packets = int(packet_rate * duration)
conversations = []
- end = duration2
- start = end - duration
-
- while client < n + 2:
- start = random.uniform(0, duration2)
- c = self.construct_conversation(start,
- client,
- server,
- hard_stop=(duration2 * 5),
- packet_rate=packet_rate)
-
- c.forget_packets_outside_window(start, end)
- c.renormalise_times(start)
- if len(c) != 0:
- conversations.append(c)
- client += 1
+ n_packets = 0
+
+ while n_packets < target_packets:
+ start = random.uniform(-lead_in, duration)
+ c = self.construct_conversation_sequence(start,
+ hard_stop=duration,
+ replay_speed=replay_speed,
+ ignore_before=0,
+ persistence=persistence)
+ # will these "packets" generate actual traffic?
+ # some (e.g. ldap unbind) will not generate anything
+ # if the previous packets are not there, and if the
+ # conversation only has those it wastes a process doing nothing.
+ for timestamp, protocol, opcode, extra in c:
+ if is_a_traffic_generating_packet(protocol, opcode):
+ break
+ else:
+ continue
- print(("we have %d conversations at rate %f" %
- (len(conversations), rate)), file=sys.stderr)
- conversations.sort()
+ conversations.append(c)
+ n_packets += len(c)
+
+ scale = self.packet_rate_to_scale(packet_rate)
+ print(("we have %d packets (target %d) in %d conversations at %.1f/s "
+ "(scale %f)" % (n_packets, target_packets, len(conversations),
+ packet_rate, scale)),
+ file=sys.stderr)
+ conversations.sort() # sorts by first element == start time
return conversations
def seq_to_conversations(seq, server=1, client=2):
    """Turn conversation sequences into Conversation objects.

    Every non-empty sequence in *seq* becomes one Conversation, starting
    at the timestamp of its first entry, with endpoints (server, client).
    Client numbers are handed out consecutively from *client*; empty
    sequences are dropped and do not use up a number.
    """
    conversations = []
    for packets in seq:
        if not packets:
            continue
        start_time = packets[0][0]
        conversations.append(
            Conversation(start_time, (server, client), packets))
        client += 1
    return conversations
+
+
IP_PROTOCOLS = {
'dns': '11',
'rpc_netlogon': '06',
return '\t'.join(line)
-def replay(conversations,
def flushing_signal_handler(signal, frame):
    """Close stderr and stdout, then exit the process immediately.

    Installed as the SIGTERM handler in forked children, so that
    buffered log and stats output reaches disk rather than being lost
    when the child is killed. Both arguments are ignored.
    """
    for stream in (sys.stderr, sys.stdout):
        stream.close()
    os._exit(0)
+
+
def replay_seq_in_fork(cs, start, context, account, client_id, server_id=1):
    """Fork a new process and replay the conversation sequence.

    :param cs: the conversation sequence, a non-empty list of
        (timestamp, protocol, opcode, extra) tuples; the first timestamp
        is the conversation's start time
    :param start: the time.time() value at which the whole replay starts
    :param context: the shared ReplayContext
    :param account: the account this conversation runs as
    :param client_id: endpoint number of the simulated client, also used
        as the conversation id and in the stats file name
    :param server_id: endpoint number of the server
    :return: the child's pid, in the parent. The child never returns: it
        exits with status 0 on success or 1 after an exception.
    """
    # We will need to reseed the random number generator or all the
    # clients will end up using the same sequence of random
    # numbers. random.randint() is mixed in so the initial seed will
    # have an effect here.
    seed = client_id * 1000 + random.randint(0, 999)

    # flush our buffers so messages won't be written by both sides
    sys.stdout.flush()
    sys.stderr.flush()
    pid = os.fork()
    if pid != 0:
        return pid

    # we must never return, or we'll end up running parts of the
    # parent's clean-up code. So we work in a try...finally, and
    # try to print any exceptions.
    #
    # status and c are bound before the try so that the handlers below
    # can always use them: previously an early failure (e.g. an empty cs)
    # made the except clause itself raise NameError on c, which hid the
    # real traceback.
    status = 0
    c = None
    try:
        random.seed(seed)
        endpoints = (server_id, client_id)
        t = cs[0][0]
        c = Conversation(t, endpoints, seq=cs, conversation_id=client_id)
        signal.signal(signal.SIGTERM, flushing_signal_handler)

        context.generate_process_local_config(account, c)
        sys.stdin.close()
        os.close(0)
        filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
                                c.conversation_id)
        f = open(filename, 'w')
        try:
            sys.stdout.close()
            os.close(1)
        except IOError as e:
            LOGGER.info("stdout closing failed with %s" % e)

        # from here on everything printed to stdout lands in the
        # per-conversation stats file
        sys.stdout = f
        now = time.time() - start
        gap = t - now
        sleep_time = gap - SLEEP_OVERHEAD
        if sleep_time > 0:
            time.sleep(sleep_time)

        max_lag, start_lag, max_sleep_miss = c.replay_with_delay(
            start=start, context=context)
        print("Maximum lag: %f" % max_lag)
        print("Start lag: %f" % start_lag)
        print("Max sleep miss: %f" % max_sleep_miss)

    except Exception:
        status = 1
        print(("EXCEPTION in child PID %d, conversation %s" %
               (os.getpid(), c)),
              file=sys.stderr)
        traceback.print_exc(sys.stderr)
        sys.stderr.flush()
    finally:
        sys.stderr.close()
        sys.stdout.close()
        os._exit(status)
+
+
def dnshammer_in_fork(dns_rate, duration, context, query_file=None):
    """Fork a new process that runs a DnsHammer for the whole replay.

    :param dns_rate: passed to DnsHammer as the query rate
    :param duration: passed to DnsHammer as the run length
    :param context: the shared ReplayContext; its statsdir receives the
        'stats-dns' file that the child's stdout is redirected to
    :param query_file: optional file of DNS query choices for DnsHammer
    :return: the child's pid, in the parent. The child never returns: it
        exits with status 0 on success or 1 after an exception.
    """
    # flush our buffers so messages won't be written by both sides
    sys.stdout.flush()
    sys.stderr.flush()
    pid = os.fork()
    if pid != 0:
        return pid

    # Everything the child does happens inside the try, so that a failure
    # while re-plumbing stdin/stdout (e.g. an unwritable stats directory)
    # cannot propagate out of this function and run the parent's code
    # in the child.
    status = 0
    try:
        sys.stdin.close()
        os.close(0)

        try:
            sys.stdout.close()
            os.close(1)
        except IOError as e:
            # Logger.warn() is a deprecated alias of warning()
            LOGGER.warning("stdout closing failed with %s" % e)
        filename = os.path.join(context.statsdir, 'stats-dns')
        sys.stdout = open(filename, 'w')

        signal.signal(signal.SIGTERM, flushing_signal_handler)
        hammer = DnsHammer(dns_rate, duration, query_file=query_file)
        hammer.replay(context=context)
    except Exception:
        status = 1
        print(("EXCEPTION in child PID %d, the DNS hammer" % (os.getpid())),
              file=sys.stderr)
        traceback.print_exc(sys.stderr)
    finally:
        sys.stderr.close()
        sys.stdout.close()
        os._exit(status)
+
+
+def replay(conversation_seq,
host=None,
creds=None,
lp=None,
accounts=None,
dns_rate=0,
+ dns_query_file=None,
duration=None,
+ latency_timeout=1.0,
+ stop_on_any_error=False,
**kwargs):
context = ReplayContext(server=host,
creds=creds,
lp=lp,
+ total_conversations=len(conversation_seq),
**kwargs)
- if len(accounts) < len(conversations):
- print(("we have %d accounts but %d conversations" %
- (accounts, conversations)), file=sys.stderr)
-
- cstack = list(zip(
- sorted(conversations, key=lambda x: x.start_time, reverse=True),
- accounts))
+ if len(accounts) < len(conversation_seq):
+ raise ValueError(("we have %d accounts but %d conversations" %
+ (len(accounts), len(conversation_seq))))
# Set the process group so that the calling scripts are not killed
# when the forked child processes are killed.
os.setpgrp()
- start = time.time()
+ # we delay the start by a bit to allow all the forks to get up and
+ # running.
+ delay = len(conversation_seq) * 0.02
+ start = time.time() + delay
if duration is None:
- # end 1 second after the last packet of the last conversation
+ # end slightly after the last packet of the last conversation
# to start. Conversations other than the last could still be
# going, but we don't care.
- duration = cstack[0][0].packets[-1].timestamp + 1.0
- print("We will stop after %.1f seconds" % duration,
- file=sys.stderr)
+ duration = conversation_seq[-1][-1][0] + latency_timeout
- end = start + duration
+ print("We will start in %.1f seconds" % delay,
+ file=sys.stderr)
+ print("We will stop after %.1f seconds" % (duration + delay),
+ file=sys.stderr)
+ print("runtime %.1f seconds" % duration,
+ file=sys.stderr)
+
+ # give one second grace for packets to finish before killing begins
+ end = start + duration + 1.0
LOGGER.info("Replaying traffic for %u conversations over %d seconds"
- % (len(conversations), duration))
+ % (len(conversation_seq), duration))
- children = {}
- if dns_rate:
- dns_hammer = DnsHammer(dns_rate, duration)
- cstack.append((dns_hammer, None))
+ context.write_stats('intentions',
+ Planned_conversations=len(conversation_seq),
+ Planned_packets=sum(len(x) for x in conversation_seq))
+ children = {}
try:
- while True:
- # we spawn a batch, wait for finishers, then spawn another
- now = time.time()
- batch_end = min(now + 2.0, end)
- fork_time = 0.0
- fork_n = 0
- while cstack:
- c, account = cstack.pop()
- if c.start_time + start > batch_end:
- cstack.append((c, account))
- break
+ if dns_rate:
+ pid = dnshammer_in_fork(dns_rate, duration, context,
+ query_file=dns_query_file)
+ children[pid] = 1
+
+ for i, cs in enumerate(conversation_seq):
+ account = accounts[i]
+ client_id = i + 2
+ pid = replay_seq_in_fork(cs, start, context, account, client_id)
+ children[pid] = client_id
+
+ # HERE, we are past all the forks
+ t = time.time()
+ print("all forks done in %.1f seconds, waiting %.1f" %
+ (t - start + delay, t - start),
+ file=sys.stderr)
- st = time.time()
- pid = c.replay_in_fork_with_delay(start, context, account)
- children[pid] = c
- t = time.time()
- elapsed = t - st
- fork_time += elapsed
- fork_n += 1
- print("forked %s in pid %s (in %fs)" % (c, pid,
- elapsed),
- file=sys.stderr)
-
- if fork_n:
- print(("forked %d times in %f seconds (avg %f)" %
- (fork_n, fork_time, fork_time / fork_n)),
- file=sys.stderr)
- elif cstack:
- debug(2, "no forks in batch ending %f" % batch_end)
-
- while time.time() < batch_end - 1.0:
- time.sleep(0.01)
- try:
- pid, status = os.waitpid(-1, os.WNOHANG)
- except OSError as e:
- if e.errno != 10: # no child processes
- raise
- break
- if pid:
- c = children.pop(pid, None)
- print(("process %d finished conversation %s;"
+ while time.time() < end and children:
+ time.sleep(0.003)
+ try:
+ pid, status = os.waitpid(-1, os.WNOHANG)
+ except OSError as e:
+ if e.errno != ECHILD: # no child processes
+ raise
+ break
+ if pid:
+ c = children.pop(pid, None)
+ if DEBUG_LEVEL > 0:
+ print(("process %d finished conversation %d;"
" %d to go" %
(pid, c, len(children))), file=sys.stderr)
-
- if time.time() >= end:
- print("time to stop", file=sys.stderr)
- break
+ if stop_on_any_error and status != 0:
+ break
except Exception:
print("EXCEPTION in parent", file=sys.stderr)
traceback.print_exc()
finally:
+ context.write_stats('unfinished',
+ Unfinished_conversations=len(children))
+
for s in (15, 15, 9):
print(("killing %d children with -%d" %
(len(children), s)), file=sys.stderr)
try:
os.kill(pid, s)
except OSError as e:
- if e.errno != 3: # don't fail if it has already died
+ if e.errno != ESRCH: # don't fail if it has already died
raise
time.sleep(0.5)
end = time.time() + 1
try:
pid, status = os.waitpid(-1, os.WNOHANG)
except OSError as e:
- if e.errno != 10:
+ if e.errno != ECHILD:
raise
if pid != 0:
c = children.pop(pid, None)
- print(("kill -%d %d KILLED conversation %s; "
+ if c is None:
+ print("children is %s, no pid found" % children)
+ sys.stderr.flush()
+ sys.stdout.flush()
+ os._exit(1)
+ print(("kill -%d %d KILLED conversation; "
"%d to go" %
- (s, pid, c, len(children))),
+ (s, pid, len(children))),
file=sys.stderr)
if time.time() >= end:
break
return ou
-class ConversationAccounts(object):
- """Details of the machine and user accounts associated with a conversation.
- """
- def __init__(self, netbios_name, machinepass, username, userpass):
- self.netbios_name = netbios_name
- self.machinepass = machinepass
- self.username = username
- self.userpass = userpass
+# ConversationAccounts holds details of the machine and user accounts
+# associated with a conversation.
+#
+# We use a named tuple to reduce shared memory usage.
+ConversationAccounts = namedtuple('ConversationAccounts',
+ ('netbios_name',
+ 'machinepass',
+ 'username',
+ 'userpass'))
def generate_replay_accounts(ldb, instance_id, number, password):
def generate_users_and_groups(ldb, instance_id, password,
number_of_users, number_of_groups,
- group_memberships, machine_accounts,
- traffic_accounts=True):
+ group_memberships, max_members,
+ machine_accounts, traffic_accounts=True):
"""Generate the required users and groups, allocating the users to
those groups."""
memberships_added = 0
groups_added,
number_of_users,
users_added,
- group_memberships)
+ group_memberships,
+ max_members)
LOGGER.info("Adding users to groups")
add_users_to_groups(ldb, instance_id, assignments)
memberships_added = assignments.total()
class GroupAssignments(object):
def __init__(self, number_of_groups, groups_added, number_of_users,
- users_added, group_memberships):
+ users_added, group_memberships, max_members):
self.count = 0
self.generate_group_distribution(number_of_groups)
self.generate_user_distribution(number_of_users, group_memberships)
+ self.max_members = max_members
self.assignments = defaultdict(list)
self.assign_groups(number_of_groups, groups_added, number_of_users,
users_added, group_memberships)
# value, so we can use random.random() as a simple index into the list
dist = []
total = sum(weights)
+ if total == 0:
+ return None
+
cumulative = 0.0
for probability in weights:
cumulative += probability
weights.append(p)
# convert the weights to a cumulative distribution between 0.0 and 1.0
+ self.group_weights = weights
self.group_dist = self.cumulative_distribution(weights)
def generate_random_membership(self):
def get_groups(self):
return self.assignments.keys()
+    def cap_group_membership(self, group, max_members):
+        """Take a group out of the selection once it is full.
+
+        If the group already holds max_members users (or more), its
+        weight is set to zero and the cumulative distribution is rebuilt
+        from the updated weights, so this group is no longer selected.
+        """
+        member_count = len(self.assignments[group])
+        if member_count < max_members:
+            return
+
+        LOGGER.info("Group {0} has {1} members".format(group, member_count))
+
+        # the weight for `group` is stored at index group - 1
+        self.group_weights[group - 1] = 0
+        self.group_dist = self.cumulative_distribution(self.group_weights)
+
def add_assignment(self, user, group):
# the assignments are stored in a dictionary where key=group,
# value=list-of-users-in-group (indexing by group-ID allows us to
self.assignments[group].append(user)
self.count += 1
+ # check if there's a cap on how big the groups can grow
+ if self.max_members:
+ self.cap_group_membership(group, self.max_members)
+
def assign_groups(self, number_of_groups, groups_added,
number_of_users, users_added, group_memberships):
"""Allocate users to groups.
float(group_memberships) *
(float(users_added) / float(number_of_users)))
+ if self.max_members:
+ group_memberships = min(group_memberships,
+ self.max_members * number_of_groups)
+
existing_users = number_of_users - users_added - 1
existing_groups = number_of_groups - groups_added - 1
while self.total() < group_memberships:
successful = 0
failed = 0
latencies = {}
- failures = {}
- unique_converations = set()
- conversations = 0
-
+ failures = Counter()
+ unique_conversations = set()
if timing_file is not None:
tw = timing_file.write
else:
tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")
+ float_values = {
+ 'Maximum lag': 0,
+ 'Start lag': 0,
+ 'Max sleep miss': 0,
+ }
+ int_values = {
+ 'Planned_conversations': 0,
+ 'Planned_packets': 0,
+ 'Unfinished_conversations': 0,
+ }
+
for filename in os.listdir(statsdir):
path = os.path.join(statsdir, filename)
with open(path, 'r') as f:
protocol = fields[2]
packet_type = fields[3]
latency = float(fields[4])
- first = min(float(fields[0]) - latency, first)
- last = max(float(fields[0]), last)
-
- if protocol not in latencies:
- latencies[protocol] = {}
- if packet_type not in latencies[protocol]:
- latencies[protocol][packet_type] = []
-
- latencies[protocol][packet_type].append(latency)
-
- if protocol not in failures:
- failures[protocol] = {}
- if packet_type not in failures[protocol]:
- failures[protocol][packet_type] = 0
+ t = float(fields[0])
+ first = min(t - latency, first)
+ last = max(t, last)
+ op = (protocol, packet_type)
+ latencies.setdefault(op, []).append(latency)
if fields[5] == 'True':
successful += 1
else:
failed += 1
- failures[protocol][packet_type] += 1
+ failures[op] += 1
- if conversation not in unique_converations:
- unique_converations.add(conversation)
- conversations += 1
+ unique_conversations.add(conversation)
tw(line)
except (ValueError, IndexError):
- # not a valid line print and ignore
- print(line, file=sys.stderr)
- pass
+ if ':' in line:
+ k, v = line.split(':', 1)
+ if k in float_values:
+ float_values[k] = max(float(v),
+ float_values[k])
+ elif k in int_values:
+ int_values[k] = max(int(v),
+ int_values[k])
+ else:
+ print(line, file=sys.stderr)
+ else:
+ # not a valid line: print and ignore
+ print(line, file=sys.stderr)
+
duration = last - first
if successful == 0:
success_rate = 0
else:
failure_rate = failed / duration
+ conversations = len(unique_conversations)
+
print("Total conversations: %10d" % conversations)
print("Successful operations: %10d (%.3f per second)"
% (successful, success_rate))
print("Failed operations: %10d (%.3f per second)"
% (failed, failure_rate))
+ for k, v in sorted(float_values.items()):
+ print("%-28s %f" % (k.replace('_', ' ') + ':', v))
+ for k, v in sorted(int_values.items()):
+ print("%-28s %d" % (k.replace('_', ' ') + ':', v))
+
print("Protocol Op Code Description "
" Count Failed Mean Median "
"95% Range Max")
- protocols = sorted(latencies.keys())
+ ops = {}
+ for proto, packet in latencies:
+ if proto not in ops:
+ ops[proto] = set()
+ ops[proto].add(packet)
+ protocols = sorted(ops.keys())
+
for protocol in protocols:
- packet_types = sorted(latencies[protocol], key=opcode_key)
+ packet_types = sorted(ops[protocol], key=opcode_key)
for packet_type in packet_types:
- values = latencies[protocol][packet_type]
+ op = (protocol, packet_type)
+ values = latencies[op]
values = sorted(values)
count = len(values)
- failed = failures[protocol][packet_type]
+ failed = failures[op]
mean = sum(values) / count
median = calc_percentile(values, 0.50)
percentile = calc_percentile(values, 0.95)
rng = values[-1] - values[0]
maxv = values[-1]
- desc = OP_DESCRIPTIONS.get((protocol, packet_type), '')
- if sys.stdout.isatty:
- print("%-12s %4s %-35s %12d %12d %12.6f "
- "%12.6f %12.6f %12.6f %12.6f"
- % (protocol,
- packet_type,
- desc,
- count,
- failed,
- mean,
- median,
- percentile,
- rng,
- maxv))
- else:
- print("%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f"
- % (protocol,
- packet_type,
- desc,
- count,
- failed,
- mean,
- median,
- percentile,
- rng,
- maxv))
+ desc = OP_DESCRIPTIONS.get(op, '')
+ print("%-12s %4s %-35s %12d %12d %12.6f "
+ "%12.6f %12.6f %12.6f %12.6f"
+ % (protocol,
+ packet_type,
+ desc,
+ count,
+ failed,
+ mean,
+ median,
+ percentile,
+ rng,
+ maxv))
def opcode_key(v):
"""Sort key for the operation code to ensure that it sorts numerically"""
try:
return "%03d" % int(v)
- except:
+ except ValueError:
return v
def mk_masked_dir(*path):
- """In a testenv we end up with 0777 diectories that look an alarming
+ """In a testenv we end up with 0777 directories that look an alarming
green colour with ls. Use umask to avoid that."""
+ # py3 os.mkdir can do this
d = os.path.join(*path)
mask = os.umask(0o077)
os.mkdir(d)