f26e763cf1f9fdcc10c2b919385d7f75e440d3c3
[samba.git] / python / samba / tests / samba_tool / domain_models.py
1 # Unix SMB/CIFS implementation.
2 #
3 # Tests for domain models and fields
4 #
5 # Copyright (C) Catalyst.Net Ltd. 2023
6 #
7 # Written by Rob van der Linde <rob@catalyst.net.nz>
8 #
9 # This program is free software; you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation; either version 3 of the License, or
12 # (at your option) any later version.
13 #
14 # This program is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17 # GNU General Public License for more details.
18 #
19 # You should have received a copy of the GNU General Public License
20 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
21 #
22
23 import os
24 from datetime import datetime, timezone
25 from xml.etree import ElementTree
26
27 from ldb import FLAG_MOD_ADD, SCOPE_ONELEVEL, MessageElement
28
29 from samba.dcerpc import security
30 from samba.dcerpc.misc import GUID
31 from samba.domain.models import (AccountType, Computer, Group, Site,
32                                  StrongNTLMPolicy, User, fields)
33 from samba.ndr import ndr_pack, ndr_unpack
34
35 from .base import SambaToolCmdTest
36
37 HOST = "ldap://{DC_SERVER}".format(**os.environ)
38 CREDS = "-U{DC_USERNAME}%{DC_PASSWORD}".format(**os.environ)
39
40
41 class ModelTests(SambaToolCmdTest):
42
43     @classmethod
44     def setUpClass(cls):
45         cls.samdb = cls.getSamDB("-H", HOST, CREDS)
46         super().setUpClass()
47
48     def test_query_count(self):
49         """Test count property on Query object without converting to a list."""
50         groups = Group.query(self.samdb)
51         self.assertEqual(groups.count, len(list(groups)))
52
53     def test_query_filter_bool(self):
54         """Tests filtering by a BooleanField."""
55         total = Group.query(self.samdb).count
56         system_groups = Group.query(self.samdb,
57                                     is_critical_system_object=True).count
58         user_groups = Group.query(self.samdb,
59                                   is_critical_system_object=False).count
60         self.assertNotEqual(system_groups, 0)
61         self.assertNotEqual(user_groups, 0)
62         self.assertEqual(system_groups + user_groups, total)
63
64     def test_query_filter_enum(self):
65         """Tests filtering by an EnumField."""
66         robots_vs_humans = User.query(self.samdb).count
67         robots = User.query(self.samdb,
68                             account_type=AccountType.WORKSTATION_TRUST).count
69         humans = User.query(self.samdb,
70                             account_type=AccountType.NORMAL_ACCOUNT).count
71         self.assertNotEqual(robots, 0)
72         self.assertNotEqual(humans, 0)
73         self.assertEqual(robots + humans, robots_vs_humans)
74
75
76 class ComputerModelTests(SambaToolCmdTest):
77
78     @classmethod
79     def setUpClass(cls):
80         cls.samdb = cls.getSamDB("-H", HOST, CREDS)
81         super().setUpClass()
82
83     def test_computer_constructor(self):
84         # Use only name
85         comp1 = Computer.create(self.samdb, name="comp1")
86         self.addCleanup(comp1.delete, self.samdb)
87         self.assertEqual(comp1.name, "comp1")
88         self.assertEqual(comp1.account_name, "comp1$")
89
90         # Use only cn
91         comp2 = Computer.create(self.samdb, cn="comp2")
92         self.addCleanup(comp2.delete, self.samdb)
93         self.assertEqual(comp2.name, "comp2")
94         self.assertEqual(comp2.account_name, "comp2$")
95
96         # Use name and account_name but missing "$" in account_name.
97         comp3 = Computer.create(self.samdb, name="comp3", account_name="comp3")
98         self.addCleanup(comp3.delete, self.samdb)
99         self.assertEqual(comp3.name, "comp3")
100         self.assertEqual(comp3.account_name, "comp3$")
101
102         # Use cn and account_name but missing "$" in account_name.
103         comp4 = Computer.create(self.samdb, cn="comp4", account_name="comp4$")
104         self.addCleanup(comp4.delete, self.samdb)
105         self.assertEqual(comp4.name, "comp4")
106         self.assertEqual(comp4.account_name, "comp4$")
107
108         # Use only account_name, the name should get the "$" removed.
109         comp5 = Computer.create(self.samdb, account_name="comp5$")
110         self.addCleanup(comp5.delete, self.samdb)
111         self.assertEqual(comp5.name, "comp5")
112         self.assertEqual(comp5.account_name, "comp5$")
113
114         # Use only account_name but accidentally forgot the "$" character.
115         comp6 = Computer.create(self.samdb, account_name="comp6")
116         self.addCleanup(comp6.delete, self.samdb)
117         self.assertEqual(comp6.name, "comp6")
118         self.assertEqual(comp6.account_name, "comp6$")
119
120
121 class FieldTestMixin:
122     """Tests a model field to ensure it behaves correctly in both directions.
123
124     Use a mixin since TestCase can't be marked as abstract.
125     """
126
127     @classmethod
128     def setUpClass(cls):
129         cls.samdb = cls.getSamDB("-H", HOST, CREDS)
130         super().setUpClass()
131
132     def get_users_dn(self):
133         """Returns Users DN."""
134         users_dn = self.samdb.get_root_basedn()
135         users_dn.add_child("CN=Users")
136         return users_dn
137
138     def test_to_db_value(self):
139         # Loop through each value and expected value combination.
140         # If the expected value is callable, treat it as a validation callback.
141         # NOTE: perhaps we should be using subtests for this.
142         for (value, expected) in self.to_db_value:
143             db_value = self.field.to_db_value(self.samdb, value, FLAG_MOD_ADD)
144             if callable(expected):
145                 self.assertTrue(expected(db_value))
146             else:
147                 self.assertEqual(db_value, expected)
148
149     def test_from_db_value(self):
150         # Loop through each value and expected value combination.
151         # NOTE: perhaps we should be using subtests for this.
152         for (db_value, expected) in self.from_db_value:
153             value = self.field.from_db_value(self.samdb, db_value)
154             self.assertEqual(value, expected)
155
156
157 class IntegerFieldTest(FieldTestMixin, SambaToolCmdTest):
158     field = fields.IntegerField("FieldName")
159
160     to_db_value = [
161         (10, MessageElement(b"10")),
162         ([1, 5, 10], MessageElement([b"1", b"5", b"10"])),
163         (None, None),
164     ]
165
166     from_db_value = [
167         (MessageElement(b"10"), 10),
168         (MessageElement([b"1", b"5", b"10"]), [1, 5, 10]),
169         (None, None),
170     ]
171
172
173 class BinaryFieldTest(FieldTestMixin, SambaToolCmdTest):
174     field = fields.BinaryField("FieldName")
175
176     to_db_value = [
177         (b"SAMBA", MessageElement(b"SAMBA")),
178         ([b"SAMBA", b"Developer"], MessageElement([b"SAMBA", b"Developer"])),
179         (None, None),
180     ]
181
182     from_db_value = [
183         (MessageElement(b"SAMBA"), b"SAMBA"),
184         (MessageElement([b"SAMBA", b"Developer"]), [b"SAMBA", b"Developer"]),
185         (None, None),
186     ]
187
188
189 class StringFieldTest(FieldTestMixin, SambaToolCmdTest):
190     field = fields.StringField("FieldName")
191
192     to_db_value = [
193         ("SAMBA", MessageElement(b"SAMBA")),
194         (["SAMBA", "Developer"], MessageElement([b"SAMBA", b"Developer"])),
195         (None, None),
196     ]
197
198     from_db_value = [
199         (MessageElement(b"SAMBA"), "SAMBA"),
200         (MessageElement([b"SAMBA", b"Developer"]), ["SAMBA", "Developer"]),
201         (None, None),
202     ]
203
204
205 class BooleanFieldTest(FieldTestMixin, SambaToolCmdTest):
206     field = fields.BooleanField("FieldName")
207
208     to_db_value = [
209         (True, MessageElement(b"TRUE")),
210         ([False, True], MessageElement([b"FALSE", b"TRUE"])),
211         (None, None),
212     ]
213
214     from_db_value = [
215         (MessageElement(b"TRUE"), True),
216         (MessageElement([b"FALSE", b"TRUE"]), [False, True]),
217         (None, None),
218     ]
219
220
221 class EnumFieldTest(FieldTestMixin, SambaToolCmdTest):
222     field = fields.EnumField("FieldName", StrongNTLMPolicy)
223
224     to_db_value = [
225         (StrongNTLMPolicy.OPTIONAL, MessageElement("1")),
226         ([StrongNTLMPolicy.REQUIRED, StrongNTLMPolicy.OPTIONAL],
227          MessageElement(["2", "1"])),
228         (None, None),
229     ]
230
231     from_db_value = [
232         (MessageElement("1"), StrongNTLMPolicy.OPTIONAL),
233         (MessageElement(["2", "1"]),
234          [StrongNTLMPolicy.REQUIRED, StrongNTLMPolicy.OPTIONAL]),
235         (None, None),
236     ]
237
238
239 class DateTimeFieldTest(FieldTestMixin, SambaToolCmdTest):
240     field = fields.DateTimeField("FieldName")
241
242     to_db_value = [
243         (datetime(2023, 1, 27, 22, 36, 41, tzinfo=timezone.utc),
244          MessageElement("20230127223641.0Z")),
245         ([datetime(2023, 1, 27, 22, 36, 41, tzinfo=timezone.utc),
246           datetime(2023, 1, 27, 22, 47, 50, tzinfo=timezone.utc)],
247          MessageElement(["20230127223641.0Z", "20230127224750.0Z"])),
248         (None, None),
249     ]
250
251     from_db_value = [
252         (MessageElement("20230127223641.0Z"),
253          datetime(2023, 1, 27, 22, 36, 41, tzinfo=timezone.utc)),
254         (MessageElement(["20230127223641.0Z", "20230127224750.0Z"]),
255          [datetime(2023, 1, 27, 22, 36, 41, tzinfo=timezone.utc),
256           datetime(2023, 1, 27, 22, 47, 50, tzinfo=timezone.utc)]),
257         (None, None),
258     ]
259
260
261 class NtTimeFieldTest(FieldTestMixin, SambaToolCmdTest):
262     field = fields.NtTimeField("FieldName")
263
264     to_db_value = [
265         (datetime(2023, 1, 27, 22, 36, 41, tzinfo=timezone.utc),
266          MessageElement("133193326010000000")),
267         ([datetime(2023, 1, 27, 22, 36, 41, tzinfo=timezone.utc),
268           datetime(2023, 1, 27, 22, 47, 50, tzinfo=timezone.utc)],
269          MessageElement(["133193326010000000", "133193332700000000"])),
270         (None, None),
271     ]
272
273     from_db_value = [
274         (MessageElement("133193326010000000"),
275          datetime(2023, 1, 27, 22, 36, 41, tzinfo=timezone.utc)),
276         (MessageElement(["133193326010000000", "133193332700000000"]),
277          [datetime(2023, 1, 27, 22, 36, 41, tzinfo=timezone.utc),
278           datetime(2023, 1, 27, 22, 47, 50, tzinfo=timezone.utc)]),
279         (None, None),
280     ]
281
282
283 class RelatedFieldTest(FieldTestMixin, SambaToolCmdTest):
284     field = fields.RelatedField("FieldName", User)
285
286     @property
287     def to_db_value(self):
288         alice = User.get(self.samdb, account_name="alice")
289         joe = User.get(self.samdb, account_name="joe")
290         return [
291             (alice, MessageElement(str(alice.dn))),
292             ([joe, alice], MessageElement([str(joe.dn), str(alice.dn)])),
293             (None, None),
294         ]
295
296     @property
297     def from_db_value(self):
298         alice = User.get(self.samdb, account_name="alice")
299         joe = User.get(self.samdb, account_name="joe")
300         return [
301             (MessageElement(str(alice.dn)), alice),
302             (MessageElement([str(joe.dn), str(alice.dn)]), [joe, alice]),
303             (None, None),
304         ]
305
306
307 class DnFieldTest(FieldTestMixin, SambaToolCmdTest):
308     field = fields.DnField("FieldName")
309
310     @property
311     def to_db_value(self):
312         alice = User.get(self.samdb, account_name="alice")
313         joe = User.get(self.samdb, account_name="joe")
314         return [
315             (alice.dn, MessageElement(str(alice.dn))),
316             ([joe.dn, alice.dn], MessageElement([str(joe.dn), str(alice.dn)])),
317             (None, None),
318         ]
319
320     @property
321     def from_db_value(self):
322         alice = User.get(self.samdb, account_name="alice")
323         joe = User.get(self.samdb, account_name="joe")
324         return [
325             (MessageElement(str(alice.dn)), alice.dn),
326             (MessageElement([str(joe.dn), str(alice.dn)]), [joe.dn, alice.dn]),
327             (None, None),
328         ]
329
330
331 class SIDFieldTest(FieldTestMixin, SambaToolCmdTest):
332     field = fields.SIDField("FieldName")
333
334     @property
335     def to_db_value(self):
336         # Create a group for testing
337         group = Group(name="group1")
338         group.save(self.samdb)
339         self.addCleanup(group.delete, self.samdb)
340
341         # Get raw value to compare against
342         group_rec = self.samdb.search(Group.get_base_dn(self.samdb),
343                                       scope=SCOPE_ONELEVEL,
344                                       expression="(name=group1)",
345                                       attrs=["objectSid"])[0]
346         raw_sid = group_rec["objectSid"]
347
348         return [
349             (group.object_sid, raw_sid),
350             (None, None),
351         ]
352
353     @property
354     def from_db_value(self):
355         # Create a group for testing
356         group = Group(name="group1")
357         group.save(self.samdb)
358         self.addCleanup(group.delete, self.samdb)
359
360         # Get raw value to compare against
361         group_rec = self.samdb.search(Group.get_base_dn(self.samdb),
362                                       scope=SCOPE_ONELEVEL,
363                                       expression="(name=group1)",
364                                       attrs=["objectSid"])[0]
365         raw_sid = group_rec["objectSid"]
366
367         return [
368             (raw_sid, group.object_sid),
369             (None, None),
370         ]
371
372
373 class GUIDFieldTest(FieldTestMixin, SambaToolCmdTest):
374     field = fields.GUIDField("FieldName")
375
376     @property
377     def to_db_value(self):
378         users_dn = self.get_users_dn()
379
380         alice = self.samdb.search(users_dn,
381                                   scope=SCOPE_ONELEVEL,
382                                   expression="(sAMAccountName=alice)",
383                                   attrs=["objectGUID"])[0]
384
385         joe = self.samdb.search(users_dn,
386                                 scope=SCOPE_ONELEVEL,
387                                 expression="(sAMAccountName=joe)",
388                                 attrs=["objectGUID"])[0]
389
390         alice_guid = str(ndr_unpack(GUID, alice["objectGUID"][0]))
391         joe_guid = str(ndr_unpack(GUID, joe["objectGUID"][0]))
392
393         return [
394             (alice_guid, alice["objectGUID"]),
395             (
396                 [joe_guid, alice_guid],
397                 MessageElement([joe["objectGUID"][0], alice["objectGUID"][0]]),
398             ),
399             (None, None),
400         ]
401
402     @property
403     def from_db_value(self):
404         users_dn = self.get_users_dn()
405
406         alice = self.samdb.search(users_dn,
407                                   scope=SCOPE_ONELEVEL,
408                                   expression="(sAMAccountName=alice)",
409                                   attrs=["objectGUID"])[0]
410
411         joe = self.samdb.search(users_dn,
412                                 scope=SCOPE_ONELEVEL,
413                                 expression="(sAMAccountName=joe)",
414                                 attrs=["objectGUID"])[0]
415
416         alice_guid = str(ndr_unpack(GUID, alice["objectGUID"][0]))
417         joe_guid = str(ndr_unpack(GUID, joe["objectGUID"][0]))
418
419         return [
420             (alice["objectGUID"], alice_guid),
421             (
422                 MessageElement([joe["objectGUID"][0], alice["objectGUID"][0]]),
423                 [joe_guid, alice_guid],
424             ),
425             (None, None),
426         ]
427
428
429 class SDDLFieldTest(FieldTestMixin, SambaToolCmdTest):
430     field = fields.SDDLField("FieldName")
431
432     def setUp(self):
433         super().setUp()
434         self.domain_sid = security.dom_sid(self.samdb.get_domain_sid())
435
436     def security_descriptor(self, sddl):
437         return security.descriptor.from_sddl(sddl, self.domain_sid)
438
439     @property
440     def to_db_value(self):
441         values = [
442             "O:SYG:SYD:(XA;OICI;CR;;;WD;(Member_of {SID(AU)}))",
443             "O:SYG:SYD:(XA;OICI;CR;;;WD;(Member_of {SID(AO)}))",
444             "O:SYG:SYD:(XA;OICI;CR;;;WD;((Member_of {SID(AO)}) || (Member_of {SID(BO)})))",
445             "O:SYG:SYD:(XA;OICI;CR;;;WD;(Member_of {SID(%s)}))" % self.domain_sid,
446         ]
447
448         # Values coming in are SDDL strings
449         expected = [
450             (value, MessageElement(ndr_pack(self.security_descriptor(value))))
451             for value in values
452         ]
453
454         # Values coming in are already security descriptors
455         expected.extend([
456             (self.security_descriptor(value),
457              MessageElement(ndr_pack(self.security_descriptor(value))))
458             for value in values
459         ])
460
461         expected.append((None, None))
462         return expected
463
464     @property
465     def from_db_value(self):
466         values = [
467             "O:SYG:SYD:(XA;OICI;CR;;;WD;(Member_of {SID(AU)}))",
468             "O:SYG:SYD:(XA;OICI;CR;;;WD;(Member_of {SID(AO)}))",
469             "O:SYG:SYD:(XA;OICI;CR;;;WD;((Member_of {SID(AO)}) || (Member_of {SID(BO)})))",
470             "O:SYG:SYD:(XA;OICI;CR;;;WD;(Member_of {SID(%s)}))" % self.domain_sid,
471         ]
472         expected = [
473             (MessageElement(ndr_pack(self.security_descriptor(value))),
474              self.security_descriptor(value))
475             for value in values
476         ]
477         expected.append((None, None))
478         return expected
479
480
481 class PossibleClaimValuesFieldTest(FieldTestMixin, SambaToolCmdTest):
482     field = fields.PossibleClaimValuesField("FieldName")
483
484     json_data = [{
485         "ValueGUID": "1c39ed4f-0b26-4536-b963-5959c8b1b676",
486         "ValueDisplayName": "Alice",
487         "ValueDescription": "Alice Description",
488         "Value": "alice",
489     }]
490
491     xml_data = "<?xml version='1.0' encoding='utf-16'?>" \
492                "<PossibleClaimValues xmlns:xsd='http://www.w3.org/2001/XMLSchema'" \
493                " xmlns:xsi='http://www.w3.org/2001/XMLSchema-instance'" \
494                " xmlns='http://schemas.microsoft.com/2010/08/ActiveDirectory/PossibleValues'>" \
495                "<StringList>" \
496                "<Item>" \
497                "<ValueGUID>1c39ed4f-0b26-4536-b963-5959c8b1b676</ValueGUID>" \
498                "<ValueDisplayName>Alice</ValueDisplayName>" \
499                "<ValueDescription>Alice Description</ValueDescription>" \
500                "<Value>alice</Value>" \
501                "</Item>" \
502                "</StringList>" \
503                "</PossibleClaimValues>"
504
505     def validate_xml(self, db_field):
506         """Callback that compares XML strings.
507
508         Tidying the HTMl output and adding consistent indentation was only
509         added to ETree in Python 3.9+ so generate a single line XML string.
510
511         This is just based on comparing the parsed XML, converted back
512         to a string, then comparing those strings.
513
514         So the expected xml_data string must have no spacing or indentation.
515
516         :param db_field: MessageElement value returned by field.to_db_field()
517         """
518         expected = ElementTree.fromstring(self.xml_data)
519         parsed = ElementTree.fromstring(str(db_field))
520         return ElementTree.tostring(parsed) == ElementTree.tostring(expected)
521
522     @property
523     def to_db_value(self):
524         return [
525             (self.json_data, self.validate_xml),     # callback to validate XML
526             (self.json_data[0], self.validate_xml),  # one item wrapped as list
527             ([], None),                              # empty list clears field
528             (None, None),
529         ]
530
531     @property
532     def from_db_value(self):
533         return [
534             (MessageElement(self.xml_data), self.json_data),
535             (None, None),
536         ]