selftest: Convert samba4.ldap.sort.python to planoldpythontest
[vlendec/samba-autobuild/.git] / source4 / dsdb / tests / python / sort.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 # Originally based on ./sam.py
4 from __future__ import print_function
5 from unicodedata import normalize
6 import locale
7 locale.setlocale(locale.LC_ALL, ('en_US', 'UTF-8'))
8
9 import optparse
10 import sys
11 import os
12 import re
13
14 sys.path.insert(0, "bin/python")
15 import samba
16 from samba.tests.subunitrun import SubunitOptions, TestProgram
17 from samba.compat import cmp_fn
18 from samba.compat import cmp_to_key_fn
19 from samba.compat import text_type
20 from samba.compat import PY3
21 import samba.getopt as options
22
23 from samba.auth import system_session
24 import ldb
25 from samba.samdb import SamDB
26
# Command-line handling. The server under test is taken from the SERVER
# environment variable (as exported by the selftest harness); positional
# arguments are accepted but not required or used, so the usage string no
# longer advertises a positional <host>.
parser = optparse.OptionParser("SERVER=<host> sort.py [options]")
sambaopts = options.SambaOptions(parser)
parser.add_option_group(sambaopts)
parser.add_option_group(options.VersionOptions(parser))
# use command line creds if available
credopts = options.CredentialsOptions(parser)
parser.add_option_group(credopts)
subunitopts = SubunitOptions(parser)
parser.add_option_group(subunitopts)

parser.add_option('--elements', type='int', default=33,
                  help="use this many elements in the tests")

opts, args = parser.parse_args()

host = os.getenv("SERVER", None)
if not host:
    # Errors go to stderr so they are not mixed into the subunit stream.
    print("Please specify the host with env variable SERVER",
          file=sys.stderr)
    sys.exit(1)

lp = sambaopts.get_loadparm()
creds = credopts.get_credentials(lp)
53
54
def norm(x):
    """Return *x* as text, NFKC-normalised and upper-cased.

    Anything that is not already text is decoded as UTF-8 first, so the
    fixture values and the values read back from ldb searches can be
    compared with each other.
    """
    text = x if isinstance(x, text_type) else x.decode('utf8')
    return normalize('NFKC', text).upper()
59
60
# Python, Windows, and Samba all sort the following sequence in
# drastically different ways. The order here is what you get from
# Windows2012R2.
#
# The expected results for 'streetAddress' and 'postalAddress' are built by
# walking this list in order, so the position of every entry matters.
#
# All entries are written as literal characters rather than escaped UTF-8
# bytes. With this file's utf-8 coding they are then the same UTF-8 byte
# strings as before on Python 2, and the intended text on Python 3. An
# escape such as '\xe2\x81\x84' would be three Latin-1 characters on
# Python 3 instead of U+2044 FRACTION SLASH.
FIENDISH_TESTS = [' ', ' e', '\t-\t', '\n\t\t', '!@#!@#!', '¼', '¹', '1',
                  '1/4', '1⁄4', '1⁄45', '3', 'abc',

                  # Here we also had '\x00food', but that seems to sort
                  # non-deterministically on Windows vis-a-vis 'fo\x00od'.

                  'kōkako', 'ŋđ¼³ŧ “«đð', 'ŋđ¼³ŧ“«đð',
                  'sorttest', 'sorttēst11,', 'śorttest2', 'śoRttest2',
                  'ś-o-r-t-t-e-s-t-2', 'soRTTēst2,', 'ṡorttest4', 'ṡorttesT4',
                  'sörttest-5', 'sÖrttest-5', 'so-rttest7,', '桑巴']
if not PY3:
    # 'fo\x00od' is only exercised on Python 2. Alphabetically it falls
    # between 'abc' and 'kōkako', so insert it there instead of appending
    # it. Appending would make it expected to sort after '桑巴'.
    FIENDISH_TESTS.insert(FIENDISH_TESTS.index('abc') + 1, 'fo\x00od')
76
class BaseSortTests(samba.tests.TestCase):
    """Shared fixture and checks for the LDAP server_sort control.

    setUp() creates an OU holding ``opts.elements`` users whose attribute
    values are chosen to sort in many different orders, and precomputes the
    order the server is expected to return for each attribute. The
    ``_test_*`` helpers run ``server_sort`` searches and compare the
    results. Concrete subclasses pick which helpers to run and whether the
    Unicode-heavy ("tricky") values are included.
    """
    # When True, every value is squashed to word characters plus ',' and
    # '.', and the extra Unicode-heavy attributes are not added at all.
    avoid_tricky_sort = False
    maxDiff = 2000

    @staticmethod
    def _as_text(value):
        """Return a single ldb attribute value as a native ``str``.

        On Python 3 the values read back from a search may be bytes (norm()
        decodes them for the same reason). Calling str() on them would give
        their repr ("b'...'"), which never equals the str fixture values.
        On Python 2 the value goes through str() exactly as before.
        """
        if PY3 and isinstance(value, bytes):
            return value.decode('utf8')
        return str(value)

    def create_user(self, i, n, prefix='sorttest', suffix='', attrs=None,
                    tricky=False):
        """Add user number *i* of *n* under self.ou and record it.

        :param i: index of the user; most attribute values are derived
            from it so that different attributes sort differently.
        :param n: total number of users being created.
        :param prefix: prefix of the user's cn.
        :param suffix: suffix of the user's cn.
        :param attrs: optional dict of extra or overriding attributes.
        :param tricky: accepted for compatibility; not used.
        :return: the attribute dict that was added, including its 'dn'.
        """
        name = "%s%d%s" % (prefix, i, suffix)
        user = {
            'cn': name,
            "objectclass": "user",
            'givenName': "abcdefghijklmnopqrstuvwxyz"[i % 26],
            "carLicense": "后来经",
            "employeeNumber": "%s%sx" % (abs(i * (99 - i)), '\n' * (i & 255)),
            "accountExpires": "%s" % (10 ** 9 + 1000000 * i),
            "msTSExpireDate4": "19%02d0101010000.0Z" % (i % 100),
            "flags": str(i * (n - i)),
            "serialNumber": "abc %s%s%s" % ('AaBb |-/'[i & 7],
                                            ' 3z}'[i & 3],
                                            '"@'[i & 1],),
            "comment": "Favourite colour is %d" % (n % (i + 1)),
        }

        if not PY3:
            user.update({"roomNumber": "%sb\x00c" % (n - i)})

        if self.avoid_tricky_sort:
            # We are not even going to try passing tests that assume
            # some kind of Unicode awareness.
            for k, v in user.items():
                user[k] = re.sub(r'[^\w,.]', 'X', v)
        else:
            # Add some even trickier ones!
            fiendish_index = i % len(FIENDISH_TESTS)
            user.update({
                # Sort doesn't look past a NUL byte.
                "audio": "%sn octet string %s%s ♫♬\x00lalala" % ('Aa'[i & 1],
                                                                 chr(i & 255),
                                                                 i),
                # Names that vary only in case. Windows returns
                # equivalent addresses in the order they were put
                # in ('a st', 'A st',...). We don't check that.
                "street": "%s st" % (chr(65 | (i & 14) | ((i & 1) * 32))),

                # postalAddress uses the mirrored index, so (apart from
                # index 0) it holds a different fiendish value.
                "streetAddress": FIENDISH_TESTS[fiendish_index],
                "postalAddress": FIENDISH_TESTS[-fiendish_index],
            })

            if not PY3:
                user.update({
                    "photo": "\x00%d" % (n - i),
                    "displayNamePrintable": "%d\x00%c" % (i, i & 255),
                    "adminDisplayName": "%d\x00b" % (n - i),
                    "title": "%d%sb" % (n - i, '\x00' * i)})

        if attrs is not None:
            user.update(attrs)

        user['dn'] = "cn=%s,%s" % (user['cn'], self.ou)

        self.users.append(user)
        self.ldb.add(user)
        return user

    def setUp(self):
        super(BaseSortTests, self).setUp()
        self.ldb = SamDB(host, credentials=creds,
                         session_info=system_session(lp), lp=lp)

        self.base_dn = self.ldb.domain_dn()
        self.ou = "ou=sort,%s" % self.base_dn

        self.ldb.add({
            "dn": self.ou,
            "objectclass": "organizationalUnit"})
        # Remove the OU through a cleanup rather than in tearDown().
        # unittest skips tearDown() when setUp() raises, but it still runs
        # cleanups. A failure while populating the OU therefore no longer
        # leaves it behind to make every later setUp() fail on add().
        self.addCleanup(self.ldb.delete, self.ou, ['tree_delete:1'])

        self.users = []
        n = opts.elements
        for i in range(n):
            self.create_user(i, n)

        self._set_up_expected_results()

    def _set_up_expected_results(self):
        """Classify the test attributes and precompute their expected order.

        Fills self.expected_results (attr -> (forward, reverse)) for every
        attribute the sort tests check. It also fills
        self.expected_results_binary for the attributes compared as plain
        strings.
        """
        attrs = set(self.users[0].keys()) - set(['objectclass', 'dn'])
        self.binary_sorted_keys = attrs.intersection(['audio',
                                                      'photo',
                                                      "msTSExpireDate4",
                                                      'serialNumber',
                                                      "displayNamePrintable"])

        self.numeric_sorted_keys = attrs.intersection(['flags',
                                                       'accountExpires'])

        self.timestamp_keys = attrs.intersection(['msTSExpireDate4'])

        self.int64_keys = set(['accountExpires'])

        self.locale_sorted_keys = [x for x in attrs if
                                   x not in (self.binary_sorted_keys |
                                             self.numeric_sorted_keys)]

        self.expected_results = {}
        self.expected_results_binary = {}

        for k in self.locale_sorted_keys:
            # Using key=locale.strxfrm fails on \x00
            forward = sorted((norm(x[k]) for x in self.users),
                             key=cmp_to_key_fn(locale.strcoll))
            reverse = list(reversed(forward))
            self.expected_results[k] = (forward, reverse)

        for k in self.binary_sorted_keys:
            forward = sorted(x[k] for x in self.users)
            reverse = list(reversed(forward))
            self.expected_results_binary[k] = (forward, reverse)
            self.expected_results[k] = (forward, reverse)

        # Fix up some because Python gets it wrong, using a Schwartzian
        # transform: order by the value with NUL bytes removed, with the
        # full value breaking ties.
        for k in ('adminDisplayName', 'title', 'streetAddress',
                  'employeeNumber'):
            if k in self.expected_results:
                broken = self.expected_results[k][0]
                fixed = [value for _, value in
                         sorted((v.replace('\x00', ''), v) for v in broken)]
                self.expected_results[k] = (fixed, list(reversed(fixed)))

        # The fiendish values have a known Windows order that no local
        # collation reproduces. Rebuild their expectations by walking
        # FIENDISH_TESTS in order, repeating each value as many times as it
        # was handed out to users.
        from collections import Counter
        for k in ('streetAddress', 'postalAddress'):
            if k in self.expected_results:
                counts = Counter(u[k] for u in self.users)
                fixed = []
                for x in FIENDISH_TESTS:
                    fixed += [norm(x)] * counts[x]
                self.expected_results[k] = (fixed, list(reversed(fixed)))

    def _test_server_sort_default(self):
        """Check the default sort of each locale-sorted attribute, both
        forwards and in reverse."""
        for attr in self.locale_sorted_keys:
            for rev in (0, 1):
                res = self.ldb.search(self.ou,
                                      scope=ldb.SCOPE_ONELEVEL, attrs=[attr],
                                      controls=["server_sort:1:%d:%s" %
                                                (rev, attr)])
                self.assertEqual(len(res), len(self.users))

                expected_order = self.expected_results[attr][rev]
                received_order = [norm(x[attr][0]) for x in res]
                if expected_order != received_order:
                    print(attr, ['forward', 'reverse'][rev])
                    print("expected", expected_order)
                    print("recieved", received_order)
                    print("unnormalised:", [x[attr][0] for x in res])
                    print("unnormalised: «%s»" % '»  «'.join(
                        self._as_text(x[attr][0]) for x in res))
                self.assertEqual(expected_order, received_order)

    def _test_server_sort_binary(self):
        """Check attributes whose expected order is a plain string sort."""
        for attr in self.binary_sorted_keys:
            for rev in (0, 1):
                res = self.ldb.search(self.ou,
                                      scope=ldb.SCOPE_ONELEVEL, attrs=[attr],
                                      controls=["server_sort:1:%d:%s" %
                                                (rev, attr)])

                self.assertEqual(len(res), len(self.users))
                expected_order = self.expected_results_binary[attr][rev]
                # Decode rather than str() so that Python 3 bytes values
                # compare equal to the str fixtures.
                received_order = [self._as_text(x[attr][0]) for x in res]
                if expected_order != received_order:
                    print(attr)
                    print(expected_order)
                    print(received_order)
                self.assertEqual(expected_order, received_order)

    def _test_server_sort_us_english(self):
        # Windows doesn't support many matching rules, but does allow
        # the locale specific sorts -- if it has the locale installed.
        # The most reliable locale is the default US English, which
        # won't change the sort order.

        for lang, oid in [('en_US', '1.2.840.113556.1.4.1499'),
                          ]:

            for attr in self.locale_sorted_keys:
                for rev in (0, 1):
                    res = self.ldb.search(self.ou,
                                          scope=ldb.SCOPE_ONELEVEL,
                                          attrs=[attr],
                                          controls=["server_sort:1:%d:%s:%s" %
                                                    (rev, attr, oid)])

                    self.assertEqual(len(res), len(self.users))
                    expected_order = self.expected_results[attr][rev]
                    received_order = [norm(x[attr][0]) for x in res]
                    if expected_order != received_order:
                        print(attr, lang)
                        print(['forward', 'reverse'][rev])
                        print("expected: ", expected_order)
                        print("recieved: ", received_order)
                        print("unnormalised:", [x[attr][0] for x in res])
                        print("unnormalised: «%s»" % '»  «'.join(
                            self._as_text(x[attr][0]) for x in res))

                    self.assertEqual(expected_order, received_order)

    def _test_server_sort_different_attr(self):
        """Sort on one attribute while only requesting another, and check
        that the requested attribute comes back in the sorted order."""

        def cmp_locale(a, b):
            return locale.strcoll(a[0], b[0])

        def cmp_binary(a, b):
            return cmp_fn(a[0], b[0])

        def cmp_numeric(a, b):
            return cmp_fn(int(a[0]), int(b[0]))

        # For testing simplicity, the attributes in here need to be
        # unique for each user. Otherwise there are multiple possible
        # valid answers.
        sort_functions = {'cn': cmp_binary,
                          "employeeNumber": cmp_locale,
                          "accountExpires": cmp_numeric,
                          "msTSExpireDate4": cmp_binary}
        attrs = list(sort_functions.keys())
        # Pair each attribute with the next one (wrapping around).
        attr_pairs = zip(attrs, attrs[1:] + attrs[:1])

        for sort_attr, result_attr in attr_pairs:
            forward = sorted(((norm(x[sort_attr]), norm(x[result_attr]))
                             for x in self.users),
                             key=cmp_to_key_fn(sort_functions[sort_attr]))
            reverse = list(reversed(forward))

            for rev in (0, 1):
                res = self.ldb.search(self.ou,
                                      scope=ldb.SCOPE_ONELEVEL,
                                      attrs=[result_attr],
                                      controls=["server_sort:1:%d:%s" %
                                                (rev, sort_attr)])
                self.assertEqual(len(res), len(self.users))
                pairs = (forward, reverse)[rev]

                expected_order = [x[1] for x in pairs]
                received_order = [norm(x[result_attr][0]) for x in res]

                if expected_order != received_order:
                    print(sort_attr, result_attr, ['forward', 'reverse'][rev])
                    print("expected", expected_order)
                    print("recieved", received_order)
                    print("unnormalised:", [x[result_attr][0] for x in res])
                    print("unnormalised: «%s»" % '»  «'.join(
                        self._as_text(x[result_attr][0]) for x in res))
                    print("pairs:", pairs)
                    # There are bugs in Windows that we don't want (or
                    # know how) to replicate regarding timestamp sorting.
                    # Let's remind ourselves.
                    if result_attr == "msTSExpireDate4":
                        print('-' * 72)
                        print("This test fails against Windows with the "
                              "default number of elements (33).")
                        print("Try with --elements=27 (or similar).")
                        print('-' * 72)

                self.assertEqual(expected_order, received_order)
                for x in res:
                    if sort_attr in x:
                        self.fail('the search for %s should not return %s' %
                                  (result_attr, sort_attr))
357
358
class SimpleSortTests(BaseSortTests):
    """Sort tests over ASCII-safe attribute values only.

    With avoid_tricky_sort set, users are created without the extra
    Unicode-heavy attributes and with their values squashed to word
    characters. Every shared helper, including the binary one, is run.
    """
    avoid_tricky_sort = True

    # Each test delegates to the BaseSortTests helper of the same name.
    def test_server_sort_binary(self):
        self._test_server_sort_binary()

    def test_server_sort_default(self):
        self._test_server_sort_default()

    def test_server_sort_different_attr(self):
        self._test_server_sort_different_attr()

    def test_server_sort_us_english(self):
        self._test_server_sort_us_english()
373
374
class UnicodeSortTests(BaseSortTests):
    """Sort tests that include the Unicode-heavy ("fiendish") values.

    The binary-order helper is deliberately not run here. Only the
    default, US English and cross-attribute sorts are checked.
    """
    avoid_tricky_sort = False

    # Each test delegates to the BaseSortTests helper of the same name.
    def test_server_sort_default(self):
        self._test_server_sort_default()

    def test_server_sort_different_attr(self):
        self._test_server_sort_different_attr()

    def test_server_sort_us_english(self):
        self._test_server_sort_us_english()
386
387
# Turn a bare host into an ldb URL. An existing local file is opened as a
# tdb database; anything else is contacted over LDAP. Values that already
# carry a scheme are left alone.
if "://" not in host:
    host = "%s://%s" % ("tdb" if os.path.isfile(host) else "ldap", host)