netcmd: models: move MODELS constant to constants.py to avoid import loop
[samba.git] / python / samba / netcmd / domain / models / model.py
1 # Unix SMB/CIFS implementation.
2 #
3 # Model and basic ORM for the Ldb database.
4 #
5 # Copyright (C) Catalyst.Net Ltd. 2023
6 #
7 # Written by Rob van der Linde <rob@catalyst.net.nz>
8 #
9 # This program is free software; you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation; either version 3 of the License, or
12 # (at your option) any later version.
13 #
14 # This program is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17 # GNU General Public License for more details.
18 #
19 # You should have received a copy of the GNU General Public License
20 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
21 #
22
23 import inspect
24 from abc import ABCMeta, abstractmethod
25
26 from ldb import (ERR_NO_SUCH_OBJECT, FLAG_MOD_ADD, FLAG_MOD_REPLACE,
27                  LdbError, Message, MessageElement, SCOPE_BASE,
28                  SCOPE_SUBTREE)
29 from samba.sd_utils import SDUtils
30
31 from .constants import MODELS
32 from .exceptions import (DeleteError, FieldError, NotFound, ProtectError,
33                          UnprotectError)
34 from .fields import (DateTimeField, DnField, Field, GUIDField, IntegerField,
35                      StringField)
36 from .query import Query
37
38
39 class ModelMeta(ABCMeta):
40
41     def __new__(mcls, name, bases, namespace, **kwargs):
42         cls = super().__new__(mcls, name, bases, namespace, **kwargs)
43
44         if cls.__name__ != "Model":
45             cls.fields = dict(inspect.getmembers(cls, lambda f: isinstance(f, Field)))
46             cls.meta = mcls
47             MODELS[name] = cls
48
49         return cls
50
51
52 class Model(metaclass=ModelMeta):
53     cn = StringField("cn")
54     distinguished_name = DnField("distinguishedName")
55     dn = DnField("dn")
56     ds_core_propagation_data = DateTimeField("dsCorePropagationData",
57                                              hidden=True, readonly=True)
58     instance_type = IntegerField("instanceType")
59     name = StringField("name")
60     object_category = DnField("objectCategory")
61     object_class = StringField("objectClass",
62                                default=lambda obj: obj.get_object_class())
63     object_guid = GUIDField("objectGUID")
64     usn_changed = IntegerField("uSNChanged", hidden=True, readonly=True)
65     usn_created = IntegerField("uSNCreated", hidden=True, readonly=True)
66     when_changed = DateTimeField("whenChanged", hidden=True, readonly=True)
67     when_created = DateTimeField("whenCreated", hidden=True, readonly=True)
68
69     def __init__(self, **kwargs):
70         """Create a new model instance and optionally populate fields.
71
72         Does not save the object to the database, call .save() for that.
73
74         :param kwargs: Optional input fields to populate object with
75         """
76         # Used by the _apply method, holds the original ldb Message,
77         # which is used by save() to determine what fields changed.
78         self._message = None
79
80         for field_name, field in self.fields.items():
81             if field_name in kwargs:
82                 default = kwargs[field_name]
83             elif callable(field.default):
84                 default = field.default(self)
85             else:
86                 default = field.default
87
88             setattr(self, field_name, default)
89
90     def __repr__(self):
91         """Return object representation for this model."""
92         return f"<{self.__class__.__name__}: {self}>"
93
94     def __str__(self):
95         """Stringify model instance to implement in each model."""
96         return str(self.cn)
97
98     def __eq__(self, other):
99         """Basic object equality check only really checks if the dn matches.
100
101         :param other: The other object to compare with
102         """
103         if other is None:
104             return False
105         else:
106             return self.dn == other.dn
107
108     def __json__(self):
109         """Automatically called by custom JSONEncoder class.
110
111         When turning an object into json any fields of type RelatedField
112         will also end up calling this method.
113         """
114         if self.dn is not None:
115             return str(self.dn)
116
117     @staticmethod
118     def get_base_dn(ldb):
119         """Return the base DN for the container of this model.
120
121         :param ldb: Ldb connection
122         :return: Dn to use for new objects
123         """
124         return ldb.get_default_basedn()
125
126     @classmethod
127     def get_search_dn(cls, ldb):
128         """Return the DN used for querying.
129
130         By default, this just calls get_base_dn, but it is possible to
131         return a different Dn for querying.
132
133         :param ldb: Ldb connection
134         :return: Dn to use for searching
135         """
136         return cls.get_base_dn(ldb)
137
138     @staticmethod
139     @abstractmethod
140     def get_object_class():
141         """Returns the objectClass for this model."""
142         pass
143
144     @classmethod
145     def from_message(cls, ldb, message):
146         """Create a new model instance from the Ldb Message object.
147
148         :param ldb: Ldb connection
149         :param message: Ldb Message object to create instance from
150         """
151         obj = cls()
152         obj._apply(ldb, message)
153         return obj
154
155     def _apply(self, ldb, message):
156         """Internal method to apply Ldb Message to current object.
157
158         :param ldb: Ldb connection
159         :param message: Ldb Message object to apply
160         """
161         # Store the ldb Message so that in save we can see what changed.
162         self._message = message
163
164         for attr, field in self.fields.items():
165             if field.name in message:
166                 setattr(self, attr, field.from_db_value(ldb, message[field.name]))
167
168     def refresh(self, ldb, fields=None):
169         """Refresh object from database.
170
171         :param ldb: Ldb connection
172         :param fields: Optional list of field names to refresh
173         """
174         attrs = [self.fields[f].name for f in fields] if fields else None
175
176         # This shouldn't normally happen but in case the object refresh fails.
177         try:
178             res = ldb.search(self.dn, scope=SCOPE_BASE, attrs=attrs)
179         except LdbError as e:
180             if e.args[0] == ERR_NO_SUCH_OBJECT:
181                 raise NotFound(f"Refresh failed, object gone: {self.dn}")
182             raise
183
184         self._apply(ldb, res[0])
185
186     def as_dict(self, include_hidden=False):
187         """Returns a dict representation of the model.
188
189         :param include_hidden: Include fields with hidden=True when set
190         :returns: dict representation of model using Ldb field names as keys
191         """
192         obj_dict = {}
193
194         for attr, field in self.fields.items():
195             if not field.hidden or include_hidden:
196                 value = getattr(self, attr)
197                 if value is not None:
198                     obj_dict[field.name] = value
199
200         return obj_dict
201
202     @classmethod
203     def build_expression(cls, **kwargs):
204         """Build LDAP search expression from kwargs.
205
206         :param kwargs: fields to use for expression using model field names
207         """
208         # Take a copy, never modify the original if it can be avoided.
209         # Then always add the object_class to the search criteria.
210         criteria = dict(kwargs)
211         criteria["object_class"] = cls.get_object_class()
212
213         # Build search expression.
214         num_fields = len(criteria)
215         expression = "" if num_fields == 1 else "(&"
216
217         for field_name, value in criteria.items():
218             field = cls.fields.get(field_name)
219             if field is None:
220                 raise ValueError(f"Unknown field '{field_name}'")
221             expression += field.expression(value)
222
223         if num_fields > 1:
224             expression += ")"
225
226         return expression
227
228     @classmethod
229     def query(cls, ldb, **kwargs):
230         """Returns a search query for this model.
231
232         :param ldb: Ldb connection
233         :param kwargs: Search criteria as keyword args
234         """
235         base_dn = cls.get_search_dn(ldb)
236
237         # If the container does not exist produce a friendly error message.
238         try:
239             result = ldb.search(base_dn,
240                                 scope=SCOPE_SUBTREE,
241                                 expression=cls.build_expression(**kwargs))
242         except LdbError as e:
243             if e.args[0] == ERR_NO_SUCH_OBJECT:
244                 raise NotFound(f"Container does not exist: {base_dn}")
245             raise
246
247         return Query(cls, ldb, result)
248
249     @classmethod
250     def get(cls, ldb, **kwargs):
251         """Get one object, must always return one item.
252
253         Either find object by dn=, or any combination of attributes via kwargs.
254         If there are more than one result, MultipleObjectsReturned is raised.
255
256         :param ldb: Ldb connection
257         :param kwargs: Search criteria as keyword args
258         :returns: Model instance or None if not found
259         :raises: MultipleObjects returned if there are more than one results
260         """
261         # If a DN is provided use that to get the object directly.
262         # Otherwise, build a search expression using kwargs provided.
263         dn = kwargs.get("dn")
264
265         if dn:
266             # Handle LDAP error 32 LDAP_NO_SUCH_OBJECT, but raise for the rest.
267             # Return None if the User does not exist.
268             try:
269                 res = ldb.search(dn, scope=SCOPE_BASE)
270             except LdbError as e:
271                 if e.args[0] == ERR_NO_SUCH_OBJECT:
272                     return None
273                 else:
274                     raise
275
276             return cls.from_message(ldb, res[0])
277         else:
278             return cls.query(ldb, **kwargs).get()
279
280     @classmethod
281     def create(cls, ldb, **kwargs):
282         """Create object constructs object and calls save straight after.
283
284         :param ldb: Ldb connection
285         :param kwargs: Fields to populate object from
286         :returns: object
287         """
288         obj = cls(**kwargs)
289         obj.save(ldb)
290         return obj
291
292     @classmethod
293     def get_or_create(cls, ldb, defaults=None, **kwargs):
294         """Retrieve object and if it doesn't exist create a new instance.
295
296         :param ldb: Ldb connection
297         :param defaults: Attributes only used for create but not search
298         :param kwargs: Attributes used for searching existing object
299         :returns: (object, bool created)
300         """
301         obj = cls.get(ldb, **kwargs)
302         if obj is None:
303             attrs = dict(kwargs)
304             if defaults is not None:
305                 attrs.update(defaults)
306             return cls.create(ldb, **attrs), True
307         else:
308             return obj, False
309
310     def save(self, ldb):
311         """Save model to Ldb database.
312
313         The save operation will save all fields excluding fields that
314         return None when calling their `to_db_value` methods.
315
316         The `to_db_value` method can either return a ldb Message object,
317         or None if the field is to be excluded.
318
319         For updates, the existing object is fetched and only fields
320         that are changed are included in the update ldb Message.
321
322         Also for updates, any fields that currently have a value,
323         but are to be set to None will be seen as a delete operation.
324
325         After the save operation the object is refreshed from the server,
326         as often the server will populate some fields.
327
328         :param ldb: Ldb connection
329         """
330         if self.dn is None:
331             dn = self.get_base_dn(ldb)
332             dn.add_child(f"CN={self.cn or self.name}")
333             self.dn = dn
334
335             message = Message(dn=self.dn)
336             for attr, field in self.fields.items():
337                 if attr != "dn" and not field.readonly:
338                     value = getattr(self, attr)
339                     try:
340                         db_value = field.to_db_value(ldb, value, FLAG_MOD_ADD)
341                     except ValueError as e:
342                         raise FieldError(e, field=field)
343
344                     # Don't add empty fields.
345                     if db_value is not None and len(db_value):
346                         message.add(db_value)
347
348             # Create object
349             ldb.add(message)
350
351             # Fetching object refreshes any automatically populated fields.
352             res = ldb.search(dn, scope=SCOPE_BASE)
353             self._apply(ldb, res[0])
354         else:
355             # Existing Message was stored to work out what fields changed.
356             existing_obj = self.from_message(ldb, self._message)
357
358             # Only modify replace or modify fields that have changed.
359             # Any fields that are set to None or an empty list get unset.
360             message = Message(dn=self.dn)
361             for attr, field in self.fields.items():
362                 if attr != "dn" and not field.readonly:
363                     value = getattr(self, attr)
364                     old_value = getattr(existing_obj, attr)
365
366                     if value != old_value:
367                         try:
368                             db_value = field.to_db_value(ldb, value,
369                                                          FLAG_MOD_REPLACE)
370                         except ValueError as e:
371                             raise FieldError(e, field=field)
372
373                         # When a field returns None or empty list, delete attr.
374                         if db_value in (None, []):
375                             db_value = MessageElement([],
376                                                       FLAG_MOD_REPLACE,
377                                                       field.name)
378                         message.add(db_value)
379
380             # Saving nothing only triggers an error.
381             if len(message):
382                 ldb.modify(message)
383
384                 # Fetching object refreshes any automatically populated fields.
385                 self.refresh(ldb)
386
387     def delete(self, ldb):
388         """Delete item from Ldb database.
389
390         If self.dn is None then the object has not yet been saved.
391
392         :param ldb: Ldb connection
393         """
394         if self.dn is None:
395             raise DeleteError("Cannot delete object that doesn't have a dn.")
396
397         try:
398             ldb.delete(self.dn)
399         except LdbError as e:
400             raise DeleteError(f"Delete failed: {e}")
401
402     def protect(self, ldb):
403         """Protect object from accidental deletion.
404
405         :param ldb: Ldb connection
406         """
407         utils = SDUtils(ldb)
408
409         try:
410             utils.dacl_add_ace(self.dn, "(D;;DTSD;;;WD)")
411         except LdbError as e:
412             raise ProtectError(f"Failed to protect object: {e}")
413
414     def unprotect(self, ldb):
415         """Unprotect object from accidental deletion.
416
417         :param ldb: Ldb connection
418         """
419         utils = SDUtils(ldb)
420
421         try:
422             utils.dacl_delete_aces(self.dn, "(D;;DTSD;;;WD)")
423         except LdbError as e:
424             raise UnprotectError(f"Failed to unprotect object: {e}")