PEP8: fix E226: missing whitespace around arithmetic operator
[sfrench/samba-autobuild/.git] / source4 / dsdb / tests / python / sort.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 # Originally based on ./sam.py
4 from __future__ import print_function
5 from unicodedata import normalize
6 import locale
7 locale.setlocale(locale.LC_ALL, ('en_US', 'UTF-8'))
8
9 import optparse
10 import sys
11 import os
12 import re
13
14 sys.path.insert(0, "bin/python")
15 import samba
16 from samba.tests.subunitrun import SubunitOptions, TestProgram
17
18 import samba.getopt as options
19
20 from samba.auth import system_session
21 import ldb
22 from samba.samdb import SamDB
23
# Assemble the command line: the standard Samba option groups first,
# then the one option that is specific to this test script.
parser = optparse.OptionParser("sort.py [options] <host>")

sambaopts = options.SambaOptions(parser)
parser.add_option_group(sambaopts)

parser.add_option_group(options.VersionOptions(parser))

# use command line creds if available
credopts = options.CredentialsOptions(parser)
parser.add_option_group(credopts)

subunitopts = SubunitOptions(parser)
parser.add_option_group(subunitopts)

# Number of user objects each test fixture creates (see setUp()).
parser.add_option('--elements', type='int', default=33,
                  help="use this many elements in the tests")

opts, args = parser.parse_args()

# The target host is the only mandatory positional argument.
if not args:
    parser.print_usage()
    sys.exit(1)

host = args[0]

# Loadparm context and credentials shared by every test case below.
lp = sambaopts.get_loadparm()
creds = credopts.get_credentials(lp)
47
48
def norm(x):
    """Return the NFKC-normalised, upper-cased form of *x* as UTF-8 bytes.

    *x* may be a UTF-8 encoded byte string or a text string.  Byte strings
    are decoded first; text strings are used as they are.  The previous
    implementation called ``.decode()`` unconditionally, which raises
    AttributeError for Python 3 text and UnicodeEncodeError for Python 2
    ``unicode`` values containing non-ASCII characters.
    """
    # On Python 2 ``bytes`` is ``str``, so plain string literals still take
    # the decode path exactly as before.
    if isinstance(x, bytes):
        x = x.decode('utf-8')
    return normalize('NFKC', x).upper().encode('utf-8')
52
# Strings that Python, Windows and Samba each sort in drastically
# different ways.  They are listed here in the order Windows2012R2
# returns them.
FIENDISH_TESTS = [
    ' ',
    ' e',
    '\t-\t',
    '\n\t\t',
    '!@#!@#!',
    '¼',
    '¹',
    '1',
    '1/4',
    '1⁄4',
    '1\xe2\x81\x845',
    '3',
    'abc',
    'fo\x00od',

    # Here we also had '\x00food', but that seems to sort
    # non-deterministically on Windows vis-a-vis 'fo\x00od'.

    'kōkako',
    'ŋđ¼³ŧ “«đð',
    'ŋđ¼³ŧ“«đð',
    'sorttest',
    'sorttēst11,',
    'śorttest2',
    'śoRttest2',
    'ś-o-r-t-t-e-s-t-2',
    'soRTTēst2,',
    'ṡorttest4',
    'ṡorttesT4',
    'sörttest-5',
    'sÖrttest-5',
    'so-rttest7,',
    '桑巴',
]
66
67
class BaseSortTests(samba.tests.TestCase):
    """Shared fixture and checks for searches using the server_sort control.

    setUp() creates an "ou=sort" container under the domain DN, fills it
    with ``opts.elements`` generated users and precomputes, per attribute,
    the order in which a sorted search is expected to return the values.
    The ``_test_*`` methods hold the actual checks; subclasses expose the
    ones they want as ``test_*`` methods and choose ``avoid_tricky_sort``.
    """

    # When true, create_user() replaces every character outside [\w,.]
    # with 'X' and leaves out the NUL/Unicode-heavy attributes.
    avoid_tricky_sort = False
    # unittest's cap on the length of diffs shown for failed comparisons.
    maxDiff = 2000

    def create_user(self, i, n, prefix='sorttest', suffix='', attrs=None,
                    tricky=False):
        """Add one generated user below self.ou and remember it.

        :param i: index of this user; most attribute values derive from it
        :param n: total number of users being created
        :param prefix: text placed before the index in the cn
        :param suffix: text placed after the index in the cn
        :param attrs: optional dict of attributes that override or extend
            the generated ones
        :param tricky: accepted for interface compatibility; not used
        :return: the dict that was passed to ldb.add(), including 'dn'
        """
        name = "%s%d%s" % (prefix, i, suffix)
        user = {
            'cn': name,
            "objectclass": "user",
            'givenName': "abcdefghijklmnopqrstuvwxyz"[i % 26],
            "roomNumber": "%sb\x00c" % (n - i),
            "carLicense": "后来经",
            "employeeNumber": "%s%sx" % (abs(i * (99 - i)), '\n' * (i & 255)),
            "accountExpires": "%s" % (10 ** 9 + 1000000 * i),
            "msTSExpireDate4": "19%02d0101010000.0Z" % (i % 100),
            "flags": str(i * (n - i)),
            "serialNumber": "abc %s%s%s" % ('AaBb |-/'[i & 7],
                                            ' 3z}'[i & 3],
                                            '"@'[i & 1],),
            "comment": "Favourite colour is %d" % (n % (i + 1)),
        }

        if self.avoid_tricky_sort:
            # We are not even going to try passing tests that assume
            # some kind of Unicode awareness.
            # Iterate over a snapshot so that rewriting the values can
            # never interfere with the iteration itself.
            for k, v in list(user.items()):
                user[k] = re.sub(r'[^\w,.]', 'X', v)
        else:
            # Add some even trickier ones!
            fiendish_index = i % len(FIENDISH_TESTS)
            user.update({
                # Sort doesn't look past a NUL byte.
                "photo": "\x00%d" % (n - i),
                "audio": "%sn octet string %s%s ♫♬\x00lalala" % ('Aa'[i & 1],
                                                                 chr(i & 255),
                                                                 i),
                "displayNamePrintable": "%d\x00%c" % (i, i & 255),
                "adminDisplayName": "%d\x00b" % (n - i),
                "title": "%d%sb" % (n - i, '\x00' * i),

                # Names that vary only in case. Windows returns
                # equivalent addresses in the order they were put
                # in ('a st', 'A st',...). We don't check that.
                "street": "%s st" % (chr(65 | (i & 14) | ((i & 1) * 32))),

                "streetAddress": FIENDISH_TESTS[fiendish_index],
                "postalAddress": FIENDISH_TESTS[-fiendish_index],
            })

        if attrs is not None:
            user.update(attrs)

        user['dn'] = "cn=%s,%s" % (user['cn'], self.ou)

        self.users.append(user)
        self.ldb.add(user)
        return user

    def setUp(self):
        """Connect, create the test OU and users, precompute expectations.

        Fills self.expected_results (and self.expected_results_binary)
        with a (forward, reverse) pair of value lists for every generated
        attribute that the checks sort on.
        """
        # cmp_to_key lets a cmp-style function such as locale.strcoll be
        # used as a sort key; sorted(cmp=...) exists only on Python 2.
        from functools import cmp_to_key

        super(BaseSortTests, self).setUp()
        self.ldb = SamDB(host, credentials=creds,
                         session_info=system_session(lp), lp=lp)

        self.base_dn = self.ldb.domain_dn()
        self.ou = "ou=sort,%s" % self.base_dn
        # Deliberately disabled: flip to True by hand to clear out an OU
        # left behind by an earlier, aborted run.
        if False:
            try:
                self.ldb.delete(self.ou, ['tree_delete:1'])
            except ldb.LdbError as e:
                print("tried deleting %s, got error %s" % (self.ou, e))

        self.ldb.add({
            "dn": self.ou,
            "objectclass": "organizationalUnit"})
        self.users = []
        n = opts.elements
        for i in range(n):
            self.create_user(i, n)

        # Every generated attribute except the structural ones.
        attrs = set(self.users[0].keys()) - set([
            'objectclass', 'dn'])
        self.binary_sorted_keys = attrs.intersection(['audio',
                                                      'photo',
                                                      "msTSExpireDate4",
                                                      'serialNumber',
                                                      "displayNamePrintable"])

        self.numeric_sorted_keys = attrs.intersection(['flags',
                                                       'accountExpires'])

        self.timestamp_keys = attrs.intersection(['msTSExpireDate4'])

        self.int64_keys = set(['accountExpires'])

        self.locale_sorted_keys = [x for x in attrs if
                                   x not in (self.binary_sorted_keys |
                                             self.numeric_sorted_keys)]

        self.expected_results = {}
        self.expected_results_binary = {}

        for k in self.locale_sorted_keys:
            # Using key=locale.strxfrm fails on \x00
            forward = sorted((norm(x[k]) for x in self.users),
                             key=cmp_to_key(locale.strcoll))
            reverse = list(reversed(forward))
            self.expected_results[k] = (forward, reverse)

        for k in self.binary_sorted_keys:
            forward = sorted((x[k] for x in self.users))
            reverse = list(reversed(forward))
            self.expected_results_binary[k] = (forward, reverse)
            self.expected_results[k] = (forward, reverse)

        # Fix up some because Python gets it wrong, using Schwartzian transform
        for k in ('adminDisplayName', 'title', 'streetAddress',
                  'employeeNumber'):
            if k in self.expected_results:
                broken = self.expected_results[k][0]
                # Order by the value with NULs removed, keeping the
                # original value alongside so it can be recovered.
                tmp = [(x.replace('\x00', ''), x) for x in broken]
                tmp.sort()
                fixed = [x[1] for x in tmp]
                self.expected_results[k] = (fixed, list(reversed(fixed)))
        for k in ('streetAddress', 'postalAddress'):
            if k in self.expected_results:
                # Count how often each fiendish value was handed out, then
                # emit the values in FIENDISH_TESTS order that many times.
                c = {}
                for u in self.users:
                    x = u[k]
                    c[x] = c.get(x, 0) + 1
                fixed = []
                for x in FIENDISH_TESTS:
                    fixed += [norm(x)] * c.get(x, 0)

                rev = list(reversed(fixed))
                self.expected_results[k] = (fixed, rev)

    def tearDown(self):
        """Remove the test OU and everything below it."""
        super(BaseSortTests, self).tearDown()
        self.ldb.delete(self.ou, ['tree_delete:1'])

    def _test_server_sort_default(self):
        """Sort on each locale-sorted attribute, forward and reverse.

        Each search requests only the sort attribute and passes a
        "server_sort:1:<rev>:<attr>" control; the normalised values
        received must equal the precomputed expectation.
        """
        attrs = self.locale_sorted_keys

        for attr in attrs:
            for rev in (0, 1):
                res = self.ldb.search(self.ou,
                                      scope=ldb.SCOPE_ONELEVEL, attrs=[attr],
                                      controls=["server_sort:1:%d:%s" %
                                                (rev, attr)])
                self.assertEqual(len(res), len(self.users))

                expected_order = self.expected_results[attr][rev]
                received_order = [norm(x[attr][0]) for x in res]
                if expected_order != received_order:
                    # Dump some context before the assertion below fails.
                    print(attr, ['forward', 'reverse'][rev])
                    print("expected", expected_order)
                    print("recieved", received_order)
                    print("unnormalised:", [x[attr][0] for x in res])
                    print("unnormalised: «%s»" % '»  «'.join(x[attr][0]
                                                             for x in res))
                self.assertEqual(expected_order, received_order)

    def _test_server_sort_binary(self):
        """Sort on each binary-sorted attribute, forward and reverse.

        The received values are compared without normalisation against
        the expectation built with Python's plain sorted().
        """
        for attr in self.binary_sorted_keys:
            for rev in (0, 1):
                res = self.ldb.search(self.ou,
                                      scope=ldb.SCOPE_ONELEVEL, attrs=[attr],
                                      controls=["server_sort:1:%d:%s" %
                                                (rev, attr)])

                self.assertEqual(len(res), len(self.users))
                expected_order = self.expected_results_binary[attr][rev]
                received_order = [x[attr][0] for x in res]
                if expected_order != received_order:
                    print(attr)
                    print(expected_order)
                    print(received_order)
                self.assertEqual(expected_order, received_order)

    def _test_server_sort_us_english(self):
        """Repeat the locale-sorted checks with an explicit ordering OID."""
        # Windows doesn't support many matching rules, but does allow
        # the locale specific sorts -- if it has the locale installed.
        # The most reliable locale is the default US English, which
        # won't change the sort order.

        for lang, oid in [('en_US', '1.2.840.113556.1.4.1499'),
                          ]:

            for attr in self.locale_sorted_keys:
                for rev in (0, 1):
                    res = self.ldb.search(self.ou,
                                          scope=ldb.SCOPE_ONELEVEL,
                                          attrs=[attr],
                                          controls=["server_sort:1:%d:%s:%s" %
                                                    (rev, attr, oid)])

                    # assertEqual reports both lengths when they differ.
                    self.assertEqual(len(res), len(self.users))
                    expected_order = self.expected_results[attr][rev]
                    received_order = [norm(x[attr][0]) for x in res]
                    if expected_order != received_order:
                        print(attr, lang)
                        print(['forward', 'reverse'][rev])
                        print("expected: ", expected_order)
                        print("recieved: ", received_order)
                        print("unnormalised:", [x[attr][0] for x in res])
                        print("unnormalised: «%s»" % '»  «'.join(x[attr][0]
                                                                 for x in res))

                    self.assertEqual(expected_order, received_order)

    def _test_server_sort_different_attr(self):
        """Sort on one attribute while requesting a different one.

        For each (sort_attr, result_attr) pair the users are ordered
        locally by sort_attr and the resulting sequence of result_attr
        values is compared with what the server returns.  The sort
        attribute itself must not appear in the returned entries.
        """
        from functools import cmp_to_key

        def three_way(a, b):
            # Same result as the Python 2 cmp() builtin: -1, 0 or 1.
            return (a > b) - (a < b)

        # Each comparator receives (sort value, result value) pairs and
        # looks only at the sort value.
        def cmp_locale(a, b):
            return locale.strcoll(a[0], b[0])

        def cmp_binary(a, b):
            return three_way(a[0], b[0])

        def cmp_numeric(a, b):
            return three_way(int(a[0]), int(b[0]))

        # For testing simplicity, the attributes in here need to be
        # unique for each user. Otherwise there are multiple possible
        # valid answers.
        sort_functions = {'cn': cmp_binary,
                          "employeeNumber": cmp_locale,
                          "accountExpires": cmp_numeric,
                          "msTSExpireDate4": cmp_binary}
        # list() because a keys view cannot be sliced.
        attrs = list(sort_functions.keys())
        # Pair every attribute with its successor, wrapping at the end.
        attr_pairs = zip(attrs, attrs[1:] + attrs[:1])

        for sort_attr, result_attr in attr_pairs:
            forward = sorted(((norm(x[sort_attr]), norm(x[result_attr]))
                              for x in self.users),
                             key=cmp_to_key(sort_functions[sort_attr]))
            reverse = list(reversed(forward))

            for rev in (0, 1):
                res = self.ldb.search(self.ou,
                                      scope=ldb.SCOPE_ONELEVEL,
                                      attrs=[result_attr],
                                      controls=["server_sort:1:%d:%s" %
                                                (rev, sort_attr)])
                self.assertEqual(len(res), len(self.users))
                pairs = (forward, reverse)[rev]

                expected_order = [x[1] for x in pairs]
                received_order = [norm(x[result_attr][0]) for x in res]

                if expected_order != received_order:
                    print(sort_attr, result_attr, ['forward', 'reverse'][rev])
                    print("expected", expected_order)
                    print("recieved", received_order)
                    print("unnormalised:", [x[result_attr][0] for x in res])
                    print("unnormalised: «%s»" % '»  «'.join(x[result_attr][0]
                                                             for x in res))
                    print("pairs:", pairs)
                    # There are bugs in Windows that we don't want (or
                    # know how) to replicate regarding timestamp sorting.
                    # Let's remind ourselves.
                    if result_attr == "msTSExpireDate4":
                        print('-' * 72)
                        print("This test fails against Windows with the "
                              "default number of elements (33).")
                        print("Try with --elements=27 (or similar).")
                        print('-' * 72)

                self.assertEqual(expected_order, received_order)
                for x in res:
                    if sort_attr in x:
                        self.fail('the search for %s should not return %s' %
                                  (result_attr, sort_attr))
344
345
class SimpleSortTests(BaseSortTests):
    # Run every shared check against the sanitised fixture: with this
    # flag set, create_user() replaces characters outside [\w,.] with 'X'
    # and leaves out the NUL/Unicode-heavy attributes.
    avoid_tricky_sort = True

    # Comments rather than docstrings are used on the test methods so the
    # descriptions unittest reports for them stay unchanged.

    def test_server_sort_default(self):
        self._test_server_sort_default()

    def test_server_sort_binary(self):
        self._test_server_sort_binary()

    def test_server_sort_us_english(self):
        self._test_server_sort_us_english()

    def test_server_sort_different_attr(self):
        self._test_server_sort_different_attr()
359
360
class UnicodeSortTests(BaseSortTests):
    # Keep the full fixture, including the NUL-containing and non-ASCII
    # attribute values.  The binary sort check is not run by this class.
    avoid_tricky_sort = False

    def test_server_sort_different_attr(self):
        self._test_server_sort_different_attr()

    def test_server_sort_us_english(self):
        self._test_server_sort_us_english()

    def test_server_sort_default(self):
        self._test_server_sort_default()
369
370     def test_server_sort_different_attr(self):
371         self._test_server_sort_different_attr()
372
373
# A bare host argument gets a URL scheme: an existing file is opened as
# a tdb database, anything else is contacted over LDAP.
if "://" not in host:
    host = ("tdb://%s" if os.path.isfile(host) else "ldap://%s") % host


# Hand the test cases defined in this module to the subunit runner.
TestProgram(module=__name__, opts=subunitopts)