lib: Move "message_send_all" to serverid.c
[nivanova/samba-autobuild/.git] / source3 / lib / serverid.c
1 /*
2    Unix SMB/CIFS implementation.
3    Implementation of a reliable server_exists()
4    Copyright (C) Volker Lendecke 2010
5
6    This program is free software; you can redistribute it and/or modify
7    it under the terms of the GNU General Public License as published by
8    the Free Software Foundation; either version 3 of the License, or
9    (at your option) any later version.
10
11    This program is distributed in the hope that it will be useful,
12    but WITHOUT ANY WARRANTY; without even the implied warranty of
13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14    GNU General Public License for more details.
15
16    You should have received a copy of the GNU General Public License
17    along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 */
19
20 #include "includes.h"
21 #include "system/filesys.h"
22 #include "serverid.h"
23 #include "util_tdb.h"
24 #include "dbwrap/dbwrap.h"
25 #include "dbwrap/dbwrap_open.h"
26 #include "lib/tdb_wrap/tdb_wrap.h"
27 #include "lib/param/param.h"
28 #include "ctdbd_conn.h"
29 #include "messages.h"
30 #include "lib/messages_dgm.h"
31
32 struct serverid_key {
33         pid_t pid;
34         uint32_t task_id;
35         uint32_t vnn;
36 };
37
38 struct serverid_data {
39         uint64_t unique_id;
40         uint32_t msg_flags;
41 };
42
43 static struct db_context *serverid_db(void)
44 {
45         static struct db_context *db;
46         char *db_path;
47
48         if (db != NULL) {
49                 return db;
50         }
51
52         db_path = lock_path("serverid.tdb");
53         if (db_path == NULL) {
54                 return NULL;
55         }
56
57         db = db_open(NULL, db_path, 0,
58                      TDB_DEFAULT|TDB_CLEAR_IF_FIRST|TDB_INCOMPATIBLE_HASH,
59                      O_RDWR|O_CREAT, 0644, DBWRAP_LOCK_ORDER_2,
60                      DBWRAP_FLAG_NONE);
61         TALLOC_FREE(db_path);
62         return db;
63 }
64
65 bool serverid_parent_init(TALLOC_CTX *mem_ctx)
66 {
67         struct db_context *db;
68
69         db = serverid_db();
70         if (db == NULL) {
71                 DEBUG(1, ("could not open serverid.tdb: %s\n",
72                           strerror(errno)));
73                 return false;
74         }
75
76         return true;
77 }
78
79 static void serverid_fill_key(const struct server_id *id,
80                               struct serverid_key *key)
81 {
82         ZERO_STRUCTP(key);
83         key->pid = id->pid;
84         key->task_id = id->task_id;
85         key->vnn = id->vnn;
86 }
87
88 bool serverid_register(const struct server_id id, uint32_t msg_flags)
89 {
90         struct db_context *db;
91         struct serverid_key key;
92         struct serverid_data data;
93         struct db_record *rec;
94         TDB_DATA tdbkey, tdbdata;
95         NTSTATUS status;
96         bool ret = false;
97
98         db = serverid_db();
99         if (db == NULL) {
100                 return false;
101         }
102
103         serverid_fill_key(&id, &key);
104         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
105
106         rec = dbwrap_fetch_locked(db, talloc_tos(), tdbkey);
107         if (rec == NULL) {
108                 DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
109                 return false;
110         }
111
112         ZERO_STRUCT(data);
113         data.unique_id = id.unique_id;
114         data.msg_flags = msg_flags;
115
116         tdbdata = make_tdb_data((uint8_t *)&data, sizeof(data));
117         status = dbwrap_record_store(rec, tdbdata, 0);
118         if (!NT_STATUS_IS_OK(status)) {
119                 DEBUG(1, ("Storing serverid.tdb record failed: %s\n",
120                           nt_errstr(status)));
121                 goto done;
122         }
123
124         if (lp_clustering()) {
125                 register_with_ctdbd(messaging_ctdbd_connection(), id.unique_id,
126                                     NULL, NULL);
127         }
128
129         ret = true;
130 done:
131         TALLOC_FREE(rec);
132         return ret;
133 }
134
135 bool serverid_deregister(struct server_id id)
136 {
137         struct db_context *db;
138         struct serverid_key key;
139         struct db_record *rec;
140         TDB_DATA tdbkey;
141         NTSTATUS status;
142         bool ret = false;
143
144         db = serverid_db();
145         if (db == NULL) {
146                 return false;
147         }
148
149         serverid_fill_key(&id, &key);
150         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
151
152         rec = dbwrap_fetch_locked(db, talloc_tos(), tdbkey);
153         if (rec == NULL) {
154                 DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
155                 return false;
156         }
157
158         status = dbwrap_record_delete(rec);
159         if (!NT_STATUS_IS_OK(status)) {
160                 DEBUG(1, ("Deleting serverid.tdb record failed: %s\n",
161                           nt_errstr(status)));
162                 goto done;
163         }
164         ret = true;
165 done:
166         TALLOC_FREE(rec);
167         return ret;
168 }
169
170 static bool serverid_exists_local(const struct server_id *id)
171 {
172         bool exists = process_exists_by_pid(id->pid);
173         uint64_t unique;
174         int ret;
175
176         if (!exists) {
177                 return false;
178         }
179
180         if (id->unique_id == SERVERID_UNIQUE_ID_NOT_TO_VERIFY) {
181                 return true;
182         }
183
184         ret = messaging_dgm_get_unique(id->pid, &unique);
185         if (ret != 0) {
186                 return false;
187         }
188
189         return (unique == id->unique_id);
190 }
191
192 bool serverid_exists(const struct server_id *id)
193 {
194         if (procid_is_local(id)) {
195                 return serverid_exists_local(id);
196         }
197
198         if (lp_clustering()) {
199                 return ctdbd_process_exists(messaging_ctdbd_connection(),
200                                             id->vnn, id->pid);
201         }
202
203         return false;
204 }
205
206 static bool serverid_rec_parse(const struct db_record *rec,
207                                struct server_id *id, uint32_t *msg_flags)
208 {
209         struct serverid_key key;
210         struct serverid_data data;
211         TDB_DATA tdbkey;
212         TDB_DATA tdbdata;
213
214         tdbkey = dbwrap_record_get_key(rec);
215         tdbdata = dbwrap_record_get_value(rec);
216
217         if (tdbkey.dsize != sizeof(key)) {
218                 DEBUG(1, ("Found invalid key length %d in serverid.tdb\n",
219                           (int)tdbkey.dsize));
220                 return false;
221         }
222         if (tdbdata.dsize != sizeof(data)) {
223                 DEBUG(1, ("Found invalid value length %d in serverid.tdb\n",
224                           (int)tdbdata.dsize));
225                 return false;
226         }
227
228         memcpy(&key, tdbkey.dptr, sizeof(key));
229         memcpy(&data, tdbdata.dptr, sizeof(data));
230
231         id->pid = key.pid;
232         id->task_id = key.task_id;
233         id->vnn = key.vnn;
234         id->unique_id = data.unique_id;
235         *msg_flags = data.msg_flags;
236         return true;
237 }
238
239 struct serverid_traverse_read_state {
240         int (*fn)(const struct server_id *id, uint32_t msg_flags,
241                   void *private_data);
242         void *private_data;
243 };
244
245 static int serverid_traverse_read_fn(struct db_record *rec, void *private_data)
246 {
247         struct serverid_traverse_read_state *state =
248                 (struct serverid_traverse_read_state *)private_data;
249         struct server_id id;
250         uint32_t msg_flags;
251
252         if (!serverid_rec_parse(rec, &id, &msg_flags)) {
253                 return 0;
254         }
255         return state->fn(&id, msg_flags,state->private_data);
256 }
257
258 bool serverid_traverse_read(int (*fn)(const struct server_id *id,
259                                       uint32_t msg_flags, void *private_data),
260                             void *private_data)
261 {
262         struct db_context *db;
263         struct serverid_traverse_read_state state;
264         NTSTATUS status;
265
266         db = serverid_db();
267         if (db == NULL) {
268                 return false;
269         }
270         state.fn = fn;
271         state.private_data = private_data;
272
273         status = dbwrap_traverse_read(db, serverid_traverse_read_fn, &state,
274                                       NULL);
275         return NT_STATUS_IS_OK(status);
276 }
277
278 struct serverid_traverse_state {
279         int (*fn)(struct db_record *rec, const struct server_id *id,
280                   uint32_t msg_flags, void *private_data);
281         void *private_data;
282 };
283
284 static int serverid_traverse_fn(struct db_record *rec, void *private_data)
285 {
286         struct serverid_traverse_state *state =
287                 (struct serverid_traverse_state *)private_data;
288         struct server_id id;
289         uint32_t msg_flags;
290
291         if (!serverid_rec_parse(rec, &id, &msg_flags)) {
292                 return 0;
293         }
294         return state->fn(rec, &id, msg_flags, state->private_data);
295 }
296
297 bool serverid_traverse(int (*fn)(struct db_record *rec,
298                                  const struct server_id *id,
299                                  uint32_t msg_flags, void *private_data),
300                             void *private_data)
301 {
302         struct db_context *db;
303         struct serverid_traverse_state state;
304         NTSTATUS status;
305
306         db = serverid_db();
307         if (db == NULL) {
308                 return false;
309         }
310         state.fn = fn;
311         state.private_data = private_data;
312
313         status = dbwrap_traverse(db, serverid_traverse_fn, &state, NULL);
314         return NT_STATUS_IS_OK(status);
315 }
316
317 struct msg_all {
318         struct messaging_context *msg_ctx;
319         int msg_type;
320         uint32_t msg_flag;
321         const void *buf;
322         size_t len;
323         int n_sent;
324 };
325
326 /****************************************************************************
327  Send one of the messages for the broadcast.
328 ****************************************************************************/
329
330 static int traverse_fn(struct db_record *rec, const struct server_id *id,
331                        uint32_t msg_flags, void *state)
332 {
333         struct msg_all *msg_all = (struct msg_all *)state;
334         NTSTATUS status;
335
336         /* Don't send if the receiver hasn't registered an interest. */
337
338         if((msg_flags & msg_all->msg_flag) == 0) {
339                 return 0;
340         }
341
342         /* If the msg send fails because the pid was not found (i.e. smbd died),
343          * the msg has already been deleted from the messages.tdb.*/
344
345         status = messaging_send_buf(msg_all->msg_ctx, *id, msg_all->msg_type,
346                                     (const uint8_t *)msg_all->buf, msg_all->len);
347
348         if (NT_STATUS_EQUAL(status, NT_STATUS_INVALID_HANDLE)) {
349                 struct server_id_buf idbuf;
350
351                 /*
352                  * If the pid was not found delete the entry from
353                  * serverid.tdb
354                  */
355
356                 DEBUG(2, ("pid %s doesn't exist\n",
357                           server_id_str_buf(*id, &idbuf)));
358
359                 dbwrap_record_delete(rec);
360         }
361         msg_all->n_sent++;
362         return 0;
363 }
364
365 /**
366  * Send a message to all smbd processes.
367  *
368  * It isn't very efficient, but should be OK for the sorts of
369  * applications that use it. When we need efficient broadcast we can add
370  * it.
371  *
372  * @param n_sent Set to the number of messages sent.  This should be
373  * equal to the number of processes, but be careful for races.
374  *
375  * @retval True for success.
376  **/
377 bool message_send_all(struct messaging_context *msg_ctx,
378                       int msg_type,
379                       const void *buf, size_t len,
380                       int *n_sent)
381 {
382         struct msg_all msg_all;
383
384         msg_all.msg_type = msg_type;
385         if (msg_type < 0x100) {
386                 msg_all.msg_flag = FLAG_MSG_GENERAL;
387         } else if (msg_type > 0x100 && msg_type < 0x200) {
388                 msg_all.msg_flag = FLAG_MSG_NMBD;
389         } else if (msg_type > 0x200 && msg_type < 0x300) {
390                 msg_all.msg_flag = FLAG_MSG_PRINT_GENERAL;
391         } else if (msg_type > 0x300 && msg_type < 0x400) {
392                 msg_all.msg_flag = FLAG_MSG_SMBD;
393         } else if (msg_type > 0x400 && msg_type < 0x600) {
394                 msg_all.msg_flag = FLAG_MSG_WINBIND;
395         } else if (msg_type > 4000 && msg_type < 5000) {
396                 msg_all.msg_flag = FLAG_MSG_DBWRAP;
397         } else {
398                 return false;
399         }
400
401         msg_all.buf = buf;
402         msg_all.len = len;
403         msg_all.n_sent = 0;
404         msg_all.msg_ctx = msg_ctx;
405
406         serverid_traverse(traverse_fn, &msg_all);
407         if (n_sent)
408                 *n_sent = msg_all.n_sent;
409         return true;
410 }