netcmd: models: Model.query adds optional polymorphic flag for returning specific...
authorRob van der Linde <rob@catalyst.net.nz>
Tue, 20 Feb 2024 03:45:45 +0000 (16:45 +1300)
committerAndrew Bartlett <abartlet@samba.org>
Fri, 1 Mar 2024 04:45:36 +0000 (04:45 +0000)
This defaults to False, query the User class returns only User instances.

    User.query(samdb)

When set to True, query the User class can return User, Computer, ManagedServiceAccount instances.

    User.query(samdb, polymorphic=True)

If polymorphic is False the same records are still returned but records will always be interpreted as the model that is being queried only, rather than a more specific model that matches that object class.

Signed-off-by: Rob van der Linde <rob@catalyst.net.nz>
Reviewed-by: Andrew Bartlett <abartlet@samba.org>
Reviewed-by: Douglas Bagnall <douglas.bagnall@catalyst.net.nz>
python/samba/netcmd/domain/models/model.py
python/samba/netcmd/domain/models/query.py

index 62cc0bbc0f2c5e87e06ee5cfa07703f2a4b44712..fd4df3f068835f811f8d7734ddb18ec5cf6258c8 100644 (file)
@@ -226,10 +226,18 @@ class Model(metaclass=ModelMeta):
         return expression
 
     @classmethod
-    def query(cls, ldb, **kwargs):
+    def query(cls, ldb, polymorphic=False, **kwargs):
         """Returns a search query for this model.
 
+        NOTE: If polymorphic is enabled then querying will return instances
+        of that specific model, for example querying User can return Computer
+        and ManagedServiceAccount instances.
+
+        By default, polymorphic querying is disabled, and querying User
+        will only return User instances.
+
         :param ldb: Ldb connection
+        :param polymorphic: If true enables polymorphic querying (see note)
         :param kwargs: Search criteria as keyword args
         """
         base_dn = cls.get_search_dn(ldb)
@@ -244,7 +252,7 @@ class Model(metaclass=ModelMeta):
                 raise NotFound(f"Container does not exist: {base_dn}")
             raise
 
-        return Query(cls, ldb, result)
+        return Query(cls, ldb, result, polymorphic)
 
     @classmethod
     def get(cls, ldb, **kwargs):
index 5c3a152744571c88d13e21fc28aa3e61acfd74d5..2856a50f825812e0c9af06810cac52cb74bb50c7 100644 (file)
@@ -22,6 +22,7 @@
 
 import re
 
+from .constants import MODELS
 from .exceptions import NotFound, MultipleObjectsReturned
 
 RE_SPLIT_CAMELCASE = re.compile(r"[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))")
@@ -30,27 +31,44 @@ RE_SPLIT_CAMELCASE = re.compile(r"[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))")
 class Query:
     """Simple Query class used by the `Model.query` method."""
 
-    def __init__(self, model, ldb, result):
+    def __init__(self, model, ldb, result, polymorphic):
         self.model = model
         self.ldb = ldb
         self.result = result
         self.count = result.count
         self.name = " ".join(RE_SPLIT_CAMELCASE.findall(model.__name__)).lower()
+        self.polymorphic = polymorphic
 
     def __iter__(self):
         """Loop over Query class yields Model instances."""
         for message in self.result:
-            yield self.model.from_message(self.ldb, message)
+            yield self._model_from_message(message)
+
+    def _model_from_message(self, message):
+        """Returns the model class to use to construct instances.
+
+        If polymorphic query is enabled it will use the last item from
+        the objectClass list.
+
+        Otherwise, it will use the model from the queryset.
+        """
+        if self.polymorphic:
+            object_class = str(message["objectClass"][-1])
+            model = MODELS.get(object_class, self.model)
+        else:
+            model = self.model
+
+        return model.from_message(self.ldb, message)
 
     def first(self):
         """Returns the first item in the Query or None for no results."""
         if self.count:
-            return self.model.from_message(self.ldb, self.result[0])
+            return self._model_from_message(self.result[0])
 
     def last(self):
         """Returns the last item in the Query or None for no results."""
         if self.count:
-            return self.model.from_message(self.ldb, self.result[-1])
+            return self._model_from_message(self.result[-1])
 
     def get(self):
         """Returns one item or None if no results were found.
@@ -62,7 +80,7 @@ class Query:
             raise MultipleObjectsReturned(
                 f"More than one {self.name} objects returned (got {self.count}).")
         elif self.count:
-            return self.model.from_message(self.ldb, self.result[0])
+            return self._model_from_message(self.result[0])
 
     def one(self):
         """Must return EXACTLY one item or raise an exception.
@@ -78,4 +96,4 @@ class Query:
             raise MultipleObjectsReturned(
                 f"More than one {self.name} objects returned (got {self.count}).")
         else:
-            return self.model.from_message(self.ldb, self.result[0])
+            return self._model_from_message(self.result[0])