import math
import sys
import signal
-import itertools
-from collections import OrderedDict, Counter, defaultdict
+from collections import OrderedDict, Counter, defaultdict, namedtuple
from samba.emulate import traffic_packets
from samba.samdb import SamDB
import ldb
import traceback
from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
from samba.auth import system_session
-from samba.dsdb import UF_WORKSTATION_TRUST_ACCOUNT, UF_PASSWD_NOTREQD
-from samba.dsdb import UF_NORMAL_ACCOUNT
-from samba.dcerpc.misc import SEC_CHAN_WKSTA
+from samba.dsdb import (
+ UF_NORMAL_ACCOUNT,
+ UF_SERVER_TRUST_ACCOUNT,
+ UF_TRUSTED_FOR_DELEGATION,
+ UF_WORKSTATION_TRUST_ACCOUNT
+)
+from samba.dcerpc.misc import SEC_CHAN_BDC
from samba import gensec
from samba import sd_utils
+from samba.compat import get_string
+from samba.logger import get_samba_logger
+import bisect
+CURRENT_MODEL_VERSION = 2 # save as this
+REQUIRED_MODEL_VERSION = 2 # load accepts this or greater
SLEEP_OVERHEAD = 3e-4
# we don't use None, because it complicates [de]serialisation
# DEBUG_LEVEL can be changed by scripts with -d
DEBUG_LEVEL = 0
+LOGGER = get_samba_logger(name=__name__)
+
def debug(level, msg, *args):
"""Print a formatted debug message to standard error.
sys.stderr.flush()
-def random_colour_print():
- """Return a function that prints a randomly coloured line to stderr"""
- n = 18 + random.randrange(214)
- prefix = "\033[38;5;%dm" % n
-
- def p(*args):
- for a in args:
- print("%s%s\033[00m" % (prefix, a), file=sys.stderr)
+def random_colour_print(seeds):
+ """Return a function that prints a coloured line to stderr. The colour
+ of the line depends on a sort of hash of the integer arguments."""
+ if seeds:
+ s = 214
+ for x in seeds:
+ s += 17
+ s *= x
+ s %= 214
+ prefix = "\033[38;5;%dm" % (18 + s)
+
+ def p(*args):
+ if DEBUG_LEVEL > 0:
+ for a in args:
+ print("%s%s\033[00m" % (prefix, a), file=sys.stderr)
+ else:
+ def p(*args):
+ if DEBUG_LEVEL > 0:
+ for a in args:
+ print(a, file=sys.stderr)
return p
class Packet(object):
"""Details of a network packet"""
- def __init__(self, fields):
- if isinstance(fields, str):
- fields = fields.rstrip('\n').split('\t')
+ __slots__ = ('timestamp',
+ 'ip_protocol',
+ 'stream_number',
+ 'src',
+ 'dest',
+ 'protocol',
+ 'opcode',
+ 'desc',
+ 'extra',
+ 'endpoints')
+ def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
+ protocol, opcode, desc, extra):
+ self.timestamp = timestamp
+ self.ip_protocol = ip_protocol
+ self.stream_number = stream_number
+ self.src = src
+ self.dest = dest
+ self.protocol = protocol
+ self.opcode = opcode
+ self.desc = desc
+ self.extra = extra
+ if self.src < self.dest:
+ self.endpoints = (self.src, self.dest)
+ else:
+ self.endpoints = (self.dest, self.src)
+ @classmethod
+ def from_line(cls, line):
+ fields = line.rstrip('\n').split('\t')
(timestamp,
ip_protocol,
stream_number,
desc) = fields[:8]
extra = fields[8:]
- self.timestamp = float(timestamp)
- self.ip_protocol = ip_protocol
- try:
- self.stream_number = int(stream_number)
- except (ValueError, TypeError):
- self.stream_number = None
- self.src = int(src)
- self.dest = int(dest)
- self.protocol = protocol
- self.opcode = opcode
- self.desc = desc
- self.extra = extra
+ timestamp = float(timestamp)
+ src = int(src)
+ dest = int(dest)
- if self.src < self.dest:
- self.endpoints = (self.src, self.dest)
- else:
- self.endpoints = (self.dest, self.src)
+ return cls(timestamp, ip_protocol, stream_number, src, dest,
+ protocol, opcode, desc, extra)
def as_summary(self, time_offset=0.0):
"""Format the packet as a traffic_summary line.
return "<Packet @%s>" % self
def copy(self):
- return self.__class__([self.timestamp,
- self.ip_protocol,
- self.stream_number,
- self.src,
- self.dest,
- self.protocol,
- self.opcode,
- self.desc] + self.extra)
+ return self.__class__(self.timestamp,
+ self.ip_protocol,
+ self.stream_number,
+ self.src,
+ self.dest,
+ self.protocol,
+ self.opcode,
+ self.desc,
+ self.extra)
def as_packet_type(self):
t = '%s:%s' % (self.protocol, self.opcode)
fn = getattr(traffic_packets, fn_name)
except AttributeError as e:
- print("Conversation(%s) Missing handler %s" % \
+ print("Conversation(%s) Missing handler %s" %
(conversation.conversation_id, fn_name),
file=sys.stderr)
return
return self.timestamp - other.timestamp
def is_really_a_packet(self, missing_packet_stats=None):
- """Is the packet one that can be ignored?
+ return is_a_real_packet(self.protocol, self.opcode)
- If so removing it will have no effect on the replay
- """
- if self.protocol in SKIPPED_PROTOCOLS:
- # Ignore any packets for the protocols we're not interested in.
- return False
- if self.protocol == "ldap" and self.opcode == '':
- # skip ldap continuation packets
- return False
- fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
- fn = getattr(traffic_packets, fn_name, None)
- if not fn:
- print("missing packet %s" % fn_name, file=sys.stderr)
- return False
- if fn is traffic_packets.null_packet:
- return False
- return True
+def is_a_real_packet(protocol, opcode):
+    """Return True if (protocol, opcode) identifies a replayable packet.
+
+    Packets for which this returns False can be dropped without
+    affecting the replay.
+    """
+    if protocol in SKIPPED_PROTOCOLS:
+        # Ignore any packets for the protocols we're not interested in.
+        return False
+    if protocol == "ldap" and opcode == '':
+        # skip ldap continuation packets
+        return False
+
+    fn_name = 'packet_%s_%s' % (protocol, opcode)
+    fn = getattr(traffic_packets, fn_name, None)
+    if fn is None:
+        # Logger.debug() takes no file= keyword (that was a print()
+        # artefact and raises TypeError); use lazy %-style arguments.
+        LOGGER.debug("missing packet %s", fn_name)
+        return False
+    if fn is traffic_packets.null_packet:
+        return False
+    return True
class ReplayContext(object):
- """State/Context for an individual conversation between an simulated client
- and a server.
+    """State/Context for a conversation between a simulated client and a
+ server. Some of the context is shared amongst all conversations
+ and should be generated before the fork, while other context is
+ specific to a particular conversation and should be generated
+ *after* the fork, in generate_process_local_config().
"""
-
def __init__(self,
server=None,
lp=None,
domain_sid=None):
self.server = server
- self.ldap_connections = []
- self.dcerpc_connections = []
- self.lsarpc_connections = []
- self.lsarpc_connections_named = []
- self.drsuapi_connections = []
- self.srvsvc_connections = []
- self.samr_contexts = []
self.netlogon_connection = None
self.creds = creds
self.lp = lp
self.last_netlogon_bad = False
self.last_samlogon_bad = False
self.generate_ldap_search_tables()
- self.next_conversation_id = itertools.count().next
def generate_ldap_search_tables(self):
session = system_session()
# for k, v in self.dn_map.items():
# print >>sys.stderr, k, len(v)
- for k, v in dn_map.items():
+ for k in list(dn_map.keys()):
if k[-3:] != ',DC':
continue
p = k[:-3]
self.attribute_clue_map = attribute_clue_map
def generate_process_local_config(self, account, conversation):
- if account is None:
- return
+ self.ldap_connections = []
+ self.dcerpc_connections = []
+ self.lsarpc_connections = []
+ self.lsarpc_connections_named = []
+ self.drsuapi_connections = []
+ self.srvsvc_connections = []
+ self.samr_contexts = []
self.netbios_name = account.netbios_name
self.machinepass = account.machinepass
self.username = account.username
'conversation-%d' %
conversation.conversation_id)
- self.lp.set("private dir", self.tempdir)
- self.lp.set("lock dir", self.tempdir)
+ self.lp.set("private dir", self.tempdir)
+ self.lp.set("lock dir", self.tempdir)
self.lp.set("state directory", self.tempdir)
self.lp.set("tls verify peer", "no_check")
than that requested, but not significantly.
"""
if not failed_last_time:
- if (self.badpassword_frequency > 0 and
- random.random() < self.badpassword_frequency):
+ if (self.badpassword_frequency and self.badpassword_frequency > 0
+ and random.random() < self.badpassword_frequency):
try:
f(bad)
except:
self.machine_creds = Credentials()
self.machine_creds.guess(self.lp)
self.machine_creds.set_workstation(self.netbios_name)
- self.machine_creds.set_secure_channel_type(SEC_CHAN_WKSTA)
+ self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
self.machine_creds.set_password(self.machinepass)
self.machine_creds.set_username(self.netbios_name + "$")
self.machine_creds.set_domain(self.domain)
self.machine_creds_bad = Credentials()
self.machine_creds_bad.guess(self.lp)
self.machine_creds_bad.set_workstation(self.netbios_name)
- self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_WKSTA)
+ self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
self.machine_creds_bad.set_password(self.machinepass[:-4])
self.machine_creds_bad.set_username(self.netbios_name + "$")
if self.prefer_kerberos:
def get_authenticator(self):
auth = self.machine_creds.new_client_authenticator()
current = netr_Authenticator()
- current.cred.data = [ord(x) for x in auth["credential"]]
+ current.cred.data = [x if isinstance(x, int) else ord(x) for x in auth["credential"]]
current.timestamp = auth["timestamp"]
subsequent = netr_Authenticator()
class Conversation(object):
"""Details of a converation between a simulated client and a server."""
- conversation_id = None
-
- def __init__(self, start_time=None, endpoints=None):
+ def __init__(self, start_time=None, endpoints=None, seq=(),
+ conversation_id=None):
self.start_time = start_time
self.endpoints = endpoints
self.packets = []
- self.msg = random_colour_print()
+ self.msg = random_colour_print(endpoints)
self.client_balance = 0.0
+ self.conversation_id = conversation_id
+ for p in seq:
+ self.add_short_packet(*p)
def __cmp__(self, other):
if self.start_time is None:
src, dest = self.guess_client_server()
if not client:
src, dest = dest, src
-
- desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
- ip_protocol = IP_PROTOCOLS.get(protocol, '06')
- fields = [timestamp - self.start_time, ip_protocol,
- '', src, dest,
- protocol, opcode, desc]
- fields.extend(extra)
- packet = Packet(fields)
+ key = (protocol, opcode)
+ desc = OP_DESCRIPTIONS[key] if key in OP_DESCRIPTIONS else ''
+ if protocol in IP_PROTOCOLS:
+ ip_protocol = IP_PROTOCOLS[protocol]
+ else:
+ ip_protocol = '06'
+ packet = Packet(timestamp - self.start_time, ip_protocol,
+ '', src, dest,
+ protocol, opcode, desc, extra)
# XXX we're assuming the timestamp is already adjusted for
# this conversation?
# XXX should we adjust client balance for guessed packets?
gap = t - now
print("gap is now %f" % gap, file=sys.stderr)
- self.conversation_id = context.next_conversation_id()
+ self.conversation_id = next(context.next_conversation_id)
pid = os.fork()
if pid != 0:
return pid
f = open(f)
print("Ingesting %s" % (f.name,), file=sys.stderr)
for line in f:
- p = Packet(line)
+ p = Packet.from_line(line)
if p.protocol == 'dns' and dns_mode != 'include':
dns_counts[p.opcode] += 1
else:
print("gathering packets into conversations", file=sys.stderr)
conversations = OrderedDict()
- for p in packets:
+ for i, p in enumerate(packets):
p.timestamp -= start_time
c = conversations.get(p.endpoints)
if c is None:
- c = Conversation()
+ c = Conversation(conversation_id=(i + 2))
conversations[p.endpoints] = c
c.add_packet(p)
self.n = n
self.dns_opcounts = defaultdict(int)
self.cumulative_duration = 0.0
- self.conversation_rate = [0, 1]
+ self.packet_rate = [0, 1]
def learn(self, conversations, dns_opcounts={}):
prev = 0.0
self.dns_opcounts[k] += v
if len(conversations) > 1:
- elapsed =\
- conversations[-1].start_time - conversations[0].start_time
- self.conversation_rate[0] = len(conversations)
- self.conversation_rate[1] = elapsed
+ first = conversations[0].start_time
+ total = 0
+ last = first + 0.1
+ for c in conversations:
+ total += len(c)
+ last = max(last, c.packets[-1].timestamp)
+
+ self.packet_rate[0] = total
+ self.packet_rate[1] = last - first
for c in conversations:
client, server = c.guess_client_server(server)
'ngrams': ngrams,
'query_details': query_details,
'cumulative_duration': self.cumulative_duration,
- 'conversation_rate': self.conversation_rate,
+ 'packet_rate': self.packet_rate,
+ 'version': CURRENT_MODEL_VERSION
}
d['dns'] = self.dns_opcounts
d = json.load(f)
+ try:
+ version = d["version"]
+ if version < REQUIRED_MODEL_VERSION:
+ raise ValueError("the model file is version %d; "
+ "version %d is required" %
+ (version, REQUIRED_MODEL_VERSION))
+ except KeyError:
+ raise ValueError("the model file lacks a version number; "
+ "version %d is required" %
+ (REQUIRED_MODEL_VERSION))
+
for k, v in d['ngrams'].items():
k = tuple(str(k).split('\t'))
values = self.ngrams.setdefault(k, [])
for p, count in v.items():
values.extend([str(p)] * count)
+ values.sort()
for k, v in d['query_details'].items():
values = self.query_details.setdefault(str(k), [])
values.extend([()] * count)
else:
values.extend([tuple(str(p).split('\t'))] * count)
+ values.sort()
if 'dns' in d:
for k, v in d['dns'].items():
self.dns_opcounts[k] += v
self.cumulative_duration = d['cumulative_duration']
- self.conversation_rate = d['conversation_rate']
-
- def construct_conversation(self, timestamp=0.0, client=2, server=1,
- hard_stop=None, packet_rate=1):
- """Construct a individual converation from the model."""
-
- c = Conversation(timestamp, (server, client))
-
+ self.packet_rate = d['packet_rate']
+
+ def construct_conversation_sequence(self, timestamp=0.0,
+ hard_stop=None,
+ replay_speed=1,
+ ignore_before=0):
+ """Construct an individual conversation packet sequence from the
+ model.
+ """
+ c = []
key = (NON_PACKET,) * (self.n - 1)
+ if ignore_before is None:
+ ignore_before = timestamp - 1
- while key in self.ngrams:
- p = random.choice(self.ngrams.get(key, NON_PACKET))
+ while True:
+ p = random.choice(self.ngrams.get(key, (NON_PACKET,)))
if p == NON_PACKET:
break
+
if p in self.query_details:
extra = random.choice(self.query_details[p])
else:
protocol, opcode = p.split(':', 1)
if protocol == 'wait':
log_wait_time = int(opcode) + random.random()
- wait = math.exp(log_wait_time) / (WAIT_SCALE * packet_rate)
+ wait = math.exp(log_wait_time) / (WAIT_SCALE * replay_speed)
timestamp += wait
else:
log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
- wait = math.exp(log_wait) / packet_rate
+ wait = math.exp(log_wait) / replay_speed
timestamp += wait
if hard_stop is not None and timestamp > hard_stop:
break
- c.add_short_packet(timestamp, protocol, opcode, extra)
+ if timestamp >= ignore_before:
+ c.append((timestamp, protocol, opcode, extra))
key = key[1:] + (p,)
return c
- def generate_conversations(self, rate, duration, packet_rate=1):
+ def generate_conversations(self, scale, duration, replay_speed=1,
+ server=1, client=2):
"""Generate a list of conversations from the model."""
- # We run the simulation for at least ten times as long as our
- # desired duration, and take a section near the start.
- rate_n, rate_t = self.conversation_rate
+ # We run the simulation for ten times as long as our desired
+ # duration, and take the section at the end.
+ lead_in = 9 * duration
+ rate_n, rate_t = self.packet_rate
+ target_packets = int(duration * scale * rate_n / rate_t)
- duration2 = max(rate_t, duration * 2)
- n = rate * duration2 * rate_n / rate_t
+ conversations = []
+ n_packets = 0
+
+ while n_packets < target_packets:
+ start = random.uniform(-lead_in, duration)
+ c = self.construct_conversation_sequence(start,
+ hard_stop=duration,
+ replay_speed=replay_speed,
+ ignore_before=0)
+ conversations.append(c)
+ n_packets += len(c)
+
+ print(("we have %d packets (target %d) in %d conversations at scale %f"
+ % (n_packets, target_packets, len(conversations), scale)),
+ file=sys.stderr)
+ conversations.sort() # sorts by first element == start time
+ return seq_to_conversations(conversations)
- server = 1
- client = 2
- conversations = []
- end = duration2
- start = end - duration
-
- while client < n + 2:
- start = random.uniform(0, duration2)
- c = self.construct_conversation(start,
- client,
- server,
- hard_stop=(duration2 * 5),
- packet_rate=packet_rate)
-
- c.forget_packets_outside_window(start, end)
- c.renormalise_times(start)
- if len(c) != 0:
- conversations.append(c)
+def seq_to_conversations(seq, server=1, client=2):
+ conversations = []
+ for s in seq:
+ if s:
+ c = Conversation(s[0][0], (server, client), s)
client += 1
-
- print(("we have %d conversations at rate %f" %
- (len(conversations), rate)), file=sys.stderr)
- conversations.sort()
- return conversations
+ conversations.append(c)
+ return conversations
IP_PROTOCOLS = {
end = start + duration
- print("Replaying traffic for %u conversations over %d seconds"
+ LOGGER.info("Replaying traffic for %u conversations over %d seconds"
% (len(conversations), duration))
children = {}
finally:
for s in (15, 15, 9):
print(("killing %d children with -%d" %
- (len(children), s)), file=sys.stderr)
+ (len(children), s)), file=sys.stderr)
for pid in children:
try:
os.kill(pid, s)
session = system_session()
ldb = SamDB(url="ldap://%s" % host,
session_info=session,
+ options=['modules:paged_searches'],
credentials=creds,
lp=lp)
return ldb
"""
ou = ou_name(ldb, instance_id)
try:
- ldb.add({"dn": ou.split(',', 1)[1],
+ ldb.add({"dn": ou.split(',', 1)[1],
"objectclass": "organizationalunit"})
except LdbError as e:
- (status, _) = e
+ (status, _) = e.args
# ignore already exists
if status != 68:
raise
try:
- ldb.add({"dn": ou,
+ ldb.add({"dn": ou,
"objectclass": "organizationalunit"})
except LdbError as e:
- (status, _) = e
+ (status, _) = e.args
# ignore already exists
if status != 68:
raise
return ou
-class ConversationAccounts(object):
- """Details of the machine and user accounts associated with a conversation.
- """
- def __init__(self, netbios_name, machinepass, username, userpass):
- self.netbios_name = netbios_name
- self.machinepass = machinepass
- self.username = username
- self.userpass = userpass
+# ConversationAccounts holds details of the machine and user accounts
+# associated with a conversation.
+#
+# We use a named tuple to reduce shared memory usage.
+ConversationAccounts = namedtuple('ConversationAccounts',
+ ('netbios_name',
+ 'machinepass',
+ 'username',
+ 'userpass'))
def generate_replay_accounts(ldb, instance_id, number, password):
"""Generate a series of unique machine and user account names."""
- generate_traffic_accounts(ldb, instance_id, number, password)
accounts = []
for i in range(1, number + 1):
- netbios_name = "STGM-%d-%d" % (instance_id, i)
- username = "STGU-%d-%d" % (instance_id, i)
+ netbios_name = machine_name(instance_id, i)
+ username = user_name(instance_id, i)
account = ConversationAccounts(netbios_name, password, username,
password)
return accounts
-def generate_traffic_accounts(ldb, instance_id, number, password):
- """Create the specified number of user and machine accounts.
-
- As accounts are not explicitly deleted between runs. This function starts
- with the last account and iterates backwards stopping either when it
- finds an already existing account or it has generated all the required
- accounts.
- """
- print(("Generating machine and conversation accounts, "
- "as required for %d conversations" % number),
- file=sys.stderr)
- added = 0
- for i in range(number, 0, -1):
- try:
- netbios_name = "STGM-%d-%d" % (instance_id, i)
- create_machine_account(ldb, instance_id, netbios_name, password)
- added += 1
- except LdbError as e:
- (status, _) = e
- if status == 68:
- break
- else:
- raise
- if added > 0:
- print("Added %d new machine accounts" % added,
- file=sys.stderr)
-
- added = 0
- for i in range(number, 0, -1):
- try:
- username = "STGU-%d-%d" % (instance_id, i)
- create_user_account(ldb, instance_id, username, password)
- added += 1
- except LdbError as e:
- (status, _) = e
- if status == 68:
- break
- else:
- raise
-
- if added > 0:
- print("Added %d new user accounts" % added,
- file=sys.stderr)
-
-
-def create_machine_account(ldb, instance_id, netbios_name, machinepass):
+def create_machine_account(ldb, instance_id, netbios_name, machinepass,
+ traffic_account=True):
"""Create a machine account via ldap."""
ou = ou_name(ldb, instance_id)
dn = "cn=%s,%s" % (netbios_name, ou)
- utf16pw = unicode(
- '"' + machinepass.encode('utf-8') + '"', 'utf-8'
- ).encode('utf-16-le')
- start = time.time()
+ utf16pw = ('"%s"' % get_string(machinepass)).encode('utf-16-le')
+
+ if traffic_account:
+ # we set these bits for the machine account otherwise the replayed
+ # traffic throws up NT_STATUS_NO_TRUST_SAM_ACCOUNT errors
+ account_controls = str(UF_TRUSTED_FOR_DELEGATION |
+ UF_SERVER_TRUST_ACCOUNT)
+
+ else:
+ account_controls = str(UF_WORKSTATION_TRUST_ACCOUNT)
+
ldb.add({
"dn": dn,
"objectclass": "computer",
"sAMAccountName": "%s$" % netbios_name,
- "userAccountControl":
- str(UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD),
+ "userAccountControl": account_controls,
"unicodePwd": utf16pw})
- end = time.time()
- duration = end - start
- print("%f\t0\tcreate\tmachine\t%f\tTrue\t" % (end, duration))
def create_user_account(ldb, instance_id, username, userpass):
"""Create a user account via ldap."""
ou = ou_name(ldb, instance_id)
user_dn = "cn=%s,%s" % (username, ou)
- utf16pw = unicode(
- '"' + userpass.encode('utf-8') + '"', 'utf-8'
- ).encode('utf-16-le')
- start = time.time()
+ utf16pw = ('"%s"' % get_string(userpass)).encode('utf-16-le')
ldb.add({
"dn": user_dn,
"objectclass": "user",
# grant user write permission to do things like write account SPN
sdutils = sd_utils.SDUtils(ldb)
- sdutils.dacl_add_ace(user_dn, "(A;;WP;;;PS)")
-
- end = time.time()
- duration = end - start
- print("%f\t0\tcreate\tuser\t%f\tTrue\t" % (end, duration))
+ sdutils.dacl_add_ace(user_dn, "(A;;WP;;;PS)")
def create_group(ldb, instance_id, name):
ou = ou_name(ldb, instance_id)
dn = "cn=%s,%s" % (name, ou)
- start = time.time()
ldb.add({
"dn": dn,
"objectclass": "group",
+ "sAMAccountName": name,
})
- end = time.time()
- duration = end - start
- print("%f\t0\tcreate\tgroup\t%f\tTrue\t" % (end, duration))
def user_name(instance_id, i):
return "STGU-%d-%d" % (instance_id, i)
+def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
+    """Search for objects of the given objectClass and return the
+    values of *attr* for each match, collected into a set of strings.
+    """
+    objs = ldb.search(
+        expression="(objectClass={})".format(objectclass),
+        attrs=[attr]
+    )
+    return {str(obj[attr]) for obj in objs}
+
+
def generate_users(ldb, instance_id, number, password):
"""Add users to the server"""
+ existing_objects = search_objectclass(ldb, objectclass='user')
users = 0
for i in range(number, 0, -1):
- try:
- username = user_name(instance_id, i)
- create_user_account(ldb, instance_id, username, password)
+ name = user_name(instance_id, i)
+ if name not in existing_objects:
+ create_user_account(ldb, instance_id, name, password)
users += 1
- except LdbError as e:
- (status, _) = e
- # Stop if entry exists
- if status == 68:
- break
- else:
- raise
+ if users % 50 == 0:
+ LOGGER.info("Created %u/%u users" % (users, number))
return users
+def machine_name(instance_id, i, traffic_account=True):
+    """Generate a machine account name from the instance id and index.
+
+    traffic_account selects the STGM- prefix (accounts used to replay
+    traffic) over the PC- prefix (background computer accounts).
+    """
+    if traffic_account:
+        # traffic accounts correspond to a given user, and use different
+        # userAccountControl flags to ensure packets get processed correctly
+        # by the DC
+        return "STGM-%d-%d" % (instance_id, i)
+    else:
+        # Otherwise we're just generating computer accounts to simulate a
+        # semi-realistic network. These use the default computer
+        # userAccountControl flags, so we use a different account name so that
+        # we don't try to use them when generating packets
+        return "PC-%d-%d" % (instance_id, i)
+
+
+def generate_machine_accounts(ldb, instance_id, number, password,
+                              traffic_account=True):
+    """Add machine accounts to the server.
+
+    Accounts that already exist are skipped; returns the number of
+    accounts actually created.
+    """
+    existing_objects = search_objectclass(ldb, objectclass='computer')
+    added = 0
+    for i in range(number, 0, -1):
+        name = machine_name(instance_id, i, traffic_account)
+        # computer sAMAccountName values carry a trailing '$'
+        if name + "$" not in existing_objects:
+            create_machine_account(ldb, instance_id, name, password,
+                                   traffic_account)
+            added += 1
+            # log progress once per 50 creations; checking at loop level
+            # would log "Created 0/%u" on every skipped iteration
+            if added % 50 == 0:
+                LOGGER.info("Created %u/%u machine accounts" % (added, number))
+
+    return added
+
+
def group_name(instance_id, i):
"""Generate a group name from instance id."""
return "STGG-%d-%d" % (instance_id, i)
def generate_groups(ldb, instance_id, number):
"""Create the required number of groups on the server."""
+ existing_objects = search_objectclass(ldb, objectclass='group')
groups = 0
for i in range(number, 0, -1):
- try:
- name = group_name(instance_id, i)
+ name = group_name(instance_id, i)
+ if name not in existing_objects:
create_group(ldb, instance_id, name)
groups += 1
- except LdbError as e:
- (status, _) = e
- # Stop if entry exists
- if status == 68:
- break
- else:
- raise
+ if groups % 1000 == 0:
+ LOGGER.info("Created %u/%u groups" % (groups, number))
+
return groups
try:
ldb.delete(ou, ["tree_delete:1"])
except LdbError as e:
- (status, _) = e
+ (status, _) = e.args
# ignore does not exist
if status != 32:
raise
def generate_users_and_groups(ldb, instance_id, password,
number_of_users, number_of_groups,
- group_memberships):
+ group_memberships, max_members,
+ machine_accounts, traffic_accounts=True):
"""Generate the required users and groups, allocating the users to
those groups."""
- assignments = []
- groups_added = 0
+ memberships_added = 0
+ groups_added = 0
+ computers_added = 0
create_ou(ldb, instance_id)
- print("Generating dummy user accounts", file=sys.stderr)
+ LOGGER.info("Generating dummy user accounts")
users_added = generate_users(ldb, instance_id, number_of_users, password)
+ LOGGER.info("Generating dummy machine accounts")
+ computers_added = generate_machine_accounts(ldb, instance_id,
+ machine_accounts, password,
+ traffic_accounts)
+
if number_of_groups > 0:
- print("Generating dummy groups", file=sys.stderr)
+ LOGGER.info("Generating dummy groups")
groups_added = generate_groups(ldb, instance_id, number_of_groups)
if group_memberships > 0:
- print("Assigning users to groups", file=sys.stderr)
- assignments = assign_groups(number_of_groups,
- groups_added,
- number_of_users,
- users_added,
- group_memberships)
- print("Adding users to groups", file=sys.stderr)
+ LOGGER.info("Assigning users to groups")
+ assignments = GroupAssignments(number_of_groups,
+ groups_added,
+ number_of_users,
+ users_added,
+ group_memberships,
+ max_members)
+ LOGGER.info("Adding users to groups")
add_users_to_groups(ldb, instance_id, assignments)
+ memberships_added = assignments.total()
if (groups_added > 0 and users_added == 0 and
number_of_groups != groups_added):
- print("Warning: the added groups will contain no members",
- file=sys.stderr)
-
- print(("Added %d users, %d groups and %d group memberships" %
- (users_added, groups_added, len(assignments))),
- file=sys.stderr)
-
-
-def assign_groups(number_of_groups,
- groups_added,
- number_of_users,
- users_added,
- group_memberships):
- """Allocate users to groups.
-
- The intention is to have a few users that belong to most groups, while
- the majority of users belong to a few groups.
-
- A few groups will contain most users, with the remaining only having a
- few users.
- """
+ LOGGER.warning("The added groups will contain no members")
+
+ LOGGER.info("Added %d users (%d machines), %d groups and %d memberships" %
+ (users_added, computers_added, groups_added,
+ memberships_added))
+
+
+class GroupAssignments(object):
+ def __init__(self, number_of_groups, groups_added, number_of_users,
+ users_added, group_memberships, max_members):
+
+ self.count = 0
+ self.generate_group_distribution(number_of_groups)
+ self.generate_user_distribution(number_of_users, group_memberships)
+ self.max_members = max_members
+ self.assignments = defaultdict(list)
+ self.assign_groups(number_of_groups, groups_added, number_of_users,
+ users_added, group_memberships)
+
+    def cumulative_distribution(self, weights):
+        """Convert a list of weights into a cumulative distribution.
+
+        Returns a list of floats rising to 1.0, or None when the
+        weights sum to zero (nothing can be selected).
+        """
+        # make sure the probabilities conform to a cumulative distribution
+        # spread between 0.0 and 1.0. Dividing by the weighted total gives each
+        # probability a proportional share of 1.0. Higher probabilities get a
+        # bigger share, so are more likely to be picked. We use the cumulative
+        # value, so we can use random.random() as a simple index into the list
+        dist = []
+        total = sum(weights)
+        if total == 0:
+            return None
+
+        cumulative = 0.0
+        for probability in weights:
+            cumulative += probability
+            dist.append(cumulative / total)
+        return dist
- def generate_user_distribution(n):
+ def generate_user_distribution(self, num_users, num_memberships):
"""Probability distribution of a user belonging to a group.
"""
- dist = []
- for x in range(1, n + 1):
- p = 1 / (x + 0.001)
- dist.append(p)
- return dist
+ # Assign a weighted probability to each user. Use the Pareto
+ # Distribution so that some users are in a lot of groups, and the
+ # bulk of users are in only a few groups. If we're assigning a large
+ # number of group memberships, use a higher shape. This means slightly
+ # fewer outlying users that are in large numbers of groups. The aim is
+ # to have no users belonging to more than ~500 groups.
+ if num_memberships > 5000000:
+ shape = 3.0
+ elif num_memberships > 2000000:
+ shape = 2.5
+ elif num_memberships > 300000:
+ shape = 2.25
+ else:
+ shape = 1.75
- def generate_group_distribution(n):
+ weights = []
+ for x in range(1, num_users + 1):
+ p = random.paretovariate(shape)
+ weights.append(p)
+
+ # convert the weights to a cumulative distribution between 0.0 and 1.0
+ self.user_dist = self.cumulative_distribution(weights)
+
+ def generate_group_distribution(self, n):
"""Probability distribution of a group containing a user."""
- dist = []
+
+        # Assign a weighted probability to each group. Probability decreases
+        # as the group-ID increases
+ weights = []
for x in range(1, n + 1):
p = 1 / (x**1.3)
- dist.append(p)
- return dist
+ weights.append(p)
+
+ # convert the weights to a cumulative distribution between 0.0 and 1.0
+ self.group_weights = weights
+ self.group_dist = self.cumulative_distribution(weights)
+
+ def generate_random_membership(self):
+ """Returns a randomly generated user-group membership"""
+
+ # the list items are cumulative distribution values between 0.0 and
+ # 1.0, which makes random() a handy way to index the list to get a
+ # weighted random user/group. (Here the user/group returned are
+ # zero-based array indexes)
+ user = bisect.bisect(self.user_dist, random.random())
+ group = bisect.bisect(self.group_dist, random.random())
+
+ return user, group
+
+ def users_in_group(self, group):
+ return self.assignments[group]
+
+ def get_groups(self):
+ return self.assignments.keys()
+
+ def cap_group_membership(self, group, max_members):
+ """Prevent the group's membership from exceeding the max specified"""
+ num_members = len(self.assignments[group])
+ if num_members >= max_members:
+ LOGGER.info("Group {0} has {1} members".format(group, num_members))
+
+ # remove this group and then recalculate the cumulative
+ # distribution, so this group is no longer selected
+ self.group_weights[group - 1] = 0
+ new_dist = self.cumulative_distribution(self.group_weights)
+ self.group_dist = new_dist
+
+ def add_assignment(self, user, group):
+ # the assignments are stored in a dictionary where key=group,
+ # value=list-of-users-in-group (indexing by group-ID allows us to
+ # optimize for DB membership writes)
+ if user not in self.assignments[group]:
+ self.assignments[group].append(user)
+ self.count += 1
+
+        # check if there's a cap on how big the groups can grow
+ if self.max_members:
+ self.cap_group_membership(group, self.max_members)
+
+ def assign_groups(self, number_of_groups, groups_added,
+ number_of_users, users_added, group_memberships):
+ """Allocate users to groups.
+
+ The intention is to have a few users that belong to most groups, while
+ the majority of users belong to a few groups.
+
+ A few groups will contain most users, with the remaining only having a
+ few users.
+ """
- assignments = set()
- if group_memberships <= 0:
- return assignments
+ if group_memberships <= 0:
+ return
- group_dist = generate_group_distribution(number_of_groups)
- user_dist = generate_user_distribution(number_of_users)
+    # Calculate the number of group memberships required
+ group_memberships = math.ceil(
+ float(group_memberships) *
+ (float(users_added) / float(number_of_users)))
- # Calculate the number of group menberships required
- group_memberships = math.ceil(
- float(group_memberships) *
- (float(users_added) / float(number_of_users)))
+ if self.max_members:
+ group_memberships = min(group_memberships,
+ self.max_members * number_of_groups)
- existing_users = number_of_users - users_added - 1
- existing_groups = number_of_groups - groups_added - 1
- while len(assignments) < group_memberships:
- user = random.randint(0, number_of_users - 1)
- group = random.randint(0, number_of_groups - 1)
- probability = group_dist[group] * user_dist[user]
+ existing_users = number_of_users - users_added - 1
+ existing_groups = number_of_groups - groups_added - 1
+ while self.total() < group_memberships:
+ user, group = self.generate_random_membership()
- if ((random.random() < probability * 10000) and
- (group > existing_groups or user > existing_users)):
- # the + 1 converts the array index to the corresponding
- # group or user number
- assignments.add(((user + 1), (group + 1)))
+ if group > existing_groups or user > existing_users:
+ # the + 1 converts the array index to the corresponding
+ # group or user number
+ self.add_assignment(user + 1, group + 1)
- return assignments
+ def total(self):
+ return self.count
def add_users_to_groups(db, instance_id, assignments):
- """Add users to their assigned groups.
+ """Takes the assignments of users to groups and applies them to the DB."""
+
+ total = assignments.total()
+ count = 0
+ added = 0
- Takes the list of (group,user) tuples generated by assign_groups and
- assign the users to their specified groups."""
+ for group in assignments.get_groups():
+ users_in_group = assignments.users_in_group(group)
+ if len(users_in_group) == 0:
+ continue
+
+ # Split up the users into chunks, so we write no more than 1K at a
+ # time. (Minimizing the DB modifies is more efficient, but writing
+ # 10K+ users to a single group becomes inefficient memory-wise)
+ for chunk in range(0, len(users_in_group), 1000):
+ chunk_of_users = users_in_group[chunk:chunk + 1000]
+ add_group_members(db, instance_id, group, chunk_of_users)
+
+ added += len(chunk_of_users)
+ count += 1
+ if count % 50 == 0:
+ LOGGER.info("Added %u/%u memberships" % (added, total))
+
+def add_group_members(db, instance_id, group, users_in_group):
+ """Adds the given users to group specified."""
ou = ou_name(db, instance_id)
def build_dn(name):
return("cn=%s,%s" % (name, ou))
- for (user, group) in assignments:
- user_dn = build_dn(user_name(instance_id, user))
- group_dn = build_dn(group_name(instance_id, group))
+ group_dn = build_dn(group_name(instance_id, group))
+ m = ldb.Message()
+ m.dn = ldb.Dn(db, group_dn)
- m = ldb.Message()
- m.dn = ldb.Dn(db, group_dn)
- m["member"] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
- start = time.time()
- db.modify(m)
- end = time.time()
- duration = end - start
- print("%f\t0\tadd\tuser\t%f\tTrue\t" % (end, duration))
+ for user in users_in_group:
+ user_dn = build_dn(user_name(instance_id, user))
+ idx = "member-" + str(user)
+ m[idx] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
+
+ db.modify(m)
def generate_stats(statsdir, timing_file):
else:
failure_rate = failed / duration
- # print the stats in more human-readable format when stdout is going to the
- # console (as opposed to being redirected to a file)
- if sys.stdout.isatty():
- print("Total conversations: %10d" % conversations)
- print("Successful operations: %10d (%.3f per second)"
- % (successful, success_rate))
- print("Failed operations: %10d (%.3f per second)"
- % (failed, failure_rate))
- else:
- print("(%d, %d, %d, %.3f, %.3f)" %
- (conversations, successful, failed, success_rate, failure_rate))
+ print("Total conversations: %10d" % conversations)
+ print("Successful operations: %10d (%.3f per second)"
+ % (successful, success_rate))
+ print("Failed operations: %10d (%.3f per second)"
+ % (failed, failure_rate))
+
+ print("Protocol Op Code Description "
+ " Count Failed Mean Median "
+ "95% Range Max")
- if sys.stdout.isatty():
- print("Protocol Op Code Description "
- " Count Failed Mean Median "
- "95% Range Max")
- else:
- print("proto\top_code\tdesc\tcount\tfailed\tmean\tmedian\t95%\trange"
- "\tmax")
protocols = sorted(latencies.keys())
for protocol in protocols:
packet_types = sorted(latencies[protocol], key=opcode_key)
def mk_masked_dir(*path):
- """In a testenv we end up with 0777 diectories that look an alarming
+ """In a testenv we end up with 0777 directories that look an alarming
green colour with ls. Use umask to avoid that."""
+ # py3 os.mkdir can do this
d = os.path.join(*path)
mask = os.umask(0o077)
os.mkdir(d)