Given how often a panic has to do with malloc() problems, don't tempt
[tprouty/samba.git] / source / lib / messages.c
1 /* 
2    Unix SMB/CIFS implementation.
3    Samba internal messaging functions
4    Copyright (C) Andrew Tridgell 2000
5    Copyright (C) 2001 by Martin Pool
6    Copyright (C) 2002 by Jeremy Allison
7    
8    This program is free software; you can redistribute it and/or modify
9    it under the terms of the GNU General Public License as published by
10    the Free Software Foundation; either version 2 of the License, or
11    (at your option) any later version.
12    
13    This program is distributed in the hope that it will be useful,
14    but WITHOUT ANY WARRANTY; without even the implied warranty of
15    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16    GNU General Public License for more details.
17    
18    You should have received a copy of the GNU General Public License
19    along with this program; if not, write to the Free Software
20    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
21 */
22
23 /**
24   @defgroup messages Internal messaging framework
25   @{
26   @file messages.c
27   
28   @brief  Module for internal messaging between Samba daemons. 
29
30    The idea is that if a part of Samba wants to do communication with
31    another Samba process then it will do a message_register() of a
32    dispatch function, and use message_send_pid() to send messages to
33    that process.
34
35    The dispatch function is given the pid of the sender, and it can
36    use that to reply by message_send_pid().  See ping_message() for a
37    simple example.
38
39    @caution Dispatch functions must be able to cope with incoming
40    messages on an *odd* byte boundary.
41
42    This system doesn't have any inherent size limitations but is not
43    very efficient for large messages or when messages are sent in very
44    quick succession.
45
46 */
47
48 #include "includes.h"
49
50 /* the locking database handle */
51 static TDB_CONTEXT *tdb;
52 static int received_signal;
53
54 /* change the message version with any incompatible changes in the protocol */
55 #define MESSAGE_VERSION 1
56
57 struct message_rec {
58         int msg_version;
59         int msg_type;
60         pid_t dest;
61         pid_t src;
62         size_t len;
63 };
64
65 /* we have a linked list of dispatch handlers */
66 static struct dispatch_fns {
67         struct dispatch_fns *next, *prev;
68         int msg_type;
69         void (*fn)(int msg_type, pid_t pid, void *buf, size_t len);
70 } *dispatch_fns;
71
72 /****************************************************************************
73  Notifications come in as signals.
74 ****************************************************************************/
75
76 static void sig_usr1(void)
77 {
78         received_signal = 1;
79         sys_select_signal();
80 }
81
82 /****************************************************************************
83  A useful function for testing the message system.
84 ****************************************************************************/
85
86 static void ping_message(int msg_type, pid_t src, void *buf, size_t len)
87 {
88         const char *msg = buf ? buf : "none";
89         DEBUG(1,("INFO: Received PING message from PID %u [%s]\n",(unsigned int)src, msg));
90         message_send_pid(src, MSG_PONG, buf, len, True);
91 }
92
93 /****************************************************************************
94  Initialise the messaging functions. 
95 ****************************************************************************/
96
97 BOOL message_init(void)
98 {
99         if (tdb) return True;
100
101         tdb = tdb_open_log(lock_path("messages.tdb"), 
102                        0, TDB_CLEAR_IF_FIRST|TDB_DEFAULT, 
103                        O_RDWR|O_CREAT,0600);
104
105         if (!tdb) {
106                 DEBUG(0,("ERROR: Failed to initialise messages database\n"));
107                 return False;
108         }
109
110         CatchSignal(SIGUSR1, SIGNAL_CAST sig_usr1);
111
112         message_register(MSG_PING, ping_message);
113
114         /* Register some debugging related messages */
115
116         register_msg_pool_usage();
117         register_dmalloc_msgs();
118
119         return True;
120 }
121
122 /*******************************************************************
123  Form a static tdb key from a pid.
124 ******************************************************************/
125
126 static TDB_DATA message_key_pid(pid_t pid)
127 {
128         static char key[20];
129         TDB_DATA kbuf;
130
131         slprintf(key, sizeof(key)-1, "PID/%d", (int)pid);
132         
133         kbuf.dptr = (char *)key;
134         kbuf.dsize = strlen(key)+1;
135         return kbuf;
136 }
137
138 /****************************************************************************
139  Notify a process that it has a message. If the process doesn't exist 
140  then delete its record in the database.
141 ****************************************************************************/
142
143 static BOOL message_notify(pid_t pid)
144 {
145         /*
146          * Doing kill with a non-positive pid causes messages to be
147          * sent to places we don't want.
148          */
149
150         SMB_ASSERT(pid > 0);
151
152         if (kill(pid, SIGUSR1) == -1) {
153                 if (errno == ESRCH) {
154                         DEBUG(2,("pid %d doesn't exist - deleting messages record\n", (int)pid));
155                         tdb_delete(tdb, message_key_pid(pid));
156                 } else {
157                         DEBUG(2,("message to process %d failed - %s\n", (int)pid, strerror(errno)));
158                 }
159                 return False;
160         }
161         return True;
162 }
163
164 /****************************************************************************
165  Send a message to a particular pid.
166 ****************************************************************************/
167
168 static BOOL message_send_pid_internal(pid_t pid, int msg_type, const void *buf, size_t len,
169                       BOOL duplicates_allowed, unsigned int timeout)
170 {
171         TDB_DATA kbuf;
172         TDB_DATA dbuf;
173         TDB_DATA old_dbuf;
174         struct message_rec rec;
175         char *ptr;
176         struct message_rec prec;
177
178         /*
179          * Doing kill with a non-positive pid causes messages to be
180          * sent to places we don't want.
181          */
182
183         SMB_ASSERT(pid > 0);
184
185         rec.msg_version = MESSAGE_VERSION;
186         rec.msg_type = msg_type;
187         rec.dest = pid;
188         rec.src = sys_getpid();
189         rec.len = len;
190
191         kbuf = message_key_pid(pid);
192
193         dbuf.dptr = (void *)malloc(len + sizeof(rec));
194         if (!dbuf.dptr)
195                 return False;
196
197         memcpy(dbuf.dptr, &rec, sizeof(rec));
198         if (len > 0)
199                 memcpy((void *)((char*)dbuf.dptr+sizeof(rec)), buf, len);
200
201         dbuf.dsize = len + sizeof(rec);
202
203         if (duplicates_allowed) {
204
205                 /* If duplicates are allowed we can just append the message and return. */
206
207                 /* lock the record for the destination */
208                 if (timeout) {
209                         if (tdb_chainlock_with_timeout(tdb, kbuf, timeout) == -1) {
210                                 DEBUG(0,("message_send_pid_internal: failed to get chainlock with timeout %ul.\n", timeout));
211                                 return False;
212                         }
213                 } else {
214                         if (tdb_chainlock(tdb, kbuf) == -1) {
215                                 DEBUG(0,("message_send_pid_internal: failed to get chainlock.\n"));
216                                 return False;
217                         }
218                 }       
219                 tdb_append(tdb, kbuf, dbuf);
220                 tdb_chainunlock(tdb, kbuf);
221
222                 SAFE_FREE(dbuf.dptr);
223                 errno = 0;                    /* paranoia */
224                 return message_notify(pid);
225         }
226
227         /* lock the record for the destination */
228         if (timeout) {
229                 if (tdb_chainlock_with_timeout(tdb, kbuf, timeout) == -1) {
230                         DEBUG(0,("message_send_pid_internal: failed to get chainlock with timeout %ul.\n", timeout));
231                         return False;
232                 }
233         } else {
234                 if (tdb_chainlock(tdb, kbuf) == -1) {
235                         DEBUG(0,("message_send_pid_internal: failed to get chainlock.\n"));
236                         return False;
237                 }
238         }       
239
240         old_dbuf = tdb_fetch(tdb, kbuf);
241
242         if (!old_dbuf.dptr) {
243                 /* its a new record */
244
245                 tdb_store(tdb, kbuf, dbuf, TDB_REPLACE);
246                 tdb_chainunlock(tdb, kbuf);
247
248                 SAFE_FREE(dbuf.dptr);
249                 errno = 0;                    /* paranoia */
250                 return message_notify(pid);
251         }
252
253         /* Not a new record. Check for duplicates. */
254
255         for(ptr = (char *)old_dbuf.dptr; ptr < old_dbuf.dptr + old_dbuf.dsize; ) {
256                 /*
257                  * First check if the message header matches, then, if it's a non-zero
258                  * sized message, check if the data matches. If so it's a duplicate and
259                  * we can discard it. JRA.
260                  */
261
262                 if (!memcmp(ptr, &rec, sizeof(rec))) {
263                         if (!len || (len && !memcmp( ptr + sizeof(rec), buf, len))) {
264                                 tdb_chainunlock(tdb, kbuf);
265                                 DEBUG(10,("message_send_pid_internal: discarding duplicate message.\n"));
266                                 SAFE_FREE(dbuf.dptr);
267                                 SAFE_FREE(old_dbuf.dptr);
268                                 return True;
269                         }
270                 }
271                 memcpy(&prec, ptr, sizeof(prec));
272                 ptr += sizeof(rec) + prec.len;
273         }
274
275         /* we're adding to an existing entry */
276
277         tdb_append(tdb, kbuf, dbuf);
278         tdb_chainunlock(tdb, kbuf);
279
280         SAFE_FREE(old_dbuf.dptr);
281         SAFE_FREE(dbuf.dptr);
282
283         errno = 0;                    /* paranoia */
284         return message_notify(pid);
285 }
286
287 /****************************************************************************
288  Send a message to a particular pid - no timeout.
289 ****************************************************************************/
290
291 BOOL message_send_pid(pid_t pid, int msg_type, const void *buf, size_t len, BOOL duplicates_allowed)
292 {
293         return message_send_pid_internal(pid, msg_type, buf, len, duplicates_allowed, 0);
294 }
295
296 /****************************************************************************
297  Send a message to a particular pid, with timeout in seconds.
298 ****************************************************************************/
299
300 BOOL message_send_pid_with_timeout(pid_t pid, int msg_type, const void *buf, size_t len,
301                 BOOL duplicates_allowed, unsigned int timeout)
302 {
303         return message_send_pid_internal(pid, msg_type, buf, len, duplicates_allowed, timeout);
304 }
305
306 /****************************************************************************
307  Count the messages pending for a particular pid. Expensive....
308 ****************************************************************************/
309
310 unsigned int messages_pending_for_pid(pid_t pid)
311 {
312         TDB_DATA kbuf;
313         TDB_DATA dbuf;
314         char *buf;
315         unsigned int message_count = 0;
316
317         kbuf = message_key_pid(sys_getpid());
318
319         dbuf = tdb_fetch(tdb, kbuf);
320         if (dbuf.dptr == NULL || dbuf.dsize == 0) {
321                 SAFE_FREE(dbuf.dptr);
322                 return 0;
323         }
324
325         for (buf = dbuf.dptr; dbuf.dsize > sizeof(struct message_rec);) {
326                 struct message_rec rec;
327                 memcpy(&rec, buf, sizeof(rec));
328                 buf += (sizeof(rec) + rec.len);
329                 dbuf.dsize -= (sizeof(rec) + rec.len);
330                 message_count++;
331         }
332
333         SAFE_FREE(dbuf.dptr);
334         return message_count;
335 }
336
337 /****************************************************************************
338  Retrieve all messages for the current process.
339 ****************************************************************************/
340
341 static BOOL retrieve_all_messages(char **msgs_buf, size_t *total_len)
342 {
343         TDB_DATA kbuf;
344         TDB_DATA dbuf;
345         TDB_DATA null_dbuf;
346
347         ZERO_STRUCT(null_dbuf);
348
349         *msgs_buf = NULL;
350         *total_len = 0;
351
352         kbuf = message_key_pid(sys_getpid());
353
354         if (tdb_chainlock(tdb, kbuf) == -1)
355                 return False;
356
357         dbuf = tdb_fetch(tdb, kbuf);
358         /*
359          * Replace with an empty record to keep the allocated
360          * space in the tdb.
361          */
362         tdb_store(tdb, kbuf, null_dbuf, TDB_REPLACE);
363         tdb_chainunlock(tdb, kbuf);
364
365         if (dbuf.dptr == NULL || dbuf.dsize == 0) {
366                 SAFE_FREE(dbuf.dptr);
367                 return False;
368         }
369
370         *msgs_buf = dbuf.dptr;
371         *total_len = dbuf.dsize;
372
373         return True;
374 }
375
376 /****************************************************************************
377  Parse out the next message for the current process.
378 ****************************************************************************/
379
380 static BOOL message_recv(char *msgs_buf, size_t total_len, int *msg_type, pid_t *src, char **buf, size_t *len)
381 {
382         struct message_rec rec;
383         char *ret_buf = *buf;
384
385         *buf = NULL;
386         *len = 0;
387
388         if (total_len - (ret_buf - msgs_buf) < sizeof(rec))
389                 return False;
390
391         memcpy(&rec, ret_buf, sizeof(rec));
392         ret_buf += sizeof(rec);
393
394         if (rec.msg_version != MESSAGE_VERSION) {
395                 DEBUG(0,("message version %d received (expected %d)\n", rec.msg_version, MESSAGE_VERSION));
396                 return False;
397         }
398
399         if (rec.len > 0) {
400                 if (total_len - (ret_buf - msgs_buf) < rec.len)
401                         return False;
402         }
403
404         *len = rec.len;
405         *msg_type = rec.msg_type;
406         *src = rec.src;
407         *buf = ret_buf;
408
409         return True;
410 }
411
412 /****************************************************************************
413  Receive and dispatch any messages pending for this process.
414  Notice that all dispatch handlers for a particular msg_type get called,
415  so you can register multiple handlers for a message.
416  *NOTE*: Dispatch functions must be able to cope with incoming
417  messages on an *odd* byte boundary.
418 ****************************************************************************/
419
420 void message_dispatch(void)
421 {
422         int msg_type;
423         pid_t src;
424         char *buf;
425         char *msgs_buf;
426         size_t len, total_len;
427         struct dispatch_fns *dfn;
428         int n_handled;
429
430         if (!received_signal)
431                 return;
432
433         DEBUG(10,("message_dispatch: received_signal = %d\n", received_signal));
434
435         received_signal = 0;
436
437         if (!retrieve_all_messages(&msgs_buf, &total_len))
438                 return;
439
440         for (buf = msgs_buf; message_recv(msgs_buf, total_len, &msg_type, &src, &buf, &len); buf += len) {
441                 DEBUG(10,("message_dispatch: received msg_type=%d src_pid=%u\n",
442                           msg_type, (unsigned int) src));
443                 n_handled = 0;
444                 for (dfn = dispatch_fns; dfn; dfn = dfn->next) {
445                         if (dfn->msg_type == msg_type) {
446                                 DEBUG(10,("message_dispatch: processing message of type %d.\n", msg_type));
447                                 dfn->fn(msg_type, src, len ? (void *)buf : NULL, len);
448                                 n_handled++;
449                         }
450                 }
451                 if (!n_handled) {
452                         DEBUG(5,("message_dispatch: warning: no handlers registed for "
453                                  "msg_type %d in pid %u\n",
454                                  msg_type, (unsigned int)sys_getpid()));
455                 }
456         }
457         SAFE_FREE(msgs_buf);
458 }
459
460 /****************************************************************************
461  Register a dispatch function for a particular message type.
462  *NOTE*: Dispatch functions must be able to cope with incoming
463  messages on an *odd* byte boundary.
464 ****************************************************************************/
465
466 void message_register(int msg_type, 
467                       void (*fn)(int msg_type, pid_t pid, void *buf, size_t len))
468 {
469         struct dispatch_fns *dfn;
470
471         dfn = (struct dispatch_fns *)malloc(sizeof(*dfn));
472
473         if (dfn != NULL) {
474
475                 ZERO_STRUCTPN(dfn);
476
477                 dfn->msg_type = msg_type;
478                 dfn->fn = fn;
479
480                 DLIST_ADD(dispatch_fns, dfn);
481         }
482         else {
483         
484                 DEBUG(0,("message_register: Not enough memory. malloc failed!\n"));
485         }
486 }
487
488 /****************************************************************************
489  De-register the function for a particular message type.
490 ****************************************************************************/
491
492 void message_deregister(int msg_type)
493 {
494         struct dispatch_fns *dfn, *next;
495
496         for (dfn = dispatch_fns; dfn; dfn = next) {
497                 next = dfn->next;
498                 if (dfn->msg_type == msg_type) {
499                         DLIST_REMOVE(dispatch_fns, dfn);
500                         SAFE_FREE(dfn);
501                 }
502         }       
503 }
504
505 struct msg_all {
506         int msg_type;
507         uint32 msg_flag;
508         const void *buf;
509         size_t len;
510         BOOL duplicates;
511         int n_sent;
512 };
513
514 /****************************************************************************
515  Send one of the messages for the broadcast.
516 ****************************************************************************/
517
518 static int traverse_fn(TDB_CONTEXT *the_tdb, TDB_DATA kbuf, TDB_DATA dbuf, void *state)
519 {
520         struct connections_data crec;
521         struct msg_all *msg_all = (struct msg_all *)state;
522
523         if (dbuf.dsize != sizeof(crec))
524                 return 0;
525
526         memcpy(&crec, dbuf.dptr, sizeof(crec));
527
528         if (crec.cnum != -1)
529                 return 0;
530
531         /* Don't send if the receiver hasn't registered an interest. */
532
533         if(!(crec.bcast_msg_flags & msg_all->msg_flag))
534                 return 0;
535
536         /* If the msg send fails because the pid was not found (i.e. smbd died), 
537          * the msg has already been deleted from the messages.tdb.*/
538
539         if (!message_send_pid(crec.pid, msg_all->msg_type,
540                               msg_all->buf, msg_all->len,
541                               msg_all->duplicates)) {
542                 
543                 /* If the pid was not found delete the entry from connections.tdb */
544
545                 if (errno == ESRCH) {
546                         DEBUG(2,("pid %u doesn't exist - deleting connections %d [%s]\n",
547                                         (unsigned int)crec.pid, crec.cnum, crec.name));
548                         tdb_delete(the_tdb, kbuf);
549                 }
550         }
551         msg_all->n_sent++;
552         return 0;
553 }
554
555 /**
556  * Send a message to all smbd processes.
557  *
558  * It isn't very efficient, but should be OK for the sorts of
559  * applications that use it. When we need efficient broadcast we can add
560  * it.
561  *
562  * @param n_sent Set to the number of messages sent.  This should be
563  * equal to the number of processes, but be careful for races.
564  *
565  * @retval True for success.
566  **/
567 BOOL message_send_all(TDB_CONTEXT *conn_tdb, int msg_type,
568                       const void *buf, size_t len,
569                       BOOL duplicates_allowed,
570                       int *n_sent)
571 {
572         struct msg_all msg_all;
573
574         msg_all.msg_type = msg_type;
575         if (msg_type < 1000)
576                 msg_all.msg_flag = FLAG_MSG_GENERAL;
577         else if (msg_type > 1000 && msg_type < 2000)
578                 msg_all.msg_flag = FLAG_MSG_NMBD;
579         else if (msg_type > 2000 && msg_type < 3000)
580                 msg_all.msg_flag = FLAG_MSG_PRINTING;
581         else if (msg_type > 3000 && msg_type < 4000)
582                 msg_all.msg_flag = FLAG_MSG_SMBD;
583         else
584                 return False;
585
586         msg_all.buf = buf;
587         msg_all.len = len;
588         msg_all.duplicates = duplicates_allowed;
589         msg_all.n_sent = 0;
590
591         tdb_traverse(conn_tdb, traverse_fn, &msg_all);
592         if (n_sent)
593                 *n_sent = msg_all.n_sent;
594         return True;
595 }
596 /** @} **/