import math
import sys
import signal
-import itertools
-from collections import OrderedDict, Counter, defaultdict
+from collections import OrderedDict, Counter, defaultdict, namedtuple
from samba.emulate import traffic_packets
from samba.samdb import SamDB
import ldb
from samba.logger import get_samba_logger
import bisect
+CURRENT_MODEL_VERSION = 2 # save as this
+REQUIRED_MODEL_VERSION = 2 # load accepts this or greater
SLEEP_OVERHEAD = 3e-4
# we don't use None, because it complicates [de]serialisation
sys.stderr.flush()
-def random_colour_print():
- """Return a function that prints a randomly coloured line to stderr"""
- n = 18 + random.randrange(214)
- prefix = "\033[38;5;%dm" % n
-
- def p(*args):
- for a in args:
- print("%s%s\033[00m" % (prefix, a), file=sys.stderr)
+def random_colour_print(seeds):
+    """Return a function that prints a coloured line to stderr. The colour
+    of the line depends on a sort of hash of the integer arguments."""
+    if seeds:
+        # Fold the seeds into a value in [0, 213]; adding 18 picks an
+        # xterm 256-colour palette index in the 18..231 range for the
+        # "\033[38;5;<n>m" foreground-colour escape below.
+        s = 214
+        for x in seeds:
+            s += 17
+            s *= x
+            s %= 214
+        prefix = "\033[38;5;%dm" % (18 + s)
+
+        def p(*args):
+            # NOTE(review): DEBUG_LEVEL is presumably a module-level
+            # debug flag defined elsewhere in this file -- confirm.
+            if DEBUG_LEVEL > 0:
+                for a in args:
+                    # "\033[00m" resets the colour at the end of each line
+                    print("%s%s\033[00m" % (prefix, a), file=sys.stderr)
+    else:
+        # No seeds given: emit plain uncoloured lines instead.
+        def p(*args):
+            if DEBUG_LEVEL > 0:
+                for a in args:
+                    print(a, file=sys.stderr)
    return p
class Packet(object):
"""Details of a network packet"""
+ __slots__ = ('timestamp',
+ 'ip_protocol',
+ 'stream_number',
+ 'src',
+ 'dest',
+ 'protocol',
+ 'opcode',
+ 'desc',
+ 'extra',
+ 'endpoints')
def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
protocol, opcode, desc, extra):
-
self.timestamp = timestamp
self.ip_protocol = ip_protocol
self.stream_number = stream_number
self.endpoints = (self.dest, self.src)
@classmethod
- def from_line(self, line):
+ def from_line(cls, line):
fields = line.rstrip('\n').split('\t')
(timestamp,
ip_protocol,
src = int(src)
dest = int(dest)
- return Packet(timestamp, ip_protocol, stream_number, src, dest,
- protocol, opcode, desc, extra)
+ return cls(timestamp, ip_protocol, stream_number, src, dest,
+ protocol, opcode, desc, extra)
def as_summary(self, time_offset=0.0):
"""Format the packet as a traffic_summary line.
return self.timestamp - other.timestamp
def is_really_a_packet(self, missing_packet_stats=None):
- """Is the packet one that can be ignored?
+ return is_a_real_packet(self.protocol, self.opcode)
- If so removing it will have no effect on the replay
- """
- if self.protocol in SKIPPED_PROTOCOLS:
- # Ignore any packets for the protocols we're not interested in.
- return False
- if self.protocol == "ldap" and self.opcode == '':
- # skip ldap continuation packets
- return False
- fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
- fn = getattr(traffic_packets, fn_name, None)
- if not fn:
- print("missing packet %s" % fn_name, file=sys.stderr)
- return False
- if fn is traffic_packets.null_packet:
- return False
- return True
+def is_a_real_packet(protocol, opcode):
+ """Is the packet one that can be ignored?
+
+ If so removing it will have no effect on the replay
+ """
+ if protocol in SKIPPED_PROTOCOLS:
+ # Ignore any packets for the protocols we're not interested in.
+ return False
+ if protocol == "ldap" and opcode == '':
+ # skip ldap continuation packets
+ return False
+
+ fn_name = 'packet_%s_%s' % (protocol, opcode)
+ fn = getattr(traffic_packets, fn_name, None)
+ if fn is None:
+ LOGGER.debug("missing packet %s" % fn_name, file=sys.stderr)
+ return False
+ if fn is traffic_packets.null_packet:
+ return False
+ return True
class ReplayContext(object):
- """State/Context for an individual conversation between an simulated client
- and a server.
+ """State/Context for a conversation between an simulated client and a
+ server. Some of the context is shared amongst all conversations
+ and should be generated before the fork, while other context is
+ specific to a particular conversation and should be generated
+ *after* the fork, in generate_process_local_config().
"""
-
def __init__(self,
server=None,
lp=None,
domain_sid=None):
self.server = server
- self.ldap_connections = []
- self.dcerpc_connections = []
- self.lsarpc_connections = []
- self.lsarpc_connections_named = []
- self.drsuapi_connections = []
- self.srvsvc_connections = []
- self.samr_contexts = []
self.netlogon_connection = None
self.creds = creds
self.lp = lp
self.last_netlogon_bad = False
self.last_samlogon_bad = False
self.generate_ldap_search_tables()
- self.next_conversation_id = itertools.count()
def generate_ldap_search_tables(self):
session = system_session()
self.attribute_clue_map = attribute_clue_map
def generate_process_local_config(self, account, conversation):
- if account is None:
- return
+ self.ldap_connections = []
+ self.dcerpc_connections = []
+ self.lsarpc_connections = []
+ self.lsarpc_connections_named = []
+ self.drsuapi_connections = []
+ self.srvsvc_connections = []
+ self.samr_contexts = []
self.netbios_name = account.netbios_name
self.machinepass = account.machinepass
self.username = account.username
class Conversation(object):
"""Details of a converation between a simulated client and a server."""
- conversation_id = None
-
- def __init__(self, start_time=None, endpoints=None):
+ def __init__(self, start_time=None, endpoints=None, seq=(),
+ conversation_id=None):
self.start_time = start_time
self.endpoints = endpoints
self.packets = []
- self.msg = random_colour_print()
+ self.msg = random_colour_print(endpoints)
self.client_balance = 0.0
+ self.conversation_id = conversation_id
+ for p in seq:
+ self.add_short_packet(*p)
def __cmp__(self, other):
if self.start_time is None:
print("gathering packets into conversations", file=sys.stderr)
conversations = OrderedDict()
- for p in packets:
+ for i, p in enumerate(packets):
p.timestamp -= start_time
c = conversations.get(p.endpoints)
if c is None:
- c = Conversation()
+ c = Conversation(conversation_id=(i + 2))
conversations[p.endpoints] = c
c.add_packet(p)
self.n = n
self.dns_opcounts = defaultdict(int)
self.cumulative_duration = 0.0
- self.conversation_rate = [0, 1]
+ self.packet_rate = [0, 1]
def learn(self, conversations, dns_opcounts={}):
prev = 0.0
self.dns_opcounts[k] += v
if len(conversations) > 1:
- elapsed =\
- conversations[-1].start_time - conversations[0].start_time
- self.conversation_rate[0] = len(conversations)
- self.conversation_rate[1] = elapsed
+ first = conversations[0].start_time
+ total = 0
+ last = first + 0.1
+ for c in conversations:
+ total += len(c)
+ last = max(last, c.packets[-1].timestamp)
+
+ self.packet_rate[0] = total
+ self.packet_rate[1] = last - first
for c in conversations:
client, server = c.guess_client_server(server)
'ngrams': ngrams,
'query_details': query_details,
'cumulative_duration': self.cumulative_duration,
- 'conversation_rate': self.conversation_rate,
+ 'packet_rate': self.packet_rate,
+ 'version': CURRENT_MODEL_VERSION
}
d['dns'] = self.dns_opcounts
d = json.load(f)
+ try:
+ version = d["version"]
+ if version < REQUIRED_MODEL_VERSION:
+ raise ValueError("the model file is version %d; "
+ "version %d is required" %
+ (version, REQUIRED_MODEL_VERSION))
+ except KeyError:
+ raise ValueError("the model file lacks a version number; "
+ "version %d is required" %
+ (REQUIRED_MODEL_VERSION))
+
for k, v in d['ngrams'].items():
k = tuple(str(k).split('\t'))
values = self.ngrams.setdefault(k, [])
for p, count in v.items():
values.extend([str(p)] * count)
+ values.sort()
for k, v in d['query_details'].items():
values = self.query_details.setdefault(str(k), [])
values.extend([()] * count)
else:
values.extend([tuple(str(p).split('\t'))] * count)
+ values.sort()
if 'dns' in d:
for k, v in d['dns'].items():
self.dns_opcounts[k] += v
self.cumulative_duration = d['cumulative_duration']
- self.conversation_rate = d['conversation_rate']
-
- def construct_conversation(self, timestamp=0.0, client=2, server=1,
- hard_stop=None, packet_rate=1):
- """Construct a individual converation from the model."""
-
- c = Conversation(timestamp, (server, client))
-
+ self.packet_rate = d['packet_rate']
+
+ def construct_conversation_sequence(self, timestamp=0.0,
+ hard_stop=None,
+ replay_speed=1,
+ ignore_before=0):
+ """Construct an individual conversation packet sequence from the
+ model.
+ """
+ c = []
key = (NON_PACKET,) * (self.n - 1)
+ if ignore_before is None:
+ ignore_before = timestamp - 1
- while key in self.ngrams:
- p = random.choice(self.ngrams.get(key, NON_PACKET))
+ while True:
+ p = random.choice(self.ngrams.get(key, (NON_PACKET,)))
if p == NON_PACKET:
break
+
if p in self.query_details:
extra = random.choice(self.query_details[p])
else:
protocol, opcode = p.split(':', 1)
if protocol == 'wait':
log_wait_time = int(opcode) + random.random()
- wait = math.exp(log_wait_time) / (WAIT_SCALE * packet_rate)
+ wait = math.exp(log_wait_time) / (WAIT_SCALE * replay_speed)
timestamp += wait
else:
log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
- wait = math.exp(log_wait) / packet_rate
+ wait = math.exp(log_wait) / replay_speed
timestamp += wait
if hard_stop is not None and timestamp > hard_stop:
break
- c.add_short_packet(timestamp, protocol, opcode, extra)
+ if timestamp >= ignore_before:
+ c.append((timestamp, protocol, opcode, extra))
key = key[1:] + (p,)
return c
- def generate_conversations(self, rate, duration, packet_rate=1):
+ def generate_conversations(self, scale, duration, replay_speed=1,
+ server=1, client=2):
"""Generate a list of conversations from the model."""
- # We run the simulation for at least ten times as long as our
- # desired duration, and take a section near the start.
- rate_n, rate_t = self.conversation_rate
+ # We run the simulation for ten times as long as our desired
+ # duration, and take the section at the end.
+ lead_in = 9 * duration
+ rate_n, rate_t = self.packet_rate
+ target_packets = int(duration * scale * rate_n / rate_t)
- duration2 = max(rate_t, duration * 2)
- n = rate * duration2 * rate_n / rate_t
+ conversations = []
+ n_packets = 0
+
+ while n_packets < target_packets:
+ start = random.uniform(-lead_in, duration)
+ c = self.construct_conversation_sequence(start,
+ hard_stop=duration,
+ replay_speed=replay_speed,
+ ignore_before=0)
+ conversations.append(c)
+ n_packets += len(c)
+
+ print(("we have %d packets (target %d) in %d conversations at scale %f"
+ % (n_packets, target_packets, len(conversations), scale)),
+ file=sys.stderr)
+ conversations.sort() # sorts by first element == start time
+ return seq_to_conversations(conversations)
- server = 1
- client = 2
- conversations = []
- end = duration2
- start = end - duration
-
- while client < n + 2:
- start = random.uniform(0, duration2)
- c = self.construct_conversation(start,
- client,
- server,
- hard_stop=(duration2 * 5),
- packet_rate=packet_rate)
-
- c.forget_packets_outside_window(start, end)
- c.renormalise_times(start)
- if len(c) != 0:
- conversations.append(c)
+def seq_to_conversations(seq, server=1, client=2):
+ conversations = []
+ for s in seq:
+ if s:
+ c = Conversation(s[0][0], (server, client), s)
client += 1
-
- print(("we have %d conversations at rate %f" %
- (len(conversations), rate)), file=sys.stderr)
- conversations.sort()
- return conversations
+ conversations.append(c)
+ return conversations
IP_PROTOCOLS = {
return ou
-class ConversationAccounts(object):
- """Details of the machine and user accounts associated with a conversation.
- """
- def __init__(self, netbios_name, machinepass, username, userpass):
- self.netbios_name = netbios_name
- self.machinepass = machinepass
- self.username = username
- self.userpass = userpass
+# ConversationAccounts holds details of the machine and user accounts
+# associated with a conversation.
+#
+# We use a named tuple to reduce shared memory usage.
+ConversationAccounts = namedtuple('ConversationAccounts',
+ ('netbios_name',
+ 'machinepass',
+ 'username',
+ 'userpass'))
def generate_replay_accounts(ldb, instance_id, number, password):
def mk_masked_dir(*path):
- """In a testenv we end up with 0777 diectories that look an alarming
+ """In a testenv we end up with 0777 directories that look an alarming
green colour with ls. Use umask to avoid that."""
+ # py3 os.mkdir can do this
d = os.path.join(*path)
mask = os.umask(0o077)
os.mkdir(d)