ldb: Free memory when repacking database
[garming/samba-autobuild/.git] / script / traffic_replay
index 6f42f2d68cd5b0e535c46e23ef58b7a36c5c2ec5..d29f0a9839c05fd000a1d1c33e6f715c7e7981f4 100755 (executable)
@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
 # Generates samba network traffic
 #
 # Copyright (C) Catalyst IT Ltd. 2017
@@ -22,12 +22,16 @@ import os
 import optparse
 import tempfile
 import shutil
+import random
 
 sys.path.insert(0, "bin/python")
 
-from samba import gensec
+from samba import gensec, get_debug_level
 from samba.emulate import traffic
 import samba.getopt as options
+from samba.logger import get_samba_logger
+from samba.samdb import SamDB
+from samba.auth import system_session
 
 
 def print_err(*args, **kwargs):
@@ -36,17 +40,18 @@ def print_err(*args, **kwargs):
 
 def main():
 
-    desc = ("Generates network traffic 'conversations' based on <summary-file>"
-            " (which should be the output file produced by either traffic_learner"
-            " or traffic_summary.pl). This traffic is sent to <dns-hostname>,"
+    desc = ("Generates network traffic 'conversations' based on a model generated"
+            " by script/traffic_learner. This traffic is sent to <dns-hostname>,"
             " which is the full DNS hostname of the DC being tested.")
 
     parser = optparse.OptionParser(
-        "%prog [--help|options] <summary-file> <dns-hostname>",
+        "%prog [--help|options] <model-file> <dns-hostname>",
         description=desc)
 
     parser.add_option('--dns-rate', type='float', default=0,
                       help='fire extra DNS packets at this rate')
+    parser.add_option('--dns-query-file', dest="dns_query_file",
+                      help='A file contains DNS query list')
     parser.add_option('-B', '--badpassword-frequency',
                       type='float', default=0.0,
                       help='frequency of connections with bad passwords')
@@ -67,20 +72,34 @@ def main():
     parser.add_option('-c', '--clean-up',
                       action="store_true",
                       help='Clean up the generated groups and user accounts')
-
+    parser.add_option('--random-seed', type='int', default=None,
+                      help='Use to keep randomness consistent across multiple runs')
+    parser.add_option('--stop-on-any-error',
+                      action="store_true",
+                      help='abort the whole thing if a child fails')
     model_group = optparse.OptionGroup(parser, 'Traffic Model Options',
                                        'These options alter the traffic '
-                                       'generated when the summary-file is a '
-                                       'traffic-model (produced by '
-                                       'traffic_learner)')
-    model_group.add_option('-S', '--scale-traffic', type='float', default=1.0,
-                           help='Increase the number of conversations by '
-                           'this factor')
-    model_group.add_option('-D', '--duration', type='float', default=None,
+                                       'generated by the model')
+    model_group.add_option('-S', '--scale-traffic', type='float',
+                           help=('Increase the number of conversations by '
+                                 'this factor (or use -T)'))
+    parser.add_option('-T', '--packets-per-second', type=float,
+                      help=('attempt this many packets per second '
+                            '(alternative to -S)'))
+    parser.add_option('--old-scale',
+                      action="store_true",
+                      help='emulate the old scale for traffic')
+    model_group.add_option('-D', '--duration', type='float', default=60.0,
                            help=('Run model for this long (approx). '
                                  'Default 60s for models'))
+    model_group.add_option('--latency-timeout', type='float', default=None,
+                           help=('Wait this long for last packet to finish'))
     model_group.add_option('-r', '--replay-rate', type='float', default=1.0,
                            help='Replay the traffic faster by this factor')
+    model_group.add_option('--conversation-persistence', type='float',
+                           default=0.0,
+                           help=('chance (0 to 1) that a conversation waits '
+                                 'when it would have died'))
     model_group.add_option('--traffic-summary',
                            help=('Generate a traffic summary file and write '
                                  'it here (- for stdout)'))
@@ -97,7 +116,7 @@ def main():
                               'the traffic')
     user_gen_group.add_option('-n', '--number-of-users', type='int', default=0,
                               help='Total number of test users to create')
-    user_gen_group.add_option('--number-of-groups', type='int', default=0,
+    user_gen_group.add_option('--number-of-groups', type='int', default=None,
                               help='Create this many groups')
     user_gen_group.add_option('--average-groups-per-user',
                               type='int', default=0,
@@ -106,6 +125,8 @@ def main():
     user_gen_group.add_option('--group-memberships', type='int', default=0,
                               help='Total memberships to assign across all '
                               'test users and all groups')
+    user_gen_group.add_option('--max-members', type='int', default=None,
+                              help='Max users to add to any one group')
     parser.add_option_group(user_gen_group)
 
     sambaopts = options.SambaOptions(parser)
@@ -123,16 +144,26 @@ def main():
     # First ensure we have reasonable arguments
 
     if len(args) == 1:
-        summary = None
+        model_file = None
         host    = args[0]
     elif len(args) == 2:
-        summary, host = args
+        model_file, host = args
     else:
         parser.print_usage()
         return
 
+    lp = sambaopts.get_loadparm()
+    debuglevel = get_debug_level()
+    logger = get_samba_logger(name=__name__,
+                              verbose=debuglevel > 3,
+                              quiet=debuglevel < 1)
+
+    traffic.DEBUG_LEVEL = debuglevel
+    # pass log level down to traffic module to make sure level is controlled
+    traffic.LOGGER.setLevel(logger.getEffectiveLevel())
+
     if opts.clean_up:
-        print_err("Removing user and machine accounts")
+        logger.info("Removing user and machine accounts")
         lp    = sambaopts.get_loadparm()
         creds = credopts.get_credentials(lp)
         creds.set_gensec_features(creds.get_gensec_features() | gensec.FEATURE_SEAL)
@@ -140,22 +171,24 @@ def main():
         traffic.clean_up_accounts(ldb, opts.instance_id)
         exit(0)
 
-    if summary:
-        if not os.path.exists(summary):
-            print_err("Summary file %s doesn't exist" % summary)
+    if model_file:
+        if not os.path.exists(model_file):
+            logger.error("Model file %s doesn't exist" % model_file)
             sys.exit(1)
-    # the summary-file can be ommitted for --generate-users-only and
+    # the model-file can be ommitted for --generate-users-only and
     # --cleanup-up, but it should be specified in all other cases
     elif not opts.generate_users_only:
-        print_err("No summary-file specified to replay traffic from")
+        logger.error("No model file specified to replay traffic from")
         sys.exit(1)
 
     if not opts.fixed_password:
-        print_err(("Please use --fixed-password to specify a password"
-                             " for the users created as part of this test"))
+        logger.error(("Please use --fixed-password to specify a password"
+                      " for the users created as part of this test"))
         sys.exit(1)
 
-    lp = sambaopts.get_loadparm()
+    if opts.random_seed is not None:
+        random.seed(opts.random_seed)
+
     creds = credopts.get_credentials(lp)
     creds.set_gensec_features(creds.get_gensec_features() | gensec.FEATURE_SEAL)
 
@@ -165,153 +198,182 @@ def main():
     else:
         domain = lp.get("workgroup")
         if domain == "WORKGROUP":
-            print_err(("NETBIOS domain does not appear to be "
-                       "specified, use the --workgroup option"))
+            logger.error(("NETBIOS domain does not appear to be "
+                          "specified, use the --workgroup option"))
             sys.exit(1)
 
     if not opts.realm and not lp.get('realm'):
-        print_err("Realm not specified, use the --realm option")
+        logger.error("Realm not specified, use the --realm option")
         sys.exit(1)
 
     if opts.generate_users_only and not (opts.number_of_users or
                                          opts.number_of_groups):
-        print_err(("Please specify the number of users and/or groups "
-                   "to generate."))
+        logger.error(("Please specify the number of users and/or groups "
+                      "to generate."))
         sys.exit(1)
 
     if opts.group_memberships and opts.average_groups_per_user:
-        print_err(("--group-memberships and --average-groups-per-user"
-                   " are incompatible options - use one or the other"))
+        logger.error(("--group-memberships and --average-groups-per-user"
+                      " are incompatible options - use one or the other"))
         sys.exit(1)
 
     if not opts.number_of_groups and opts.average_groups_per_user:
-        print_err(("--average-groups-per-user requires "
-                   "--number-of-groups"))
+        logger.error(("--average-groups-per-user requires "
+                      "--number-of-groups"))
         sys.exit(1)
 
+    if opts.number_of_groups and opts.average_groups_per_user:
+        if opts.number_of_groups < opts.average_groups_per_user:
+            logger.error(("--average-groups-per-user can not be more than "
+                          "--number-of-groups"))
+            sys.exit(1)
+
     if not opts.number_of_groups and opts.group_memberships:
-        print_err("--group-memberships requires --number-of-groups")
+        logger.error("--group-memberships requires --number-of-groups")
+        sys.exit(1)
+
+    if opts.scale_traffic is not None and opts.packets_per_second is not None:
+        logger.error("--scale-traffic and --packets-per-second "
+                     "are incompatible. Use one or the other.")
         sys.exit(1)
 
+    if not opts.scale_traffic and not opts.packets_per_second:
+        logger.info("No packet rate specified. Using --scale-traffic=1.0")
+        opts.scale_traffic = 1.0
+
     if opts.timing_data not in ('-', None):
         try:
             open(opts.timing_data, 'w').close()
-        except IOError as e:
-            print_err(("the supplied timing data destination "
-                       "(%s) is not writable" % opts.timing_data))
-            print_err(e)
+        except IOError:
+            # exception info will be added to log automatically
+            logger.exception(("the supplied timing data destination "
+                              "(%s) is not writable" % opts.timing_data))
             sys.exit()
 
     if opts.traffic_summary not in ('-', None):
         try:
             open(opts.traffic_summary, 'w').close()
-        except IOError as e:
-            print_err(("the supplied traffic summary destination "
-                       "(%s) is not writable" % opts.traffic_summary))
-            print_err(e)
+        except IOError:
+            # exception info will be added to log automatically
+            if debuglevel > 0:
+                import traceback
+                traceback.print_exc()
+            logger.exception(("the supplied traffic summary destination "
+                              "(%s) is not writable" % opts.traffic_summary))
             sys.exit()
 
-    traffic.DEBUG_LEVEL = opts.debuglevel
-
-    duration = opts.duration
-    if duration is None:
-        duration = 60.0
+    if opts.old_scale:
+        # we used to use a silly calculation based on the number
+        # of conversations; now we use the number of packets and
+        # scale traffic accurately. To roughly compare with older
+        # numbers you use --old-scale which approximates as follows:
+        opts.scale_traffic *= 0.55
 
-    # ingest the model or traffic summary
-    if summary:
+    # ingest the model
+    if model_file and not opts.generate_users_only:
+        model = traffic.TrafficModel()
         try:
-            conversations, interval, duration, dns_counts = \
-                                            traffic.ingest_summaries([summary])
-
-            print_err(("Using conversations from the traffic summary "
-                       "file specified"))
-
-            # honour the specified duration if it's different to the
-            # capture duration
-            if opts.duration is not None:
-                duration = opts.duration
-
-        except ValueError as e:
-            if not e.message.startswith('need more than'):
-                raise
-
-            model = traffic.TrafficModel()
-
-            try:
-                model.load(summary)
-            except ValueError:
-                print_err(("Could not parse %s. The summary file "
-                           "should be the output from either the "
-                           "traffic_summary.pl or "
-                           "traffic_learner scripts."
-                           % summary))
-                sys.exit()
-
-            print_err(("Using the specified model file to "
-                       "generate conversations"))
+            model.load(model_file)
+        except ValueError:
+            if debuglevel > 0:
+                import traceback
+                traceback.print_exc()
+            logger.error(("Could not parse %s, which does not seem to be "
+                          "a model generated by script/traffic_learner."
+                          % model_file))
+            sys.exit(1)
 
-            conversations = model.generate_conversations(opts.scale_traffic,
-                                                         duration,
-                                                         opts.replay_rate)
+        logger.info(("Using the specified model file to "
+                     "generate conversations"))
 
+        if opts.scale_traffic:
+            packets_per_second = model.scale_to_packet_rate(opts.scale_traffic)
+        else:
+            packets_per_second =  opts.packets_per_second
+
+        conversations = \
+            model.generate_conversation_sequences(
+                packets_per_second,
+                opts.duration,
+                opts.replay_rate,
+                opts.conversation_persistence)
     else:
         conversations = []
 
-    if opts.debuglevel > 5:
-        for c in conversations:
-            for p in c.packets:
-                print("    ", p)
-
-        print('=' * 72)
-
     if opts.number_of_users and opts.number_of_users < len(conversations):
-        print_err(("--number-of-users (%d) is less than the "
-                   "number of conversations to replay (%d)"
-                   % (opts.number_of_users, len(conversations))))
+        logger.error(("--number-of-users (%d) is less than the "
+                      "number of conversations to replay (%d)"
+                     % (opts.number_of_users, len(conversations))))
         sys.exit(1)
 
     number_of_users = max(opts.number_of_users, len(conversations))
+
+    if opts.number_of_groups is None:
+        opts.number_of_groups = max(int(number_of_users / 10), 1)
+
     max_memberships = number_of_users * opts.number_of_groups
 
     if not opts.group_memberships and opts.average_groups_per_user:
         opts.group_memberships = opts.average_groups_per_user * number_of_users
-        print_err(("Using %d group-memberships based on %u average "
-                   "memberships for %d users"
-                   % (opts.group_memberships,
-                      opts.average_groups_per_user, number_of_users)))
+        logger.info(("Using %d group-memberships based on %u average "
+                     "memberships for %d users"
+                     % (opts.group_memberships,
+                        opts.average_groups_per_user, number_of_users)))
 
     if opts.group_memberships > max_memberships:
-        print_err(("The group memberships specified (%d) exceeds "
-                   "the total users (%d) * total groups (%d)"
-                   % (opts.group_memberships, number_of_users,
-                      opts.number_of_groups)))
+        logger.error(("The group memberships specified (%d) exceeds "
+                      "the total users (%d) * total groups (%d)"
+                      % (opts.group_memberships, number_of_users,
+                         opts.number_of_groups)))
         sys.exit(1)
 
+    # if no groups were specified by the user, then make sure we create some
+    # group memberships (otherwise it's not really a fair test)
+    if not opts.group_memberships and not opts.average_groups_per_user:
+        opts.group_memberships = min(number_of_users * 5, max_memberships)
+
+    # Get an LDB connection.
     try:
-        ldb = traffic.openLdb(host, creds, lp)
+        # if we're only adding users, then it's OK to pass a sam.ldb filepath
+        # as the host, which creates the users much faster. In all other cases
+        # we should be connecting to a remote DC
+        if opts.generate_users_only and os.path.isfile(host):
+            ldb = SamDB(url="ldb://{0}".format(host),
+                        session_info=system_session(), lp=lp)
+        else:
+            ldb = traffic.openLdb(host, creds, lp)
     except:
-        print_err(("\nInitial LDAP connection failed! Did you supply "
-                   "a DNS host name and the correct credentials?"))
+        logger.error(("\nInitial LDAP connection failed! Did you supply "
+                      "a DNS host name and the correct credentials?"))
         sys.exit(1)
 
     if opts.generate_users_only:
+        # generate computer accounts for added realism. Assume there will be
+        # some overhang with more computer accounts than users
+        computer_accounts = int(1.25 * number_of_users)
         traffic.generate_users_and_groups(ldb,
                                           opts.instance_id,
                                           opts.fixed_password,
                                           opts.number_of_users,
                                           opts.number_of_groups,
-                                          opts.group_memberships)
+                                          opts.group_memberships,
+                                          opts.max_members,
+                                          machine_accounts=computer_accounts,
+                                          traffic_accounts=False)
         sys.exit()
 
     tempdir = tempfile.mkdtemp(prefix="samba_tg_")
-    print_err("Using temp dir %s" % tempdir)
+    logger.info("Using temp dir %s" % tempdir)
 
     traffic.generate_users_and_groups(ldb,
                                       opts.instance_id,
                                       opts.fixed_password,
                                       number_of_users,
                                       opts.number_of_groups,
-                                      opts.group_memberships)
+                                      opts.group_memberships,
+                                      opts.max_members,
+                                      machine_accounts=len(conversations),
+                                      traffic_accounts=True)
 
     accounts = traffic.generate_replay_accounts(ldb,
                                                 opts.instance_id,
@@ -326,9 +388,9 @@ def main():
         else:
             summary_dest = open(opts.traffic_summary, 'w')
 
-        print_err("Writing traffic summary")
+        logger.info("Writing traffic summary")
         summaries = []
-        for c in conversations:
+        for c in traffic.seq_to_conversations(conversations):
             summaries += c.replay_as_summary_lines()
 
         summaries.sort()
@@ -337,12 +399,15 @@ def main():
 
         exit(0)
 
-    traffic.replay(conversations, host,
+    traffic.replay(conversations,
+                   host,
                    lp=lp,
                    creds=creds,
                    accounts=accounts,
                    dns_rate=opts.dns_rate,
-                   duration=duration,
+                   dns_query_file=opts.dns_query_file,
+                   duration=opts.duration,
+                   latency_timeout=opts.latency_timeout,
                    badpassword_frequency=opts.badpassword_frequency,
                    prefer_kerberos=opts.prefer_kerberos,
                    statsdir=statsdir,
@@ -350,7 +415,9 @@ def main():
                    base_dn=ldb.domain_dn(),
                    ou=traffic.ou_name(ldb, opts.instance_id),
                    tempdir=tempdir,
-                   domain_sid=ldb.get_domain_sid())
+                   stop_on_any_error=opts.stop_on_any_error,
+                   domain_sid=ldb.get_domain_sid(),
+                   instance_id=opts.instance_id)
 
     if opts.timing_data == '-':
         timing_dest = sys.stdout
@@ -359,12 +426,21 @@ def main():
     else:
         timing_dest = open(opts.timing_data, 'w')
 
-    print_err("Generating statistics")
+    logger.info("Generating statistics")
     traffic.generate_stats(statsdir, timing_dest)
 
     if not opts.preserve_tempdir:
-        print_err("Removing temporary directory")
+        logger.info("Removing temporary directory")
         shutil.rmtree(tempdir)
-
+    else:
+        # delete the empty directories anyway. There are thousands of
+        # them and they're EMPTY.
+        for d in os.listdir(tempdir):
+            if d.startswith('conversation-'):
+                path = os.path.join(tempdir, d)
+                try:
+                    os.rmdir(path)
+                except OSError as e:
+                    logger.info("not removing %s (%s)" % (path, e))
 
 main()