s3-messages: only include messages.h where needed.
[vlendec/samba-autobuild/.git] / source3 / lib / serverid.c
1 /*
2    Unix SMB/CIFS implementation.
3    Implementation of a reliable server_exists()
4    Copyright (C) Volker Lendecke 2010
5
6    This program is free software; you can redistribute it and/or modify
7    it under the terms of the GNU General Public License as published by
8    the Free Software Foundation; either version 3 of the License, or
9    (at your option) any later version.
10
11    This program is distributed in the hope that it will be useful,
12    but WITHOUT ANY WARRANTY; without even the implied warranty of
13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14    GNU General Public License for more details.
15
16    You should have received a copy of the GNU General Public License
17    along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 */
19
20 #include "includes.h"
21 #include "system/filesys.h"
22 #include "serverid.h"
23 #include "dbwrap.h"
24
25 struct serverid_key {
26         pid_t pid;
27         uint32_t vnn;
28 };
29
30 struct serverid_data {
31         uint64_t unique_id;
32         uint32_t msg_flags;
33 };
34
35 bool serverid_parent_init(TALLOC_CTX *mem_ctx)
36 {
37         struct tdb_wrap *db;
38
39         /*
40          * Open the tdb in the parent process (smbd) so that our
41          * CLEAR_IF_FIRST optimization in tdb_reopen_all can properly
42          * work.
43          */
44
45         db = tdb_wrap_open(mem_ctx, lock_path("serverid.tdb"),
46                            0, TDB_DEFAULT|TDB_CLEAR_IF_FIRST|TDB_INCOMPATIBLE_HASH, O_RDWR|O_CREAT,
47                            0644);
48         if (db == NULL) {
49                 DEBUG(1, ("could not open serverid.tdb: %s\n",
50                           strerror(errno)));
51                 return false;
52         }
53         return true;
54 }
55
56 static struct db_context *serverid_db(void)
57 {
58         static struct db_context *db;
59
60         if (db != NULL) {
61                 return db;
62         }
63         db = db_open(NULL, lock_path("serverid.tdb"), 0,
64                      TDB_DEFAULT|TDB_CLEAR_IF_FIRST|TDB_INCOMPATIBLE_HASH, O_RDWR|O_CREAT, 0644);
65         return db;
66 }
67
68 static void serverid_fill_key(const struct server_id *id,
69                               struct serverid_key *key)
70 {
71         ZERO_STRUCTP(key);
72         key->pid = id->pid;
73         key->vnn = id->vnn;
74 }
75
76 bool serverid_register(const struct server_id id, uint32_t msg_flags)
77 {
78         struct db_context *db;
79         struct serverid_key key;
80         struct serverid_data data;
81         struct db_record *rec;
82         TDB_DATA tdbkey, tdbdata;
83         NTSTATUS status;
84         bool ret = false;
85
86         db = serverid_db();
87         if (db == NULL) {
88                 return false;
89         }
90
91         serverid_fill_key(&id, &key);
92         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
93
94         rec = db->fetch_locked(db, talloc_tos(), tdbkey);
95         if (rec == NULL) {
96                 DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
97                 return false;
98         }
99
100         ZERO_STRUCT(data);
101         data.unique_id = id.unique_id;
102         data.msg_flags = msg_flags;
103
104         tdbdata = make_tdb_data((uint8_t *)&data, sizeof(data));
105         status = rec->store(rec, tdbdata, 0);
106         if (!NT_STATUS_IS_OK(status)) {
107                 DEBUG(1, ("Storing serverid.tdb record failed: %s\n",
108                           nt_errstr(status)));
109                 goto done;
110         }
111         ret = true;
112 done:
113         TALLOC_FREE(rec);
114         return ret;
115 }
116
117 bool serverid_register_msg_flags(const struct server_id id, bool do_reg,
118                                  uint32_t msg_flags)
119 {
120         struct db_context *db;
121         struct serverid_key key;
122         struct serverid_data *data;
123         struct db_record *rec;
124         TDB_DATA tdbkey;
125         NTSTATUS status;
126         bool ret = false;
127
128         db = serverid_db();
129         if (db == NULL) {
130                 return false;
131         }
132
133         serverid_fill_key(&id, &key);
134         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
135
136         rec = db->fetch_locked(db, talloc_tos(), tdbkey);
137         if (rec == NULL) {
138                 DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
139                 return false;
140         }
141
142         if (rec->value.dsize != sizeof(struct serverid_data)) {
143                 DEBUG(1, ("serverid record has unexpected size %d "
144                           "(wanted %d)\n", (int)rec->value.dsize,
145                           (int)sizeof(struct serverid_data)));
146                 goto done;
147         }
148
149         data = (struct serverid_data *)rec->value.dptr;
150
151         if (do_reg) {
152                 data->msg_flags |= msg_flags;
153         } else {
154                 data->msg_flags &= ~msg_flags;
155         }
156
157         status = rec->store(rec, rec->value, 0);
158         if (!NT_STATUS_IS_OK(status)) {
159                 DEBUG(1, ("Storing serverid.tdb record failed: %s\n",
160                           nt_errstr(status)));
161                 goto done;
162         }
163         ret = true;
164 done:
165         TALLOC_FREE(rec);
166         return ret;
167 }
168
169 bool serverid_deregister(struct server_id id)
170 {
171         struct db_context *db;
172         struct serverid_key key;
173         struct db_record *rec;
174         TDB_DATA tdbkey;
175         NTSTATUS status;
176         bool ret = false;
177
178         db = serverid_db();
179         if (db == NULL) {
180                 return false;
181         }
182
183         serverid_fill_key(&id, &key);
184         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
185
186         rec = db->fetch_locked(db, talloc_tos(), tdbkey);
187         if (rec == NULL) {
188                 DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
189                 return false;
190         }
191
192         status = rec->delete_rec(rec);
193         if (!NT_STATUS_IS_OK(status)) {
194                 DEBUG(1, ("Deleting serverid.tdb record failed: %s\n",
195                           nt_errstr(status)));
196                 goto done;
197         }
198         ret = true;
199 done:
200         TALLOC_FREE(rec);
201         return ret;
202 }
203
204 struct serverid_exists_state {
205         const struct server_id *id;
206         bool exists;
207 };
208
209 static int server_exists_parse(TDB_DATA key, TDB_DATA data, void *priv)
210 {
211         struct serverid_exists_state *state =
212                 (struct serverid_exists_state *)priv;
213
214         if (data.dsize != sizeof(struct serverid_data)) {
215                 return -1;
216         }
217
218         /*
219          * Use memcmp, not direct compare. data.dptr might not be
220          * aligned.
221          */
222         state->exists = (memcmp(&state->id->unique_id, data.dptr,
223                                 sizeof(state->id->unique_id)) == 0);
224         return 0;
225 }
226
227 bool serverid_exists(const struct server_id *id)
228 {
229         struct db_context *db;
230         struct serverid_exists_state state;
231         struct serverid_key key;
232         TDB_DATA tdbkey;
233
234         if (lp_clustering() && !process_exists(*id)) {
235                 return false;
236         }
237
238         db = serverid_db();
239         if (db == NULL) {
240                 return false;
241         }
242
243         serverid_fill_key(id, &key);
244         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
245
246         state.id = id;
247         state.exists = false;
248
249         if (db->parse_record(db, tdbkey, server_exists_parse, &state) == -1) {
250                 return false;
251         }
252         return state.exists;
253 }
254
255 static bool serverid_rec_parse(const struct db_record *rec,
256                                struct server_id *id, uint32_t *msg_flags)
257 {
258         struct serverid_key key;
259         struct serverid_data data;
260
261         if (rec->key.dsize != sizeof(key)) {
262                 DEBUG(1, ("Found invalid key length %d in serverid.tdb\n",
263                           (int)rec->key.dsize));
264                 return false;
265         }
266         if (rec->value.dsize != sizeof(data)) {
267                 DEBUG(1, ("Found invalid value length %d in serverid.tdb\n",
268                           (int)rec->value.dsize));
269                 return false;
270         }
271
272         memcpy(&key, rec->key.dptr, sizeof(key));
273         memcpy(&data, rec->value.dptr, sizeof(data));
274
275         id->pid = key.pid;
276         id->vnn = key.vnn;
277         id->unique_id = data.unique_id;
278         *msg_flags = data.msg_flags;
279         return true;
280 }
281
282 struct serverid_traverse_read_state {
283         int (*fn)(const struct server_id *id, uint32_t msg_flags,
284                   void *private_data);
285         void *private_data;
286 };
287
288 static int serverid_traverse_read_fn(struct db_record *rec, void *private_data)
289 {
290         struct serverid_traverse_read_state *state =
291                 (struct serverid_traverse_read_state *)private_data;
292         struct server_id id;
293         uint32_t msg_flags;
294
295         if (!serverid_rec_parse(rec, &id, &msg_flags)) {
296                 return 0;
297         }
298         return state->fn(&id, msg_flags,state->private_data);
299 }
300
301 bool serverid_traverse_read(int (*fn)(const struct server_id *id,
302                                       uint32_t msg_flags, void *private_data),
303                             void *private_data)
304 {
305         struct db_context *db;
306         struct serverid_traverse_read_state state;
307
308         db = serverid_db();
309         if (db == NULL) {
310                 return false;
311         }
312         state.fn = fn;
313         state.private_data = private_data;
314         return db->traverse_read(db, serverid_traverse_read_fn, &state);
315 }
316
317 struct serverid_traverse_state {
318         int (*fn)(struct db_record *rec, const struct server_id *id,
319                   uint32_t msg_flags, void *private_data);
320         void *private_data;
321 };
322
323 static int serverid_traverse_fn(struct db_record *rec, void *private_data)
324 {
325         struct serverid_traverse_state *state =
326                 (struct serverid_traverse_state *)private_data;
327         struct server_id id;
328         uint32_t msg_flags;
329
330         if (!serverid_rec_parse(rec, &id, &msg_flags)) {
331                 return 0;
332         }
333         return state->fn(rec, &id, msg_flags, state->private_data);
334 }
335
336 bool serverid_traverse(int (*fn)(struct db_record *rec,
337                                  const struct server_id *id,
338                                  uint32_t msg_flags, void *private_data),
339                             void *private_data)
340 {
341         struct db_context *db;
342         struct serverid_traverse_state state;
343
344         db = serverid_db();
345         if (db == NULL) {
346                 return false;
347         }
348         state.fn = fn;
349         state.private_data = private_data;
350         return db->traverse(db, serverid_traverse_fn, &state);
351 }