s3-g_lock: Properly free "rec" on retry to avoid deadlock
[kai/samba.git] / source3 / lib / serverid.c
1 /*
2    Unix SMB/CIFS implementation.
3    Implementation of a reliable server_exists()
4    Copyright (C) Volker Lendecke 2010
5
6    This program is free software; you can redistribute it and/or modify
7    it under the terms of the GNU General Public License as published by
8    the Free Software Foundation; either version 3 of the License, or
9    (at your option) any later version.
10
11    This program is distributed in the hope that it will be useful,
12    but WITHOUT ANY WARRANTY; without even the implied warranty of
13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14    GNU General Public License for more details.
15
16    You should have received a copy of the GNU General Public License
17    along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 */
19
20 #include "includes.h"
21 #include "system/filesys.h"
22 #include "serverid.h"
23 #include "util_tdb.h"
24 #include "dbwrap/dbwrap.h"
25 #include "dbwrap/dbwrap_open.h"
26 #include "lib/tdb_wrap/tdb_wrap.h"
27 #include "lib/param/param.h"
28 #include "ctdbd_conn.h"
29 #include "messages.h"
30
31 struct serverid_key {
32         pid_t pid;
33         uint32_t task_id;
34         uint32_t vnn;
35 };
36
37 struct serverid_data {
38         uint64_t unique_id;
39         uint32_t msg_flags;
40 };
41
42 bool serverid_parent_init(TALLOC_CTX *mem_ctx)
43 {
44         struct tdb_wrap *db;
45         struct loadparm_context *lp_ctx;
46
47         lp_ctx = loadparm_init_s3(mem_ctx, loadparm_s3_helpers());
48         if (lp_ctx == NULL) {
49                 DEBUG(0, ("loadparm_init_s3 failed\n"));
50                 return false;
51         }
52
53         /*
54          * Open the tdb in the parent process (smbd) so that our
55          * CLEAR_IF_FIRST optimization in tdb_reopen_all can properly
56          * work.
57          */
58
59         db = tdb_wrap_open(mem_ctx, lock_path("serverid.tdb"),
60                            0, TDB_DEFAULT|TDB_CLEAR_IF_FIRST|TDB_INCOMPATIBLE_HASH, O_RDWR|O_CREAT,
61                            0644, lp_ctx);
62         talloc_unlink(mem_ctx, lp_ctx);
63         if (db == NULL) {
64                 DEBUG(1, ("could not open serverid.tdb: %s\n",
65                           strerror(errno)));
66                 return false;
67         }
68         return true;
69 }
70
71 static struct db_context *serverid_db(void)
72 {
73         static struct db_context *db;
74
75         if (db != NULL) {
76                 return db;
77         }
78         db = db_open(NULL, lock_path("serverid.tdb"), 0,
79                      TDB_DEFAULT|TDB_CLEAR_IF_FIRST|TDB_INCOMPATIBLE_HASH,
80                      O_RDWR|O_CREAT, 0644, DBWRAP_LOCK_ORDER_2);
81         return db;
82 }
83
84 static void serverid_fill_key(const struct server_id *id,
85                               struct serverid_key *key)
86 {
87         ZERO_STRUCTP(key);
88         key->pid = id->pid;
89         key->task_id = id->task_id;
90         key->vnn = id->vnn;
91 }
92
93 bool serverid_register(const struct server_id id, uint32_t msg_flags)
94 {
95         struct db_context *db;
96         struct serverid_key key;
97         struct serverid_data data;
98         struct db_record *rec;
99         TDB_DATA tdbkey, tdbdata;
100         NTSTATUS status;
101         bool ret = false;
102
103         db = serverid_db();
104         if (db == NULL) {
105                 return false;
106         }
107
108         serverid_fill_key(&id, &key);
109         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
110
111         rec = dbwrap_fetch_locked(db, talloc_tos(), tdbkey);
112         if (rec == NULL) {
113                 DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
114                 return false;
115         }
116
117         ZERO_STRUCT(data);
118         data.unique_id = id.unique_id;
119         data.msg_flags = msg_flags;
120
121         tdbdata = make_tdb_data((uint8_t *)&data, sizeof(data));
122         status = dbwrap_record_store(rec, tdbdata, 0);
123         if (!NT_STATUS_IS_OK(status)) {
124                 DEBUG(1, ("Storing serverid.tdb record failed: %s\n",
125                           nt_errstr(status)));
126                 goto done;
127         }
128 #ifdef HAVE_CTDB_CONTROL_CHECK_SRVIDS_DECL
129         if (lp_clustering()) {
130                 register_with_ctdbd(messaging_ctdbd_connection(), id.unique_id);
131         }
132 #endif
133         ret = true;
134 done:
135         TALLOC_FREE(rec);
136         return ret;
137 }
138
139 bool serverid_register_msg_flags(const struct server_id id, bool do_reg,
140                                  uint32_t msg_flags)
141 {
142         struct db_context *db;
143         struct serverid_key key;
144         struct serverid_data *data;
145         struct db_record *rec;
146         TDB_DATA tdbkey;
147         TDB_DATA value;
148         NTSTATUS status;
149         bool ret = false;
150
151         db = serverid_db();
152         if (db == NULL) {
153                 return false;
154         }
155
156         serverid_fill_key(&id, &key);
157         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
158
159         rec = dbwrap_fetch_locked(db, talloc_tos(), tdbkey);
160         if (rec == NULL) {
161                 DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
162                 return false;
163         }
164
165         value = dbwrap_record_get_value(rec);
166
167         if (value.dsize != sizeof(struct serverid_data)) {
168                 DEBUG(1, ("serverid record has unexpected size %d "
169                           "(wanted %d)\n", (int)value.dsize,
170                           (int)sizeof(struct serverid_data)));
171                 goto done;
172         }
173
174         data = (struct serverid_data *)value.dptr;
175
176         if (do_reg) {
177                 data->msg_flags |= msg_flags;
178         } else {
179                 data->msg_flags &= ~msg_flags;
180         }
181
182         status = dbwrap_record_store(rec, value, 0);
183         if (!NT_STATUS_IS_OK(status)) {
184                 DEBUG(1, ("Storing serverid.tdb record failed: %s\n",
185                           nt_errstr(status)));
186                 goto done;
187         }
188         ret = true;
189 done:
190         TALLOC_FREE(rec);
191         return ret;
192 }
193
194 bool serverid_deregister(struct server_id id)
195 {
196         struct db_context *db;
197         struct serverid_key key;
198         struct db_record *rec;
199         TDB_DATA tdbkey;
200         NTSTATUS status;
201         bool ret = false;
202
203         db = serverid_db();
204         if (db == NULL) {
205                 return false;
206         }
207
208         serverid_fill_key(&id, &key);
209         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
210
211         rec = dbwrap_fetch_locked(db, talloc_tos(), tdbkey);
212         if (rec == NULL) {
213                 DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
214                 return false;
215         }
216
217         status = dbwrap_record_delete(rec);
218         if (!NT_STATUS_IS_OK(status)) {
219                 DEBUG(1, ("Deleting serverid.tdb record failed: %s\n",
220                           nt_errstr(status)));
221                 goto done;
222         }
223         ret = true;
224 done:
225         TALLOC_FREE(rec);
226         return ret;
227 }
228
229 struct serverid_exists_state {
230         const struct server_id *id;
231         bool exists;
232 };
233
234 static void server_exists_parse(TDB_DATA key, TDB_DATA data, void *priv)
235 {
236         struct serverid_exists_state *state =
237                 (struct serverid_exists_state *)priv;
238
239         if (data.dsize != sizeof(struct serverid_data)) {
240                 state->exists = false;
241                 return;
242         }
243
244         /*
245          * Use memcmp, not direct compare. data.dptr might not be
246          * aligned.
247          */
248         state->exists = (memcmp(&state->id->unique_id, data.dptr,
249                                 sizeof(state->id->unique_id)) == 0);
250 }
251
252 bool serverid_exists(const struct server_id *id)
253 {
254         struct db_context *db;
255         struct serverid_exists_state state;
256         struct serverid_key key;
257         TDB_DATA tdbkey;
258         NTSTATUS status;
259
260         if (procid_is_me(id)) {
261                 return true;
262         }
263
264         if (!process_exists(*id)) {
265                 return false;
266         }
267
268         if (id->unique_id == SERVERID_UNIQUE_ID_NOT_TO_VERIFY) {
269                 return true;
270         }
271
272         db = serverid_db();
273         if (db == NULL) {
274                 return false;
275         }
276
277         serverid_fill_key(id, &key);
278         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
279
280         state.id = id;
281         state.exists = false;
282
283         status = dbwrap_parse_record(db, tdbkey, server_exists_parse, &state);
284         if (!NT_STATUS_IS_OK(status)) {
285                 return false;
286         }
287         return state.exists;
288 }
289
290 bool serverids_exist(const struct server_id *ids, int num_ids, bool *results)
291 {
292         struct db_context *db;
293         int i;
294
295 #ifdef HAVE_CTDB_CONTROL_CHECK_SRVIDS_DECL
296         if (lp_clustering()) {
297                 return ctdb_serverids_exist(messaging_ctdbd_connection(),
298                                             ids, num_ids, results);
299         }
300 #endif
301         if (!processes_exist(ids, num_ids, results)) {
302                 return false;
303         }
304
305         db = serverid_db();
306         if (db == NULL) {
307                 return false;
308         }
309
310         for (i=0; i<num_ids; i++) {
311                 struct serverid_exists_state state;
312                 struct serverid_key key;
313                 TDB_DATA tdbkey;
314                 NTSTATUS status;
315
316                 if (ids[i].unique_id == SERVERID_UNIQUE_ID_NOT_TO_VERIFY) {
317                         results[i] = true;
318                         continue;
319                 }
320                 if (!results[i]) {
321                         continue;
322                 }
323
324                 serverid_fill_key(&ids[i], &key);
325                 tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
326
327                 state.id = &ids[i];
328                 state.exists = false;
329                 status = dbwrap_parse_record(db, tdbkey, server_exists_parse, &state);
330                 if (!NT_STATUS_IS_OK(status)) {
331                         results[i] = false;
332                         continue;
333                 }
334                 results[i] = state.exists;
335         }
336         return true;
337 }
338
339 static bool serverid_rec_parse(const struct db_record *rec,
340                                struct server_id *id, uint32_t *msg_flags)
341 {
342         struct serverid_key key;
343         struct serverid_data data;
344         TDB_DATA tdbkey;
345         TDB_DATA tdbdata;
346
347         tdbkey = dbwrap_record_get_key(rec);
348         tdbdata = dbwrap_record_get_value(rec);
349
350         if (tdbkey.dsize != sizeof(key)) {
351                 DEBUG(1, ("Found invalid key length %d in serverid.tdb\n",
352                           (int)tdbkey.dsize));
353                 return false;
354         }
355         if (tdbdata.dsize != sizeof(data)) {
356                 DEBUG(1, ("Found invalid value length %d in serverid.tdb\n",
357                           (int)tdbdata.dsize));
358                 return false;
359         }
360
361         memcpy(&key, tdbkey.dptr, sizeof(key));
362         memcpy(&data, tdbdata.dptr, sizeof(data));
363
364         id->pid = key.pid;
365         id->task_id = key.task_id;
366         id->vnn = key.vnn;
367         id->unique_id = data.unique_id;
368         *msg_flags = data.msg_flags;
369         return true;
370 }
371
372 struct serverid_traverse_read_state {
373         int (*fn)(const struct server_id *id, uint32_t msg_flags,
374                   void *private_data);
375         void *private_data;
376 };
377
378 static int serverid_traverse_read_fn(struct db_record *rec, void *private_data)
379 {
380         struct serverid_traverse_read_state *state =
381                 (struct serverid_traverse_read_state *)private_data;
382         struct server_id id;
383         uint32_t msg_flags;
384
385         if (!serverid_rec_parse(rec, &id, &msg_flags)) {
386                 return 0;
387         }
388         return state->fn(&id, msg_flags,state->private_data);
389 }
390
391 bool serverid_traverse_read(int (*fn)(const struct server_id *id,
392                                       uint32_t msg_flags, void *private_data),
393                             void *private_data)
394 {
395         struct db_context *db;
396         struct serverid_traverse_read_state state;
397         NTSTATUS status;
398
399         db = serverid_db();
400         if (db == NULL) {
401                 return false;
402         }
403         state.fn = fn;
404         state.private_data = private_data;
405
406         status = dbwrap_traverse_read(db, serverid_traverse_read_fn, &state,
407                                       NULL);
408         return NT_STATUS_IS_OK(status);
409 }
410
411 struct serverid_traverse_state {
412         int (*fn)(struct db_record *rec, const struct server_id *id,
413                   uint32_t msg_flags, void *private_data);
414         void *private_data;
415 };
416
417 static int serverid_traverse_fn(struct db_record *rec, void *private_data)
418 {
419         struct serverid_traverse_state *state =
420                 (struct serverid_traverse_state *)private_data;
421         struct server_id id;
422         uint32_t msg_flags;
423
424         if (!serverid_rec_parse(rec, &id, &msg_flags)) {
425                 return 0;
426         }
427         return state->fn(rec, &id, msg_flags, state->private_data);
428 }
429
430 bool serverid_traverse(int (*fn)(struct db_record *rec,
431                                  const struct server_id *id,
432                                  uint32_t msg_flags, void *private_data),
433                             void *private_data)
434 {
435         struct db_context *db;
436         struct serverid_traverse_state state;
437         NTSTATUS status;
438
439         db = serverid_db();
440         if (db == NULL) {
441                 return false;
442         }
443         state.fn = fn;
444         state.private_data = private_data;
445
446         status = dbwrap_traverse(db, serverid_traverse_fn, &state, NULL);
447         return NT_STATUS_IS_OK(status);
448 }
449
450 uint64_t serverid_get_random_unique_id(void)
451 {
452         uint64_t unique_id = SERVERID_UNIQUE_ID_NOT_TO_VERIFY;
453
454         while (unique_id == SERVERID_UNIQUE_ID_NOT_TO_VERIFY) {
455                 generate_random_buffer((uint8_t *)&unique_id,
456                                        sizeof(unique_id));
457         }
458
459         return unique_id;
460 }
461
462 bool serverid_equal(const struct server_id *p1, const struct server_id *p2)
463 {
464         if (p1->pid != p2->pid) {
465                 return false;
466         }
467
468         if (p1->task_id != p2->task_id) {
469                 return false;
470         }
471
472         if (p1->vnn != p2->vnn) {
473                 return false;
474         }
475
476         if (p1->unique_id != p2->unique_id) {
477                 return false;
478         }
479
480         return true;
481 }