PEP8: fix E303: too many blank lines (2)
[nivanova/samba-autobuild/.git] / source4 / dsdb / tests / python / linked_attributes.py
index 225ee47473ed5b832fd1de3e04dc24cce31ea4af..f297ac9c16b982c17ea7af5361744e805fc07664 100644 (file)
@@ -1,12 +1,11 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 # Originally based on ./sam.py
+from __future__ import print_function
 import optparse
 import sys
 import os
-import base64
-import random
-import re
+import itertools
 
 sys.path.insert(0, "bin/python")
 import samba
@@ -19,8 +18,6 @@ import ldb
 from samba.samdb import SamDB
 from samba.dcerpc import misc
 
-import time
-
 parser = optparse.OptionParser("linked_attributes.py [options] <host>")
 sambaopts = options.SambaOptions(parser)
 parser.add_option_group(sambaopts)
@@ -61,15 +58,15 @@ class LATests(samba.tests.TestCase):
     def setUp(self):
         super(LATests, self).setUp()
         self.samdb = SamDB(host, credentials=creds,
-                         session_info=system_session(lp), lp=lp)
+                           session_info=system_session(lp), lp=lp)
 
         self.base_dn = self.samdb.domain_dn()
         self.ou = "OU=la,%s" % self.base_dn
         if opts.delete_in_setup:
             try:
                 self.samdb.delete(self.ou, ['tree_delete:1'])
-            except ldb.LdbError, e:
-                print "tried deleting %s, got error %s" % (self.ou, e)
+            except ldb.LdbError as e:
+                print("tried deleting %s, got error %s" % (self.ou, e))
         self.samdb.add({'objectclass': 'organizationalUnit',
                         'dn': self.ou})
 
@@ -78,41 +75,48 @@ class LATests(samba.tests.TestCase):
         if not opts.no_cleanup:
             self.samdb.delete(self.ou, ['tree_delete:1'])
 
-    def delete_user(self, user):
-        self.samdb.delete(user['dn'])
-        del self.users[self.users.index(user)]
-
-    def add_object(self, cn, objectclass):
+    def add_object(self, cn, objectclass, more_attrs={}):
         dn = "CN=%s,%s" % (cn, self.ou)
-        self.samdb.add({'cn': cn,
-                      'objectclass': objectclass,
-                      'dn': dn})
+        attrs = {'cn': cn,
+                 'objectclass': objectclass,
+                 'dn': dn}
+        attrs.update(more_attrs)
+        self.samdb.add(attrs)
 
         return dn
 
-    def add_objects(self, n, objectclass, prefix=None):
+    def add_objects(self, n, objectclass, prefix=None, more_attrs={}):
         if prefix is None:
             prefix = objectclass
         dns = []
         for i in range(n):
             dns.append(self.add_object("%s%d" % (prefix, i + 1),
-                                       objectclass))
+                                       objectclass,
+                                       more_attrs=more_attrs))
         return dns
 
-    def add_linked_attribute(self, src, dest, attr='member'):
+    def add_linked_attribute(self, src, dest, attr='member',
+                             controls=None):
         m = ldb.Message()
         m.dn = ldb.Dn(self.samdb, src)
         m[attr] = ldb.MessageElement(dest, ldb.FLAG_MOD_ADD, attr)
-        self.samdb.modify(m)
+        self.samdb.modify(m, controls=controls)
 
-    def remove_linked_attribute(self, src, dest, attr='member'):
+    def remove_linked_attribute(self, src, dest, attr='member',
+                                controls=None):
         m = ldb.Message()
         m.dn = ldb.Dn(self.samdb, src)
         m[attr] = ldb.MessageElement(dest, ldb.FLAG_MOD_DELETE, attr)
-        self.samdb.modify(m)
+        self.samdb.modify(m, controls=controls)
 
-    def attr_search(self, obj, expected, attr, scope=ldb.SCOPE_BASE,
-                    **controls):
+    def replace_linked_attribute(self, src, dest, attr='member',
+                                 controls=None):
+        m = ldb.Message()
+        m.dn = ldb.Dn(self.samdb, src)
+        m[attr] = ldb.MessageElement(dest, ldb.FLAG_MOD_REPLACE, attr)
+        self.samdb.modify(m, controls=controls)
+
+    def attr_search(self, obj, attr, scope=ldb.SCOPE_BASE, **controls):
         if opts.no_reveal_internals:
             if 'reveal_internals' in controls:
                 del controls['reveal_internals']
@@ -125,9 +129,8 @@ class LATests(samba.tests.TestCase):
                                 controls=controls)
         return res
 
-    def assert_links(self, obj, expected, attr, sorted=False, msg='',
-                     **kwargs):
-        res = self.attr_search(obj, expected, attr, **kwargs)
+    def assert_links(self, obj, expected, attr, msg='', **kwargs):
+        res = self.attr_search(obj, attr, **kwargs)
 
         if len(expected) == 0:
             if attr in res[0]:
@@ -139,14 +142,13 @@ class LATests(samba.tests.TestCase):
         except KeyError:
             self.fail("missing attr '%s' on %s" % (attr, obj))
 
-        if sorted == False:
-            results = set(results)
-            expected = set(expected)
+        expected = sorted(expected)
+        results = sorted(results)
 
         if expected != results:
-            print msg
-            print "expected %s" % expected
-            print "received %s" % results
+            print(msg)
+            print("expected %s" % expected)
+            print("received %s" % results)
 
         self.assertEqual(results, expected)
 
@@ -164,6 +166,27 @@ class LATests(samba.tests.TestCase):
                                 attrs=['objectGUID'])
         return str(misc.GUID(res[0]['objectGUID'][0]))
 
+    def assertRaisesLdbError(self, errcode, msg, f, *args, **kwargs):
+        """Assert a function raises a particular LdbError."""
+        try:
+            f(*args, **kwargs)
+        except ldb.LdbError as e:
+            (num, msg) = e.args
+            if num != errcode:
+                lut = {v: k for k, v in vars(ldb).items()
+                       if k.startswith('ERR_') and isinstance(v, int)}
+                self.fail("%s, expected "
+                          "LdbError %s, (%d) "
+                          "got %s (%d)" % (msg,
+                                           lut.get(errcode), errcode,
+                                           lut.get(num), num))
+        else:
+            lut = {v: k for k, v in vars(ldb).items()
+                   if k.startswith('ERR_') and isinstance(v, int)}
+            self.fail("%s, expected "
+                      "LdbError %s, (%d) "
+                      "but we got success" % (msg, lut.get(errcode), errcode))
+
     def _test_la_backlinks(self, reveal=False):
         tag = 'backlinks'
         kwargs = {}
@@ -186,7 +209,7 @@ class LATests(samba.tests.TestCase):
 
     def test_la_backlinks_reveal(self):
         if opts.no_reveal_internals:
-            print 'skipping because --no-reveal-internals'
+            print('skipping because --no-reveal-internals')
             return
         self._test_la_backlinks(True)
 
@@ -214,7 +237,7 @@ class LATests(samba.tests.TestCase):
 
     def test_la_backlinks_delete_group_reveal(self):
         if opts.no_reveal_internals:
-            print 'skipping because --no-reveal-internals'
+            print('skipping because --no-reveal-internals')
             return
         self._test_la_backlinks_delete_group(True)
 
@@ -250,10 +273,10 @@ class LATests(samba.tests.TestCase):
         self.samdb.delete(g2)
         self.assert_back_links(u1, [g1], show_deleted=1, show_recycled=1,
                                show_deactivated_link=0,
-                                  reveal_internals=0)
+                               reveal_internals=0)
         self.assert_back_links(u2, set(), show_deleted=1, show_recycled=1,
                                show_deactivated_link=0,
-                                  reveal_internals=0)
+                               reveal_internals=0)
         self.assert_forward_links(g1, [u1], show_deleted=1, show_recycled=1,
                                   show_deactivated_link=0,
                                   reveal_internals=0)
@@ -302,8 +325,15 @@ class LATests(samba.tests.TestCase):
         self.assert_forward_links(g2, [u1])
         self.remove_linked_attribute(g2, u1)
         self.assert_forward_links(g2, [])
+        self.remove_linked_attribute(g1, [])
+        self.assert_forward_links(g1, [])
 
-    def test_la_links_delete_link_reveal(self):
+        # removing a duplicate link in the same message should fail
+        self.add_linked_attribute(g2, [u1, u2])
+        self.assertRaises(ldb.LdbError,
+                          self.remove_linked_attribute, g2, [u1, u1])
+
+    def _test_la_links_delete_link_reveal(self):
         u1, u2 = self.add_objects(2, 'user', 'u_del_link_reveal')
         g1, g2 = self.add_objects(2, 'group', 'g_del_link_reveal')
 
@@ -317,7 +347,13 @@ class LATests(samba.tests.TestCase):
                                   show_recycled=1,
                                   show_deactivated_link=0,
                                   reveal_internals=0
-        )
+                                  )
+
+    def test_la_links_delete_link_reveal(self):
+        if opts.no_reveal_internals:
+            print('skipping because --no-reveal-internals')
+            return
+        self._test_la_links_delete_link_reveal()
 
     def test_la_links_delete_user(self):
         u1, u2 = self.add_objects(2, 'user', 'u_del_user')
@@ -371,27 +407,249 @@ class LATests(samba.tests.TestCase):
                                   show_deactivated_link=0,
                                   reveal_internals=0)
 
-    def _test_la_links_sort_order(self):
-        u1, u2, u3 = self.add_objects(3, 'user', 'u_sort_order')
-        g1, g2, g3 = self.add_objects(3, 'group', 'g_sort_order')
+    def test_multiple_links(self):
+        u1, u2, u3, u4 = self.add_objects(4, 'user', 'u_multiple_links')
+        g1, g2, g3, g4 = self.add_objects(4, 'group', 'g_multiple_links')
 
-        # Add these in a haphazard order
-        self.add_linked_attribute(g2, u3)
+        self.add_linked_attribute(g1, [u1, u2, u3, u4])
+        self.add_linked_attribute(g2, [u3, u1])
         self.add_linked_attribute(g3, u2)
-        self.add_linked_attribute(g1, u3)
-        self.add_linked_attribute(g1, u1)
-        self.add_linked_attribute(g2, u1)
-        self.add_linked_attribute(g2, u2)
-        self.add_linked_attribute(g3, u3)
+
+        self.assertRaisesLdbError(ldb.ERR_ENTRY_ALREADY_EXISTS,
+                                  "adding duplicate values",
+                                  self.add_linked_attribute, g2,
+                                  [u1, u2, u3, u2])
+
+        self.assert_forward_links(g1, [u1, u2, u3, u4])
+        self.assert_forward_links(g2, [u3, u1])
+        self.assert_forward_links(g3, [u2])
+        self.assert_back_links(u1, [g2, g1])
+        self.assert_back_links(u2, [g3, g1])
+        self.assert_back_links(u3, [g2, g1])
+        self.assert_back_links(u4, [g1])
+
+        self.remove_linked_attribute(g2, [u1, u3])
+        self.remove_linked_attribute(g1, [u1, u3])
+
+        self.assert_forward_links(g1, [u2, u4])
+        self.assert_forward_links(g2, [])
+        self.assert_forward_links(g3, [u2])
+        self.assert_back_links(u1, [])
+        self.assert_back_links(u2, [g3, g1])
+        self.assert_back_links(u3, [])
+        self.assert_back_links(u4, [g1])
+
+        self.add_linked_attribute(g1, [u1, u3])
+        self.add_linked_attribute(g2, [u3, u1])
+        self.add_linked_attribute(g3, [u1, u3])
+
+        self.assert_forward_links(g1, [u1, u2, u3, u4])
+        self.assert_forward_links(g2, [u1, u3])
+        self.assert_forward_links(g3, [u1, u2, u3])
+        self.assert_back_links(u1, [g1, g2, g3])
+        self.assert_back_links(u2, [g3, g1])
+        self.assert_back_links(u3, [g3, g2, g1])
+        self.assert_back_links(u4, [g1])
+
+    def test_la_links_replace(self):
+        u1, u2, u3, u4 = self.add_objects(4, 'user', 'u_replace')
+        g1, g2, g3, g4 = self.add_objects(4, 'group', 'g_replace')
+
+        self.add_linked_attribute(g1, [u1, u2])
+        self.add_linked_attribute(g2, [u1, u3])
         self.add_linked_attribute(g3, u1)
 
-        self.assert_forward_links(g1, [u3, u1], sorted=True)
-        self.assert_forward_links(g2, [u3, u2, u1], sorted=True)
-        self.assert_forward_links(g3, [u3, u2, u1], sorted=True)
+        self.replace_linked_attribute(g1, [u2])
+        self.replace_linked_attribute(g2, [u2, u3])
+        self.replace_linked_attribute(g3, [u1, u3])
+        self.replace_linked_attribute(g4, [u4])
+
+        self.assert_forward_links(g1, [u2])
+        self.assert_forward_links(g2, [u3, u2])
+        self.assert_forward_links(g3, [u3, u1])
+        self.assert_forward_links(g4, [u4])
+        self.assert_back_links(u1, [g3])
+        self.assert_back_links(u2, [g1, g2])
+        self.assert_back_links(u3, [g2, g3])
+        self.assert_back_links(u4, [g4])
+
+        self.replace_linked_attribute(g1, [u1, u2, u3])
+        self.replace_linked_attribute(g2, [u1])
+        self.replace_linked_attribute(g3, [u2])
+        self.replace_linked_attribute(g4, [])
+
+        self.assert_forward_links(g1, [u1, u2, u3])
+        self.assert_forward_links(g2, [u1])
+        self.assert_forward_links(g3, [u2])
+        self.assert_forward_links(g4, [])
+        self.assert_back_links(u1, [g1, g2])
+        self.assert_back_links(u2, [g1, g3])
+        self.assert_back_links(u3, [g1])
+        self.assert_back_links(u4, [])
+
+        self.assertRaisesLdbError(ldb.ERR_ENTRY_ALREADY_EXISTS,
+                                  "replacing duplicate values",
+                                  self.replace_linked_attribute, g2,
+                                  [u1, u2, u3, u2])
+
+    def test_la_links_replace2(self):
+        users = self.add_objects(12, 'user', 'u_replace2')
+        g1, = self.add_objects(1, 'group', 'g_replace2')
+
+        self.add_linked_attribute(g1, users[:6])
+        self.assert_forward_links(g1, users[:6])
+        self.replace_linked_attribute(g1, users)
+        self.assert_forward_links(g1, users)
+        self.replace_linked_attribute(g1, users[6:])
+        self.assert_forward_links(g1, users[6:])
+        self.remove_linked_attribute(g1, users[6:9])
+        self.assert_forward_links(g1, users[9:])
+        self.remove_linked_attribute(g1, users[9:])
+        self.assert_forward_links(g1, [])
+
+    def test_la_links_permutations(self):
+        """Make sure the order in which we add links doesn't matter."""
+        users = self.add_objects(3, 'user', 'u_permutations')
+        groups = self.add_objects(6, 'group', 'g_permutations')
+
+        for g, p in zip(groups, itertools.permutations(users)):
+            self.add_linked_attribute(g, p)
 
-        self.assert_back_links(u1, [g3, g2, g1], sorted=True)
-        self.assert_back_links(u2, [g3, g2], sorted=True)
-        self.assert_back_links(u3, [g3, g2, g1], sorted=True)
+        # everyone should be in every group
+        for g in groups:
+            self.assert_forward_links(g, users)
+
+        for u in users:
+            self.assert_back_links(u, groups)
+
+        for g, p in zip(groups[::-1], itertools.permutations(users)):
+            self.replace_linked_attribute(g, p)
+
+        for g in groups:
+            self.assert_forward_links(g, users)
+
+        for u in users:
+            self.assert_back_links(u, groups)
+
+        for g, p in zip(groups, itertools.permutations(users)):
+            self.remove_linked_attribute(g, p)
+
+        for g in groups:
+            self.assert_forward_links(g, [])
+
+        for u in users:
+            self.assert_back_links(u, [])
+
+    def test_la_links_relaxed(self):
+        """Check that the relax control doesn't mess with linked attributes."""
+        relax_control = ['relax:0']
+
+        users = self.add_objects(10, 'user', 'u_relax')
+        groups = self.add_objects(3, 'group', 'g_relax',
+                                  more_attrs={'member': users[:2]})
+        g_relax1, g_relax2, g_uptight = groups
+
+        # g_relax1 has all users added at once
+        # g_relax2 gets them one at a time in reverse order
+        # g_uptight never relaxes
+
+        self.add_linked_attribute(g_relax1, users[2:5], controls=relax_control)
+
+        for u in reversed(users[2:5]):
+            self.add_linked_attribute(g_relax2, u, controls=relax_control)
+            self.add_linked_attribute(g_uptight, u)
+
+        for g in groups:
+            self.assert_forward_links(g, users[:5])
+
+            self.add_linked_attribute(g, users[5:7])
+            self.assert_forward_links(g, users[:7])
+
+            for u in users[7:]:
+                self.add_linked_attribute(g, u)
+
+            self.assert_forward_links(g, users)
+
+        for u in users:
+            self.assert_back_links(u, groups)
+
+        # try some replacement permutations
+        import random
+        random.seed(1)
+        users2 = users[:]
+        for i in range(5):
+            random.shuffle(users2)
+            self.replace_linked_attribute(g_relax1, users2,
+                                          controls=relax_control)
+
+            self.assert_forward_links(g_relax1, users)
+
+        for i in range(5):
+            random.shuffle(users2)
+            self.remove_linked_attribute(g_relax2, users2,
+                                         controls=relax_control)
+            self.remove_linked_attribute(g_uptight, users2)
+
+            self.replace_linked_attribute(g_relax1, [], controls=relax_control)
+
+            random.shuffle(users2)
+            self.add_linked_attribute(g_relax2, users2,
+                                      controls=relax_control)
+            self.add_linked_attribute(g_uptight, users2)
+            self.replace_linked_attribute(g_relax1, users2,
+                                          controls=relax_control)
+
+            self.assert_forward_links(g_relax1, users)
+            self.assert_forward_links(g_relax2, users)
+            self.assert_forward_links(g_uptight, users)
+
+        for u in users:
+            self.assert_back_links(u, groups)
+
+    def test_add_all_at_once(self):
+        """All these other tests are creating linked attributes after the
+        objects are there. We want to test creating them all at once
+        using LDIF.
+        """
+        users = self.add_objects(7, 'user', 'u_all_at_once')
+        g1, g3 = self.add_objects(2, 'group', 'g_all_at_once',
+                                  more_attrs={'member': users})
+        (g2,) = self.add_objects(1, 'group', 'g_all_at_once2',
+                                 more_attrs={'member': users[:5]})
+
+        self.assertRaisesLdbError(ldb.ERR_ENTRY_ALREADY_EXISTS,
+                                  "adding multiple duplicate values",
+                                  self.add_objects, 1, 'group',
+                                  'g_with_duplicate_links',
+                                  more_attrs={'member': users[:5] + users[1:2]})
+
+        self.assert_forward_links(g1, users)
+        self.assert_forward_links(g2, users[:5])
+        self.assert_forward_links(g3, users)
+        for u in users[:5]:
+            self.assert_back_links(u, [g1, g2, g3])
+        for u in users[5:]:
+            self.assert_back_links(u, [g1, g3])
+
+        self.remove_linked_attribute(g2, users[0])
+        self.remove_linked_attribute(g2, users[1])
+        self.add_linked_attribute(g2, users[1])
+        self.add_linked_attribute(g2, users[5])
+        self.add_linked_attribute(g2, users[6])
+
+        self.assert_forward_links(g1, users)
+        self.assert_forward_links(g2, users[1:])
+
+        for u in users[1:]:
+            self.remove_linked_attribute(g2, u)
+        self.remove_linked_attribute(g1, users)
+
+        for u in users:
+            self.samdb.delete(u)
+
+        self.assert_forward_links(g1, [])
+        self.assert_forward_links(g2, [])
+        self.assert_forward_links(g3, [])
 
     def test_one_way_attributes(self):
         e1, e2 = self.add_objects(2, 'msExchConfigurationContainer',