r15854: more talloc_set_destructor() typesafe fixes
[bbaumbach/samba-autobuild/.git] / source4 / lib / stream / packet.c
1 /* 
2    Unix SMB/CIFS mplementation.
3
4    helper layer for breaking up streams into discrete requests
5    
6    Copyright (C) Andrew Tridgell  2005
7     
8    This program is free software; you can redistribute it and/or modify
9    it under the terms of the GNU General Public License as published by
10    the Free Software Foundation; either version 2 of the License, or
11    (at your option) any later version.
12    
13    This program is distributed in the hope that it will be useful,
14    but WITHOUT ANY WARRANTY; without even the implied warranty of
15    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16    GNU General Public License for more details.
17    
18    You should have received a copy of the GNU General Public License
19    along with this program; if not, write to the Free Software
20    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
21    
22 */
23
24 #include "includes.h"
25 #include "smb.h"
26 #include "dlinklist.h"
27 #include "lib/events/events.h"
28 #include "lib/socket/socket.h"
29 #include "lib/stream/packet.h"
30
31
32 struct packet_context {
33         packet_callback_fn_t callback;
34         packet_full_request_fn_t full_request;
35         packet_error_handler_fn_t error_handler;
36         DATA_BLOB partial;
37         uint32_t num_read;
38         uint32_t initial_read;
39         struct socket_context *sock;
40         struct event_context *ev;
41         size_t packet_size;
42         void *private;
43         struct fd_event *fde;
44         BOOL serialise;
45         int processing;
46         BOOL recv_disable;
47         BOOL nofree;
48
49         BOOL busy;
50         BOOL destructor_called;
51
52         struct send_element {
53                 struct send_element *next, *prev;
54                 DATA_BLOB blob;
55                 size_t nsent;
56         } *send_queue;
57 };
58
59 /*
60   a destructor used when we are processing packets to prevent freeing of this
61   context while it is being used
62 */
63 static int packet_destructor(struct packet_context *pc)
64 {
65         if (pc->busy) {
66                 pc->destructor_called = True;
67                 /* now we refuse the talloc_free() request. The free will
68                    happen again in the packet_recv() code */
69                 return -1;
70         }
71
72         return 0;
73 }
74
75
76 /*
77   initialise a packet receiver
78 */
79 _PUBLIC_ struct packet_context *packet_init(TALLOC_CTX *mem_ctx)
80 {
81         struct packet_context *pc = talloc_zero(mem_ctx, struct packet_context);
82         if (pc != NULL) {
83                 talloc_set_destructor(pc, packet_destructor);
84         }
85         return pc;
86 }
87
88
89 /*
90   set the request callback, called when a full request is ready
91 */
92 _PUBLIC_ void packet_set_callback(struct packet_context *pc, packet_callback_fn_t callback)
93 {
94         pc->callback = callback;
95 }
96
97 /*
98   set the error handler
99 */
100 _PUBLIC_ void packet_set_error_handler(struct packet_context *pc, packet_error_handler_fn_t handler)
101 {
102         pc->error_handler = handler;
103 }
104
105 /*
106   set the private pointer passed to the callback functions
107 */
108 _PUBLIC_ void packet_set_private(struct packet_context *pc, void *private)
109 {
110         pc->private = private;
111 }
112
113 /*
114   set the full request callback. Should return as follows:
115      NT_STATUS_OK == blob is a full request.
116      STATUS_MORE_ENTRIES == blob is not complete yet
117      any error == blob is not a valid 
118 */
119 _PUBLIC_ void packet_set_full_request(struct packet_context *pc, packet_full_request_fn_t callback)
120 {
121         pc->full_request = callback;
122 }
123
124 /*
125   set a socket context to use. You must set a socket_context
126 */
127 _PUBLIC_ void packet_set_socket(struct packet_context *pc, struct socket_context *sock)
128 {
129         pc->sock = sock;
130 }
131
132 /*
133   set an event context. If this is set then the code will ensure that
134   packets arrive with separate events, by creating a immediate event
135   for any secondary packets when more than one packet is read at one
136   time on a socket. This can matter for code that relies on not
137   getting more than one packet per event
138 */
139 _PUBLIC_ void packet_set_event_context(struct packet_context *pc, struct event_context *ev)
140 {
141         pc->ev = ev;
142 }
143
144 /*
145   tell the packet layer the fde for the socket
146 */
147 _PUBLIC_ void packet_set_fde(struct packet_context *pc, struct fd_event *fde)
148 {
149         pc->fde = fde;
150 }
151
152 /*
153   tell the packet layer to serialise requests, so we don't process two
154   requests at once on one connection. You must have set the
155   event_context and fde
156 */
157 _PUBLIC_ void packet_set_serialise(struct packet_context *pc)
158 {
159         pc->serialise = True;
160 }
161
162 /*
163   tell the packet layer how much to read when starting a new packet
164   this ensures it doesn't overread
165 */
166 _PUBLIC_ void packet_set_initial_read(struct packet_context *pc, uint32_t initial_read)
167 {
168         pc->initial_read = initial_read;
169 }
170
171 /*
172   tell the packet system not to steal/free blobs given to packet_send()
173 */
174 _PUBLIC_ void packet_set_nofree(struct packet_context *pc)
175 {
176         pc->nofree = True;
177 }
178
179
180 /*
181   tell the caller we have an error
182 */
183 static void packet_error(struct packet_context *pc, NTSTATUS status)
184 {
185         pc->sock = NULL;
186         if (pc->error_handler) {
187                 pc->error_handler(pc->private, status);
188                 return;
189         }
190         /* default error handler is to free the callers private pointer */
191         if (!NT_STATUS_EQUAL(status, NT_STATUS_END_OF_FILE)) {
192                 DEBUG(0,("packet_error on %s - %s\n", 
193                          talloc_get_name(pc->private), nt_errstr(status)));
194         }
195         talloc_free(pc->private);
196         return;
197 }
198
199
200 /*
201   tell the caller we have EOF
202 */
203 static void packet_eof(struct packet_context *pc)
204 {
205         packet_error(pc, NT_STATUS_END_OF_FILE);
206 }
207
208
209 /*
210   used to put packets on event boundaries
211 */
212 static void packet_next_event(struct event_context *ev, struct timed_event *te, 
213                               struct timeval t, void *private)
214 {
215         struct packet_context *pc = talloc_get_type(private, struct packet_context);
216         if (pc->num_read != 0 && pc->packet_size != 0 &&
217             pc->packet_size <= pc->num_read) {
218                 packet_recv(pc);
219         }
220 }
221
222
223 /*
224   call this when the socket becomes readable to kick off the whole
225   stream parsing process
226 */
227 _PUBLIC_ void packet_recv(struct packet_context *pc)
228 {
229         size_t npending;
230         NTSTATUS status;
231         size_t nread = 0;
232         DATA_BLOB blob;
233
234         if (pc->processing) {
235                 EVENT_FD_NOT_READABLE(pc->fde);
236                 pc->processing++;
237                 return;
238         }
239
240         if (pc->recv_disable) {
241                 EVENT_FD_NOT_READABLE(pc->fde);
242                 return;
243         }
244
245         if (pc->packet_size != 0 && pc->num_read >= pc->packet_size) {
246                 goto next_partial;
247         }
248
249         if (pc->packet_size != 0) {
250                 /* we've already worked out how long this next packet is, so skip the
251                    socket_pending() call */
252                 npending = pc->packet_size - pc->num_read;
253         } else if (pc->initial_read != 0) {
254                 npending = pc->initial_read - pc->num_read;
255         } else {
256                 if (pc->sock) {
257                         status = socket_pending(pc->sock, &npending);
258                 } else {
259                         status = NT_STATUS_CONNECTION_DISCONNECTED;
260                 }
261                 if (!NT_STATUS_IS_OK(status)) {
262                         packet_error(pc, status);
263                         return;
264                 }
265         }
266
267         if (npending == 0) {
268                 packet_eof(pc);
269                 return;
270         }
271
272         /* possibly expand the partial packet buffer */
273         if (npending + pc->num_read > pc->partial.length) {
274                 status = data_blob_realloc(pc, &pc->partial, npending+pc->num_read);
275                 if (!NT_STATUS_IS_OK(status)) {
276                         packet_error(pc, status);
277                         return;
278                 }
279         }
280
281         status = socket_recv(pc->sock, pc->partial.data + pc->num_read, 
282                              npending, &nread);
283
284         if (NT_STATUS_IS_ERR(status)) {
285                 packet_error(pc, status);
286                 return;
287         }
288         if (!NT_STATUS_IS_OK(status)) {
289                 return;
290         }
291
292         if (nread == 0) {
293                 packet_eof(pc);
294                 return;
295         }
296
297         pc->num_read += nread;
298
299 next_partial:
300         if (pc->partial.length != pc->num_read) {
301                 status = data_blob_realloc(pc, &pc->partial, pc->num_read);
302                 if (!NT_STATUS_IS_OK(status)) {
303                         packet_error(pc, status);
304                         return;
305                 }
306         }
307
308         /* see if its a full request */
309         blob = pc->partial;
310         blob.length = pc->num_read;
311         status = pc->full_request(pc->private, blob, &pc->packet_size);
312         if (NT_STATUS_IS_ERR(status)) {
313                 packet_error(pc, status);
314                 return;
315         }
316         if (!NT_STATUS_IS_OK(status)) {
317                 return;
318         }
319
320         if (pc->packet_size > pc->num_read) {
321                 /* the caller made an error */
322                 DEBUG(0,("Invalid packet_size %lu greater than num_read %lu\n",
323                          (long)pc->packet_size, (long)pc->num_read));
324                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
325                 return;
326         }
327
328         /* it is a full request - give it to the caller */
329         blob = pc->partial;
330         blob.length = pc->num_read;
331
332         if (pc->packet_size < pc->num_read) {
333                 pc->partial = data_blob_talloc(pc, blob.data + pc->packet_size, 
334                                                pc->num_read - pc->packet_size);
335                 if (pc->partial.data == NULL) {
336                         packet_error(pc, NT_STATUS_NO_MEMORY);
337                         return;
338                 }
339                 status = data_blob_realloc(pc, &blob, pc->packet_size);
340                 if (!NT_STATUS_IS_OK(status)) {
341                         packet_error(pc, status);
342                         return;
343                 }
344         } else {
345                 pc->partial = data_blob(NULL, 0);
346         }
347         pc->num_read -= pc->packet_size;
348         pc->packet_size = 0;
349         
350         if (pc->serialise) {
351                 pc->processing = 1;
352         }
353
354         pc->busy = True;
355
356         status = pc->callback(pc->private, blob);
357
358         pc->busy = False;
359
360         if (pc->destructor_called) {
361                 talloc_free(pc);
362                 return;
363         }
364
365         if (pc->processing) {
366                 if (pc->processing > 1) {
367                         EVENT_FD_READABLE(pc->fde);
368                 }
369                 pc->processing = 0;
370         }
371
372         if (!NT_STATUS_IS_OK(status)) {
373                 packet_error(pc, status);
374                 return;
375         }
376
377         if (pc->partial.length == 0) {
378                 return;
379         }
380
381         /* we got multiple packets in one tcp read */
382         if (pc->ev == NULL) {
383                 goto next_partial;
384         }
385
386         blob = pc->partial;
387         blob.length = pc->num_read;
388
389         status = pc->full_request(pc->private, blob, &pc->packet_size);
390         if (NT_STATUS_IS_ERR(status)) {
391                 packet_error(pc, status);
392                 return;
393         }
394
395         if (!NT_STATUS_IS_OK(status)) {
396                 return;
397         }
398
399         event_add_timed(pc->ev, pc, timeval_zero(), packet_next_event, pc);
400 }
401
402
403 /*
404   temporarily disable receiving 
405 */
406 _PUBLIC_ void packet_recv_disable(struct packet_context *pc)
407 {
408         EVENT_FD_NOT_READABLE(pc->fde);
409         pc->recv_disable = True;
410 }
411
412 /*
413   re-enable receiving 
414 */
415 _PUBLIC_ void packet_recv_enable(struct packet_context *pc)
416 {
417         EVENT_FD_READABLE(pc->fde);
418         pc->recv_disable = False;
419         if (pc->num_read != 0 && pc->packet_size >= pc->num_read) {
420                 event_add_timed(pc->ev, pc, timeval_zero(), packet_next_event, pc);
421         }
422 }
423
424 /*
425   trigger a run of the send queue
426 */
427 _PUBLIC_ void packet_queue_run(struct packet_context *pc)
428 {
429         while (pc->send_queue) {
430                 struct send_element *el = pc->send_queue;
431                 NTSTATUS status;
432                 size_t nwritten;
433                 DATA_BLOB blob = data_blob_const(el->blob.data + el->nsent,
434                                                  el->blob.length - el->nsent);
435
436                 status = socket_send(pc->sock, &blob, &nwritten);
437
438                 if (NT_STATUS_IS_ERR(status)) {
439                         packet_error(pc, NT_STATUS_NET_WRITE_FAULT);
440                         return;
441                 }
442                 if (!NT_STATUS_IS_OK(status)) {
443                         return;
444                 }
445                 el->nsent += nwritten;
446                 if (el->nsent == el->blob.length) {
447                         DLIST_REMOVE(pc->send_queue, el);
448                         talloc_free(el);
449                 }
450         }
451
452         /* we're out of requests to send, so don't wait for write
453            events any more */
454         EVENT_FD_NOT_WRITEABLE(pc->fde);
455 }
456
457 /*
458   put a packet in the send queue
459 */
460 _PUBLIC_ NTSTATUS packet_send(struct packet_context *pc, DATA_BLOB blob)
461 {
462         struct send_element *el;
463         el = talloc(pc, struct send_element);
464         NT_STATUS_HAVE_NO_MEMORY(el);
465
466         DLIST_ADD_END(pc->send_queue, el, struct send_element *);
467         el->blob = blob;
468         el->nsent = 0;
469
470         /* if we aren't going to free the packet then we must reference it
471            to ensure it doesn't disappear before going out */
472         if (pc->nofree) {
473                 if (!talloc_reference(el, blob.data)) {
474                         return NT_STATUS_NO_MEMORY;
475                 }
476         } else {
477                 talloc_steal(el, blob.data);
478         }
479
480         EVENT_FD_WRITEABLE(pc->fde);
481
482         return NT_STATUS_OK;
483 }
484
485
486 /*
487   a full request checker for NBT formatted packets (first 3 bytes are length)
488 */
489 _PUBLIC_ NTSTATUS packet_full_request_nbt(void *private, DATA_BLOB blob, size_t *size)
490 {
491         if (blob.length < 4) {
492                 return STATUS_MORE_ENTRIES;
493         }
494         *size = 4 + smb_len(blob.data);
495         if (*size > blob.length) {
496                 return STATUS_MORE_ENTRIES;
497         }
498         return NT_STATUS_OK;
499 }
500
501
502 /*
503   work out if a packet is complete for protocols that use a 32 bit network byte
504   order length
505 */
506 _PUBLIC_ NTSTATUS packet_full_request_u32(void *private, DATA_BLOB blob, size_t *size)
507 {
508         if (blob.length < 4) {
509                 return STATUS_MORE_ENTRIES;
510         }
511         *size = 4 + RIVAL(blob.data, 0);
512         if (*size > blob.length) {
513                 return STATUS_MORE_ENTRIES;
514         }
515         return NT_STATUS_OK;
516 }