from errno import ECHILD, ESRCH
from collections import OrderedDict, Counter, defaultdict, namedtuple
+from dns.resolver import query as dns_query
+
from samba.emulate import traffic_packets
from samba.samdb import SamDB
import ldb
from samba.dcerpc.misc import SEC_CHAN_BDC
from samba import gensec
from samba import sd_utils
-from samba.compat import get_string
+from samba.common import get_string
from samba.logger import get_samba_logger
import bisect
server=None,
lp=None,
creds=None,
+ total_conversations=None,
badpassword_frequency=None,
prefer_kerberos=None,
tempdir=None,
ou=None,
base_dn=None,
domain=os.environ.get("DOMAIN"),
- domain_sid=None):
+ domain_sid=None,
+ instance_id=None):
self.server = server
self.netlogon_connection = None
self.creds = creds
self.lp = lp
- self.prefer_kerberos = prefer_kerberos
+ if prefer_kerberos:
+ self.kerberos_state = MUST_USE_KERBEROS
+ else:
+ self.kerberos_state = DONT_USE_KERBEROS
self.ou = ou
self.base_dn = base_dn
self.domain = domain
self.global_tempdir = tempdir
self.domain_sid = domain_sid
self.realm = lp.get('realm')
+ self.instance_id = instance_id
# Bad password attempt controls
self.badpassword_frequency = badpassword_frequency
self.last_drsuapi_bad = False
self.last_netlogon_bad = False
self.last_samlogon_bad = False
+ self.total_conversations = total_conversations
self.generate_ldap_search_tables()
def generate_ldap_search_tables(self):
self.dn_map = dn_map
self.attribute_clue_map = attribute_clue_map
+ # pre-populate DN-based search filters (it's simplest to generate them
+ # once, when the test starts). These are used by guess_search_filter()
+ # to avoid full-scans
+ self.search_filters = {}
+
+ # lookup all the GPO DNs
+ res = db.search(db.domain_dn(), scope=ldb.SCOPE_SUBTREE, attrs=['dn'],
+ expression='(objectclass=groupPolicyContainer)')
+ gpos_by_dn = "".join("(distinguishedName={0})".format(msg['dn']) for msg in res)
+
+ # a search for the 'gPCFileSysPath' attribute is probably a GPO search
+ # (as per the MS-GPOL spec) which searches for GPOs by DN
+ self.search_filters['gPCFileSysPath'] = "(|{0})".format(gpos_by_dn)
+
+ # likewise, a search for gpLink is probably the Domain SOM search part
+ # of the MS-GPOL, in which case it's looking up a few OUs by DN
+ ou_str = ""
+ for ou in ["Domain Controllers,", "traffic_replay,", ""]:
+ ou_str += "(distinguishedName={0}{1})".format(ou, db.domain_dn())
+ self.search_filters['gpLink'] = "(|{0})".format(ou_str)
+
+ # The CEP Web Service can query the AD DC to get pKICertificateTemplate
+ # objects (as per MS-WCCE)
+ self.search_filters['pKIExtendedKeyUsage'] = \
+ '(objectCategory=pKICertificateTemplate)'
+
+ # assume that anything querying the usnChanged is some kind of
+ # synchronization tool, e.g. AD Change Detection Connector
+ res = db.search('', scope=ldb.SCOPE_BASE, attrs=['highestCommittedUSN'])
+ self.search_filters['usnChanged'] = \
+ '(usnChanged>={0})'.format(res[0]['highestCommittedUSN'])
+
+ # The traffic_learner script doesn't preserve the LDAP search filter, and
+ # having no filter can result in a full DB scan. This is costly for a large
+ # DB, and not necessarily representative of real world traffic. As there
+ # are several standard LDAP queries that get used by AD tools, we can apply
+ # some logic and guess what the search filter might have been originally.
+    def guess_search_filter(self, attrs, dn_sig, dn):
+        """Guess a plausible LDAP search filter for a replayed search.
+
+        The learned traffic summaries don't preserve the original filter,
+        so we infer one from the requested attributes (attrs) and the DN
+        signature (dn_sig), falling back to a match-everything filter.
+        """
+        # there are some standard spec-based searches that query fairly
+        # unique attributes. Check if the search is likely one of these
+        for key, search_filter in self.search_filters.items():
+            if key in attrs:
+                return search_filter
+
+        # if it's the top-level domain, assume we're looking up a single user,
+        # e.g. like powershell Get-ADUser or a similar tool
+        if dn_sig == 'DC,DC':
+            # BUG FIX: random.random() % n is a float in [0, 1) (float
+            # modulo n>=1 returns the float unchanged), so the '%d' in
+            # user_name always rendered user zero. randrange() picks a
+            # proper integer user id across all conversations.
+            random_user_id = random.randrange(self.total_conversations)
+            account_name = user_name(self.instance_id, random_user_id)
+            return '(&(sAMAccountName=%s)(objectClass=user))' % account_name
+
+        # otherwise just return everything in the sub-tree
+        return '(objectClass=*)'
+
def generate_process_local_config(self, account, conversation):
self.ldap_connections = []
self.dcerpc_connections = []
than that requested, but not significantly.
"""
if not failed_last_time:
- if (self.badpassword_frequency and self.badpassword_frequency > 0
- and random.random() < self.badpassword_frequency):
+ if (self.badpassword_frequency and
+ random.random() < self.badpassword_frequency):
try:
f(bad)
- except:
+ except Exception:
# Ignore any exceptions as the operation may fail
# as it's being performed with bad credentials
pass
self.user_creds.set_password(self.userpass)
self.user_creds.set_username(self.username)
self.user_creds.set_domain(self.domain)
- if self.prefer_kerberos:
- self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
- else:
- self.user_creds.set_kerberos_state(DONT_USE_KERBEROS)
+ self.user_creds.set_kerberos_state(self.kerberos_state)
self.user_creds_bad = Credentials()
self.user_creds_bad.guess(self.lp)
self.user_creds_bad.set_workstation(self.netbios_name)
self.user_creds_bad.set_password(self.userpass[:-4])
self.user_creds_bad.set_username(self.username)
- if self.prefer_kerberos:
- self.user_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
- else:
- self.user_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
+ self.user_creds_bad.set_kerberos_state(self.kerberos_state)
# Credentials for ldap simple bind.
self.simple_bind_creds = Credentials()
self.simple_bind_creds.set_username(self.username)
self.simple_bind_creds.set_gensec_features(
self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
- if self.prefer_kerberos:
- self.simple_bind_creds.set_kerberos_state(MUST_USE_KERBEROS)
- else:
- self.simple_bind_creds.set_kerberos_state(DONT_USE_KERBEROS)
+ self.simple_bind_creds.set_kerberos_state(self.kerberos_state)
self.simple_bind_creds.set_bind_dn(self.user_dn)
self.simple_bind_creds_bad = Credentials()
self.simple_bind_creds_bad.set_gensec_features(
self.simple_bind_creds_bad.get_gensec_features() |
gensec.FEATURE_SEAL)
- if self.prefer_kerberos:
- self.simple_bind_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
- else:
- self.simple_bind_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
+ self.simple_bind_creds_bad.set_kerberos_state(self.kerberos_state)
self.simple_bind_creds_bad.set_bind_dn(self.user_dn)
def generate_machine_creds(self):
self.machine_creds.set_password(self.machinepass)
self.machine_creds.set_username(self.netbios_name + "$")
self.machine_creds.set_domain(self.domain)
- if self.prefer_kerberos:
- self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
- else:
- self.machine_creds.set_kerberos_state(DONT_USE_KERBEROS)
+ self.machine_creds.set_kerberos_state(self.kerberos_state)
self.machine_creds_bad = Credentials()
self.machine_creds_bad.guess(self.lp)
self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
self.machine_creds_bad.set_password(self.machinepass[:-4])
self.machine_creds_bad.set_username(self.netbios_name + "$")
- if self.prefer_kerberos:
- self.machine_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
- else:
- self.machine_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
+ self.machine_creds_bad.set_kerberos_state(self.kerberos_state)
def get_matching_dn(self, pattern, attributes=None):
# If the pattern is an empty string, we assume ROOTDSE,
def get_authenticator(self):
auth = self.machine_creds.new_client_authenticator()
current = netr_Authenticator()
- current.cred.data = [x if isinstance(x, int) else ord(x) for x in auth["credential"]]
+ current.cred.data = [x if isinstance(x, int) else ord(x)
+ for x in auth["credential"]]
current.timestamp = auth["timestamp"]
subsequent = netr_Authenticator()
return (current, subsequent)
+    def write_stats(self, filename, **kwargs):
+        """Write arbitrary key/value pairs to a file in our stats directory in
+        order for them to be picked up later by another process working out
+        statistics."""
+        filename = os.path.join(self.statsdir, filename)
+        # use a context manager so the file is closed even if a write
+        # raises; the bare open()/close() pair would leak the handle
+        with open(filename, 'w') as f:
+            for k, v in kwargs.items():
+                print("%s: %s" % (k, v), file=f)
+
class SamrContext(object):
"""State/Context associated with a samr connection.
self.packets.append(p)
def add_short_packet(self, timestamp, protocol, opcode, extra,
- client=True):
+ client=True, skip_unused_packets=True):
"""Create a packet from a timestamp, and 'protocol:opcode' pair, and a
(possibly empty) list of extra data. If client is True, assume
this packet is from the client to the server.
"""
+ if skip_unused_packets and not is_a_real_packet(protocol, opcode):
+ return
+
src, dest = self.guess_client_server()
if not client:
src, dest = dest, src
key = (protocol, opcode)
- desc = OP_DESCRIPTIONS[key] if key in OP_DESCRIPTIONS else ''
- if protocol in IP_PROTOCOLS:
- ip_protocol = IP_PROTOCOLS[protocol]
- else:
- ip_protocol = '06'
+ desc = OP_DESCRIPTIONS.get(key, '')
+ ip_protocol = IP_PROTOCOLS.get(protocol, '06')
packet = Packet(timestamp - self.start_time, ip_protocol,
'', src, dest,
protocol, opcode, desc, extra)
return self.packets[-1].timestamp - self.packets[0].timestamp
def replay_as_summary_lines(self):
- lines = []
- for p in self.packets:
- lines.append(p.as_summary(self.start_time))
- return lines
+ return [p.as_summary(self.start_time) for p in self.packets]
def replay_with_delay(self, start, context=None, account=None):
"""Replay the conversation at the right time.
"""A lightweight conversation that generates a lot of dns:0 packets on
the fly"""
- def __init__(self, dns_rate, duration):
+ def __init__(self, dns_rate, duration, query_file=None):
n = int(dns_rate * duration)
self.times = [random.uniform(0, duration) for i in range(n)]
self.times.sort()
self.rate = dns_rate
self.duration = duration
self.start_time = 0
- self.msg = random_colour_print()
+ self.query_choices = self._get_query_choices(query_file=query_file)
def __str__(self):
return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
(len(self.times), self.duration, self.rate))
+    def _get_query_choices(self, query_file=None):
+        """
+        Read dns query choices from a file, or return default
+
+        rname may contain format string like `{realm}`
+        realm can be fetched from context.realm
+        """
+
+        if query_file:
+            with open(query_file, 'r') as f:
+                text = f.read()
+            choices = []
+            for line in text.splitlines():
+                line = line.strip()
+                # skip blank lines and '#' comment lines
+                if line and not line.startswith('#'):
+                    # each data line is: opcode,rname,rtype,exist
+                    args = line.split(',')
+                    # NOTE(review): assert is stripped under python -O;
+                    # presumably this test tool never runs optimised --
+                    # confirm, or validate explicitly
+                    assert len(args) == 4
+                    choices.append(args)
+            return choices
+        else:
+            # default mix of (opcode, rname, rtype, exist) tuples;
+            # exist == 'yes' means the query is expected to get answers
+            return [
+                (0, '{realm}', 'A', 'yes'),
+                (1, '{realm}', 'NS', 'yes'),
+                (2, '*.{realm}', 'A', 'no'),
+                (3, '*.{realm}', 'NS', 'no'),
+                (10, '_msdcs.{realm}', 'A', 'yes'),
+                (11, '_msdcs.{realm}', 'NS', 'yes'),
+                (20, 'nx.realm.com', 'A', 'no'),
+                (21, 'nx.realm.com', 'NS', 'no'),
+                (22, '*.nx.realm.com', 'A', 'no'),
+                (23, '*.nx.realm.com', 'NS', 'no'),
+            ]
+
def replay(self, context=None):
+ assert context
+ assert context.realm
start = time.time()
- fn = traffic_packets.packet_dns_0
for t in self.times:
now = time.time() - start
gap = t - now
if sleep_time > 0:
time.sleep(sleep_time)
+ opcode, rname, rtype, exist = random.choice(self.query_choices)
+ rname = rname.format(realm=context.realm)
+ success = True
packet_start = time.time()
try:
- fn(None, None, context)
- end = time.time()
- duration = end - packet_start
- print("%f\tDNS\tdns\t0\t%f\tTrue\t" % (end, duration))
- except Exception as e:
+ answers = dns_query(rname, rtype)
+ if exist == 'yes' and not len(answers):
+ # expect answers but didn't get, fail
+ success = False
+ except Exception:
+ success = False
+ finally:
end = time.time()
duration = end - packet_start
- print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" % (end, duration, e))
+ print("%f\tDNS\tdns\t%s\t%f\t%s\t" % (end, opcode, duration, success))
def ingest_summaries(files, dns_mode='count'):
def construct_conversation_sequence(self, timestamp=0.0,
hard_stop=None,
replay_speed=1,
- ignore_before=0):
+ ignore_before=0,
+ persistence=0):
"""Construct an individual conversation packet sequence from the
model.
"""
while True:
p = random.choice(self.ngrams.get(key, (NON_PACKET,)))
if p == NON_PACKET:
- break
+ if timestamp < ignore_before:
+ break
+ if random.random() > persistence:
+ print("ending after %s (persistence %.1f)" % (key, persistence),
+ file=sys.stderr)
+ break
+
+ p = 'wait:%d' % random.randrange(5, 12)
+ print("trying %s instead of end" % p, file=sys.stderr)
if p in self.query_details:
extra = random.choice(self.query_details[p])
c.append((timestamp, protocol, opcode, extra))
key = key[1:] + (p,)
+ if key[-2][:5] == 'wait:' and key[-1][:5] == 'wait:':
+ # two waits in a row can only be caused by "persistence"
+ # tricks, and will not result in any packets being found.
+ # Instead we pretend this is a fresh start.
+ key = (NON_PACKET,) * (self.n - 1)
return c
- def generate_conversation_sequences(self, scale, duration, replay_speed=1):
+    def scale_to_packet_rate(self, scale):
+        """Convert a model scale factor into packets per second.
+
+        self.packet_rate is an (n_packets, period_seconds) pair measured
+        from the ingested traffic, so rate_n / rate_t is the model's
+        native packets-per-second.
+        """
+        rate_n, rate_t = self.packet_rate
+        return scale * rate_n / rate_t
+
+    def packet_rate_to_scale(self, pps):
+        """Inverse of scale_to_packet_rate: packets per second -> scale."""
+        rate_n, rate_t = self.packet_rate
+        return pps * rate_t / rate_n
+
+ def generate_conversation_sequences(self, packet_rate, duration, replay_speed=1,
+ persistence=0):
"""Generate a list of conversation descriptions from the model."""
# We run the simulation for ten times as long as our desired
# duration, and take the section at the end.
lead_in = 9 * duration
- rate_n, rate_t = self.packet_rate
- target_packets = int(duration * scale * rate_n / rate_t)
-
+ target_packets = int(packet_rate * duration)
conversations = []
n_packets = 0
c = self.construct_conversation_sequence(start,
hard_stop=duration,
replay_speed=replay_speed,
- ignore_before=0)
+ ignore_before=0,
+ persistence=persistence)
# will these "packets" generate actual traffic?
# some (e.g. ldap unbind) will not generate anything
# if the previous packets are not there, and if the
conversations.append(c)
n_packets += len(c)
- print(("we have %d packets (target %d) in %d conversations at scale %f"
- % (n_packets, target_packets, len(conversations), scale)),
+ scale = self.packet_rate_to_scale(packet_rate)
+ print(("we have %d packets (target %d) in %d conversations at %.1f/s "
+ "(scale %f)" % (n_packets, target_packets, len(conversations),
+ packet_rate, scale)),
file=sys.stderr)
conversations.sort() # sorts by first element == start time
return conversations
os._exit(status)
-def dnshammer_in_fork(dns_rate, duration):
+def dnshammer_in_fork(dns_rate, duration, context, query_file=None):
sys.stdout.flush()
sys.stderr.flush()
pid = os.fork()
if pid != 0:
return pid
+
+ sys.stdin.close()
+ os.close(0)
+
+ try:
+ sys.stdout.close()
+ os.close(1)
+ except IOError as e:
+ LOGGER.warn("stdout closing failed with %s" % e)
+ pass
+ filename = os.path.join(context.statsdir, 'stats-dns')
+ sys.stdout = open(filename, 'w')
+
try:
status = 0
signal.signal(signal.SIGTERM, flushing_signal_handler)
- hammer = DnsHammer(dns_rate, duration)
- hammer.replay()
+ hammer = DnsHammer(dns_rate, duration, query_file=query_file)
+ hammer.replay(context=context)
except Exception:
status = 1
print(("EXCEPTION in child PID %d, the DNS hammer" % (os.getpid())),
lp=None,
accounts=None,
dns_rate=0,
+ dns_query_file=None,
duration=None,
latency_timeout=1.0,
stop_on_any_error=False,
context = ReplayContext(server=host,
creds=creds,
lp=lp,
+ total_conversations=len(conversation_seq),
**kwargs)
if len(accounts) < len(conversation_seq):
LOGGER.info("Replaying traffic for %u conversations over %d seconds"
% (len(conversation_seq), duration))
+ context.write_stats('intentions',
+ Planned_conversations=len(conversation_seq),
+ Planned_packets=sum(len(x) for x in conversation_seq))
children = {}
try:
if dns_rate:
- pid = dnshammer_in_fork(dns_rate, duration)
+ pid = dnshammer_in_fork(dns_rate, duration, context,
+ query_file=dns_query_file)
children[pid] = 1
for i, cs in enumerate(conversation_seq):
print("EXCEPTION in parent", file=sys.stderr)
traceback.print_exc()
finally:
+ context.write_stats('unfinished',
+ Unfinished_conversations=len(children))
+
for s in (15, 15, 9):
print(("killing %d children with -%d" %
(len(children), s)), file=sys.stderr)
successful = 0
failed = 0
latencies = {}
- failures = {}
- unique_converations = set()
- conversations = 0
-
+ failures = Counter()
+ unique_conversations = set()
if timing_file is not None:
tw = timing_file.write
else:
tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")
+ float_values = {
+ 'Maximum lag': 0,
+ 'Start lag': 0,
+ 'Max sleep miss': 0,
+ }
+ int_values = {
+ 'Planned_conversations': 0,
+ 'Planned_packets': 0,
+ 'Unfinished_conversations': 0,
+ }
+
for filename in os.listdir(statsdir):
path = os.path.join(statsdir, filename)
with open(path, 'r') as f:
protocol = fields[2]
packet_type = fields[3]
latency = float(fields[4])
- first = min(float(fields[0]) - latency, first)
- last = max(float(fields[0]), last)
-
- if protocol not in latencies:
- latencies[protocol] = {}
- if packet_type not in latencies[protocol]:
- latencies[protocol][packet_type] = []
-
- latencies[protocol][packet_type].append(latency)
-
- if protocol not in failures:
- failures[protocol] = {}
- if packet_type not in failures[protocol]:
- failures[protocol][packet_type] = 0
+ t = float(fields[0])
+ first = min(t - latency, first)
+ last = max(t, last)
+ op = (protocol, packet_type)
+ latencies.setdefault(op, []).append(latency)
if fields[5] == 'True':
successful += 1
else:
failed += 1
- failures[protocol][packet_type] += 1
+ failures[op] += 1
- if conversation not in unique_converations:
- unique_converations.add(conversation)
- conversations += 1
+ unique_conversations.add(conversation)
tw(line)
except (ValueError, IndexError):
- # not a valid line print and ignore
- print(line, file=sys.stderr)
- pass
+ if ':' in line:
+ k, v = line.split(':', 1)
+ if k in float_values:
+ float_values[k] = max(float(v),
+ float_values[k])
+ elif k in int_values:
+ int_values[k] = max(int(v),
+ int_values[k])
+ else:
+ print(line, file=sys.stderr)
+ else:
+ # not a valid line print and ignore
+ print(line, file=sys.stderr)
+
duration = last - first
if successful == 0:
success_rate = 0
else:
failure_rate = failed / duration
+ conversations = len(unique_conversations)
+
print("Total conversations: %10d" % conversations)
print("Successful operations: %10d (%.3f per second)"
% (successful, success_rate))
print("Failed operations: %10d (%.3f per second)"
% (failed, failure_rate))
+ for k, v in sorted(float_values.items()):
+ print("%-28s %f" % (k.replace('_', ' ') + ':', v))
+ for k, v in sorted(int_values.items()):
+ print("%-28s %d" % (k.replace('_', ' ') + ':', v))
+
print("Protocol Op Code Description "
" Count Failed Mean Median "
"95% Range Max")
- protocols = sorted(latencies.keys())
+ ops = {}
+ for proto, packet in latencies:
+ if proto not in ops:
+ ops[proto] = set()
+ ops[proto].add(packet)
+ protocols = sorted(ops.keys())
+
for protocol in protocols:
- packet_types = sorted(latencies[protocol], key=opcode_key)
+ packet_types = sorted(ops[protocol], key=opcode_key)
for packet_type in packet_types:
- values = latencies[protocol][packet_type]
+ op = (protocol, packet_type)
+ values = latencies[op]
values = sorted(values)
count = len(values)
- failed = failures[protocol][packet_type]
+ failed = failures[op]
mean = sum(values) / count
median = calc_percentile(values, 0.50)
percentile = calc_percentile(values, 0.95)
rng = values[-1] - values[0]
maxv = values[-1]
- desc = OP_DESCRIPTIONS.get((protocol, packet_type), '')
- if sys.stdout.isatty:
- print("%-12s %4s %-35s %12d %12d %12.6f "
- "%12.6f %12.6f %12.6f %12.6f"
- % (protocol,
- packet_type,
- desc,
- count,
- failed,
- mean,
- median,
- percentile,
- rng,
- maxv))
- else:
- print("%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f"
- % (protocol,
- packet_type,
- desc,
- count,
- failed,
- mean,
- median,
- percentile,
- rng,
- maxv))
+ desc = OP_DESCRIPTIONS.get(op, '')
+ print("%-12s %4s %-35s %12d %12d %12.6f "
+ "%12.6f %12.6f %12.6f %12.6f"
+ % (protocol,
+ packet_type,
+ desc,
+ count,
+ failed,
+ mean,
+ median,
+ percentile,
+ rng,
+ maxv))
def opcode_key(v):
"""Sort key for the operation code to ensure that it sorts numerically"""
try:
return "%03d" % int(v)
- except:
+ except ValueError:
return v