python selftest: enabled samba.tests.s3registry to run with py3
[samba.git] / python / samba / samba3 / __init__.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007
3 #
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 #
17
18 """Support for reading Samba 3 data files."""
19
20 __docformat__ = "restructuredText"
21
22 REGISTRY_VALUE_PREFIX = b"SAMBA_REGVAL"
23 REGISTRY_DB_VERSION = 1
24
25 import os
26 import struct
27 import tdb
28
29 import samba.samba3.passdb
30 from samba.samba3 import param as s3param
31
32 def fetch_uint32(db, key):
33     try:
34         data = db[key]
35     except KeyError:
36         return None
37     assert len(data) == 4
38     return struct.unpack("<L", data)[0]
39
40
41 def fetch_int32(db, key):
42     try:
43         data = db[key]
44     except KeyError:
45         return None
46     assert len(data) == 4
47     return struct.unpack("<l", data)[0]
48
49
50 class DbDatabase(object):
51     """Simple Samba 3 TDB database reader."""
52     def __init__(self, file):
53         """Open a file.
54
55         :param file: Path of the file to open, appending .tdb or .ntdb.
56         """
57         self.db = tdb.Tdb(file + ".tdb", flags=os.O_RDONLY)
58         self._check_version()
59
60     def _check_version(self):
61         pass
62
63     def close(self):
64         """Close resources associated with this object."""
65         self.db.close()
66
67
68 class Registry(DbDatabase):
69     """Simple read-only support for reading the Samba3 registry.
70
71     :note: This object uses the same syntax for registry key paths as
72         Samba 3. This particular format uses forward slashes for key path
73         separators and abbreviations for the predefined key names.
74         e.g.: HKLM/Software/Bar.
75     """
76     def __len__(self):
77         """Return the number of keys."""
78         return len(self.keys())
79
80     def keys(self):
81         """Return list with all the keys."""
82         return [k.rstrip(b"\x00") for k in self.db if not k.startswith(REGISTRY_VALUE_PREFIX)]
83
84     def subkeys(self, key):
85         """Retrieve the subkeys for the specified key.
86
87         :param key: Key path.
88         :return: list with key names
89         """
90         data = self.db.get(b"%s\x00" % key)
91         if data is None:
92             return []
93         (num, ) = struct.unpack("<L", data[0:4])
94         keys = data[4:].split(b"\0")
95         assert keys[-1] == b""
96         keys.pop()
97         assert len(keys) == num
98         return keys
99
100     def values(self, key):
101         """Return a dictionary with the values set for a specific key.
102
103         :param key: Key to retrieve values for.
104         :return: Dictionary with value names as key, tuple with type and
105             data as value."""
106         data = self.db.get(b"%s/%s\x00" % (REGISTRY_VALUE_PREFIX, key))
107         if data is None:
108             return {}
109         ret = {}
110         (num, ) = struct.unpack("<L", data[0:4])
111         data = data[4:]
112         for i in range(num):
113             # Value name
114             (name, data) = data.split(b"\0", 1)
115
116             (type, ) = struct.unpack("<L", data[0:4])
117             data = data[4:]
118             (value_len, ) = struct.unpack("<L", data[0:4])
119             data = data[4:]
120
121             ret[name] = (type, data[:value_len])
122             data = data[value_len:]
123
124         return ret
125
126
127 # High water mark keys
128 IDMAP_HWM_GROUP = "GROUP HWM\0"
129 IDMAP_HWM_USER = "USER HWM\0"
130
131 IDMAP_GROUP_PREFIX = "GID "
132 IDMAP_USER_PREFIX = "UID "
133
134 # idmap version determines auto-conversion
135 IDMAP_VERSION_V2 = 2
136
137 class IdmapDatabase(DbDatabase):
138     """Samba 3 ID map database reader."""
139
140     def _check_version(self):
141         assert fetch_int32(self.db, "IDMAP_VERSION\0") == IDMAP_VERSION_V2
142
143     def ids(self):
144         """Retrieve a list of all ids in this database."""
145         for k in self.db.iterkeys():
146             if k.startswith(IDMAP_USER_PREFIX):
147                 yield k.rstrip("\0").split(" ")
148             if k.startswith(IDMAP_GROUP_PREFIX):
149                 yield k.rstrip("\0").split(" ")
150
151     def uids(self):
152         """Retrieve a list of all uids in this database."""
153         for k in self.db.iterkeys():
154             if k.startswith(IDMAP_USER_PREFIX):
155                 yield int(k[len(IDMAP_USER_PREFIX):].rstrip("\0"))
156
157     def gids(self):
158         """Retrieve a list of all gids in this database."""
159         for k in self.db.iterkeys():
160             if k.startswith(IDMAP_GROUP_PREFIX):
161                 yield int(k[len(IDMAP_GROUP_PREFIX):].rstrip("\0"))
162
163     def get_sid(self, xid, id_type):
164         """Retrive SID associated with a particular id and type.
165
166         :param xid: UID or GID to retrive SID for.
167         :param id_type: Type of id specified - 'UID' or 'GID'
168         """
169         data = self.db.get("%s %s\0" % (id_type, str(xid)))
170         if data is None:
171             return data
172         return data.rstrip("\0")
173
174     def get_user_sid(self, uid):
175         """Retrieve the SID associated with a particular uid.
176
177         :param uid: UID to retrieve SID for.
178         :return: A SID or None if no mapping was found.
179         """
180         data = self.db.get("%s%d\0" % (IDMAP_USER_PREFIX, uid))
181         if data is None:
182             return data
183         return data.rstrip("\0")
184
185     def get_group_sid(self, gid):
186         data = self.db.get("%s%d\0" % (IDMAP_GROUP_PREFIX, gid))
187         if data is None:
188             return data
189         return data.rstrip("\0")
190
191     def get_user_hwm(self):
192         """Obtain the user high-water mark."""
193         return fetch_uint32(self.db, IDMAP_HWM_USER)
194
195     def get_group_hwm(self):
196         """Obtain the group high-water mark."""
197         return fetch_uint32(self.db, IDMAP_HWM_GROUP)
198
199
200 class SecretsDatabase(DbDatabase):
201     """Samba 3 Secrets database reader."""
202
203     def get_auth_password(self):
204         return self.db.get("SECRETS/AUTH_PASSWORD")
205
206     def get_auth_domain(self):
207         return self.db.get("SECRETS/AUTH_DOMAIN")
208
209     def get_auth_user(self):
210         return self.db.get("SECRETS/AUTH_USER")
211
212     def get_domain_guid(self, host):
213         return self.db.get("SECRETS/DOMGUID/%s" % host)
214
215     def ldap_dns(self):
216         for k in self.db.iterkeys():
217             if k.startswith("SECRETS/LDAP_BIND_PW/"):
218                 yield k[len("SECRETS/LDAP_BIND_PW/"):].rstrip("\0")
219
220     def domains(self):
221         """Iterate over domains in this database.
222
223         :return: Iterator over the names of domains in this database.
224         """
225         for k in self.db.iterkeys():
226             if k.startswith("SECRETS/SID/"):
227                 yield k[len("SECRETS/SID/"):].rstrip("\0")
228
229     def get_ldap_bind_pw(self, host):
230         return self.db.get("SECRETS/LDAP_BIND_PW/%s" % host)
231
232     def get_afs_keyfile(self, host):
233         return self.db.get("SECRETS/AFS_KEYFILE/%s" % host)
234
235     def get_machine_sec_channel_type(self, host):
236         return fetch_uint32(self.db, "SECRETS/MACHINE_SEC_CHANNEL_TYPE/%s" % host)
237
238     def get_machine_last_change_time(self, host):
239         return fetch_uint32(self.db, "SECRETS/MACHINE_LAST_CHANGE_TIME/%s" % host)
240
241     def get_machine_password(self, host):
242         return self.db.get("SECRETS/MACHINE_PASSWORD/%s" % host)
243
244     def get_machine_acc(self, host):
245         return self.db.get("SECRETS/$MACHINE.ACC/%s" % host)
246
247     def get_domtrust_acc(self, host):
248         return self.db.get("SECRETS/$DOMTRUST.ACC/%s" % host)
249
250     def trusted_domains(self):
251         for k in self.db.iterkeys():
252             if k.startswith("SECRETS/$DOMTRUST.ACC/"):
253                 yield k[len("SECRETS/$DOMTRUST.ACC/"):].rstrip("\0")
254
255     def get_random_seed(self):
256         return self.db.get("INFO/random_seed")
257
258     def get_sid(self, host):
259         return self.db.get("SECRETS/SID/%s" % host.upper())
260
261
262 SHARE_DATABASE_VERSION_V1 = 1
263 SHARE_DATABASE_VERSION_V2 = 2
264
265
266 class ShareInfoDatabase(DbDatabase):
267     """Samba 3 Share Info database reader."""
268
269     def _check_version(self):
270         assert fetch_int32(self.db, "INFO/version\0") in (SHARE_DATABASE_VERSION_V1, SHARE_DATABASE_VERSION_V2)
271
272     def get_secdesc(self, name):
273         """Obtain the security descriptor on a particular share.
274
275         :param name: Name of the share
276         """
277         secdesc = self.db.get("SECDESC/%s" % name)
278         # FIXME: Run ndr_pull_security_descriptor
279         return secdesc
280
281
282 class Shares(object):
283     """Container for share objects."""
284     def __init__(self, lp, shareinfo):
285         self.lp = lp
286         self.shareinfo = shareinfo
287
288     def __len__(self):
289         """Number of shares."""
290         return len(self.lp) - 1
291
292     def __iter__(self):
293         """Iterate over the share names."""
294         return self.lp.__iter__()
295
296
297 def shellsplit(text):
298     """Very simple shell-like line splitting.
299
300     :param text: Text to split.
301     :return: List with parts of the line as strings.
302     """
303     ret = list()
304     inquotes = False
305     current = ""
306     for c in text:
307         if c == "\"":
308             inquotes = not inquotes
309         elif c in ("\t", "\n", " ") and not inquotes:
310             if current != "":
311                 ret.append(current)
312             current = ""
313         else:
314             current += c
315     if current != "":
316         ret.append(current)
317     return ret
318
319
320 class WinsDatabase(object):
321     """Samba 3 WINS database reader."""
322     def __init__(self, file):
323         self.entries = {}
324         f = open(file, 'r')
325         assert f.readline().rstrip("\n") == "VERSION 1 0"
326         for l in f.readlines():
327             if l[0] == "#": # skip comments
328                 continue
329             entries = shellsplit(l.rstrip("\n"))
330             name = entries[0]
331             ttl = int(entries[1])
332             i = 2
333             ips = []
334             while "." in entries[i]:
335                 ips.append(entries[i])
336                 i+=1
337             nb_flags = int(entries[i][:-1], 16)
338             assert not name in self.entries, "Name %s exists twice" % name
339             self.entries[name] = (ttl, ips, nb_flags)
340         f.close()
341
342     def __getitem__(self, name):
343         return self.entries[name]
344
345     def __len__(self):
346         return len(self.entries)
347
348     def __iter__(self):
349         return iter(self.entries)
350
351     def items(self):
352         """Return the entries in this WINS database."""
353         return self.entries.items()
354
355     def close(self): # for consistency
356         pass
357
358
359 class Samba3(object):
360     """Samba 3 configuration and state data reader."""
361
362     def __init__(self, smbconfpath, s3_lp_ctx=None):
363         """Open the configuration and data for a Samba 3 installation.
364
365         :param smbconfpath: Path to the smb.conf file.
366         :param s3_lp_ctx: Samba3 Loadparm context
367         """
368         self.smbconfpath = smbconfpath
369         if s3_lp_ctx:
370             self.lp = s3_lp_ctx
371         else:
372             self.lp = s3param.get_context()
373             self.lp.load(smbconfpath)
374
375     def statedir_path(self, path):
376         if path[0] == "/" or path[0] == ".":
377             return path
378         return os.path.join(self.lp.get("state directory"), path)
379
380     def privatedir_path(self, path):
381         if path[0] == "/" or path[0] == ".":
382             return path
383         return os.path.join(self.lp.get("private dir"), path)
384
385     def get_conf(self):
386         return self.lp
387
388     def get_sam_db(self):
389         return passdb.PDB(self.lp.get('passdb backend'))
390
391     def get_registry(self):
392         return Registry(self.statedir_path("registry"))
393
394     def get_secrets_db(self):
395         return SecretsDatabase(self.privatedir_path("secrets"))
396
397     def get_shareinfo_db(self):
398         return ShareInfoDatabase(self.statedir_path("share_info"))
399
400     def get_idmap_db(self):
401         return IdmapDatabase(self.statedir_path("winbindd_idmap"))
402
403     def get_wins_db(self):
404         return WinsDatabase(self.statedir_path("wins.dat"))
405
406     def get_shares(self):
407         return Shares(self.get_conf(), self.get_shareinfo_db())