s3:lib: implement serverid_exists() as wrapper of serverids_exist()
[kai/samba.git] / source3 / lib / serverid.c
1 /*
2    Unix SMB/CIFS implementation.
3    Implementation of a reliable server_exists()
4    Copyright (C) Volker Lendecke 2010
5
6    This program is free software; you can redistribute it and/or modify
7    it under the terms of the GNU General Public License as published by
8    the Free Software Foundation; either version 3 of the License, or
9    (at your option) any later version.
10
11    This program is distributed in the hope that it will be useful,
12    but WITHOUT ANY WARRANTY; without even the implied warranty of
13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14    GNU General Public License for more details.
15
16    You should have received a copy of the GNU General Public License
17    along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 */
19
20 #include "includes.h"
21 #include "system/filesys.h"
22 #include "serverid.h"
23 #include "util_tdb.h"
24 #include "dbwrap/dbwrap.h"
25 #include "dbwrap/dbwrap_open.h"
26 #include "lib/tdb_wrap/tdb_wrap.h"
27 #include "lib/param/param.h"
28 #include "ctdbd_conn.h"
29 #include "messages.h"
30
31 struct serverid_key {
32         pid_t pid;
33         uint32_t task_id;
34         uint32_t vnn;
35 };
36
37 struct serverid_data {
38         uint64_t unique_id;
39         uint32_t msg_flags;
40 };
41
42 bool serverid_parent_init(TALLOC_CTX *mem_ctx)
43 {
44         struct tdb_wrap *db;
45         struct loadparm_context *lp_ctx;
46
47         lp_ctx = loadparm_init_s3(mem_ctx, loadparm_s3_helpers());
48         if (lp_ctx == NULL) {
49                 DEBUG(0, ("loadparm_init_s3 failed\n"));
50                 return false;
51         }
52
53         /*
54          * Open the tdb in the parent process (smbd) so that our
55          * CLEAR_IF_FIRST optimization in tdb_reopen_all can properly
56          * work.
57          */
58
59         db = tdb_wrap_open(mem_ctx, lock_path("serverid.tdb"),
60                            0, TDB_DEFAULT|TDB_CLEAR_IF_FIRST|TDB_INCOMPATIBLE_HASH, O_RDWR|O_CREAT,
61                            0644, lp_ctx);
62         talloc_unlink(mem_ctx, lp_ctx);
63         if (db == NULL) {
64                 DEBUG(1, ("could not open serverid.tdb: %s\n",
65                           strerror(errno)));
66                 return false;
67         }
68         return true;
69 }
70
71 static struct db_context *serverid_db(void)
72 {
73         static struct db_context *db;
74
75         if (db != NULL) {
76                 return db;
77         }
78         db = db_open(NULL, lock_path("serverid.tdb"), 0,
79                      TDB_DEFAULT|TDB_CLEAR_IF_FIRST|TDB_INCOMPATIBLE_HASH,
80                      O_RDWR|O_CREAT, 0644, DBWRAP_LOCK_ORDER_2);
81         return db;
82 }
83
84 static void serverid_fill_key(const struct server_id *id,
85                               struct serverid_key *key)
86 {
87         ZERO_STRUCTP(key);
88         key->pid = id->pid;
89         key->task_id = id->task_id;
90         key->vnn = id->vnn;
91 }
92
93 bool serverid_register(const struct server_id id, uint32_t msg_flags)
94 {
95         struct db_context *db;
96         struct serverid_key key;
97         struct serverid_data data;
98         struct db_record *rec;
99         TDB_DATA tdbkey, tdbdata;
100         NTSTATUS status;
101         bool ret = false;
102
103         db = serverid_db();
104         if (db == NULL) {
105                 return false;
106         }
107
108         serverid_fill_key(&id, &key);
109         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
110
111         rec = dbwrap_fetch_locked(db, talloc_tos(), tdbkey);
112         if (rec == NULL) {
113                 DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
114                 return false;
115         }
116
117         ZERO_STRUCT(data);
118         data.unique_id = id.unique_id;
119         data.msg_flags = msg_flags;
120
121         tdbdata = make_tdb_data((uint8_t *)&data, sizeof(data));
122         status = dbwrap_record_store(rec, tdbdata, 0);
123         if (!NT_STATUS_IS_OK(status)) {
124                 DEBUG(1, ("Storing serverid.tdb record failed: %s\n",
125                           nt_errstr(status)));
126                 goto done;
127         }
128 #ifdef HAVE_CTDB_CONTROL_CHECK_SRVIDS_DECL
129         if (lp_clustering()) {
130                 register_with_ctdbd(messaging_ctdbd_connection(), id.unique_id);
131         }
132 #endif
133         ret = true;
134 done:
135         TALLOC_FREE(rec);
136         return ret;
137 }
138
139 bool serverid_register_msg_flags(const struct server_id id, bool do_reg,
140                                  uint32_t msg_flags)
141 {
142         struct db_context *db;
143         struct serverid_key key;
144         struct serverid_data *data;
145         struct db_record *rec;
146         TDB_DATA tdbkey;
147         TDB_DATA value;
148         NTSTATUS status;
149         bool ret = false;
150
151         db = serverid_db();
152         if (db == NULL) {
153                 return false;
154         }
155
156         serverid_fill_key(&id, &key);
157         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
158
159         rec = dbwrap_fetch_locked(db, talloc_tos(), tdbkey);
160         if (rec == NULL) {
161                 DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
162                 return false;
163         }
164
165         value = dbwrap_record_get_value(rec);
166
167         if (value.dsize != sizeof(struct serverid_data)) {
168                 DEBUG(1, ("serverid record has unexpected size %d "
169                           "(wanted %d)\n", (int)value.dsize,
170                           (int)sizeof(struct serverid_data)));
171                 goto done;
172         }
173
174         data = (struct serverid_data *)value.dptr;
175
176         if (do_reg) {
177                 data->msg_flags |= msg_flags;
178         } else {
179                 data->msg_flags &= ~msg_flags;
180         }
181
182         status = dbwrap_record_store(rec, value, 0);
183         if (!NT_STATUS_IS_OK(status)) {
184                 DEBUG(1, ("Storing serverid.tdb record failed: %s\n",
185                           nt_errstr(status)));
186                 goto done;
187         }
188         ret = true;
189 done:
190         TALLOC_FREE(rec);
191         return ret;
192 }
193
194 bool serverid_deregister(struct server_id id)
195 {
196         struct db_context *db;
197         struct serverid_key key;
198         struct db_record *rec;
199         TDB_DATA tdbkey;
200         NTSTATUS status;
201         bool ret = false;
202
203         db = serverid_db();
204         if (db == NULL) {
205                 return false;
206         }
207
208         serverid_fill_key(&id, &key);
209         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
210
211         rec = dbwrap_fetch_locked(db, talloc_tos(), tdbkey);
212         if (rec == NULL) {
213                 DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
214                 return false;
215         }
216
217         status = dbwrap_record_delete(rec);
218         if (!NT_STATUS_IS_OK(status)) {
219                 DEBUG(1, ("Deleting serverid.tdb record failed: %s\n",
220                           nt_errstr(status)));
221                 goto done;
222         }
223         ret = true;
224 done:
225         TALLOC_FREE(rec);
226         return ret;
227 }
228
229 struct serverid_exists_state {
230         const struct server_id *id;
231         bool exists;
232 };
233
234 static void server_exists_parse(TDB_DATA key, TDB_DATA data, void *priv)
235 {
236         struct serverid_exists_state *state =
237                 (struct serverid_exists_state *)priv;
238
239         if (data.dsize != sizeof(struct serverid_data)) {
240                 state->exists = false;
241                 return;
242         }
243
244         /*
245          * Use memcmp, not direct compare. data.dptr might not be
246          * aligned.
247          */
248         state->exists = (memcmp(&state->id->unique_id, data.dptr,
249                                 sizeof(state->id->unique_id)) == 0);
250 }
251
252 bool serverid_exists(const struct server_id *id)
253 {
254         bool result = false;
255         bool ok = false;
256
257         ok = serverids_exist(id, 1, &result);
258         if (!ok) {
259                 return false;
260         }
261
262         return result;
263 }
264
265 bool serverids_exist(const struct server_id *ids, int num_ids, bool *results)
266 {
267         struct db_context *db;
268         int i;
269
270         if (!processes_exist(ids, num_ids, results)) {
271                 return false;
272         }
273
274         db = serverid_db();
275         if (db == NULL) {
276                 return false;
277         }
278
279         for (i=0; i<num_ids; i++) {
280                 struct serverid_exists_state state;
281                 struct serverid_key key;
282                 TDB_DATA tdbkey;
283                 NTSTATUS status;
284
285                 if (ids[i].unique_id == SERVERID_UNIQUE_ID_NOT_TO_VERIFY) {
286                         results[i] = true;
287                         continue;
288                 }
289                 if (!results[i]) {
290                         continue;
291                 }
292
293                 serverid_fill_key(&ids[i], &key);
294                 tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
295
296                 state.id = &ids[i];
297                 state.exists = false;
298                 status = dbwrap_parse_record(db, tdbkey, server_exists_parse, &state);
299                 if (!NT_STATUS_IS_OK(status)) {
300                         results[i] = false;
301                         continue;
302                 }
303                 results[i] = state.exists;
304         }
305         return true;
306 }
307
308 static bool serverid_rec_parse(const struct db_record *rec,
309                                struct server_id *id, uint32_t *msg_flags)
310 {
311         struct serverid_key key;
312         struct serverid_data data;
313         TDB_DATA tdbkey;
314         TDB_DATA tdbdata;
315
316         tdbkey = dbwrap_record_get_key(rec);
317         tdbdata = dbwrap_record_get_value(rec);
318
319         if (tdbkey.dsize != sizeof(key)) {
320                 DEBUG(1, ("Found invalid key length %d in serverid.tdb\n",
321                           (int)tdbkey.dsize));
322                 return false;
323         }
324         if (tdbdata.dsize != sizeof(data)) {
325                 DEBUG(1, ("Found invalid value length %d in serverid.tdb\n",
326                           (int)tdbdata.dsize));
327                 return false;
328         }
329
330         memcpy(&key, tdbkey.dptr, sizeof(key));
331         memcpy(&data, tdbdata.dptr, sizeof(data));
332
333         id->pid = key.pid;
334         id->task_id = key.task_id;
335         id->vnn = key.vnn;
336         id->unique_id = data.unique_id;
337         *msg_flags = data.msg_flags;
338         return true;
339 }
340
341 struct serverid_traverse_read_state {
342         int (*fn)(const struct server_id *id, uint32_t msg_flags,
343                   void *private_data);
344         void *private_data;
345 };
346
347 static int serverid_traverse_read_fn(struct db_record *rec, void *private_data)
348 {
349         struct serverid_traverse_read_state *state =
350                 (struct serverid_traverse_read_state *)private_data;
351         struct server_id id;
352         uint32_t msg_flags;
353
354         if (!serverid_rec_parse(rec, &id, &msg_flags)) {
355                 return 0;
356         }
357         return state->fn(&id, msg_flags,state->private_data);
358 }
359
360 bool serverid_traverse_read(int (*fn)(const struct server_id *id,
361                                       uint32_t msg_flags, void *private_data),
362                             void *private_data)
363 {
364         struct db_context *db;
365         struct serverid_traverse_read_state state;
366         NTSTATUS status;
367
368         db = serverid_db();
369         if (db == NULL) {
370                 return false;
371         }
372         state.fn = fn;
373         state.private_data = private_data;
374
375         status = dbwrap_traverse_read(db, serverid_traverse_read_fn, &state,
376                                       NULL);
377         return NT_STATUS_IS_OK(status);
378 }
379
380 struct serverid_traverse_state {
381         int (*fn)(struct db_record *rec, const struct server_id *id,
382                   uint32_t msg_flags, void *private_data);
383         void *private_data;
384 };
385
386 static int serverid_traverse_fn(struct db_record *rec, void *private_data)
387 {
388         struct serverid_traverse_state *state =
389                 (struct serverid_traverse_state *)private_data;
390         struct server_id id;
391         uint32_t msg_flags;
392
393         if (!serverid_rec_parse(rec, &id, &msg_flags)) {
394                 return 0;
395         }
396         return state->fn(rec, &id, msg_flags, state->private_data);
397 }
398
399 bool serverid_traverse(int (*fn)(struct db_record *rec,
400                                  const struct server_id *id,
401                                  uint32_t msg_flags, void *private_data),
402                             void *private_data)
403 {
404         struct db_context *db;
405         struct serverid_traverse_state state;
406         NTSTATUS status;
407
408         db = serverid_db();
409         if (db == NULL) {
410                 return false;
411         }
412         state.fn = fn;
413         state.private_data = private_data;
414
415         status = dbwrap_traverse(db, serverid_traverse_fn, &state, NULL);
416         return NT_STATUS_IS_OK(status);
417 }
418
419 uint64_t serverid_get_random_unique_id(void)
420 {
421         uint64_t unique_id = SERVERID_UNIQUE_ID_NOT_TO_VERIFY;
422
423         while (unique_id == SERVERID_UNIQUE_ID_NOT_TO_VERIFY) {
424                 generate_random_buffer((uint8_t *)&unique_id,
425                                        sizeof(unique_id));
426         }
427
428         return unique_id;
429 }