s3:serverid: use dbwrap_traverse_read() in serverid_traverse_read()
[samba.git] / source3 / lib / serverid.c
1 /*
2    Unix SMB/CIFS implementation.
3    Implementation of a reliable server_exists()
4    Copyright (C) Volker Lendecke 2010
5
6    This program is free software; you can redistribute it and/or modify
7    it under the terms of the GNU General Public License as published by
8    the Free Software Foundation; either version 3 of the License, or
9    (at your option) any later version.
10
11    This program is distributed in the hope that it will be useful,
12    but WITHOUT ANY WARRANTY; without even the implied warranty of
13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14    GNU General Public License for more details.
15
16    You should have received a copy of the GNU General Public License
17    along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 */
19
20 #include "includes.h"
21 #include "system/filesys.h"
22 #include "serverid.h"
23 #include "util_tdb.h"
24 #include "dbwrap/dbwrap.h"
25 #include "dbwrap/dbwrap_open.h"
26 #include "lib/util/tdb_wrap.h"
27
28 struct serverid_key {
29         pid_t pid;
30         uint32_t task_id;
31         uint32_t vnn;
32 };
33
34 struct serverid_data {
35         uint64_t unique_id;
36         uint32_t msg_flags;
37 };
38
39 bool serverid_parent_init(TALLOC_CTX *mem_ctx)
40 {
41         struct tdb_wrap *db;
42
43         /*
44          * Open the tdb in the parent process (smbd) so that our
45          * CLEAR_IF_FIRST optimization in tdb_reopen_all can properly
46          * work.
47          */
48
49         db = tdb_wrap_open(mem_ctx, lock_path("serverid.tdb"),
50                            0, TDB_DEFAULT|TDB_CLEAR_IF_FIRST|TDB_INCOMPATIBLE_HASH, O_RDWR|O_CREAT,
51                            0644);
52         if (db == NULL) {
53                 DEBUG(1, ("could not open serverid.tdb: %s\n",
54                           strerror(errno)));
55                 return false;
56         }
57         return true;
58 }
59
60 static struct db_context *serverid_db(void)
61 {
62         static struct db_context *db;
63
64         if (db != NULL) {
65                 return db;
66         }
67         db = db_open(NULL, lock_path("serverid.tdb"), 0,
68                      TDB_DEFAULT|TDB_CLEAR_IF_FIRST|TDB_INCOMPATIBLE_HASH, O_RDWR|O_CREAT, 0644);
69         return db;
70 }
71
72 static void serverid_fill_key(const struct server_id *id,
73                               struct serverid_key *key)
74 {
75         ZERO_STRUCTP(key);
76         key->pid = id->pid;
77         key->task_id = id->task_id;
78         key->vnn = id->vnn;
79 }
80
81 bool serverid_register(const struct server_id id, uint32_t msg_flags)
82 {
83         struct db_context *db;
84         struct serverid_key key;
85         struct serverid_data data;
86         struct db_record *rec;
87         TDB_DATA tdbkey, tdbdata;
88         NTSTATUS status;
89         bool ret = false;
90
91         db = serverid_db();
92         if (db == NULL) {
93                 return false;
94         }
95
96         serverid_fill_key(&id, &key);
97         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
98
99         rec = db->fetch_locked(db, talloc_tos(), tdbkey);
100         if (rec == NULL) {
101                 DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
102                 return false;
103         }
104
105         ZERO_STRUCT(data);
106         data.unique_id = id.unique_id;
107         data.msg_flags = msg_flags;
108
109         tdbdata = make_tdb_data((uint8_t *)&data, sizeof(data));
110         status = rec->store(rec, tdbdata, 0);
111         if (!NT_STATUS_IS_OK(status)) {
112                 DEBUG(1, ("Storing serverid.tdb record failed: %s\n",
113                           nt_errstr(status)));
114                 goto done;
115         }
116         ret = true;
117 done:
118         TALLOC_FREE(rec);
119         return ret;
120 }
121
122 bool serverid_register_msg_flags(const struct server_id id, bool do_reg,
123                                  uint32_t msg_flags)
124 {
125         struct db_context *db;
126         struct serverid_key key;
127         struct serverid_data *data;
128         struct db_record *rec;
129         TDB_DATA tdbkey;
130         NTSTATUS status;
131         bool ret = false;
132
133         db = serverid_db();
134         if (db == NULL) {
135                 return false;
136         }
137
138         serverid_fill_key(&id, &key);
139         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
140
141         rec = db->fetch_locked(db, talloc_tos(), tdbkey);
142         if (rec == NULL) {
143                 DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
144                 return false;
145         }
146
147         if (rec->value.dsize != sizeof(struct serverid_data)) {
148                 DEBUG(1, ("serverid record has unexpected size %d "
149                           "(wanted %d)\n", (int)rec->value.dsize,
150                           (int)sizeof(struct serverid_data)));
151                 goto done;
152         }
153
154         data = (struct serverid_data *)rec->value.dptr;
155
156         if (do_reg) {
157                 data->msg_flags |= msg_flags;
158         } else {
159                 data->msg_flags &= ~msg_flags;
160         }
161
162         status = rec->store(rec, rec->value, 0);
163         if (!NT_STATUS_IS_OK(status)) {
164                 DEBUG(1, ("Storing serverid.tdb record failed: %s\n",
165                           nt_errstr(status)));
166                 goto done;
167         }
168         ret = true;
169 done:
170         TALLOC_FREE(rec);
171         return ret;
172 }
173
174 bool serverid_deregister(struct server_id id)
175 {
176         struct db_context *db;
177         struct serverid_key key;
178         struct db_record *rec;
179         TDB_DATA tdbkey;
180         NTSTATUS status;
181         bool ret = false;
182
183         db = serverid_db();
184         if (db == NULL) {
185                 return false;
186         }
187
188         serverid_fill_key(&id, &key);
189         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
190
191         rec = db->fetch_locked(db, talloc_tos(), tdbkey);
192         if (rec == NULL) {
193                 DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
194                 return false;
195         }
196
197         status = rec->delete_rec(rec);
198         if (!NT_STATUS_IS_OK(status)) {
199                 DEBUG(1, ("Deleting serverid.tdb record failed: %s\n",
200                           nt_errstr(status)));
201                 goto done;
202         }
203         ret = true;
204 done:
205         TALLOC_FREE(rec);
206         return ret;
207 }
208
209 struct serverid_exists_state {
210         const struct server_id *id;
211         bool exists;
212 };
213
214 static int server_exists_parse(TDB_DATA key, TDB_DATA data, void *priv)
215 {
216         struct serverid_exists_state *state =
217                 (struct serverid_exists_state *)priv;
218
219         if (data.dsize != sizeof(struct serverid_data)) {
220                 return -1;
221         }
222
223         /*
224          * Use memcmp, not direct compare. data.dptr might not be
225          * aligned.
226          */
227         state->exists = (memcmp(&state->id->unique_id, data.dptr,
228                                 sizeof(state->id->unique_id)) == 0);
229         return 0;
230 }
231
232 bool serverid_exists(const struct server_id *id)
233 {
234         struct db_context *db;
235         struct serverid_exists_state state;
236         struct serverid_key key;
237         TDB_DATA tdbkey;
238
239         if (procid_is_me(id)) {
240                 return true;
241         }
242
243         if (!process_exists(*id)) {
244                 return false;
245         }
246
247         db = serverid_db();
248         if (db == NULL) {
249                 return false;
250         }
251
252         serverid_fill_key(id, &key);
253         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
254
255         state.id = id;
256         state.exists = false;
257
258         if (db->parse_record(db, tdbkey, server_exists_parse, &state) != 0) {
259                 return false;
260         }
261         return state.exists;
262 }
263
264 static bool serverid_rec_parse(const struct db_record *rec,
265                                struct server_id *id, uint32_t *msg_flags)
266 {
267         struct serverid_key key;
268         struct serverid_data data;
269
270         if (rec->key.dsize != sizeof(key)) {
271                 DEBUG(1, ("Found invalid key length %d in serverid.tdb\n",
272                           (int)rec->key.dsize));
273                 return false;
274         }
275         if (rec->value.dsize != sizeof(data)) {
276                 DEBUG(1, ("Found invalid value length %d in serverid.tdb\n",
277                           (int)rec->value.dsize));
278                 return false;
279         }
280
281         memcpy(&key, rec->key.dptr, sizeof(key));
282         memcpy(&data, rec->value.dptr, sizeof(data));
283
284         id->pid = key.pid;
285         id->task_id = key.task_id;
286         id->vnn = key.vnn;
287         id->unique_id = data.unique_id;
288         *msg_flags = data.msg_flags;
289         return true;
290 }
291
292 struct serverid_traverse_read_state {
293         int (*fn)(const struct server_id *id, uint32_t msg_flags,
294                   void *private_data);
295         void *private_data;
296 };
297
298 static int serverid_traverse_read_fn(struct db_record *rec, void *private_data)
299 {
300         struct serverid_traverse_read_state *state =
301                 (struct serverid_traverse_read_state *)private_data;
302         struct server_id id;
303         uint32_t msg_flags;
304
305         if (!serverid_rec_parse(rec, &id, &msg_flags)) {
306                 return 0;
307         }
308         return state->fn(&id, msg_flags,state->private_data);
309 }
310
311 bool serverid_traverse_read(int (*fn)(const struct server_id *id,
312                                       uint32_t msg_flags, void *private_data),
313                             void *private_data)
314 {
315         struct db_context *db;
316         struct serverid_traverse_read_state state;
317         NTSTATUS status;
318
319         db = serverid_db();
320         if (db == NULL) {
321                 return false;
322         }
323         state.fn = fn;
324         state.private_data = private_data;
325
326         status = dbwrap_traverse_read(db, serverid_traverse_read_fn, &state,
327                                       NULL);
328         return NT_STATUS_IS_OK(status);
329 }
330
331 struct serverid_traverse_state {
332         int (*fn)(struct db_record *rec, const struct server_id *id,
333                   uint32_t msg_flags, void *private_data);
334         void *private_data;
335 };
336
337 static int serverid_traverse_fn(struct db_record *rec, void *private_data)
338 {
339         struct serverid_traverse_state *state =
340                 (struct serverid_traverse_state *)private_data;
341         struct server_id id;
342         uint32_t msg_flags;
343
344         if (!serverid_rec_parse(rec, &id, &msg_flags)) {
345                 return 0;
346         }
347         return state->fn(rec, &id, msg_flags, state->private_data);
348 }
349
350 bool serverid_traverse(int (*fn)(struct db_record *rec,
351                                  const struct server_id *id,
352                                  uint32_t msg_flags, void *private_data),
353                             void *private_data)
354 {
355         struct db_context *db;
356         struct serverid_traverse_state state;
357         NTSTATUS status;
358
359         db = serverid_db();
360         if (db == NULL) {
361                 return false;
362         }
363         state.fn = fn;
364         state.private_data = private_data;
365
366         status = dbwrap_traverse(db, serverid_traverse_fn, &state, NULL);
367         return NT_STATUS_IS_OK(status);
368 }