Add an output parameter to message_send_all that says how many
[samba.git] / source3 / lib / messages.c
1 /* 
2    Unix SMB/Netbios implementation.
3    Version 3.0
4    Samba internal messaging functions
5    Copyright (C) Andrew Tridgell 2000
6    Copyright (C) 2001 by Martin Pool
7    
8    This program is free software; you can redistribute it and/or modify
9    it under the terms of the GNU General Public License as published by
10    the Free Software Foundation; either version 2 of the License, or
11    (at your option) any later version.
12    
13    This program is distributed in the hope that it will be useful,
14    but WITHOUT ANY WARRANTY; without even the implied warranty of
15    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16    GNU General Public License for more details.
17    
18    You should have received a copy of the GNU General Public License
19    along with this program; if not, write to the Free Software
20    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
21 */
22
23 /**
24    @defgroups messages Internal messaging framework
25    @{
26    @file messages.c
27
28    This module is used for internal messaging between Samba daemons. 
29
30    The idea is that if a part of Samba wants to do communication with
31    another Samba process then it will do a message_register() of a
32    dispatch function, and use message_send_pid() to send messages to
33    that process.
34
35    The dispatch function is given the pid of the sender, and it can
36    use that to reply by message_send_pid().  See ping_message() for a
37    simple example.
38
39    This system doesn't have any inherent size limitations but is not
40    very efficient for large messages or when messages are sent in very
41    quick succession.
42
43 */
44
45 #include "includes.h"
46
47 /* the locking database handle */
48 static TDB_CONTEXT *tdb;
49 static int received_signal;
50
51 /* change the message version with any incompatible changes in the protocol */
52 #define MESSAGE_VERSION 1
53
54 struct message_rec {
55         int msg_version;
56         int msg_type;
57         pid_t dest;
58         pid_t src;
59         size_t len;
60 };
61
62 /* we have a linked list of dispatch handlers */
63 static struct dispatch_fns {
64         struct dispatch_fns *next, *prev;
65         int msg_type;
66         void (*fn)(int msg_type, pid_t pid, void *buf, size_t len);
67 } *dispatch_fns;
68
69 /****************************************************************************
70  Notifications come in as signals.
71 ****************************************************************************/
72
73 static void sig_usr1(void)
74 {
75         received_signal = 1;
76         sys_select_signal();
77 }
78
79 /****************************************************************************
80  A useful function for testing the message system.
81 ****************************************************************************/
82
83 void ping_message(int msg_type, pid_t src, void *buf, size_t len)
84 {
85         char *msg = buf ? buf : "none";
86         DEBUG(1,("INFO: Received PING message from PID %u [%s]\n",(unsigned int)src, msg));
87         message_send_pid(src, MSG_PONG, buf, len, True);
88 }
89
90 /****************************************************************************
91  Return current debug level.
92 ****************************************************************************/
93
94 void debuglevel_message(int msg_type, pid_t src, void *buf, size_t len)
95 {
96         DEBUG(1,("INFO: Received REQ_DEBUGLEVEL message from PID %u\n",(unsigned int)src));
97         message_send_pid(src, MSG_DEBUGLEVEL, DEBUGLEVEL_CLASS, sizeof(DEBUGLEVEL_CLASS), True);
98 }
99
100 /****************************************************************************
101  Initialise the messaging functions. 
102 ****************************************************************************/
103
104 BOOL message_init(void)
105 {
106         if (tdb) return True;
107
108         tdb = tdb_open_log(lock_path("messages.tdb"), 
109                        0, TDB_CLEAR_IF_FIRST|TDB_DEFAULT, 
110                        O_RDWR|O_CREAT,0600);
111
112         if (!tdb) {
113                 DEBUG(0,("ERROR: Failed to initialise messages database\n"));
114                 return False;
115         }
116
117         CatchSignal(SIGUSR1, SIGNAL_CAST sig_usr1);
118
119         message_register(MSG_PING, ping_message);
120         message_register(MSG_REQ_DEBUGLEVEL, debuglevel_message);
121
122         return True;
123 }
124
125 /*******************************************************************
126  Form a static tdb key from a pid.
127 ******************************************************************/
128
129 static TDB_DATA message_key_pid(pid_t pid)
130 {
131         static char key[20];
132         TDB_DATA kbuf;
133
134         slprintf(key, sizeof(key)-1, "PID/%d", (int)pid);
135         
136         kbuf.dptr = (char *)key;
137         kbuf.dsize = strlen(key)+1;
138         return kbuf;
139 }
140
141 /****************************************************************************
142  Notify a process that it has a message. If the process doesn't exist 
143  then delete its record in the database.
144 ****************************************************************************/
145
146 static BOOL message_notify(pid_t pid)
147 {
148         if (kill(pid, SIGUSR1) == -1) {
149                 if (errno == ESRCH) {
150                         DEBUG(2,("pid %d doesn't exist - deleting messages record\n", (int)pid));
151                         tdb_delete(tdb, message_key_pid(pid));
152                 } else {
153                         DEBUG(2,("message to process %d failed - %s\n", (int)pid, strerror(errno)));
154                 }
155                 return False;
156         }
157         return True;
158 }
159
160 /****************************************************************************
161  Send a message to a particular pid.
162 ****************************************************************************/
163
164 BOOL message_send_pid(pid_t pid, int msg_type, const void *buf, size_t len,
165                       BOOL duplicates_allowed)
166 {
167         TDB_DATA kbuf;
168         TDB_DATA dbuf;
169         struct message_rec rec;
170         void *p;
171
172         rec.msg_version = MESSAGE_VERSION;
173         rec.msg_type = msg_type;
174         rec.dest = pid;
175         rec.src = sys_getpid();
176         rec.len = len;
177
178         kbuf = message_key_pid(pid);
179
180         /* lock the record for the destination */
181         tdb_chainlock(tdb, kbuf);
182
183         dbuf = tdb_fetch(tdb, kbuf);
184
185         if (!dbuf.dptr) {
186                 /* its a new record */
187                 p = (void *)malloc(len + sizeof(rec));
188                 if (!p) goto failed;
189
190                 memcpy(p, &rec, sizeof(rec));
191                 if (len > 0) memcpy((void *)((char*)p+sizeof(rec)), buf, len);
192
193                 dbuf.dptr = p;
194                 dbuf.dsize = len + sizeof(rec);
195                 tdb_store(tdb, kbuf, dbuf, TDB_REPLACE);
196                 SAFE_FREE(p);
197                 goto ok;
198         }
199
200         if (!duplicates_allowed) {
201                 char *ptr;
202                 struct message_rec prec;
203                 
204                 for(ptr = (char *)dbuf.dptr; ptr < dbuf.dptr + dbuf.dsize; ) {
205                         /*
206                          * First check if the message header matches, then, if it's a non-zero
207                          * sized message, check if the data matches. If so it's a duplicate and
208                          * we can discard it. JRA.
209                          */
210
211                         if (!memcmp(ptr, &rec, sizeof(rec))) {
212                                 if (!len || (len && !memcmp( ptr + sizeof(rec), buf, len))) {
213                                         DEBUG(10,("message_send_pid: discarding duplicate message.\n"));
214                                         SAFE_FREE(dbuf.dptr);
215                                         tdb_chainunlock(tdb, kbuf);
216                                         return True;
217                                 }
218                         }
219                         memcpy(&prec, ptr, sizeof(prec));
220                         ptr += sizeof(rec) + prec.len;
221                 }
222         }
223
224         /* we're adding to an existing entry */
225         p = (void *)malloc(dbuf.dsize + len + sizeof(rec));
226         if (!p) goto failed;
227
228         memcpy(p, dbuf.dptr, dbuf.dsize);
229         memcpy((void *)((char*)p+dbuf.dsize), &rec, sizeof(rec));
230         if (len > 0) memcpy((void *)((char*)p+dbuf.dsize+sizeof(rec)), buf, len);
231
232         SAFE_FREE(dbuf.dptr);
233         dbuf.dptr = p;
234         dbuf.dsize += len + sizeof(rec);
235         tdb_store(tdb, kbuf, dbuf, TDB_REPLACE);
236         SAFE_FREE(dbuf.dptr);
237
238  ok:
239         tdb_chainunlock(tdb, kbuf);
240         errno = 0;                    /* paranoia */
241         return message_notify(pid);
242
243  failed:
244         tdb_chainunlock(tdb, kbuf);
245         errno = 0;                    /* paranoia */
246         return False;
247 }
248
249 /****************************************************************************
250  Retrieve the next message for the current process.
251 ****************************************************************************/
252
253 static BOOL message_recv(int *msg_type, pid_t *src, void **buf, size_t *len)
254 {
255         TDB_DATA kbuf;
256         TDB_DATA dbuf;
257         struct message_rec rec;
258
259         kbuf = message_key_pid(sys_getpid());
260
261         tdb_chainlock(tdb, kbuf);
262         
263         dbuf = tdb_fetch(tdb, kbuf);
264         if (dbuf.dptr == NULL || dbuf.dsize == 0) goto failed;
265
266         memcpy(&rec, dbuf.dptr, sizeof(rec));
267
268         if (rec.msg_version != MESSAGE_VERSION) {
269                 DEBUG(0,("message version %d received (expected %d)\n", rec.msg_version, MESSAGE_VERSION));
270                 goto failed;
271         }
272
273         if (rec.len > 0) {
274                 (*buf) = (void *)malloc(rec.len);
275                 if (!(*buf)) goto failed;
276
277                 memcpy(*buf, dbuf.dptr+sizeof(rec), rec.len);
278         } else {
279                 *buf = NULL;
280         }
281
282         *len = rec.len;
283         *msg_type = rec.msg_type;
284         *src = rec.src;
285
286         if (dbuf.dsize - (sizeof(rec)+rec.len) > 0)
287                 memmove(dbuf.dptr, dbuf.dptr+sizeof(rec)+rec.len, dbuf.dsize - (sizeof(rec)+rec.len));
288         dbuf.dsize -= sizeof(rec)+rec.len;
289
290         if (dbuf.dsize == 0)
291                 tdb_delete(tdb, kbuf);
292         else
293                 tdb_store(tdb, kbuf, dbuf, TDB_REPLACE);
294
295         SAFE_FREE(dbuf.dptr);
296         tdb_chainunlock(tdb, kbuf);
297         return True;
298
299  failed:
300         tdb_chainunlock(tdb, kbuf);
301         return False;
302 }
303
304 /****************************************************************************
305  Receive and dispatch any messages pending for this process.
306  Notice that all dispatch handlers for a particular msg_type get called,
307  so you can register multiple handlers for a message.
308 ****************************************************************************/
309
310 void message_dispatch(void)
311 {
312         int msg_type;
313         pid_t src;
314         void *buf;
315         size_t len;
316         struct dispatch_fns *dfn;
317
318         if (!received_signal) return;
319
320         DEBUG(10,("message_dispatch: received_signal = %d\n", received_signal));
321
322         received_signal = 0;
323
324         while (message_recv(&msg_type, &src, &buf, &len)) {
325                 for (dfn = dispatch_fns; dfn; dfn = dfn->next) {
326                         if (dfn->msg_type == msg_type) {
327                                 DEBUG(10,("message_dispatch: processing message of type %d.\n", msg_type));
328                                 dfn->fn(msg_type, src, buf, len);
329                         }
330                 }
331                 SAFE_FREE(buf);
332         }
333 }
334
335 /****************************************************************************
336  Register a dispatch function for a particular message type.
337 ****************************************************************************/
338
339 void message_register(int msg_type, 
340                       void (*fn)(int msg_type, pid_t pid, void *buf, size_t len))
341 {
342         struct dispatch_fns *dfn;
343
344         dfn = (struct dispatch_fns *)malloc(sizeof(*dfn));
345
346         if (dfn != NULL) {
347
348                 ZERO_STRUCTPN(dfn);
349
350                 dfn->msg_type = msg_type;
351                 dfn->fn = fn;
352
353                 DLIST_ADD(dispatch_fns, dfn);
354         }
355         else {
356         
357                 DEBUG(0,("message_register: Not enough memory. malloc failed!\n"));
358         }
359 }
360
361 /****************************************************************************
362  De-register the function for a particular message type.
363 ****************************************************************************/
364
365 void message_deregister(int msg_type)
366 {
367         struct dispatch_fns *dfn, *next;
368
369         for (dfn = dispatch_fns; dfn; dfn = next) {
370                 next = dfn->next;
371                 if (dfn->msg_type == msg_type) {
372                         DLIST_REMOVE(dispatch_fns, dfn);
373                         SAFE_FREE(dfn);
374                 }
375         }       
376 }
377
378 struct msg_all {
379         int msg_type;
380         const void *buf;
381         size_t len;
382         BOOL duplicates;
383         int             n_sent;
384 };
385
386 /****************************************************************************
387  Send one of the messages for the broadcast.
388 ****************************************************************************/
389
390 static int traverse_fn(TDB_CONTEXT *the_tdb, TDB_DATA kbuf, TDB_DATA dbuf, void *state)
391 {
392         struct connections_data crec;
393         struct msg_all *msg_all = (struct msg_all *)state;
394
395         if (dbuf.dsize != sizeof(crec))
396                 return 0;
397
398         memcpy(&crec, dbuf.dptr, sizeof(crec));
399
400         if (crec.cnum != -1)
401                 return 0;
402
403         /* if the msg send fails because the pid was not found (i.e. smbd died), 
404          * the msg has already been deleted from the messages.tdb.*/
405         if (!message_send_pid(crec.pid, msg_all->msg_type,
406                               msg_all->buf, msg_all->len,
407                               msg_all->duplicates)) {
408                 
409                 /* if the pid was not found delete the entry from connections.tdb */
410                 if (errno == ESRCH) {
411                         DEBUG(2,("pid %u doesn't exist - deleting connections %d [%s]\n",
412                                         (unsigned int)crec.pid, crec.cnum, crec.name));
413                         tdb_delete(the_tdb, kbuf);
414                 }
415         }
416         msg_all->n_sent++;
417         return 0;
418 }
419
420 /**
421  * Send a message to all smbd processes.
422  *
423  * It isn't very efficient, but should be OK for the sorts of
424  * applications that use it. When we need efficient broadcast we can add
425  * it.
426  *
427  * @param n_sent Set to the number of messages sent.  This should be
428  * equal to the number of processes, but be careful for races.
429  *
430  * @return True for success.
431  **/
432 BOOL message_send_all(TDB_CONTEXT *conn_tdb, int msg_type,
433                       const void *buf, size_t len,
434                       BOOL duplicates_allowed,
435                       int *n_sent)
436 {
437         struct msg_all msg_all;
438
439         msg_all.msg_type = msg_type;
440         msg_all.buf = buf;
441         msg_all.len = len;
442         msg_all.duplicates = duplicates_allowed;
443         msg_all.n_sent = 0;
444
445         tdb_traverse(conn_tdb, traverse_fn, &msg_all);
446         if (n_sent)
447                 *n_sent = msg_all.n_sent;
448         return True;
449 }
450
451 /** @} **/