PEP8: fix E713: test for membership should be 'not in'
[sfrench/samba-autobuild/.git] / python / samba / samba3 / __init__.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007
3 #
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 #
17
18 """Support for reading Samba 3 data files."""
19
20 __docformat__ = "restructuredText"
21
22 REGISTRY_VALUE_PREFIX = b"SAMBA_REGVAL"
23 REGISTRY_DB_VERSION = 1
24
25 import os
26 import struct
27 import tdb
28
29 import samba.samba3.passdb
30 from samba.samba3 import param as s3param
31
32
33 def fetch_uint32(db, key):
34     try:
35         data = db[key]
36     except KeyError:
37         return None
38     assert len(data) == 4
39     return struct.unpack("<L", data)[0]
40
41
42 def fetch_int32(db, key):
43     try:
44         data = db[key]
45     except KeyError:
46         return None
47     assert len(data) == 4
48     return struct.unpack("<l", data)[0]
49
50
51 class DbDatabase(object):
52     """Simple Samba 3 TDB database reader."""
53     def __init__(self, file):
54         """Open a file.
55
56         :param file: Path of the file to open, appending .tdb or .ntdb.
57         """
58         self.db = tdb.Tdb(file + ".tdb", flags=os.O_RDONLY)
59         self._check_version()
60
61     def _check_version(self):
62         pass
63
64     def close(self):
65         """Close resources associated with this object."""
66         self.db.close()
67
68
69 class Registry(DbDatabase):
70     """Simple read-only support for reading the Samba3 registry.
71
72     :note: This object uses the same syntax for registry key paths as
73         Samba 3. This particular format uses forward slashes for key path
74         separators and abbreviations for the predefined key names.
75         e.g.: HKLM/Software/Bar.
76     """
77     def __len__(self):
78         """Return the number of keys."""
79         return len(self.keys())
80
81     def keys(self):
82         """Return list with all the keys."""
83         return [k.rstrip(b"\x00") for k in self.db if not k.startswith(REGISTRY_VALUE_PREFIX)]
84
85     def subkeys(self, key):
86         """Retrieve the subkeys for the specified key.
87
88         :param key: Key path.
89         :return: list with key names
90         """
91         data = self.db.get(key + b"\x00")
92         if data is None:
93             return []
94         (num, ) = struct.unpack("<L", data[0:4])
95         keys = data[4:].split(b"\0")
96         assert keys[-1] == b""
97         keys.pop()
98         assert len(keys) == num
99         return keys
100
101     def values(self, key):
102         """Return a dictionary with the values set for a specific key.
103
104         :param key: Key to retrieve values for.
105         :return: Dictionary with value names as key, tuple with type and
106             data as value."""
107         data = self.db.get(REGISTRY_VALUE_PREFIX + b'/' + key + b'\x00')
108         if data is None:
109             return {}
110         ret = {}
111         (num, ) = struct.unpack("<L", data[0:4])
112         data = data[4:]
113         for i in range(num):
114             # Value name
115             (name, data) = data.split(b"\0", 1)
116
117             (type, ) = struct.unpack("<L", data[0:4])
118             data = data[4:]
119             (value_len, ) = struct.unpack("<L", data[0:4])
120             data = data[4:]
121
122             ret[name] = (type, data[:value_len])
123             data = data[value_len:]
124
125         return ret
126
127
128 # High water mark keys
129 IDMAP_HWM_GROUP = b"GROUP HWM\0"
130 IDMAP_HWM_USER = b"USER HWM\0"
131
132 IDMAP_GROUP_PREFIX = b"GID "
133 IDMAP_USER_PREFIX = b"UID "
134
135 # idmap version determines auto-conversion
136 IDMAP_VERSION_V2 = 2
137
138
139 class IdmapDatabase(DbDatabase):
140     """Samba 3 ID map database reader."""
141
142     def _check_version(self):
143         assert fetch_int32(self.db, b"IDMAP_VERSION\0") == IDMAP_VERSION_V2
144
145     def ids(self):
146         """Retrieve a list of all ids in this database."""
147         for k in self.db:
148             if k.startswith(IDMAP_USER_PREFIX):
149                 yield k.rstrip(b"\0").split(b" ")
150             if k.startswith(IDMAP_GROUP_PREFIX):
151                 yield k.rstrip(b"\0").split(b" ")
152
153     def uids(self):
154         """Retrieve a list of all uids in this database."""
155         for k in self.db:
156             if k.startswith(IDMAP_USER_PREFIX):
157                 yield int(k[len(IDMAP_USER_PREFIX):].rstrip(b"\0"))
158
159     def gids(self):
160         """Retrieve a list of all gids in this database."""
161         for k in self.db:
162             if k.startswith(IDMAP_GROUP_PREFIX):
163                 yield int(k[len(IDMAP_GROUP_PREFIX):].rstrip(b"\0"))
164
165     def get_sid(self, xid, id_type):
166         """Retrive SID associated with a particular id and type.
167
168         :param xid: UID or GID to retrieve SID for.
169         :param id_type: Type of id specified - 'UID' or 'GID'
170         """
171         data = self.db.get("%s %s\0" % (id_type, str(xid)))
172         if data is None:
173             return data
174         return data.rstrip("\0")
175
176     def get_user_sid(self, uid):
177         """Retrieve the SID associated with a particular uid.
178
179         :param uid: UID to retrieve SID for.
180         :return: A SID or None if no mapping was found.
181         """
182         data = self.db.get(IDMAP_USER_PREFIX + str(uid).encode() + b'\0')
183         if data is None:
184             return data
185         return data.rstrip(b"\0")
186
187     def get_group_sid(self, gid):
188         data = self.db.get(IDMAP_GROUP_PREFIX + str(gid).encode() + b'\0')
189         if data is None:
190             return data
191         return data.rstrip(b"\0")
192
193     def get_user_hwm(self):
194         """Obtain the user high-water mark."""
195         return fetch_uint32(self.db, IDMAP_HWM_USER)
196
197     def get_group_hwm(self):
198         """Obtain the group high-water mark."""
199         return fetch_uint32(self.db, IDMAP_HWM_GROUP)
200
201
202 class SecretsDatabase(DbDatabase):
203     """Samba 3 Secrets database reader."""
204
205     def get_auth_password(self):
206         return self.db.get("SECRETS/AUTH_PASSWORD")
207
208     def get_auth_domain(self):
209         return self.db.get("SECRETS/AUTH_DOMAIN")
210
211     def get_auth_user(self):
212         return self.db.get("SECRETS/AUTH_USER")
213
214     def get_domain_guid(self, host):
215         return self.db.get("SECRETS/DOMGUID/%s" % host)
216
217     def ldap_dns(self):
218         for k in self.db:
219             if k.startswith("SECRETS/LDAP_BIND_PW/"):
220                 yield k[len("SECRETS/LDAP_BIND_PW/"):].rstrip("\0")
221
222     def domains(self):
223         """Iterate over domains in this database.
224
225         :return: Iterator over the names of domains in this database.
226         """
227         for k in self.db:
228             if k.startswith("SECRETS/SID/"):
229                 yield k[len("SECRETS/SID/"):].rstrip("\0")
230
231     def get_ldap_bind_pw(self, host):
232         return self.db.get("SECRETS/LDAP_BIND_PW/%s" % host)
233
234     def get_afs_keyfile(self, host):
235         return self.db.get("SECRETS/AFS_KEYFILE/%s" % host)
236
237     def get_machine_sec_channel_type(self, host):
238         return fetch_uint32(self.db, "SECRETS/MACHINE_SEC_CHANNEL_TYPE/%s" % host)
239
240     def get_machine_last_change_time(self, host):
241         return fetch_uint32(self.db, "SECRETS/MACHINE_LAST_CHANGE_TIME/%s" % host)
242
243     def get_machine_password(self, host):
244         return self.db.get("SECRETS/MACHINE_PASSWORD/%s" % host)
245
246     def get_machine_acc(self, host):
247         return self.db.get("SECRETS/$MACHINE.ACC/%s" % host)
248
249     def get_domtrust_acc(self, host):
250         return self.db.get("SECRETS/$DOMTRUST.ACC/%s" % host)
251
252     def trusted_domains(self):
253         for k in self.db:
254             if k.startswith("SECRETS/$DOMTRUST.ACC/"):
255                 yield k[len("SECRETS/$DOMTRUST.ACC/"):].rstrip("\0")
256
257     def get_random_seed(self):
258         return self.db.get("INFO/random_seed")
259
260     def get_sid(self, host):
261         return self.db.get("SECRETS/SID/%s" % host.upper())
262
263
264 SHARE_DATABASE_VERSION_V1 = 1
265 SHARE_DATABASE_VERSION_V2 = 2
266
267
268 class ShareInfoDatabase(DbDatabase):
269     """Samba 3 Share Info database reader."""
270
271     def _check_version(self):
272         assert fetch_int32(self.db, "INFO/version\0") in (SHARE_DATABASE_VERSION_V1, SHARE_DATABASE_VERSION_V2)
273
274     def get_secdesc(self, name):
275         """Obtain the security descriptor on a particular share.
276
277         :param name: Name of the share
278         """
279         secdesc = self.db.get("SECDESC/%s" % name)
280         # FIXME: Run ndr_pull_security_descriptor
281         return secdesc
282
283
284 class Shares(object):
285     """Container for share objects."""
286     def __init__(self, lp, shareinfo):
287         self.lp = lp
288         self.shareinfo = shareinfo
289
290     def __len__(self):
291         """Number of shares."""
292         return len(self.lp) - 1
293
294     def __iter__(self):
295         """Iterate over the share names."""
296         return self.lp.__iter__()
297
298
299 def shellsplit(text):
300     """Very simple shell-like line splitting.
301
302     :param text: Text to split.
303     :return: List with parts of the line as strings.
304     """
305     ret = list()
306     inquotes = False
307     current = ""
308     for c in text:
309         if c == "\"":
310             inquotes = not inquotes
311         elif c in ("\t", "\n", " ") and not inquotes:
312             if current != "":
313                 ret.append(current)
314             current = ""
315         else:
316             current += c
317     if current != "":
318         ret.append(current)
319     return ret
320
321
322 class WinsDatabase(object):
323     """Samba 3 WINS database reader."""
324     def __init__(self, file):
325         self.entries = {}
326         f = open(file, 'r')
327         assert f.readline().rstrip("\n") == "VERSION 1 0"
328         for l in f.readlines():
329             if l[0] == "#":  # skip comments
330                 continue
331             entries = shellsplit(l.rstrip("\n"))
332             name = entries[0]
333             ttl = int(entries[1])
334             i = 2
335             ips = []
336             while "." in entries[i]:
337                 ips.append(entries[i])
338                 i += 1
339             nb_flags = int(entries[i][:-1], 16)
340             assert name not in self.entries, "Name %s exists twice" % name
341             self.entries[name] = (ttl, ips, nb_flags)
342         f.close()
343
344     def __getitem__(self, name):
345         return self.entries[name]
346
347     def __len__(self):
348         return len(self.entries)
349
350     def __iter__(self):
351         return iter(self.entries)
352
353     def items(self):
354         """Return the entries in this WINS database."""
355         return self.entries.items()
356
357     def close(self):  # for consistency
358         pass
359
360
361 class Samba3(object):
362     """Samba 3 configuration and state data reader."""
363
364     def __init__(self, smbconfpath, s3_lp_ctx=None):
365         """Open the configuration and data for a Samba 3 installation.
366
367         :param smbconfpath: Path to the smb.conf file.
368         :param s3_lp_ctx: Samba3 Loadparm context
369         """
370         self.smbconfpath = smbconfpath
371         if s3_lp_ctx:
372             self.lp = s3_lp_ctx
373         else:
374             self.lp = s3param.get_context()
375             self.lp.load(smbconfpath)
376
377     def statedir_path(self, path):
378         if path[0] == "/" or path[0] == ".":
379             return path
380         return os.path.join(self.lp.get("state directory"), path)
381
382     def privatedir_path(self, path):
383         if path[0] == "/" or path[0] == ".":
384             return path
385         return os.path.join(self.lp.get("private dir"), path)
386
387     def get_conf(self):
388         return self.lp
389
390     def get_sam_db(self):
391         return passdb.PDB(self.lp.get('passdb backend'))
392
393     def get_registry(self):
394         return Registry(self.statedir_path("registry"))
395
396     def get_secrets_db(self):
397         return SecretsDatabase(self.privatedir_path("secrets"))
398
399     def get_shareinfo_db(self):
400         return ShareInfoDatabase(self.statedir_path("share_info"))
401
402     def get_idmap_db(self):
403         return IdmapDatabase(self.statedir_path("winbindd_idmap"))
404
405     def get_wins_db(self):
406         return WinsDatabase(self.statedir_path("wins.dat"))
407
408     def get_shares(self):
409         return Shares(self.get_conf(), self.get_shareinfo_db())