first public release of samba4 code
[samba.git] / source4 / lib / messages.c
1 /* 
2    Unix SMB/CIFS implementation.
3    Samba internal messaging functions
4    Copyright (C) Andrew Tridgell 2000
5    Copyright (C) 2001 by Martin Pool
6    Copyright (C) 2002 by Jeremy Allison
7    
8    This program is free software; you can redistribute it and/or modify
9    it under the terms of the GNU General Public License as published by
10    the Free Software Foundation; either version 2 of the License, or
11    (at your option) any later version.
12    
13    This program is distributed in the hope that it will be useful,
14    but WITHOUT ANY WARRANTY; without even the implied warranty of
15    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16    GNU General Public License for more details.
17    
18    You should have received a copy of the GNU General Public License
19    along with this program; if not, write to the Free Software
20    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
21 */
22
23 /**
24   @defgroup messages Internal messaging framework
25   @{
26   @file messages.c
27   
28   @brief  Module for internal messaging between Samba daemons. 
29
30    The idea is that if a part of Samba wants to do communication with
31    another Samba process then it will do a message_register() of a
32    dispatch function, and use message_send_pid() to send messages to
33    that process.
34
35    The dispatch function is given the pid of the sender, and it can
36    use that to reply by message_send_pid().  See ping_message() for a
37    simple example.
38
39    @caution Dispatch functions must be able to cope with incoming
40    messages on an *odd* byte boundary.
41
42    This system doesn't have any inherent size limitations but is not
43    very efficient for large messages or when messages are sent in very
44    quick succession.
45
46 */
47
48 #include "includes.h"
49
50 /* the locking database handle */
51 static TDB_CONTEXT *tdb;
52 static int received_signal;
53
54 /* change the message version with any incompatible changes in the protocol */
55 #define MESSAGE_VERSION 1
56
57 struct message_rec {
58         int msg_version;
59         int msg_type;
60         pid_t dest;
61         pid_t src;
62         size_t len;
63 };
64
65 /* we have a linked list of dispatch handlers */
66 static struct dispatch_fns {
67         struct dispatch_fns *next, *prev;
68         int msg_type;
69         void (*fn)(int msg_type, pid_t pid, void *buf, size_t len);
70 } *dispatch_fns;
71
72 /****************************************************************************
73  Notifications come in as signals.
74 ****************************************************************************/
75
76 static void sig_usr1(void)
77 {
78         received_signal = 1;
79         sys_select_signal();
80 }
81
82 /****************************************************************************
83  A useful function for testing the message system.
84 ****************************************************************************/
85
86 static void ping_message(int msg_type, pid_t src, void *buf, size_t len)
87 {
88         const char *msg = buf ? buf : "none";
89         DEBUG(1,("INFO: Received PING message from PID %u [%s]\n",(unsigned int)src, msg));
90         message_send_pid(src, MSG_PONG, buf, len, True);
91 }
92
93 /****************************************************************************
94  Initialise the messaging functions. 
95 ****************************************************************************/
96
97 BOOL message_init(void)
98 {
99         TALLOC_CTX *mem_ctx;
100         
101         if (tdb) return True;
102
103         mem_ctx = talloc_init("message_init");
104         if (!mem_ctx) {
105                 DEBUG(0,("ERROR: No memory to initialise messages database\n"));
106                 return False;
107         }
108         tdb = tdb_open_log(lock_path(mem_ctx, "messages.tdb"), 
109                        0, TDB_CLEAR_IF_FIRST|TDB_DEFAULT, 
110                        O_RDWR|O_CREAT,0600);
111         talloc_destroy(mem_ctx);
112
113         if (!tdb) {
114                 DEBUG(0,("ERROR: Failed to initialise messages database\n"));
115                 return False;
116         }
117
118         CatchSignal(SIGUSR1, SIGNAL_CAST sig_usr1);
119
120         message_register(MSG_PING, ping_message);
121
122         return True;
123 }
124
125 /*******************************************************************
126  Form a static tdb key from a pid.
127 ******************************************************************/
128
129 static TDB_DATA message_key_pid(pid_t pid)
130 {
131         static char key[20];
132         TDB_DATA kbuf;
133
134         slprintf(key, sizeof(key)-1, "PID/%d", (int)pid);
135         
136         kbuf.dptr = (char *)key;
137         kbuf.dsize = strlen(key)+1;
138         return kbuf;
139 }
140
141 /****************************************************************************
142  Notify a process that it has a message. If the process doesn't exist 
143  then delete its record in the database.
144 ****************************************************************************/
145
146 static BOOL message_notify(pid_t pid)
147 {
148         /*
149          * Doing kill with a non-positive pid causes messages to be
150          * sent to places we don't want.
151          */
152
153         SMB_ASSERT(pid > 0);
154
155         if (kill(pid, SIGUSR1) == -1) {
156                 if (errno == ESRCH) {
157                         DEBUG(2,("pid %d doesn't exist - deleting messages record\n", (int)pid));
158                         tdb_delete(tdb, message_key_pid(pid));
159                 } else {
160                         DEBUG(2,("message to process %d failed - %s\n", (int)pid, strerror(errno)));
161                 }
162                 return False;
163         }
164         return True;
165 }
166
167 /****************************************************************************
168  Send a message to a particular pid.
169 ****************************************************************************/
170
171 static BOOL message_send_pid_internal(pid_t pid, int msg_type, const void *buf, size_t len,
172                       BOOL duplicates_allowed, unsigned int timeout)
173 {
174         TDB_DATA kbuf;
175         TDB_DATA dbuf;
176         TDB_DATA old_dbuf;
177         struct message_rec rec;
178         char *ptr;
179         struct message_rec prec;
180
181         /*
182          * Doing kill with a non-positive pid causes messages to be
183          * sent to places we don't want.
184          */
185
186         SMB_ASSERT(pid > 0);
187
188         rec.msg_version = MESSAGE_VERSION;
189         rec.msg_type = msg_type;
190         rec.dest = pid;
191         rec.src = getpid();
192         rec.len = len;
193
194         kbuf = message_key_pid(pid);
195
196         dbuf.dptr = (void *)malloc(len + sizeof(rec));
197         if (!dbuf.dptr)
198                 return False;
199
200         memcpy(dbuf.dptr, &rec, sizeof(rec));
201         if (len > 0)
202                 memcpy((void *)((char*)dbuf.dptr+sizeof(rec)), buf, len);
203
204         dbuf.dsize = len + sizeof(rec);
205
206         if (duplicates_allowed) {
207
208                 /* If duplicates are allowed we can just append the message and return. */
209
210                 /* lock the record for the destination */
211                 if (timeout) {
212                         if (tdb_chainlock_with_timeout(tdb, kbuf, timeout) == -1) {
213                                 DEBUG(0,("message_send_pid_internal: failed to get chainlock with timeout %ul.\n", timeout));
214                                 return False;
215                         }
216                 } else {
217                         if (tdb_chainlock(tdb, kbuf) == -1) {
218                                 DEBUG(0,("message_send_pid_internal: failed to get chainlock.\n"));
219                                 return False;
220                         }
221                 }       
222                 tdb_append(tdb, kbuf, dbuf);
223                 tdb_chainunlock(tdb, kbuf);
224
225                 SAFE_FREE(dbuf.dptr);
226                 errno = 0;                    /* paranoia */
227                 return message_notify(pid);
228         }
229
230         /* lock the record for the destination */
231         if (timeout) {
232                 if (tdb_chainlock_with_timeout(tdb, kbuf, timeout) == -1) {
233                         DEBUG(0,("message_send_pid_internal: failed to get chainlock with timeout %ul.\n", timeout));
234                         return False;
235                 }
236         } else {
237                 if (tdb_chainlock(tdb, kbuf) == -1) {
238                         DEBUG(0,("message_send_pid_internal: failed to get chainlock.\n"));
239                         return False;
240                 }
241         }       
242
243         old_dbuf = tdb_fetch(tdb, kbuf);
244
245         if (!old_dbuf.dptr) {
246                 /* its a new record */
247
248                 tdb_store(tdb, kbuf, dbuf, TDB_REPLACE);
249                 tdb_chainunlock(tdb, kbuf);
250
251                 SAFE_FREE(dbuf.dptr);
252                 errno = 0;                    /* paranoia */
253                 return message_notify(pid);
254         }
255
256         /* Not a new record. Check for duplicates. */
257
258         for(ptr = (char *)old_dbuf.dptr; ptr < old_dbuf.dptr + old_dbuf.dsize; ) {
259                 /*
260                  * First check if the message header matches, then, if it's a non-zero
261                  * sized message, check if the data matches. If so it's a duplicate and
262                  * we can discard it. JRA.
263                  */
264
265                 if (!memcmp(ptr, &rec, sizeof(rec))) {
266                         if (!len || (len && !memcmp( ptr + sizeof(rec), buf, len))) {
267                                 tdb_chainunlock(tdb, kbuf);
268                                 DEBUG(10,("message_send_pid_internal: discarding duplicate message.\n"));
269                                 SAFE_FREE(dbuf.dptr);
270                                 SAFE_FREE(old_dbuf.dptr);
271                                 return True;
272                         }
273                 }
274                 memcpy(&prec, ptr, sizeof(prec));
275                 ptr += sizeof(rec) + prec.len;
276         }
277
278         /* we're adding to an existing entry */
279
280         tdb_append(tdb, kbuf, dbuf);
281         tdb_chainunlock(tdb, kbuf);
282
283         SAFE_FREE(old_dbuf.dptr);
284         SAFE_FREE(dbuf.dptr);
285
286         errno = 0;                    /* paranoia */
287         return message_notify(pid);
288 }
289
290 /****************************************************************************
291  Send a message to a particular pid - no timeout.
292 ****************************************************************************/
293
294 BOOL message_send_pid(pid_t pid, int msg_type, const void *buf, size_t len, BOOL duplicates_allowed)
295 {
296         return message_send_pid_internal(pid, msg_type, buf, len, duplicates_allowed, 0);
297 }
298
299 /****************************************************************************
300  Send a message to a particular pid, with timeout in seconds.
301 ****************************************************************************/
302
303 BOOL message_send_pid_with_timeout(pid_t pid, int msg_type, const void *buf, size_t len,
304                 BOOL duplicates_allowed, unsigned int timeout)
305 {
306         return message_send_pid_internal(pid, msg_type, buf, len, duplicates_allowed, timeout);
307 }
308
309 /****************************************************************************
310  Retrieve all messages for the current process.
311 ****************************************************************************/
312
313 static BOOL retrieve_all_messages(char **msgs_buf, size_t *total_len)
314 {
315         TDB_DATA kbuf;
316         TDB_DATA dbuf;
317         TDB_DATA null_dbuf;
318
319         ZERO_STRUCT(null_dbuf);
320
321         *msgs_buf = NULL;
322         *total_len = 0;
323
324         kbuf = message_key_pid(getpid());
325
326         tdb_chainlock(tdb, kbuf);
327         dbuf = tdb_fetch(tdb, kbuf);
328         /*
329          * Replace with an empty record to keep the allocated
330          * space in the tdb.
331          */
332         tdb_store(tdb, kbuf, null_dbuf, TDB_REPLACE);
333         tdb_chainunlock(tdb, kbuf);
334
335         if (dbuf.dptr == NULL || dbuf.dsize == 0) {
336                 SAFE_FREE(dbuf.dptr);
337                 return False;
338         }
339
340         *msgs_buf = dbuf.dptr;
341         *total_len = dbuf.dsize;
342
343         return True;
344 }
345
346 /****************************************************************************
347  Parse out the next message for the current process.
348 ****************************************************************************/
349
350 static BOOL message_recv(char *msgs_buf, size_t total_len, int *msg_type, pid_t *src, char **buf, size_t *len)
351 {
352         struct message_rec rec;
353         char *ret_buf = *buf;
354
355         *buf = NULL;
356         *len = 0;
357
358         if (total_len - (ret_buf - msgs_buf) < sizeof(rec))
359                 return False;
360
361         memcpy(&rec, ret_buf, sizeof(rec));
362         ret_buf += sizeof(rec);
363
364         if (rec.msg_version != MESSAGE_VERSION) {
365                 DEBUG(0,("message version %d received (expected %d)\n", rec.msg_version, MESSAGE_VERSION));
366                 return False;
367         }
368
369         if (rec.len > 0) {
370                 if (total_len - (ret_buf - msgs_buf) < rec.len)
371                         return False;
372         }
373
374         *len = rec.len;
375         *msg_type = rec.msg_type;
376         *src = rec.src;
377         *buf = ret_buf;
378
379         return True;
380 }
381
382 /****************************************************************************
383  Receive and dispatch any messages pending for this process.
384  Notice that all dispatch handlers for a particular msg_type get called,
385  so you can register multiple handlers for a message.
386  *NOTE*: Dispatch functions must be able to cope with incoming
387  messages on an *odd* byte boundary.
388 ****************************************************************************/
389
390 void message_dispatch(void)
391 {
392         int msg_type;
393         pid_t src;
394         char *buf;
395         char *msgs_buf;
396         size_t len, total_len;
397         struct dispatch_fns *dfn;
398         int n_handled;
399
400         if (!received_signal)
401                 return;
402
403         DEBUG(10,("message_dispatch: received_signal = %d\n", received_signal));
404
405         received_signal = 0;
406
407         if (!retrieve_all_messages(&msgs_buf, &total_len))
408                 return;
409
410         for (buf = msgs_buf; message_recv(msgs_buf, total_len, &msg_type, &src, &buf, &len); buf += len) {
411                 DEBUG(10,("message_dispatch: received msg_type=%d src_pid=%u\n",
412                           msg_type, (unsigned int) src));
413                 n_handled = 0;
414                 for (dfn = dispatch_fns; dfn; dfn = dfn->next) {
415                         if (dfn->msg_type == msg_type) {
416                                 DEBUG(10,("message_dispatch: processing message of type %d.\n", msg_type));
417                                 dfn->fn(msg_type, src, len ? (void *)buf : NULL, len);
418                                 n_handled++;
419                         }
420                 }
421                 if (!n_handled) {
422                         DEBUG(5,("message_dispatch: warning: no handlers registed for "
423                                  "msg_type %d in pid %u\n",
424                                  msg_type, (unsigned int)getpid()));
425                 }
426         }
427         SAFE_FREE(msgs_buf);
428 }
429
430 /****************************************************************************
431  Register a dispatch function for a particular message type.
432  *NOTE*: Dispatch functions must be able to cope with incoming
433  messages on an *odd* byte boundary.
434 ****************************************************************************/
435
436 void message_register(int msg_type, 
437                       void (*fn)(int msg_type, pid_t pid, void *buf, size_t len))
438 {
439         struct dispatch_fns *dfn;
440
441         dfn = (struct dispatch_fns *)malloc(sizeof(*dfn));
442
443         if (dfn != NULL) {
444
445                 ZERO_STRUCTPN(dfn);
446
447                 dfn->msg_type = msg_type;
448                 dfn->fn = fn;
449
450                 DLIST_ADD(dispatch_fns, dfn);
451         }
452         else {
453         
454                 DEBUG(0,("message_register: Not enough memory. malloc failed!\n"));
455         }
456 }
457
458 /****************************************************************************
459  De-register the function for a particular message type.
460 ****************************************************************************/
461
462 void message_deregister(int msg_type)
463 {
464         struct dispatch_fns *dfn, *next;
465
466         for (dfn = dispatch_fns; dfn; dfn = next) {
467                 next = dfn->next;
468                 if (dfn->msg_type == msg_type) {
469                         DLIST_REMOVE(dispatch_fns, dfn);
470                         SAFE_FREE(dfn);
471                 }
472         }       
473 }
474
475 struct msg_all {
476         int msg_type;
477         uint32 msg_flag;
478         const void *buf;
479         size_t len;
480         BOOL duplicates;
481         int n_sent;
482 };
483
484 /****************************************************************************
485  Send one of the messages for the broadcast.
486 ****************************************************************************/
487
488 static int traverse_fn(TDB_CONTEXT *the_tdb, TDB_DATA kbuf, TDB_DATA dbuf, void *state)
489 {
490         struct connections_data crec;
491         struct msg_all *msg_all = (struct msg_all *)state;
492
493         if (dbuf.dsize != sizeof(crec))
494                 return 0;
495
496         memcpy(&crec, dbuf.dptr, sizeof(crec));
497
498         if (crec.cnum != -1)
499                 return 0;
500
501         /* Don't send if the receiver hasn't registered an interest. */
502
503         if(!(crec.bcast_msg_flags & msg_all->msg_flag))
504                 return 0;
505
506         /* If the msg send fails because the pid was not found (i.e. smbd died), 
507          * the msg has already been deleted from the messages.tdb.*/
508
509         if (!message_send_pid(crec.pid, msg_all->msg_type,
510                               msg_all->buf, msg_all->len,
511                               msg_all->duplicates)) {
512                 
513                 /* If the pid was not found delete the entry from connections.tdb */
514
515                 if (errno == ESRCH) {
516                         DEBUG(2,("pid %u doesn't exist - deleting connections %d [%s]\n",
517                                         (unsigned int)crec.pid, crec.cnum, crec.name));
518                         tdb_delete(the_tdb, kbuf);
519                 }
520         }
521         msg_all->n_sent++;
522         return 0;
523 }
524
525 /**
526  * Send a message to all smbd processes.
527  *
528  * It isn't very efficient, but should be OK for the sorts of
529  * applications that use it. When we need efficient broadcast we can add
530  * it.
531  *
532  * @param n_sent Set to the number of messages sent.  This should be
533  * equal to the number of processes, but be careful for races.
534  *
535  * @retval True for success.
536  **/
537 BOOL message_send_all(TDB_CONTEXT *conn_tdb, int msg_type,
538                       const void *buf, size_t len,
539                       BOOL duplicates_allowed,
540                       int *n_sent)
541 {
542         struct msg_all msg_all;
543
544         msg_all.msg_type = msg_type;
545         if (msg_type < 1000)
546                 msg_all.msg_flag = FLAG_MSG_GENERAL;
547         else if (msg_type > 1000 && msg_type < 2000)
548                 msg_all.msg_flag = FLAG_MSG_NMBD;
549         else if (msg_type > 2000 && msg_type < 3000)
550                 msg_all.msg_flag = FLAG_MSG_PRINTING;
551         else if (msg_type > 3000 && msg_type < 4000)
552                 msg_all.msg_flag = FLAG_MSG_SMBD;
553         else
554                 return False;
555
556         msg_all.buf = buf;
557         msg_all.len = len;
558         msg_all.duplicates = duplicates_allowed;
559         msg_all.n_sent = 0;
560
561         tdb_traverse(conn_tdb, traverse_fn, &msg_all);
562         if (n_sent)
563                 *n_sent = msg_all.n_sent;
564         return True;
565 }
566 /** @} **/