1da7f5706b9706e1a0f205041b6bedb2c64b9bf8
[bbaumbach/samba-autobuild/.git] / source4 / lib / stream / packet.c
1 /* 
2    Unix SMB/CIFS mplementation.
3
4    helper layer for breaking up streams into discrete requests
5    
6    Copyright (C) Andrew Tridgell  2005
7     
8    This program is free software; you can redistribute it and/or modify
9    it under the terms of the GNU General Public License as published by
10    the Free Software Foundation; either version 2 of the License, or
11    (at your option) any later version.
12    
13    This program is distributed in the hope that it will be useful,
14    but WITHOUT ANY WARRANTY; without even the implied warranty of
15    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16    GNU General Public License for more details.
17    
18    You should have received a copy of the GNU General Public License
19    along with this program; if not, write to the Free Software
20    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
21    
22 */
23
24 #include "includes.h"
25 #include "smb.h"
26 #include "dlinklist.h"
27 #include "lib/events/events.h"
28 #include "lib/socket/socket.h"
29 #include "lib/stream/packet.h"
30
31
32 struct packet_context {
33         packet_callback_fn_t callback;
34         packet_full_request_fn_t full_request;
35         packet_error_handler_fn_t error_handler;
36         DATA_BLOB partial;
37         uint32_t num_read;
38         uint32_t initial_read;
39         struct socket_context *sock;
40         struct event_context *ev;
41         size_t packet_size;
42         void *private;
43         struct fd_event *fde;
44         BOOL serialise;
45         int processing;
46         BOOL recv_disable;
47         BOOL nofree;
48
49         BOOL busy;
50         BOOL destructor_called;
51
52         struct send_element {
53                 struct send_element *next, *prev;
54                 DATA_BLOB blob;
55                 size_t nsent;
56         } *send_queue;
57 };
58
59 /*
60   a destructor used when we are processing packets to prevent freeing of this
61   context while it is being used
62 */
63 static int packet_destructor(void *p)
64 {
65         struct packet_context *pc = talloc_get_type(p, struct packet_context);
66
67         if (pc->busy) {
68                 pc->destructor_called = True;
69                 /* now we refuse the talloc_free() request. The free will
70                    happen again in the packet_recv() code */
71                 return -1;
72         }
73
74         return 0;
75 }
76
77
78 /*
79   initialise a packet receiver
80 */
81 _PUBLIC_ struct packet_context *packet_init(TALLOC_CTX *mem_ctx)
82 {
83         struct packet_context *pc = talloc_zero(mem_ctx, struct packet_context);
84         if (pc != NULL) {
85                 talloc_set_destructor(pc, packet_destructor);
86         }
87         return pc;
88 }
89
90
91 /*
92   set the request callback, called when a full request is ready
93 */
94 _PUBLIC_ void packet_set_callback(struct packet_context *pc, packet_callback_fn_t callback)
95 {
96         pc->callback = callback;
97 }
98
99 /*
100   set the error handler
101 */
102 _PUBLIC_ void packet_set_error_handler(struct packet_context *pc, packet_error_handler_fn_t handler)
103 {
104         pc->error_handler = handler;
105 }
106
107 /*
108   set the private pointer passed to the callback functions
109 */
110 _PUBLIC_ void packet_set_private(struct packet_context *pc, void *private)
111 {
112         pc->private = private;
113 }
114
115 /*
116   set the full request callback. Should return as follows:
117      NT_STATUS_OK == blob is a full request.
118      STATUS_MORE_ENTRIES == blob is not complete yet
119      any error == blob is not a valid 
120 */
121 _PUBLIC_ void packet_set_full_request(struct packet_context *pc, packet_full_request_fn_t callback)
122 {
123         pc->full_request = callback;
124 }
125
126 /*
127   set a socket context to use. You must set a socket_context
128 */
129 _PUBLIC_ void packet_set_socket(struct packet_context *pc, struct socket_context *sock)
130 {
131         pc->sock = sock;
132 }
133
134 /*
135   set an event context. If this is set then the code will ensure that
136   packets arrive with separate events, by creating a immediate event
137   for any secondary packets when more than one packet is read at one
138   time on a socket. This can matter for code that relies on not
139   getting more than one packet per event
140 */
141 _PUBLIC_ void packet_set_event_context(struct packet_context *pc, struct event_context *ev)
142 {
143         pc->ev = ev;
144 }
145
146 /*
147   tell the packet layer the fde for the socket
148 */
149 _PUBLIC_ void packet_set_fde(struct packet_context *pc, struct fd_event *fde)
150 {
151         pc->fde = fde;
152 }
153
154 /*
155   tell the packet layer to serialise requests, so we don't process two
156   requests at once on one connection. You must have set the
157   event_context and fde
158 */
159 _PUBLIC_ void packet_set_serialise(struct packet_context *pc)
160 {
161         pc->serialise = True;
162 }
163
164 /*
165   tell the packet layer how much to read when starting a new packet
166   this ensures it doesn't overread
167 */
168 _PUBLIC_ void packet_set_initial_read(struct packet_context *pc, uint32_t initial_read)
169 {
170         pc->initial_read = initial_read;
171 }
172
173 /*
174   tell the packet system not to steal/free blobs given to packet_send()
175 */
176 _PUBLIC_ void packet_set_nofree(struct packet_context *pc)
177 {
178         pc->nofree = True;
179 }
180
181
182 /*
183   tell the caller we have an error
184 */
185 static void packet_error(struct packet_context *pc, NTSTATUS status)
186 {
187         pc->sock = NULL;
188         if (pc->error_handler) {
189                 pc->error_handler(pc->private, status);
190                 return;
191         }
192         /* default error handler is to free the callers private pointer */
193         if (!NT_STATUS_EQUAL(status, NT_STATUS_END_OF_FILE)) {
194                 DEBUG(0,("packet_error on %s - %s\n", 
195                          talloc_get_name(pc->private), nt_errstr(status)));
196         }
197         talloc_free(pc->private);
198         return;
199 }
200
201
202 /*
203   tell the caller we have EOF
204 */
205 static void packet_eof(struct packet_context *pc)
206 {
207         packet_error(pc, NT_STATUS_END_OF_FILE);
208 }
209
210
211 /*
212   used to put packets on event boundaries
213 */
214 static void packet_next_event(struct event_context *ev, struct timed_event *te, 
215                               struct timeval t, void *private)
216 {
217         struct packet_context *pc = talloc_get_type(private, struct packet_context);
218         if (pc->num_read != 0 && pc->packet_size != 0 &&
219             pc->packet_size <= pc->num_read) {
220                 packet_recv(pc);
221         }
222 }
223
224
225 /*
226   call this when the socket becomes readable to kick off the whole
227   stream parsing process
228 */
229 _PUBLIC_ void packet_recv(struct packet_context *pc)
230 {
231         size_t npending;
232         NTSTATUS status;
233         size_t nread = 0;
234         DATA_BLOB blob;
235
236         if (pc->processing) {
237                 EVENT_FD_NOT_READABLE(pc->fde);
238                 pc->processing++;
239                 return;
240         }
241
242         if (pc->recv_disable) {
243                 EVENT_FD_NOT_READABLE(pc->fde);
244                 return;
245         }
246
247         if (pc->packet_size != 0 && pc->num_read >= pc->packet_size) {
248                 goto next_partial;
249         }
250
251         if (pc->packet_size != 0) {
252                 /* we've already worked out how long this next packet is, so skip the
253                    socket_pending() call */
254                 npending = pc->packet_size - pc->num_read;
255         } else if (pc->initial_read != 0) {
256                 npending = pc->initial_read - pc->num_read;
257         } else {
258                 if (pc->sock) {
259                         status = socket_pending(pc->sock, &npending);
260                 } else {
261                         status = NT_STATUS_CONNECTION_DISCONNECTED;
262                 }
263                 if (!NT_STATUS_IS_OK(status)) {
264                         packet_error(pc, status);
265                         return;
266                 }
267         }
268
269         if (npending == 0) {
270                 packet_eof(pc);
271                 return;
272         }
273
274         /* possibly expand the partial packet buffer */
275         if (npending + pc->num_read > pc->partial.length) {
276                 status = data_blob_realloc(pc, &pc->partial, npending+pc->num_read);
277                 if (!NT_STATUS_IS_OK(status)) {
278                         packet_error(pc, status);
279                         return;
280                 }
281         }
282
283         status = socket_recv(pc->sock, pc->partial.data + pc->num_read, 
284                              npending, &nread);
285
286         if (NT_STATUS_IS_ERR(status)) {
287                 packet_error(pc, status);
288                 return;
289         }
290         if (!NT_STATUS_IS_OK(status)) {
291                 return;
292         }
293
294         if (nread == 0) {
295                 packet_eof(pc);
296                 return;
297         }
298
299         pc->num_read += nread;
300
301 next_partial:
302         if (pc->partial.length != pc->num_read) {
303                 status = data_blob_realloc(pc, &pc->partial, pc->num_read);
304                 if (!NT_STATUS_IS_OK(status)) {
305                         packet_error(pc, status);
306                         return;
307                 }
308         }
309
310         /* see if its a full request */
311         blob = pc->partial;
312         blob.length = pc->num_read;
313         status = pc->full_request(pc->private, blob, &pc->packet_size);
314         if (NT_STATUS_IS_ERR(status)) {
315                 packet_error(pc, status);
316                 return;
317         }
318         if (!NT_STATUS_IS_OK(status)) {
319                 return;
320         }
321
322         if (pc->packet_size > pc->num_read) {
323                 /* the caller made an error */
324                 DEBUG(0,("Invalid packet_size %lu greater than num_read %lu\n",
325                          (long)pc->packet_size, (long)pc->num_read));
326                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
327                 return;
328         }
329
330         /* it is a full request - give it to the caller */
331         blob = pc->partial;
332         blob.length = pc->num_read;
333
334         if (pc->packet_size < pc->num_read) {
335                 pc->partial = data_blob_talloc(pc, blob.data + pc->packet_size, 
336                                                pc->num_read - pc->packet_size);
337                 if (pc->partial.data == NULL) {
338                         packet_error(pc, NT_STATUS_NO_MEMORY);
339                         return;
340                 }
341                 status = data_blob_realloc(pc, &blob, pc->packet_size);
342                 if (!NT_STATUS_IS_OK(status)) {
343                         packet_error(pc, status);
344                         return;
345                 }
346         } else {
347                 pc->partial = data_blob(NULL, 0);
348         }
349         pc->num_read -= pc->packet_size;
350         pc->packet_size = 0;
351         
352         if (pc->serialise) {
353                 pc->processing = 1;
354         }
355
356         pc->busy = True;
357
358         status = pc->callback(pc->private, blob);
359
360         pc->busy = False;
361
362         if (pc->destructor_called) {
363                 talloc_free(pc);
364                 return;
365         }
366
367         if (pc->processing) {
368                 if (pc->processing > 1) {
369                         EVENT_FD_READABLE(pc->fde);
370                 }
371                 pc->processing = 0;
372         }
373
374         if (!NT_STATUS_IS_OK(status)) {
375                 packet_error(pc, status);
376                 return;
377         }
378
379         if (pc->partial.length == 0) {
380                 return;
381         }
382
383         /* we got multiple packets in one tcp read */
384         if (pc->ev == NULL) {
385                 goto next_partial;
386         }
387
388         blob = pc->partial;
389         blob.length = pc->num_read;
390
391         status = pc->full_request(pc->private, blob, &pc->packet_size);
392         if (NT_STATUS_IS_ERR(status)) {
393                 packet_error(pc, status);
394                 return;
395         }
396
397         if (!NT_STATUS_IS_OK(status)) {
398                 return;
399         }
400
401         event_add_timed(pc->ev, pc, timeval_zero(), packet_next_event, pc);
402 }
403
404
405 /*
406   temporarily disable receiving 
407 */
408 _PUBLIC_ void packet_recv_disable(struct packet_context *pc)
409 {
410         EVENT_FD_NOT_READABLE(pc->fde);
411         pc->recv_disable = True;
412 }
413
414 /*
415   re-enable receiving 
416 */
417 _PUBLIC_ void packet_recv_enable(struct packet_context *pc)
418 {
419         EVENT_FD_READABLE(pc->fde);
420         pc->recv_disable = False;
421         if (pc->num_read != 0 && pc->packet_size >= pc->num_read) {
422                 event_add_timed(pc->ev, pc, timeval_zero(), packet_next_event, pc);
423         }
424 }
425
426 /*
427   trigger a run of the send queue
428 */
429 _PUBLIC_ void packet_queue_run(struct packet_context *pc)
430 {
431         while (pc->send_queue) {
432                 struct send_element *el = pc->send_queue;
433                 NTSTATUS status;
434                 size_t nwritten;
435                 DATA_BLOB blob = data_blob_const(el->blob.data + el->nsent,
436                                                  el->blob.length - el->nsent);
437
438                 status = socket_send(pc->sock, &blob, &nwritten);
439
440                 if (NT_STATUS_IS_ERR(status)) {
441                         packet_error(pc, NT_STATUS_NET_WRITE_FAULT);
442                         return;
443                 }
444                 if (!NT_STATUS_IS_OK(status)) {
445                         return;
446                 }
447                 el->nsent += nwritten;
448                 if (el->nsent == el->blob.length) {
449                         DLIST_REMOVE(pc->send_queue, el);
450                         talloc_free(el);
451                 }
452         }
453
454         /* we're out of requests to send, so don't wait for write
455            events any more */
456         EVENT_FD_NOT_WRITEABLE(pc->fde);
457 }
458
459 /*
460   put a packet in the send queue
461 */
462 _PUBLIC_ NTSTATUS packet_send(struct packet_context *pc, DATA_BLOB blob)
463 {
464         struct send_element *el;
465         el = talloc(pc, struct send_element);
466         NT_STATUS_HAVE_NO_MEMORY(el);
467
468         DLIST_ADD_END(pc->send_queue, el, struct send_element *);
469         el->blob = blob;
470         el->nsent = 0;
471
472         /* if we aren't going to free the packet then we must reference it
473            to ensure it doesn't disappear before going out */
474         if (pc->nofree) {
475                 if (!talloc_reference(el, blob.data)) {
476                         return NT_STATUS_NO_MEMORY;
477                 }
478         } else {
479                 talloc_steal(el, blob.data);
480         }
481
482         EVENT_FD_WRITEABLE(pc->fde);
483
484         return NT_STATUS_OK;
485 }
486
487
488 /*
489   a full request checker for NBT formatted packets (first 3 bytes are length)
490 */
491 _PUBLIC_ NTSTATUS packet_full_request_nbt(void *private, DATA_BLOB blob, size_t *size)
492 {
493         if (blob.length < 4) {
494                 return STATUS_MORE_ENTRIES;
495         }
496         *size = 4 + smb_len(blob.data);
497         if (*size > blob.length) {
498                 return STATUS_MORE_ENTRIES;
499         }
500         return NT_STATUS_OK;
501 }
502
503
504 /*
505   work out if a packet is complete for protocols that use a 32 bit network byte
506   order length
507 */
508 _PUBLIC_ NTSTATUS packet_full_request_u32(void *private, DATA_BLOB blob, size_t *size)
509 {
510         if (blob.length < 4) {
511                 return STATUS_MORE_ENTRIES;
512         }
513         *size = 4 + RIVAL(blob.data, 0);
514         if (*size > blob.length) {
515                 return STATUS_MORE_ENTRIES;
516         }
517         return NT_STATUS_OK;
518 }