1 # Unix SMB/CIFS implementation.
3 # Tests for domain models and fields
5 # Copyright (C) Catalyst.Net Ltd. 2023
7 # Written by Rob van der Linde <rob@catalyst.net.nz>
9 # This program is free software; you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation; either version 3 of the License, or
12 # (at your option) any later version.
14 # This program is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 # GNU General Public License for more details.
19 # You should have received a copy of the GNU General Public License
20 # along with this program. If not, see <http://www.gnu.org/licenses/>.
import os
from datetime import datetime, timezone
from xml.etree import ElementTree

from ldb import FLAG_MOD_ADD, SCOPE_ONELEVEL, MessageElement

from samba.dcerpc import security
from samba.dcerpc.misc import GUID
from samba.domain.models import (AccountType, Computer, Group, Site,
                                 StrongNTLMPolicy, User, fields)
from samba.ndr import ndr_pack, ndr_unpack

from .base import SambaToolCmdTest
# Connection details for the DC under test, read from the environment when
# the module is imported.  A missing DC_SERVER, DC_USERNAME or DC_PASSWORD
# variable raises KeyError at import time.
HOST = "ldap://{DC_SERVER}".format(**os.environ)
# Single samba-tool style credentials option in "-Uuser%password" form.
CREDS = "-U{DC_USERNAME}%{DC_PASSWORD}".format(**os.environ)
class ModelTests(SambaToolCmdTest):
    """Tests for the model Query API against the test DC."""

    @classmethod
    def setUpClass(cls):
        cls.samdb = cls.getSamDB("-H", HOST, CREDS)
        super().setUpClass()

    def test_query_count(self):
        """Test count property on Query object without converting to a list."""
        query = Group.query(self.samdb)
        # Read .count before iterating, so the count is taken from the
        # query itself rather than after the results were listed.
        reported = query.count
        self.assertEqual(reported, len(list(query)))

    def test_query_filter_bool(self):
        """Tests filtering by a BooleanField."""
        total = Group.query(self.samdb).count
        critical = Group.query(
            self.samdb, is_critical_system_object=True).count
        non_critical = Group.query(
            self.samdb, is_critical_system_object=False).count

        # Both partitions must be populated and together cover every group.
        for partition in (critical, non_critical):
            self.assertNotEqual(partition, 0)
        self.assertEqual(critical + non_critical, total)

    def test_query_filter_enum(self):
        """Tests filtering by an EnumField."""
        everyone = User.query(self.samdb).count
        machines = User.query(
            self.samdb, account_type=AccountType.WORKSTATION_TRUST).count
        people = User.query(
            self.samdb, account_type=AccountType.NORMAL_ACCOUNT).count

        # NOTE(review): the sum check assumes the test DC has no User objects
        # of any other AccountType (e.g. trust accounts) -- confirm.
        for partition in (machines, people):
            self.assertNotEqual(partition, 0)
        self.assertEqual(machines + people, everyone)
class ComputerModelTests(SambaToolCmdTest):
    """Tests for how the Computer model derives name and account_name."""

    @classmethod
    def setUpClass(cls):
        cls.samdb = cls.getSamDB("-H", HOST, CREDS)
        super().setUpClass()

    def test_computer_constructor(self):
        """Computer.create() always yields a matching name/account_name pair.

        Whichever of name, cn and account_name is supplied, the resulting
        name has no trailing "$" and account_name is that name plus "$".
        """
        # Each case is (keyword arguments for Computer.create, expected name).
        cases = [
            # Use only name.
            ({"name": "comp1"}, "comp1"),
            # Use only cn.
            ({"cn": "comp2"}, "comp2"),
            # Use name and account_name but missing "$" in account_name.
            ({"name": "comp3", "account_name": "comp3"}, "comp3"),
            # Use cn and account_name but missing "$" in account_name.
            ({"cn": "comp4", "account_name": "comp4"}, "comp4"),
            # Use only account_name, the name should get the "$" removed.
            ({"account_name": "comp5$"}, "comp5"),
            # Use only account_name but accidentally forgot the "$" character.
            ({"account_name": "comp6"}, "comp6"),
        ]

        for kwargs, name in cases:
            with self.subTest(**kwargs):
                computer = Computer.create(self.samdb, **kwargs)
                # Register the cleanup before asserting so a failed check
                # still deletes the object.
                self.addCleanup(computer.delete, self.samdb)
                self.assertEqual(computer.name, name)
                self.assertEqual(computer.account_name, name + "$")
class FieldTestMixin:
    """Tests a model field to ensure it behaves correctly in both directions.

    Use a mixin since TestCase can't be marked as abstract.

    Concrete test classes combine this mixin with SambaToolCmdTest and set:

    * ``field``: the field instance under test.
    * ``to_db_value``: ``(value, expected)`` pairs.  ``expected`` is either
      the exact value ``field.to_db_value()`` must return, or a callable
      that receives the returned value and returns True if it is valid.
    * ``from_db_value``: ``(db_value, expected)`` pairs, where ``expected``
      is the exact value ``field.from_db_value()`` must return.

    Either list may be a plain class attribute or a property.
    """

    @classmethod
    def setUpClass(cls):
        cls.samdb = cls.getSamDB("-H", HOST, CREDS)
        super().setUpClass()

    def get_users_dn(self):
        """Returns Users DN.

        :return: the root base DN of self.samdb with CN=Users added as a child.
        """
        users_dn = self.samdb.get_root_basedn()
        # NOTE(review): add_child() mutates the DN in place; this assumes
        # get_root_basedn() hands back an object we are free to modify.
        users_dn.add_child("CN=Users")
        return users_dn

    def test_to_db_value(self):
        """Check to_db_value() output for each (value, expected) pair."""
        for value, expected in self.to_db_value:
            with self.subTest(value=value):
                db_value = self.field.to_db_value(self.samdb, value,
                                                  FLAG_MOD_ADD)
                if callable(expected):
                    # A callable is a validation callback, not a value.
                    self.assertTrue(expected(db_value),
                                    "validation failed for %r" % (db_value,))
                else:
                    self.assertEqual(db_value, expected)

    def test_from_db_value(self):
        """Check from_db_value() output for each (db_value, expected) pair."""
        for db_value, expected in self.from_db_value:
            with self.subTest(db_value=db_value):
                value = self.field.from_db_value(self.samdb, db_value)
                self.assertEqual(value, expected)
class IntegerFieldTest(FieldTestMixin, SambaToolCmdTest):
    """Conversion tests for IntegerField in both directions."""

    field = fields.IntegerField("FieldName")

    to_db_value = [
        (10, MessageElement(b"10")),
        # A list of values becomes a multi-valued element.
        ([1, 5, 10], MessageElement([b"1", b"5", b"10"])),
        # None passes through unchanged.
        (None, None),
    ]

    from_db_value = [
        (MessageElement(b"10"), 10),
        (MessageElement([b"1", b"5", b"10"]), [1, 5, 10]),
        (None, None),
    ]
class BinaryFieldTest(FieldTestMixin, SambaToolCmdTest):
    """Conversion tests for BinaryField in both directions."""

    field = fields.BinaryField("FieldName")

    to_db_value = [
        (b"SAMBA", MessageElement(b"SAMBA")),
        # A list of values becomes a multi-valued element.
        ([b"SAMBA", b"Developer"], MessageElement([b"SAMBA", b"Developer"])),
        # None passes through unchanged.
        (None, None),
    ]

    from_db_value = [
        (MessageElement(b"SAMBA"), b"SAMBA"),
        (MessageElement([b"SAMBA", b"Developer"]), [b"SAMBA", b"Developer"]),
        (None, None),
    ]
class StringFieldTest(FieldTestMixin, SambaToolCmdTest):
    """Conversion tests for StringField in both directions."""

    field = fields.StringField("FieldName")

    to_db_value = [
        ("SAMBA", MessageElement(b"SAMBA")),
        # A list of values becomes a multi-valued element.
        (["SAMBA", "Developer"], MessageElement([b"SAMBA", b"Developer"])),
        # None passes through unchanged.
        (None, None),
    ]

    from_db_value = [
        (MessageElement(b"SAMBA"), "SAMBA"),
        (MessageElement([b"SAMBA", b"Developer"]), ["SAMBA", "Developer"]),
        (None, None),
    ]
class BooleanFieldTest(FieldTestMixin, SambaToolCmdTest):
    """Conversion tests for BooleanField in both directions."""

    field = fields.BooleanField("FieldName")

    # Booleans are stored as the strings "TRUE" / "FALSE".
    to_db_value = [
        (True, MessageElement(b"TRUE")),
        ([False, True], MessageElement([b"FALSE", b"TRUE"])),
        # None passes through unchanged.
        (None, None),
    ]

    from_db_value = [
        (MessageElement(b"TRUE"), True),
        (MessageElement([b"FALSE", b"TRUE"]), [False, True]),
        (None, None),
    ]
class EnumFieldTest(FieldTestMixin, SambaToolCmdTest):
    """Conversion tests for EnumField using StrongNTLMPolicy members."""

    field = fields.EnumField("FieldName", StrongNTLMPolicy)

    # OPTIONAL and REQUIRED are expected to be stored as "1" and "2".
    to_db_value = [
        (StrongNTLMPolicy.OPTIONAL, MessageElement("1")),
        ([StrongNTLMPolicy.REQUIRED, StrongNTLMPolicy.OPTIONAL],
         MessageElement(["2", "1"])),
        # None passes through unchanged.
        (None, None),
    ]

    from_db_value = [
        (MessageElement("1"), StrongNTLMPolicy.OPTIONAL),
        (MessageElement(["2", "1"]),
         [StrongNTLMPolicy.REQUIRED, StrongNTLMPolicy.OPTIONAL]),
        (None, None),
    ]
class DateTimeFieldTest(FieldTestMixin, SambaToolCmdTest):
    """Conversion tests for DateTimeField (GeneralizedTime strings)."""

    field = fields.DateTimeField("FieldName")

    # Aware UTC datetimes map to "YYYYMMDDHHMMSS.0Z" strings.
    to_db_value = [
        (datetime(2023, 1, 27, 22, 36, 41, tzinfo=timezone.utc),
         MessageElement("20230127223641.0Z")),
        ([datetime(2023, 1, 27, 22, 36, 41, tzinfo=timezone.utc),
          datetime(2023, 1, 27, 22, 47, 50, tzinfo=timezone.utc)],
         MessageElement(["20230127223641.0Z", "20230127224750.0Z"])),
        # None passes through unchanged.
        (None, None),
    ]

    from_db_value = [
        (MessageElement("20230127223641.0Z"),
         datetime(2023, 1, 27, 22, 36, 41, tzinfo=timezone.utc)),
        (MessageElement(["20230127223641.0Z", "20230127224750.0Z"]),
         [datetime(2023, 1, 27, 22, 36, 41, tzinfo=timezone.utc),
          datetime(2023, 1, 27, 22, 47, 50, tzinfo=timezone.utc)]),
        (None, None),
    ]
class NtTimeFieldTest(FieldTestMixin, SambaToolCmdTest):
    """Conversion tests for NtTimeField (100ns intervals since 1601-01-01)."""

    field = fields.NtTimeField("FieldName")

    # e.g. 2023-01-27 22:36:41 UTC is (1674859001 + 11644473600) * 10**7.
    to_db_value = [
        (datetime(2023, 1, 27, 22, 36, 41, tzinfo=timezone.utc),
         MessageElement("133193326010000000")),
        ([datetime(2023, 1, 27, 22, 36, 41, tzinfo=timezone.utc),
          datetime(2023, 1, 27, 22, 47, 50, tzinfo=timezone.utc)],
         MessageElement(["133193326010000000", "133193332700000000"])),
        # None passes through unchanged.
        (None, None),
    ]

    from_db_value = [
        (MessageElement("133193326010000000"),
         datetime(2023, 1, 27, 22, 36, 41, tzinfo=timezone.utc)),
        (MessageElement(["133193326010000000", "133193332700000000"]),
         [datetime(2023, 1, 27, 22, 36, 41, tzinfo=timezone.utc),
          datetime(2023, 1, 27, 22, 47, 50, tzinfo=timezone.utc)]),
        (None, None),
    ]
class RelatedFieldTest(FieldTestMixin, SambaToolCmdTest):
    """Conversion tests for RelatedField pointing at User objects."""

    field = fields.RelatedField("FieldName", User)

    def _get_users(self):
        """Return the (alice, joe) User objects from the test DC.

        NOTE(review): assumes both accounts exist on the test DC.
        """
        alice = User.get(self.samdb, account_name="alice")
        joe = User.get(self.samdb, account_name="joe")
        return alice, joe

    @property
    def to_db_value(self):
        """(User or list of Users, expected DN MessageElement) pairs."""
        alice, joe = self._get_users()
        return [
            (alice, MessageElement(str(alice.dn))),
            ([joe, alice], MessageElement([str(joe.dn), str(alice.dn)])),
            (None, None),
        ]

    @property
    def from_db_value(self):
        """(DN MessageElement, expected User or list of Users) pairs."""
        alice, joe = self._get_users()
        return [
            (MessageElement(str(alice.dn)), alice),
            (MessageElement([str(joe.dn), str(alice.dn)]), [joe, alice]),
            (None, None),
        ]
class DnFieldTest(FieldTestMixin, SambaToolCmdTest):
    """Conversion tests for DnField using the DNs of alice and joe."""

    field = fields.DnField("FieldName")

    def _get_user_dns(self):
        """Return the DNs of the alice and joe User objects.

        NOTE(review): assumes both accounts exist on the test DC.
        """
        alice = User.get(self.samdb, account_name="alice")
        joe = User.get(self.samdb, account_name="joe")
        return alice.dn, joe.dn

    @property
    def to_db_value(self):
        """(Dn or list of Dns, expected MessageElement) pairs."""
        alice_dn, joe_dn = self._get_user_dns()
        return [
            (alice_dn, MessageElement(str(alice_dn))),
            ([joe_dn, alice_dn],
             MessageElement([str(joe_dn), str(alice_dn)])),
            (None, None),
        ]

    @property
    def from_db_value(self):
        """(MessageElement, expected Dn or list of Dns) pairs."""
        alice_dn, joe_dn = self._get_user_dns()
        return [
            (MessageElement(str(alice_dn)), alice_dn),
            (MessageElement([str(joe_dn), str(alice_dn)]),
             [joe_dn, alice_dn]),
            (None, None),
        ]
class SIDFieldTest(FieldTestMixin, SambaToolCmdTest):
    """Conversion tests for SIDField using a freshly created group."""

    field = fields.SIDField("FieldName")

    def _create_group(self):
        """Create "group1" and return it with its raw objectSid element.

        The group is deleted again via addCleanup.  Because a fixed name is
        used, this must only be called once per test.
        """
        group = Group(name="group1")
        group.save(self.samdb)
        self.addCleanup(group.delete, self.samdb)

        # Get raw value to compare against
        group_rec = self.samdb.search(Group.get_base_dn(self.samdb),
                                      scope=SCOPE_ONELEVEL,
                                      expression="(name=group1)",
                                      attrs=["objectSid"])[0]
        return group, group_rec["objectSid"]

    @property
    def to_db_value(self):
        """(object_sid value, expected raw objectSid element) pairs."""
        group, raw_sid = self._create_group()
        return [
            (group.object_sid, raw_sid),
            (None, None),
        ]

    @property
    def from_db_value(self):
        """(raw objectSid element, expected object_sid value) pairs."""
        group, raw_sid = self._create_group()
        return [
            (raw_sid, group.object_sid),
            (None, None),
        ]
class GUIDFieldTest(FieldTestMixin, SambaToolCmdTest):
    """Conversion tests for GUIDField using the GUIDs of alice and joe."""

    field = fields.GUIDField("FieldName")

    def _lookup_guid(self, account_name):
        """Look up the objectGUID of an account directly under CN=Users.

        :param account_name: sAMAccountName of an existing account.
        :return: tuple of (raw objectGUID element, GUID as a string).
        """
        record = self.samdb.search(
            self.get_users_dn(),
            scope=SCOPE_ONELEVEL,
            expression="(sAMAccountName=%s)" % account_name,
            attrs=["objectGUID"])[0]
        raw_guid = record["objectGUID"]
        return raw_guid, str(ndr_unpack(GUID, raw_guid[0]))

    @property
    def to_db_value(self):
        """(GUID string or list of strings, expected raw element) pairs."""
        alice_raw, alice_guid = self._lookup_guid("alice")
        joe_raw, joe_guid = self._lookup_guid("joe")
        return [
            (alice_guid, alice_raw),
            ([joe_guid, alice_guid],
             MessageElement([joe_raw[0], alice_raw[0]])),
            (None, None),
        ]

    @property
    def from_db_value(self):
        """(raw element, expected GUID string or list of strings) pairs."""
        alice_raw, alice_guid = self._lookup_guid("alice")
        joe_raw, joe_guid = self._lookup_guid("joe")
        return [
            (alice_raw, alice_guid),
            (MessageElement([joe_raw[0], alice_raw[0]]),
             [joe_guid, alice_guid]),
            (None, None),
        ]
class SDDLFieldTest(FieldTestMixin, SambaToolCmdTest):
    """Conversion tests for SDDLField using conditional-ACE descriptors."""

    field = fields.SDDLField("FieldName")

    def setUp(self):
        super().setUp()
        self.domain_sid = security.dom_sid(self.samdb.get_domain_sid())

    def security_descriptor(self, sddl):
        """Parse an SDDL string into a security descriptor.

        :param sddl: SDDL string, parsed relative to self.domain_sid.
        """
        return security.descriptor.from_sddl(sddl, self.domain_sid)

    def _sddl_values(self):
        """Return the SDDL strings exercised in both conversion directions."""
        return [
            "O:SYG:SYD:(XA;OICI;CR;;;WD;(Member_of {SID(AU)}))",
            "O:SYG:SYD:(XA;OICI;CR;;;WD;(Member_of {SID(AO)}))",
            "O:SYG:SYD:(XA;OICI;CR;;;WD;((Member_of {SID(AO)}) || (Member_of {SID(BO)})))",
            "O:SYG:SYD:(XA;OICI;CR;;;WD;(Member_of {SID(%s)}))" % self.domain_sid,
        ]

    @property
    def to_db_value(self):
        """Pairs for SDDL string input, then descriptor input, then None."""
        values = self._sddl_values()

        # Values coming in are SDDL strings
        expected = [
            (value, MessageElement(ndr_pack(self.security_descriptor(value))))
            for value in values
        ]

        # Values coming in are already security descriptors
        expected.extend(
            (self.security_descriptor(value),
             MessageElement(ndr_pack(self.security_descriptor(value))))
            for value in values
        )

        expected.append((None, None))
        return expected

    @property
    def from_db_value(self):
        """(NDR-packed descriptor element, expected descriptor) pairs."""
        expected = [
            (MessageElement(ndr_pack(self.security_descriptor(value))),
             self.security_descriptor(value))
            for value in self._sddl_values()
        ]
        expected.append((None, None))
        return expected
481 class PossibleClaimValuesFieldTest(FieldTestMixin, SambaToolCmdTest):
482 field = fields.PossibleClaimValuesField("FieldName")
485 "ValueGUID": "1c39ed4f-0b26-4536-b963-5959c8b1b676",
486 "ValueDisplayName": "Alice",
487 "ValueDescription": "Alice Description",
491 xml_data = "<?xml version='1.0' encoding='utf-16'?>" \
492 "<PossibleClaimValues xmlns:xsd='http://www.w3.org/2001/XMLSchema'" \
493 " xmlns:xsi='http://www.w3.org/2001/XMLSchema-instance'" \
494 " xmlns='http://schemas.microsoft.com/2010/08/ActiveDirectory/PossibleValues'>" \
497 "<ValueGUID>1c39ed4f-0b26-4536-b963-5959c8b1b676</ValueGUID>" \
498 "<ValueDisplayName>Alice</ValueDisplayName>" \
499 "<ValueDescription>Alice Description</ValueDescription>" \
500 "<Value>alice</Value>" \
503 "</PossibleClaimValues>"
505 def validate_xml(self, db_field):
506 """Callback that compares XML strings.
508 Tidying the HTMl output and adding consistent indentation was only
509 added to ETree in Python 3.9+ so generate a single line XML string.
511 This is just based on comparing the parsed XML, converted back
512 to a string, then comparing those strings.
514 So the expected xml_data string must have no spacing or indentation.
516 :param db_field: MessageElement value returned by field.to_db_field()
518 expected = ElementTree.fromstring(self.xml_data)
519 parsed = ElementTree.fromstring(str(db_field))
520 return ElementTree.tostring(parsed) == ElementTree.tostring(expected)
523 def to_db_value(self):
525 (self.json_data, self.validate_xml), # callback to validate XML
526 (self.json_data[0], self.validate_xml), # one item wrapped as list
527 ([], None), # empty list clears field
532 def from_db_value(self):
534 (MessageElement(self.xml_data), self.json_data),