netcmd: models: fix build_expression did not work with EnumField
authorRob van der Linde <rob@catalyst.net.nz>
Thu, 18 Jan 2024 02:47:52 +0000 (15:47 +1300)
committerAndrew Bartlett <abartlet@samba.org>
Thu, 8 Feb 2024 02:48:44 +0000 (02:48 +0000)
Signed-off-by: Rob van der Linde <rob@catalyst.net.nz>
Reviewed-by: Douglas Bagnall <douglas.bagnall@catalyst.net.nz>
Reviewed-by: Andrew Bartlett <abartlet@samba.org>
python/samba/netcmd/domain/models/fields.py
python/samba/tests/samba_tool/domain_models.py

index e9f0529680079b4e913629e63bee266752a84367..c02562e7c3704a9f1b51d693ebac2195c32cf0dd 100644 (file)
@@ -209,6 +209,10 @@ class EnumField(Field):
         else:
             return MessageElement(str(value.value), flags, self.name)
 
+    def expression(self, value):
+        """Returns the ldb search expression for this field."""
+        return f"({self.name}={binary_encode(str(value.value))})"
+
 
 class DateTimeField(Field):
     """A field for parsing ldb timestamps into Python datetime."""
index d58f47bfd9afc2eb3626c550f7977e3d7d578191..45d6095c775a5eb0f3f691149b16063676be87c3 100644 (file)
@@ -27,8 +27,8 @@ from xml.etree import ElementTree
 from ldb import FLAG_MOD_ADD, MessageElement, SCOPE_ONELEVEL
 from samba.dcerpc import security
 from samba.dcerpc.misc import GUID
-from samba.netcmd.domain.models import (Group, Site, User, StrongNTLMPolicy,
-                                        fields)
+from samba.netcmd.domain.models import (AccountType, Group, Site, User,
+                                        StrongNTLMPolicy, fields)
 from samba.ndr import ndr_pack, ndr_unpack
 
 from .base import SambaToolCmdTest
@@ -37,6 +37,41 @@ HOST = "ldap://{DC_SERVER}".format(**os.environ)
 CREDS = "-U{DC_USERNAME}%{DC_PASSWORD}".format(**os.environ)
 
 
+class ModelTests(SambaToolCmdTest):
+
+    @classmethod
+    def setUpClass(cls):
+        cls.samdb = cls.getSamDB("-H", HOST, CREDS)
+        super().setUpClass()
+
+    def test_query_count(self):
+        """Test count property on Query object without converting to a list."""
+        groups = Group.query(self.samdb)
+        self.assertEqual(groups.count, len(list(groups)))
+
+    def test_query_filter_bool(self):
+        """Tests filtering by a BooleanField."""
+        total = Group.query(self.samdb).count
+        system_groups = Group.query(self.samdb,
+                                    is_critical_system_object=True).count
+        user_groups = Group.query(self.samdb,
+                                  is_critical_system_object=False).count
+        self.assertNotEqual(system_groups, 0)
+        self.assertNotEqual(user_groups, 0)
+        self.assertEqual(system_groups + user_groups, total)
+
+    def test_query_filter_enum(self):
+        """Tests filtering by an EnumField."""
+        robots_vs_humans = User.query(self.samdb).count
+        robots = User.query(self.samdb,
+                            account_type=AccountType.WORKSTATION_TRUST).count
+        humans = User.query(self.samdb,
+                            account_type=AccountType.NORMAL_ACCOUNT).count
+        self.assertNotEqual(robots, 0)
+        self.assertNotEqual(humans, 0)
+        self.assertEqual(robots + humans, robots_vs_humans)
+
+
 class FieldTestMixin:
     """Tests a model field to ensure it behaves correctly in both directions.