Lots of fixes for error paths where tdb_fetch() data need freeing.
[bbaumbach/samba-autobuild/.git] / source / lib / messages.c
1 /* 
2    Unix SMB/CIFS implementation.
3    Samba internal messaging functions
4    Copyright (C) Andrew Tridgell 2000
5    Copyright (C) 2001 by Martin Pool
6    
7    This program is free software; you can redistribute it and/or modify
8    it under the terms of the GNU General Public License as published by
9    the Free Software Foundation; either version 2 of the License, or
10    (at your option) any later version.
11    
12    This program is distributed in the hope that it will be useful,
13    but WITHOUT ANY WARRANTY; without even the implied warranty of
14    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15    GNU General Public License for more details.
16    
17    You should have received a copy of the GNU General Public License
18    along with this program; if not, write to the Free Software
19    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
20 */
21
22 /**
23    @defgroups messages Internal messaging framework
24    @{
25    @file messages.c
26
27    This module is used for internal messaging between Samba daemons. 
28
29    The idea is that if a part of Samba wants to do communication with
30    another Samba process then it will do a message_register() of a
31    dispatch function, and use message_send_pid() to send messages to
32    that process.
33
34    The dispatch function is given the pid of the sender, and it can
35    use that to reply by message_send_pid().  See ping_message() for a
36    simple example.
37
38    This system doesn't have any inherent size limitations but is not
39    very efficient for large messages or when messages are sent in very
40    quick succession.
41
42 */
43
44 #include "includes.h"
45
46 /* the locking database handle */
47 static TDB_CONTEXT *tdb;
48 static int received_signal;
49
50 /* change the message version with any incompatible changes in the protocol */
51 #define MESSAGE_VERSION 1
52
53 struct message_rec {
54         int msg_version;
55         int msg_type;
56         pid_t dest;
57         pid_t src;
58         size_t len;
59 };
60
61 /* we have a linked list of dispatch handlers */
62 static struct dispatch_fns {
63         struct dispatch_fns *next, *prev;
64         int msg_type;
65         void (*fn)(int msg_type, pid_t pid, void *buf, size_t len);
66 } *dispatch_fns;
67
68 /****************************************************************************
69  Notifications come in as signals.
70 ****************************************************************************/
71
72 static void sig_usr1(void)
73 {
74         received_signal = 1;
75         sys_select_signal();
76 }
77
78 /****************************************************************************
79  A useful function for testing the message system.
80 ****************************************************************************/
81
82 static void ping_message(int msg_type, pid_t src, void *buf, size_t len)
83 {
84         char *msg = buf ? buf : "none";
85         DEBUG(1,("INFO: Received PING message from PID %u [%s]\n",(unsigned int)src, msg));
86         message_send_pid(src, MSG_PONG, buf, len, True);
87 }
88
89 /****************************************************************************
90  Initialise the messaging functions. 
91 ****************************************************************************/
92
93 BOOL message_init(void)
94 {
95         if (tdb) return True;
96
97         tdb = tdb_open_log(lock_path("messages.tdb"), 
98                        0, TDB_CLEAR_IF_FIRST|TDB_DEFAULT, 
99                        O_RDWR|O_CREAT,0600);
100
101         if (!tdb) {
102                 DEBUG(0,("ERROR: Failed to initialise messages database\n"));
103                 return False;
104         }
105
106         CatchSignal(SIGUSR1, SIGNAL_CAST sig_usr1);
107
108         message_register(MSG_PING, ping_message);
109
110         return True;
111 }
112
113 /*******************************************************************
114  Form a static tdb key from a pid.
115 ******************************************************************/
116
117 static TDB_DATA message_key_pid(pid_t pid)
118 {
119         static char key[20];
120         TDB_DATA kbuf;
121
122         slprintf(key, sizeof(key)-1, "PID/%d", (int)pid);
123         
124         kbuf.dptr = (char *)key;
125         kbuf.dsize = strlen(key)+1;
126         return kbuf;
127 }
128
129 /****************************************************************************
130  Notify a process that it has a message. If the process doesn't exist 
131  then delete its record in the database.
132 ****************************************************************************/
133
134 static BOOL message_notify(pid_t pid)
135 {
136         /* Doing kill with a non-positive pid causes messages to be
137          * sent to places we don't want. */
138         SMB_ASSERT(pid > 0);
139         if (kill(pid, SIGUSR1) == -1) {
140                 if (errno == ESRCH) {
141                         DEBUG(2,("pid %d doesn't exist - deleting messages record\n", (int)pid));
142                         tdb_delete(tdb, message_key_pid(pid));
143                 } else {
144                         DEBUG(2,("message to process %d failed - %s\n", (int)pid, strerror(errno)));
145                 }
146                 return False;
147         }
148         return True;
149 }
150
151 /****************************************************************************
152  Send a message to a particular pid.
153 ****************************************************************************/
154
155 BOOL message_send_pid(pid_t pid, int msg_type, const void *buf, size_t len,
156                       BOOL duplicates_allowed)
157 {
158         TDB_DATA kbuf;
159         TDB_DATA dbuf;
160         struct message_rec rec;
161         void *p;
162
163         rec.msg_version = MESSAGE_VERSION;
164         rec.msg_type = msg_type;
165         rec.dest = pid;
166         rec.src = sys_getpid();
167         rec.len = len;
168
169         /* Doing kill with a non-positive pid causes messages to be
170          * sent to places we don't want. */
171         SMB_ASSERT(pid > 0);
172
173         kbuf = message_key_pid(pid);
174
175         /* lock the record for the destination */
176         tdb_chainlock(tdb, kbuf);
177
178         dbuf = tdb_fetch(tdb, kbuf);
179
180         if (!dbuf.dptr) {
181                 /* its a new record */
182                 p = (void *)malloc(len + sizeof(rec));
183                 if (!p)
184                         goto failed;
185
186                 memcpy(p, &rec, sizeof(rec));
187                 if (len > 0)
188                         memcpy((void *)((char*)p+sizeof(rec)), buf, len);
189
190                 dbuf.dptr = p;
191                 dbuf.dsize = len + sizeof(rec);
192                 tdb_store(tdb, kbuf, dbuf, TDB_REPLACE);
193                 SAFE_FREE(p);
194                 goto ok;
195         }
196
197         if (!duplicates_allowed) {
198                 char *ptr;
199                 struct message_rec prec;
200                 
201                 for(ptr = (char *)dbuf.dptr; ptr < dbuf.dptr + dbuf.dsize; ) {
202                         /*
203                          * First check if the message header matches, then, if it's a non-zero
204                          * sized message, check if the data matches. If so it's a duplicate and
205                          * we can discard it. JRA.
206                          */
207
208                         if (!memcmp(ptr, &rec, sizeof(rec))) {
209                                 if (!len || (len && !memcmp( ptr + sizeof(rec), buf, len))) {
210                                         DEBUG(10,("message_send_pid: discarding duplicate message.\n"));
211                                         SAFE_FREE(dbuf.dptr);
212                                         tdb_chainunlock(tdb, kbuf);
213                                         return True;
214                                 }
215                         }
216                         memcpy(&prec, ptr, sizeof(prec));
217                         ptr += sizeof(rec) + prec.len;
218                 }
219         }
220
221         /* we're adding to an existing entry */
222         p = (void *)malloc(dbuf.dsize + len + sizeof(rec));
223         if (!p)
224                 goto failed;
225
226         memcpy(p, dbuf.dptr, dbuf.dsize);
227         memcpy((void *)((char*)p+dbuf.dsize), &rec, sizeof(rec));
228         if (len > 0)
229                 memcpy((void *)((char*)p+dbuf.dsize+sizeof(rec)), buf, len);
230
231         SAFE_FREE(dbuf.dptr);
232         dbuf.dptr = p;
233         dbuf.dsize += len + sizeof(rec);
234         tdb_store(tdb, kbuf, dbuf, TDB_REPLACE);
235         SAFE_FREE(dbuf.dptr);
236
237  ok:
238         tdb_chainunlock(tdb, kbuf);
239         errno = 0;                    /* paranoia */
240         return message_notify(pid);
241
242  failed:
243         tdb_chainunlock(tdb, kbuf);
244         errno = 0;                    /* paranoia */
245         return False;
246 }
247
248 /****************************************************************************
249  Retrieve the next message for the current process.
250 ****************************************************************************/
251
252 static BOOL message_recv(int *msg_type, pid_t *src, void **buf, size_t *len)
253 {
254         TDB_DATA kbuf;
255         TDB_DATA dbuf;
256         struct message_rec rec;
257
258         kbuf = message_key_pid(sys_getpid());
259
260         tdb_chainlock(tdb, kbuf);
261         
262         dbuf = tdb_fetch(tdb, kbuf);
263         if (dbuf.dptr == NULL || dbuf.dsize == 0)
264                 goto failed;
265
266         memcpy(&rec, dbuf.dptr, sizeof(rec));
267
268         if (rec.msg_version != MESSAGE_VERSION) {
269                 DEBUG(0,("message version %d received (expected %d)\n", rec.msg_version, MESSAGE_VERSION));
270                 goto failed;
271         }
272
273         if (rec.len > 0) {
274                 (*buf) = (void *)malloc(rec.len);
275                 if (!(*buf))
276                         goto failed;
277
278                 memcpy(*buf, dbuf.dptr+sizeof(rec), rec.len);
279         } else {
280                 *buf = NULL;
281         }
282
283         *len = rec.len;
284         *msg_type = rec.msg_type;
285         *src = rec.src;
286
287         if (dbuf.dsize - (sizeof(rec)+rec.len) > 0)
288                 memmove(dbuf.dptr, dbuf.dptr+sizeof(rec)+rec.len, dbuf.dsize - (sizeof(rec)+rec.len));
289         dbuf.dsize -= sizeof(rec)+rec.len;
290
291         if (dbuf.dsize == 0)
292                 tdb_delete(tdb, kbuf);
293         else
294                 tdb_store(tdb, kbuf, dbuf, TDB_REPLACE);
295
296         SAFE_FREE(dbuf.dptr);
297         tdb_chainunlock(tdb, kbuf);
298         return True;
299
300  failed:
301         tdb_chainunlock(tdb, kbuf);
302         SAFE_FREE(dbuf.dptr);
303         return False;
304 }
305
306 /****************************************************************************
307  Receive and dispatch any messages pending for this process.
308  Notice that all dispatch handlers for a particular msg_type get called,
309  so you can register multiple handlers for a message.
310 ****************************************************************************/
311
312 void message_dispatch(void)
313 {
314         int msg_type;
315         pid_t src;
316         void *buf;
317         size_t len;
318         struct dispatch_fns *dfn;
319         int n_handled;
320
321         if (!received_signal) return;
322
323         DEBUG(10,("message_dispatch: received_signal = %d\n", received_signal));
324
325         received_signal = 0;
326
327         while (message_recv(&msg_type, &src, &buf, &len)) {
328                 DEBUG(10,("message_dispatch: received msg_type=%d src_pid=%d\n",
329                           msg_type, (int) src));
330                 n_handled = 0;
331                 for (dfn = dispatch_fns; dfn; dfn = dfn->next) {
332                         if (dfn->msg_type == msg_type) {
333                                 DEBUG(10,("message_dispatch: processing message of type %d.\n", msg_type));
334                                 dfn->fn(msg_type, src, buf, len);
335                                 n_handled++;
336                         }
337                 }
338                 if (!n_handled) {
339                         DEBUG(5,("message_dispatch: warning: no handlers registered for "
340                                  "msg_type %d in pid %d\n",
341                                  msg_type, sys_getpid()));
342                 }
343                 SAFE_FREE(buf);
344         }
345 }
346
347 /****************************************************************************
348  Register a dispatch function for a particular message type.
349 ****************************************************************************/
350
351 void message_register(int msg_type, 
352                       void (*fn)(int msg_type, pid_t pid, void *buf, size_t len))
353 {
354         struct dispatch_fns *dfn;
355
356         dfn = (struct dispatch_fns *)malloc(sizeof(*dfn));
357
358         if (dfn != NULL) {
359
360                 ZERO_STRUCTPN(dfn);
361
362                 dfn->msg_type = msg_type;
363                 dfn->fn = fn;
364
365                 DLIST_ADD(dispatch_fns, dfn);
366         }
367         else {
368         
369                 DEBUG(0,("message_register: Not enough memory. malloc failed!\n"));
370         }
371 }
372
373 /****************************************************************************
374  De-register the function for a particular message type.
375 ****************************************************************************/
376
377 void message_deregister(int msg_type)
378 {
379         struct dispatch_fns *dfn, *next;
380
381         for (dfn = dispatch_fns; dfn; dfn = next) {
382                 next = dfn->next;
383                 if (dfn->msg_type == msg_type) {
384                         DLIST_REMOVE(dispatch_fns, dfn);
385                         SAFE_FREE(dfn);
386                 }
387         }       
388 }
389
390 struct msg_all {
391         int msg_type;
392         uint32 msg_flag;
393         const void *buf;
394         size_t len;
395         BOOL duplicates;
396         int n_sent;
397 };
398
399 /****************************************************************************
400  Send one of the messages for the broadcast.
401 ****************************************************************************/
402
403 static int traverse_fn(TDB_CONTEXT *the_tdb, TDB_DATA kbuf, TDB_DATA dbuf, void *state)
404 {
405         struct connections_data crec;
406         struct msg_all *msg_all = (struct msg_all *)state;
407
408         if (dbuf.dsize != sizeof(crec))
409                 return 0;
410
411         memcpy(&crec, dbuf.dptr, sizeof(crec));
412
413         if (crec.cnum != -1)
414                 return 0;
415
416         /* Don't send if the receiver hasn't registered an interest. */
417
418         if(!(crec.bcast_msg_flags & msg_all->msg_flag))
419                 return 0;
420
421         /* If the msg send fails because the pid was not found (i.e. smbd died), 
422          * the msg has already been deleted from the messages.tdb.*/
423
424         if (!message_send_pid(crec.pid, msg_all->msg_type,
425                               msg_all->buf, msg_all->len,
426                               msg_all->duplicates)) {
427                 
428                 /* If the pid was not found delete the entry from connections.tdb */
429
430                 if (errno == ESRCH) {
431                         DEBUG(2,("pid %u doesn't exist - deleting connections %d [%s]\n",
432                                         (unsigned int)crec.pid, crec.cnum, crec.name));
433                         tdb_delete(the_tdb, kbuf);
434                 }
435         }
436         msg_all->n_sent++;
437         return 0;
438 }
439
440 /**
441  * Send a message to all smbd processes.
442  *
443  * It isn't very efficient, but should be OK for the sorts of
444  * applications that use it. When we need efficient broadcast we can add
445  * it.
446  *
447  * @param n_sent Set to the number of messages sent.  This should be
448  * equal to the number of processes, but be careful for races.
449  *
450  * @return True for success.
451  **/
452 BOOL message_send_all(TDB_CONTEXT *conn_tdb, int msg_type,
453                       const void *buf, size_t len,
454                       BOOL duplicates_allowed,
455                       int *n_sent)
456 {
457         struct msg_all msg_all;
458
459         msg_all.msg_type = msg_type;
460         if (msg_type < 1000)
461                 msg_all.msg_flag = FLAG_MSG_GENERAL;
462         else if (msg_type > 1000 && msg_type < 2000)
463                 msg_all.msg_flag = FLAG_MSG_NMBD;
464         else if (msg_type > 2000 && msg_type < 3000)
465                 msg_all.msg_flag = FLAG_MSG_PRINTING;
466         else if (msg_type > 3000 && msg_type < 4000)
467                 msg_all.msg_flag = FLAG_MSG_SMBD;
468         else
469                 return False;
470
471         msg_all.buf = buf;
472         msg_all.len = len;
473         msg_all.duplicates = duplicates_allowed;
474         msg_all.n_sent = 0;
475
476         tdb_traverse(conn_tdb, traverse_fn, &msg_all);
477         if (n_sent)
478                 *n_sent = msg_all.n_sent;
479         return True;
480 }
481 /** @} **/