# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
-from __future__ import print_function
+from __future__ import print_function, division
import time
import os
import traceback
from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
from samba.auth import system_session
-from samba.dsdb import UF_WORKSTATION_TRUST_ACCOUNT, UF_PASSWD_NOTREQD
-from samba.dsdb import UF_NORMAL_ACCOUNT
-from samba.dcerpc.misc import SEC_CHAN_WKSTA
+from samba.dsdb import (
+ UF_NORMAL_ACCOUNT,
+ UF_SERVER_TRUST_ACCOUNT,
+ UF_TRUSTED_FOR_DELEGATION
+)
+from samba.dcerpc.misc import SEC_CHAN_BDC
from samba import gensec
+from samba import sd_utils
SLEEP_OVERHEAD = 3e-4
class Packet(object):
"""Details of a network packet"""
- def __init__(self, fields):
- if isinstance(fields, str):
- fields = fields.rstrip('\n').split('\t')
+ def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
+ protocol, opcode, desc, extra):
+ self.timestamp = timestamp
+ self.ip_protocol = ip_protocol
+ self.stream_number = stream_number
+ self.src = src
+ self.dest = dest
+ self.protocol = protocol
+ self.opcode = opcode
+ self.desc = desc
+ self.extra = extra
+ if self.src < self.dest:
+ self.endpoints = (self.src, self.dest)
+ else:
+ self.endpoints = (self.dest, self.src)
+
+ @classmethod
+ def from_line(self, line):
+ fields = line.rstrip('\n').split('\t')
(timestamp,
ip_protocol,
stream_number,
desc) = fields[:8]
extra = fields[8:]
- self.timestamp = float(timestamp)
- self.ip_protocol = ip_protocol
- try:
- self.stream_number = int(stream_number)
- except (ValueError, TypeError):
- self.stream_number = None
- self.src = int(src)
- self.dest = int(dest)
- self.protocol = protocol
- self.opcode = opcode
- self.desc = desc
- self.extra = extra
+ timestamp = float(timestamp)
+ src = int(src)
+ dest = int(dest)
- if self.src < self.dest:
- self.endpoints = (self.src, self.dest)
- else:
- self.endpoints = (self.dest, self.src)
+ return Packet(timestamp, ip_protocol, stream_number, src, dest,
+ protocol, opcode, desc, extra)
def as_summary(self, time_offset=0.0):
"""Format the packet as a traffic_summary line.
return "<Packet @%s>" % self
def copy(self):
- return self.__class__([self.timestamp,
- self.ip_protocol,
- self.stream_number,
- self.src,
- self.dest,
- self.protocol,
- self.opcode,
- self.desc] + self.extra)
+ return self.__class__(self.timestamp,
+ self.ip_protocol,
+ self.stream_number,
+ self.src,
+ self.dest,
+ self.protocol,
+ self.opcode,
+ self.desc,
+ self.extra)
def as_packet_type(self):
t = '%s:%s' % (self.protocol, self.opcode)
return False
fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
- try:
- fn = getattr(traffic_packets, fn_name)
- if fn is traffic_packets.null_packet:
- return False
- except AttributeError:
+ fn = getattr(traffic_packets, fn_name, None)
+ if not fn:
print("missing packet %s" % fn_name, file=sys.stderr)
return False
+ if fn is traffic_packets.null_packet:
+ return False
return True
self.last_netlogon_bad = False
self.last_samlogon_bad = False
self.generate_ldap_search_tables()
- self.next_conversation_id = itertools.count().next
+ self.next_conversation_id = itertools.count()
def generate_ldap_search_tables(self):
session = system_session()
res = db.search(db.domain_dn(),
scope=ldb.SCOPE_SUBTREE,
+ controls=["paged_results:1:1000"],
attrs=['dn'])
# find a list of dns for each pattern
# for k, v in self.dn_map.items():
# print >>sys.stderr, k, len(v)
- for k, v in dn_map.items():
+ for k in list(dn_map.keys()):
if k[-3:] != ',DC':
continue
p = k[:-3]
'conversation-%d' %
conversation.conversation_id)
- self.lp.set("private dir", self.tempdir)
- self.lp.set("lock dir", self.tempdir)
+ self.lp.set("private dir", self.tempdir)
+ self.lp.set("lock dir", self.tempdir)
self.lp.set("state directory", self.tempdir)
self.lp.set("tls verify peer", "no_check")
self.user_creds.set_workstation(self.netbios_name)
self.user_creds.set_password(self.userpass)
self.user_creds.set_username(self.username)
+ self.user_creds.set_domain(self.domain)
if self.prefer_kerberos:
self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
else:
self.machine_creds = Credentials()
self.machine_creds.guess(self.lp)
self.machine_creds.set_workstation(self.netbios_name)
- self.machine_creds.set_secure_channel_type(SEC_CHAN_WKSTA)
+ self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
self.machine_creds.set_password(self.machinepass)
self.machine_creds.set_username(self.netbios_name + "$")
+ self.machine_creds.set_domain(self.domain)
if self.prefer_kerberos:
self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
else:
self.machine_creds_bad = Credentials()
self.machine_creds_bad.guess(self.lp)
self.machine_creds_bad.set_workstation(self.netbios_name)
- self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_WKSTA)
+ self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
self.machine_creds_bad.set_password(self.machinepass[:-4])
self.machine_creds_bad.set_username(self.netbios_name + "$")
if self.prefer_kerberos:
return self.ldap_connections[-1]
def simple_bind(creds):
+ """
+ To run simple bind against Windows, we need to run
+ following commands in PowerShell:
+
+ Install-windowsfeature ADCS-Cert-Authority
+ Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
+ Restart-Computer
+
+ """
return SamDB('ldaps://%s' % self.server,
credentials=creds,
lp=self.lp)
def get_samr_context(self, new=False):
if not self.samr_contexts or new:
- self.samr_contexts.append(SamrContext(self.server))
+ self.samr_contexts.append(
+ SamrContext(self.server, lp=self.lp, creds=self.creds))
return self.samr_contexts[-1]
def get_netlogon_connection(self):
class SamrContext(object):
"""State/Context associated with a samr connection.
"""
- def __init__(self, server):
+ def __init__(self, server, lp=None, creds=None):
self.connection = None
self.handle = None
self.domain_handle = None
self.user_handle = None
self.rids = None
self.server = server
+ self.lp = lp
+ self.creds = creds
def get_connection(self):
if not self.connection:
- self.connection = samr.samr("ncacn_ip_tcp:%s" % (self.server))
+ self.connection = samr.samr(
+ "ncacn_ip_tcp:%s[seal]" % (self.server),
+ lp_ctx=self.lp,
+ credentials=self.creds)
+
return self.connection
def get_handle(self):
if p.is_really_a_packet():
self.packets.append(p)
- def add_short_packet(self, timestamp, p, extra, client=True):
+ def add_short_packet(self, timestamp, protocol, opcode, extra,
+ client=True):
"""Create a packet from a timestamp, and 'protocol:opcode' pair, and a
(possibly empty) list of extra data. If client is True, assume
this packet is from the client to the server.
"""
- protocol, opcode = p.split(':', 1)
src, dest = self.guess_client_server()
if not client:
src, dest = dest, src
-
- desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
- ip_protocol = IP_PROTOCOLS.get(protocol, '06')
- fields = [timestamp - self.start_time, ip_protocol,
- '', src, dest,
- protocol, opcode, desc]
- fields.extend(extra)
- packet = Packet(fields)
+ key = (protocol, opcode)
+ desc = OP_DESCRIPTIONS[key] if key in OP_DESCRIPTIONS else ''
+ if protocol in IP_PROTOCOLS:
+ ip_protocol = IP_PROTOCOLS[protocol]
+ else:
+ ip_protocol = '06'
+ packet = Packet(timestamp - self.start_time, ip_protocol,
+ '', src, dest,
+ protocol, opcode, desc, extra)
# XXX we're assuming the timestamp is already adjusted for
# this conversation?
# XXX should we adjust client balance for guessed packets?
gap = t - now
print("gap is now %f" % gap, file=sys.stderr)
- self.conversation_id = context.next_conversation_id()
+ self.conversation_id = next(context.next_conversation_id)
pid = os.fork()
if pid != 0:
return pid
:param s: start of the window
:param e: end of the window
"""
-
- new_packets = []
- for p in self.packets:
- if p.timestamp < s or p.timestamp > e:
- continue
- new_packets.append(p)
-
- self.packets = new_packets
- if new_packets:
- self.start_time = new_packets[0].timestamp
- else:
- self.start_time = None
+ self.packets = [p for p in self.packets if s <= p.timestamp <= e]
+ self.start_time = self.packets[0].timestamp if self.packets else None
def renormalise_times(self, start_time):
"""Adjust the packet start times relative to the new start time."""
f = open(f)
print("Ingesting %s" % (f.name,), file=sys.stderr)
for line in f:
- p = Packet(line)
+ p = Packet.from_line(line)
if p.protocol == 'dns' and dns_mode != 'include':
dns_counts[p.opcode] += 1
else:
timestamp += wait
if hard_stop is not None and timestamp > hard_stop:
break
- c.add_short_packet(timestamp, p, extra)
+ c.add_short_packet(timestamp, protocol, opcode, extra)
key = key[1:] + (p,)
client += 1
print(("we have %d conversations at rate %f" %
- (len(conversations), rate)), file=sys.stderr)
+ (len(conversations), rate)), file=sys.stderr)
conversations.sort()
return conversations
print(("we have %d accounts but %d conversations" %
(accounts, conversations)), file=sys.stderr)
- cstack = zip(sorted(conversations,
- key=lambda x: x.start_time, reverse=True),
- accounts)
+ cstack = list(zip(
+ sorted(conversations, key=lambda x: x.start_time, reverse=True),
+ accounts))
# Set the process group so that the calling scripts are not killed
# when the forked child processes are killed.
finally:
for s in (15, 15, 9):
print(("killing %d children with -%d" %
- (len(children), s)), file=sys.stderr)
+ (len(children), s)), file=sys.stderr)
for pid in children:
try:
os.kill(pid, s)
"""
ou = ou_name(ldb, instance_id)
try:
- ldb.add({"dn": ou.split(',', 1)[1],
+ ldb.add({"dn": ou.split(',', 1)[1],
"objectclass": "organizationalunit"})
except LdbError as e:
- (status, _) = e
+ (status, _) = e.args
# ignore already exists
if status != 68:
raise
try:
- ldb.add({"dn": ou,
+ ldb.add({"dn": ou,
"objectclass": "organizationalunit"})
except LdbError as e:
- (status, _) = e
+ (status, _) = e.args
# ignore already exists
if status != 68:
raise
create_machine_account(ldb, instance_id, netbios_name, password)
added += 1
except LdbError as e:
- (status, _) = e
+ (status, _) = e.args
if status == 68:
break
else:
create_user_account(ldb, instance_id, username, password)
added += 1
except LdbError as e:
- (status, _) = e
+ (status, _) = e.args
if status == 68:
break
else:
"objectclass": "computer",
"sAMAccountName": "%s$" % netbios_name,
"userAccountControl":
- str(UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD),
+ str(UF_TRUSTED_FOR_DELEGATION | UF_SERVER_TRUST_ACCOUNT),
"unicodePwd": utf16pw})
end = time.time()
duration = end - start
"userAccountControl": str(UF_NORMAL_ACCOUNT),
"unicodePwd": utf16pw
})
+
+ # grant user write permission to do things like write account SPN
+ sdutils = sd_utils.SDUtils(ldb)
+ sdutils.dacl_add_ace(user_dn, "(A;;WP;;;PS)")
+
end = time.time()
duration = end - start
print("%f\t0\tcreate\tuser\t%f\tTrue\t" % (end, duration))
ldb.add({
"dn": dn,
"objectclass": "group",
+ "sAMAccountName": name,
})
end = time.time()
duration = end - start
create_user_account(ldb, instance_id, username, password)
users += 1
except LdbError as e:
- (status, _) = e
+ (status, _) = e.args
# Stop if entry exists
if status == 68:
break
create_group(ldb, instance_id, name)
groups += 1
except LdbError as e:
- (status, _) = e
+ (status, _) = e.args
# Stop if entry exists
if status == 68:
break
try:
ldb.delete(ou, ["tree_delete:1"])
except LdbError as e:
- (status, _) = e
+ (status, _) = e.args
# ignore does not exist
if status != 32:
raise
else:
failure_rate = failed / duration
- # print the stats in more human-readable format when stdout is going to the
- # console (as opposed to being redirected to a file)
- if sys.stdout.isatty():
- print("Total conversations: %10d" % conversations)
- print("Successful operations: %10d (%.3f per second)"
- % (successful, success_rate))
- print("Failed operations: %10d (%.3f per second)"
- % (failed, failure_rate))
- else:
- print("(%d, %d, %d, %.3f, %.3f)" %
- (conversations, successful, failed, success_rate, failure_rate))
+ print("Total conversations: %10d" % conversations)
+ print("Successful operations: %10d (%.3f per second)"
+ % (successful, success_rate))
+ print("Failed operations: %10d (%.3f per second)"
+ % (failed, failure_rate))
+
+ print("Protocol Op Code Description "
+ " Count Failed Mean Median "
+ "95% Range Max")
- if sys.stdout.isatty():
- print("Protocol Op Code Description "
- " Count Failed Mean Median "
- "95% Range Max")
- else:
- print("proto\top_code\tdesc\tcount\tfailed\tmean\tmedian\t95%\trange"
- "\tmax")
protocols = sorted(latencies.keys())
for protocol in protocols:
packet_types = sorted(latencies[protocol], key=opcode_key)