1 # Unix SMB/CIFS implementation.
3 # Model and basic ORM for the Ldb database.
5 # Copyright (C) Catalyst.Net Ltd. 2023
7 # Written by Rob van der Linde <rob@catalyst.net.nz>
9 # This program is free software; you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation; either version 3 of the License, or
12 # (at your option) any later version.
14 # This program is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 # GNU General Public License for more details.
19 # You should have received a copy of the GNU General Public License
20 # along with this program. If not, see <http://www.gnu.org/licenses/>.
24 from abc import ABCMeta, abstractmethod
26 from ldb import (ERR_NO_SUCH_OBJECT, FLAG_MOD_ADD, FLAG_MOD_REPLACE,
27 LdbError, Message, MessageElement, SCOPE_BASE,
29 from samba.sd_utils import SDUtils
31 from .constants import MODELS
32 from .exceptions import (DeleteError, FieldError, NotFound, ProtectError,
34 from .fields import (DateTimeField, DnField, Field, GUIDField, IntegerField,
36 from .query import Query
39 class ModelMeta(ABCMeta):
41 def __new__(mcls, name, bases, namespace, **kwargs):
42 cls = super().__new__(mcls, name, bases, namespace, **kwargs)
44 if cls.__name__ != "Model":
45 cls.fields = dict(inspect.getmembers(cls, lambda f: isinstance(f, Field)))
52 class Model(metaclass=ModelMeta):
53 cn = StringField("cn")
54 distinguished_name = DnField("distinguishedName")
56 ds_core_propagation_data = DateTimeField("dsCorePropagationData",
57 hidden=True, readonly=True)
58 instance_type = IntegerField("instanceType")
59 name = StringField("name")
60 object_category = DnField("objectCategory")
61 object_class = StringField("objectClass",
62 default=lambda obj: obj.get_object_class())
63 object_guid = GUIDField("objectGUID")
64 usn_changed = IntegerField("uSNChanged", hidden=True, readonly=True)
65 usn_created = IntegerField("uSNCreated", hidden=True, readonly=True)
66 when_changed = DateTimeField("whenChanged", hidden=True, readonly=True)
67 when_created = DateTimeField("whenCreated", hidden=True, readonly=True)
def __init__(self, **kwargs):
    """Create a new model instance and optionally populate fields.

    Does not save the object to the database, call .save() for that.

    Every field in ``self.fields`` is set as an instance attribute: the
    value from kwargs when given, otherwise the field default (a callable
    default is called with this instance).

    :param kwargs: Optional input fields to populate object with,
        keyed by model field name
    """
    # Used by the _apply method, holds the original ldb Message,
    # which is used by save() to determine what fields changed.
    # NOTE(review): the assignment itself was missing from the extracted
    # source; None is assumed as the "not loaded yet" value - confirm.
    self._message = None

    for field_name, field in self.fields.items():
        if field_name in kwargs:
            default = kwargs[field_name]
        elif callable(field.default):
            default = field.default(self)
        else:
            default = field.default

        setattr(self, field_name, default)
91 """Return object representation for this model."""
92 return f"<{self.__class__.__name__}: {self}>"
95 """Stringify model instance to implement in each model."""
98 def __eq__(self, other):
99 """Basic object equality check only really checks if the dn matches.
101 :param other: The other object to compare with
106 return self.dn == other.dn
109 """Automatically called by custom JSONEncoder class.
111 When turning an object into json any fields of type RelatedField
112 will also end up calling this method.
114 if self.dn is not None:
@staticmethod
def get_base_dn(ldb):
    """Return the base DN for the container of this model.

    This implementation returns the default base DN of the connection.

    :param ldb: Ldb connection
    :return: Dn to use for new objects
    """
    return ldb.get_default_basedn()
@classmethod
def get_search_dn(cls, ldb):
    """Return the DN used for querying.

    By default, this just calls get_base_dn, but it is possible to
    return a different Dn for querying.

    :param ldb: Ldb connection
    :return: Dn to use for searching
    """
    return cls.get_base_dn(ldb)
140 def get_object_class():
141 """Returns the objectClass for this model."""
@classmethod
def from_message(cls, ldb, message):
    """Create a new model instance from the Ldb Message object.

    :param ldb: Ldb connection
    :param message: Ldb Message object to create instance from
    :returns: Model instance populated from the message
    """
    # NOTE(review): instance creation and the return were missing from the
    # extracted source; restored as a plain no-argument construction.
    obj = cls()
    obj._apply(ldb, message)
    return obj
def _apply(self, ldb, message):
    """Internal method to apply Ldb Message to current object.

    Only fields whose Ldb attribute name is present in the message are
    updated; all other attributes of the instance are left untouched.

    :param ldb: Ldb connection
    :param message: Ldb Message object to apply
    """
    # Store the ldb Message so that in save we can see what changed.
    self._message = message

    for attr, field in self.fields.items():
        if field.name in message:
            setattr(self, attr, field.from_db_value(ldb, message[field.name]))
def refresh(self, ldb, fields=None):
    """Refresh object from database.

    :param ldb: Ldb connection
    :param fields: Optional list of field names to refresh
    :raises NotFound: if the object no longer exists
    :raises LdbError: for any other search failure
    """
    # Translate model field names to Ldb attribute names; None is passed
    # through to the search when no field list is given.
    attrs = [self.fields[f].name for f in fields] if fields else None

    # This shouldn't normally happen but in case the object refresh fails.
    try:
        res = ldb.search(self.dn, scope=SCOPE_BASE, attrs=attrs)
    except LdbError as e:
        if e.args[0] == ERR_NO_SUCH_OBJECT:
            raise NotFound(f"Refresh failed, object gone: {self.dn}") from e
        raise

    self._apply(ldb, res[0])
def as_dict(self, include_hidden=False):
    """Returns a dict representation of the model.

    Fields whose current value is None are left out.

    :param include_hidden: Include fields with hidden=True when set
    :returns: dict representation of model using Ldb field names as keys
    """
    obj_dict = {}

    for attr, field in self.fields.items():
        if not field.hidden or include_hidden:
            value = getattr(self, attr)
            if value is not None:
                obj_dict[field.name] = value

    return obj_dict
@classmethod
def build_expression(cls, **kwargs):
    """Build LDAP search expression from kwargs.

    The object class of the model is always part of the criteria. A single
    term is returned as-is; several terms are wrapped in "(&...)".

    :param kwargs: fields to use for expression using model field names
    :returns: LDAP search expression string
    :raises ValueError: if a keyword is not a field of this model
    """
    # Take a copy, never modify the original if it can be avoided.
    # Then always add the object_class to the search criteria.
    criteria = dict(kwargs)
    criteria["object_class"] = cls.get_object_class()

    # Build search expression.
    terms = []
    for field_name, value in criteria.items():
        field = cls.fields.get(field_name)
        if field is None:
            raise ValueError(f"Unknown field '{field_name}'")
        terms.append(field.expression(value))

    expression = "".join(terms)

    # NOTE(review): the closing of the "(&" group and the return were
    # missing from the extracted source; restored to balance the filter.
    if len(terms) > 1:
        expression = f"(&{expression})"

    return expression
229 def query(cls, ldb, **kwargs):
230 """Returns a search query for this model.
232 :param ldb: Ldb connection
233 :param kwargs: Search criteria as keyword args
235 base_dn = cls.get_search_dn(ldb)
237 # If the container does not exist produce a friendly error message.
239 result = ldb.search(base_dn,
241 expression=cls.build_expression(**kwargs))
242 except LdbError as e:
243 if e.args[0] == ERR_NO_SUCH_OBJECT:
244 raise NotFound(f"Container does not exist: {base_dn}")
247 return Query(cls, ldb, result)
@classmethod
def get(cls, ldb, **kwargs):
    """Get one object, must always return one item.

    Either find object by dn=, or any combination of attributes via kwargs.
    If there is more than one result, MultipleObjectsReturned is raised.

    :param ldb: Ldb connection
    :param kwargs: Search criteria as keyword args
    :returns: Model instance or None if not found
    :raises: MultipleObjectsReturned if there is more than one result
    """
    # If a DN is provided use that to get the object directly.
    # Otherwise, build a search expression using kwargs provided.
    dn = kwargs.get("dn")

    # NOTE(review): the branch/try skeleton below was missing from the
    # extracted source and is restored from the surviving comments.
    if dn:
        # Handle LDAP error 32 LDAP_NO_SUCH_OBJECT, but raise for the rest.
        # Return None if the object does not exist.
        try:
            res = ldb.search(dn, scope=SCOPE_BASE)
        except LdbError as e:
            if e.args[0] == ERR_NO_SUCH_OBJECT:
                return None
            raise

        return cls.from_message(ldb, res[0])

    return cls.query(ldb, **kwargs).get()
281 def create(cls, ldb, **kwargs):
282 """Create object constructs object and calls save straight after.
284 :param ldb: Ldb connection
285 :param kwargs: Fields to populate object from
@classmethod
def get_or_create(cls, ldb, defaults=None, **kwargs):
    """Retrieve object and if it doesn't exist create a new instance.

    :param ldb: Ldb connection
    :param defaults: Attributes only used for create but not search
    :param kwargs: Attributes used for searching existing object
    :returns: (object, bool created)
    """
    obj = cls.get(ldb, **kwargs)
    if obj is not None:
        return obj, False

    # NOTE(review): the line binding attrs was missing from the extracted
    # source; a copy of the search kwargs is assumed so that the caller's
    # arguments are not modified by the defaults.
    attrs = dict(kwargs)
    if defaults is not None:
        attrs.update(defaults)
    return cls.create(ldb, **attrs), True
311 """Save model to Ldb database.
313 The save operation will save all fields excluding fields that
314 return None when calling their `to_db_value` methods.
316 The `to_db_value` method can either return a ldb Message object,
317 or None if the field is to be excluded.
319 For updates, the existing object is fetched and only fields
320 that are changed are included in the update ldb Message.
322 Also for updates, any fields that currently have a value,
323 but are to be set to None will be seen as a delete operation.
325 After the save operation the object is refreshed from the server,
326 as often the server will populate some fields.
328 :param ldb: Ldb connection
331 dn = self.get_base_dn(ldb)
332 dn.add_child(f"CN={self.cn or self.name}")
335 message = Message(dn=self.dn)
336 for attr, field in self.fields.items():
337 if attr != "dn" and not field.readonly:
338 value = getattr(self, attr)
340 db_value = field.to_db_value(ldb, value, FLAG_MOD_ADD)
341 except ValueError as e:
342 raise FieldError(e, field=field)
344 # Don't add empty fields.
345 if db_value is not None and len(db_value):
346 message.add(db_value)
351 # Fetching object refreshes any automatically populated fields.
352 res = ldb.search(dn, scope=SCOPE_BASE)
353 self._apply(ldb, res[0])
355 # Existing Message was stored to work out what fields changed.
356 existing_obj = self.from_message(ldb, self._message)
358 # Only modify replace or modify fields that have changed.
359 # Any fields that are set to None or an empty list get unset.
360 message = Message(dn=self.dn)
361 for attr, field in self.fields.items():
362 if attr != "dn" and not field.readonly:
363 value = getattr(self, attr)
364 old_value = getattr(existing_obj, attr)
366 if value != old_value:
368 db_value = field.to_db_value(ldb, value,
370 except ValueError as e:
371 raise FieldError(e, field=field)
373 # When a field returns None or empty list, delete attr.
374 if db_value in (None, []):
375 db_value = MessageElement([],
378 message.add(db_value)
380 # Saving nothing only triggers an error.
384 # Fetching object refreshes any automatically populated fields.
def delete(self, ldb):
    """Delete item from Ldb database.

    If self.dn is None then the object has not yet been saved.

    :param ldb: Ldb connection
    :raises DeleteError: if the object has no dn or the delete fails
    """
    if self.dn is None:
        raise DeleteError("Cannot delete object that doesn't have a dn.")

    try:
        # NOTE(review): the delete call was missing from the extracted
        # source; restored as a delete of this object's dn - confirm.
        ldb.delete(self.dn)
    except LdbError as e:
        raise DeleteError(f"Delete failed: {e}") from e
def protect(self, ldb):
    """Protect object from accidental deletion.

    Adds the ACE "(D;;DTSD;;;WD)" to the DACL of the object; in SDDL
    this is a deny ACE for delete-tree and delete, applied to Everyone.

    :param ldb: Ldb connection
    :raises ProtectError: if the ACE could not be added
    """
    # NOTE(review): construction of utils was missing from the extracted
    # source; restored from the SDUtils import - confirm the argument.
    utils = SDUtils(ldb)

    try:
        utils.dacl_add_ace(self.dn, "(D;;DTSD;;;WD)")
    except LdbError as e:
        raise ProtectError(f"Failed to protect object: {e}") from e
def unprotect(self, ldb):
    """Unprotect object from accidental deletion.

    Removes the ACE "(D;;DTSD;;;WD)" from the DACL of the object; in SDDL
    this is a deny ACE for delete-tree and delete, applied to Everyone.

    :param ldb: Ldb connection
    :raises UnprotectError: if the ACE could not be removed
    """
    # NOTE(review): construction of utils was missing from the extracted
    # source; restored from the SDUtils import - confirm the argument.
    utils = SDUtils(ldb)

    try:
        utils.dacl_delete_aces(self.dn, "(D;;DTSD;;;WD)")
    except LdbError as e:
        raise UnprotectError(f"Failed to unprotect object: {e}") from e