3f4fcb1ef7578ada455930ec0d8962bcd4e325de
[kai/samba.git] / source4 / lib / stream / packet.c
1 /* 
2    Unix SMB/CIFS mplementation.
3
4    helper layer for breaking up streams into discrete requests
5    
6    Copyright (C) Andrew Tridgell  2005
7     
8    This program is free software; you can redistribute it and/or modify
9    it under the terms of the GNU General Public License as published by
10    the Free Software Foundation; either version 3 of the License, or
11    (at your option) any later version.
12    
13    This program is distributed in the hope that it will be useful,
14    but WITHOUT ANY WARRANTY; without even the implied warranty of
15    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16    GNU General Public License for more details.
17    
18    You should have received a copy of the GNU General Public License
19    along with this program.  If not, see <http://www.gnu.org/licenses/>.
20    
21 */
22
23 #include "includes.h"
24 #include "lib/util/dlinklist.h"
25 #include "lib/events/events.h"
26 #include "lib/socket/socket.h"
27 #include "lib/stream/packet.h"
28 #include "libcli/raw/smb.h"
29
30 struct packet_context {
31         packet_callback_fn_t callback;
32         packet_full_request_fn_t full_request;
33         packet_error_handler_fn_t error_handler;
34         DATA_BLOB partial;
35         uint32_t num_read;
36         uint32_t initial_read;
37         struct socket_context *sock;
38         struct event_context *ev;
39         size_t packet_size;
40         void *private;
41         struct fd_event *fde;
42         BOOL serialise;
43         int processing;
44         BOOL recv_disable;
45         BOOL nofree;
46
47         BOOL busy;
48         BOOL destructor_called;
49
50         struct send_element {
51                 struct send_element *next, *prev;
52                 DATA_BLOB blob;
53                 size_t nsent;
54                 packet_send_callback_fn_t send_callback;
55                 void *send_callback_private;
56         } *send_queue;
57 };
58
59 /*
60   a destructor used when we are processing packets to prevent freeing of this
61   context while it is being used
62 */
63 static int packet_destructor(struct packet_context *pc)
64 {
65         if (pc->busy) {
66                 pc->destructor_called = True;
67                 /* now we refuse the talloc_free() request. The free will
68                    happen again in the packet_recv() code */
69                 return -1;
70         }
71
72         return 0;
73 }
74
75
76 /*
77   initialise a packet receiver
78 */
79 _PUBLIC_ struct packet_context *packet_init(TALLOC_CTX *mem_ctx)
80 {
81         struct packet_context *pc = talloc_zero(mem_ctx, struct packet_context);
82         if (pc != NULL) {
83                 talloc_set_destructor(pc, packet_destructor);
84         }
85         return pc;
86 }
87
88
89 /*
90   set the request callback, called when a full request is ready
91 */
92 _PUBLIC_ void packet_set_callback(struct packet_context *pc, packet_callback_fn_t callback)
93 {
94         pc->callback = callback;
95 }
96
97 /*
98   set the error handler
99 */
100 _PUBLIC_ void packet_set_error_handler(struct packet_context *pc, packet_error_handler_fn_t handler)
101 {
102         pc->error_handler = handler;
103 }
104
105 /*
106   set the private pointer passed to the callback functions
107 */
108 _PUBLIC_ void packet_set_private(struct packet_context *pc, void *private)
109 {
110         pc->private = private;
111 }
112
113 /*
114   set the full request callback. Should return as follows:
115      NT_STATUS_OK == blob is a full request.
116      STATUS_MORE_ENTRIES == blob is not complete yet
117      any error == blob is not a valid 
118 */
119 _PUBLIC_ void packet_set_full_request(struct packet_context *pc, packet_full_request_fn_t callback)
120 {
121         pc->full_request = callback;
122 }
123
124 /*
125   set a socket context to use. You must set a socket_context
126 */
127 _PUBLIC_ void packet_set_socket(struct packet_context *pc, struct socket_context *sock)
128 {
129         pc->sock = sock;
130 }
131
132 /*
133   set an event context. If this is set then the code will ensure that
134   packets arrive with separate events, by creating a immediate event
135   for any secondary packets when more than one packet is read at one
136   time on a socket. This can matter for code that relies on not
137   getting more than one packet per event
138 */
139 _PUBLIC_ void packet_set_event_context(struct packet_context *pc, struct event_context *ev)
140 {
141         pc->ev = ev;
142 }
143
144 /*
145   tell the packet layer the fde for the socket
146 */
147 _PUBLIC_ void packet_set_fde(struct packet_context *pc, struct fd_event *fde)
148 {
149         pc->fde = fde;
150 }
151
152 /*
153   tell the packet layer to serialise requests, so we don't process two
154   requests at once on one connection. You must have set the
155   event_context and fde
156 */
157 _PUBLIC_ void packet_set_serialise(struct packet_context *pc)
158 {
159         pc->serialise = True;
160 }
161
162 /*
163   tell the packet layer how much to read when starting a new packet
164   this ensures it doesn't overread
165 */
166 _PUBLIC_ void packet_set_initial_read(struct packet_context *pc, uint32_t initial_read)
167 {
168         pc->initial_read = initial_read;
169 }
170
171 /*
172   tell the packet system not to steal/free blobs given to packet_send()
173 */
174 _PUBLIC_ void packet_set_nofree(struct packet_context *pc)
175 {
176         pc->nofree = True;
177 }
178
179
180 /*
181   tell the caller we have an error
182 */
183 static void packet_error(struct packet_context *pc, NTSTATUS status)
184 {
185         pc->sock = NULL;
186         if (pc->error_handler) {
187                 pc->error_handler(pc->private, status);
188                 return;
189         }
190         /* default error handler is to free the callers private pointer */
191         if (!NT_STATUS_EQUAL(status, NT_STATUS_END_OF_FILE)) {
192                 DEBUG(0,("packet_error on %s - %s\n", 
193                          talloc_get_name(pc->private), nt_errstr(status)));
194         }
195         talloc_free(pc->private);
196         return;
197 }
198
199
200 /*
201   tell the caller we have EOF
202 */
203 static void packet_eof(struct packet_context *pc)
204 {
205         packet_error(pc, NT_STATUS_END_OF_FILE);
206 }
207
208
209 /*
210   used to put packets on event boundaries
211 */
212 static void packet_next_event(struct event_context *ev, struct timed_event *te, 
213                               struct timeval t, void *private)
214 {
215         struct packet_context *pc = talloc_get_type(private, struct packet_context);
216         if (pc->num_read != 0 && pc->packet_size != 0 &&
217             pc->packet_size <= pc->num_read) {
218                 packet_recv(pc);
219         }
220 }
221
222
223 /*
224   call this when the socket becomes readable to kick off the whole
225   stream parsing process
226 */
227 _PUBLIC_ void packet_recv(struct packet_context *pc)
228 {
229         size_t npending;
230         NTSTATUS status;
231         size_t nread = 0;
232         DATA_BLOB blob;
233
234         if (pc->processing) {
235                 EVENT_FD_NOT_READABLE(pc->fde);
236                 pc->processing++;
237                 return;
238         }
239
240         if (pc->recv_disable) {
241                 EVENT_FD_NOT_READABLE(pc->fde);
242                 return;
243         }
244
245         if (pc->packet_size != 0 && pc->num_read >= pc->packet_size) {
246                 goto next_partial;
247         }
248
249         if (pc->packet_size != 0) {
250                 /* we've already worked out how long this next packet is, so skip the
251                    socket_pending() call */
252                 npending = pc->packet_size - pc->num_read;
253         } else if (pc->initial_read != 0) {
254                 npending = pc->initial_read - pc->num_read;
255         } else {
256                 if (pc->sock) {
257                         status = socket_pending(pc->sock, &npending);
258                 } else {
259                         status = NT_STATUS_CONNECTION_DISCONNECTED;
260                 }
261                 if (!NT_STATUS_IS_OK(status)) {
262                         packet_error(pc, status);
263                         return;
264                 }
265         }
266
267         if (npending == 0) {
268                 packet_eof(pc);
269                 return;
270         }
271
272         if (npending + pc->num_read < npending) {
273                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
274                 return;
275         }
276
277         if (npending + pc->num_read < pc->num_read) {
278                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
279                 return;
280         }
281
282         /* possibly expand the partial packet buffer */
283         if (npending + pc->num_read > pc->partial.length) {
284                 status = data_blob_realloc(pc, &pc->partial, npending+pc->num_read);
285                 if (!NT_STATUS_IS_OK(status)) {
286                         packet_error(pc, status);
287                         return;
288                 }
289         }
290
291         if (pc->partial.length < pc->num_read + npending) {
292                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
293                 return;
294         }
295
296         if ((uint8_t *)pc->partial.data + pc->num_read < (uint8_t *)pc->partial.data) {
297                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
298                 return;
299         }
300         if ((uint8_t *)pc->partial.data + pc->num_read + npending < (uint8_t *)pc->partial.data) {
301                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
302                 return;
303         }
304
305         status = socket_recv(pc->sock, pc->partial.data + pc->num_read, 
306                              npending, &nread);
307
308         if (NT_STATUS_IS_ERR(status)) {
309                 packet_error(pc, status);
310                 return;
311         }
312         if (!NT_STATUS_IS_OK(status)) {
313                 return;
314         }
315
316         if (nread == 0) {
317                 packet_eof(pc);
318                 return;
319         }
320
321         pc->num_read += nread;
322
323 next_partial:
324         if (pc->partial.length != pc->num_read) {
325                 status = data_blob_realloc(pc, &pc->partial, pc->num_read);
326                 if (!NT_STATUS_IS_OK(status)) {
327                         packet_error(pc, status);
328                         return;
329                 }
330         }
331
332         /* see if its a full request */
333         blob = pc->partial;
334         blob.length = pc->num_read;
335         status = pc->full_request(pc->private, blob, &pc->packet_size);
336         if (NT_STATUS_IS_ERR(status)) {
337                 packet_error(pc, status);
338                 return;
339         }
340         if (!NT_STATUS_IS_OK(status)) {
341                 return;
342         }
343
344         if (pc->packet_size > pc->num_read) {
345                 /* the caller made an error */
346                 DEBUG(0,("Invalid packet_size %lu greater than num_read %lu\n",
347                          (long)pc->packet_size, (long)pc->num_read));
348                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
349                 return;
350         }
351
352         /* it is a full request - give it to the caller */
353         blob = pc->partial;
354         blob.length = pc->num_read;
355
356         if (pc->packet_size < pc->num_read) {
357                 pc->partial = data_blob_talloc(pc, blob.data + pc->packet_size, 
358                                                pc->num_read - pc->packet_size);
359                 if (pc->partial.data == NULL) {
360                         packet_error(pc, NT_STATUS_NO_MEMORY);
361                         return;
362                 }
363                 /* Trunate the blob sent to the caller to only the packet length */
364                 status = data_blob_realloc(pc, &blob, pc->packet_size);
365                 if (!NT_STATUS_IS_OK(status)) {
366                         packet_error(pc, status);
367                         return;
368                 }
369         } else {
370                 pc->partial = data_blob(NULL, 0);
371         }
372         pc->num_read -= pc->packet_size;
373         pc->packet_size = 0;
374         
375         if (pc->serialise) {
376                 pc->processing = 1;
377         }
378
379         pc->busy = True;
380
381         status = pc->callback(pc->private, blob);
382
383         pc->busy = False;
384
385         if (pc->destructor_called) {
386                 talloc_free(pc);
387                 return;
388         }
389
390         if (pc->processing) {
391                 if (pc->processing > 1) {
392                         EVENT_FD_READABLE(pc->fde);
393                 }
394                 pc->processing = 0;
395         }
396
397         if (!NT_STATUS_IS_OK(status)) {
398                 packet_error(pc, status);
399                 return;
400         }
401
402         /* Have we consumed the whole buffer yet? */
403         if (pc->partial.length == 0) {
404                 return;
405         }
406
407         /* we got multiple packets in one tcp read */
408         if (pc->ev == NULL) {
409                 goto next_partial;
410         }
411
412         blob = pc->partial;
413         blob.length = pc->num_read;
414
415         status = pc->full_request(pc->private, blob, &pc->packet_size);
416         if (NT_STATUS_IS_ERR(status)) {
417                 packet_error(pc, status);
418                 return;
419         }
420
421         if (!NT_STATUS_IS_OK(status)) {
422                 return;
423         }
424
425         event_add_timed(pc->ev, pc, timeval_zero(), packet_next_event, pc);
426 }
427
428
429 /*
430   temporarily disable receiving 
431 */
432 _PUBLIC_ void packet_recv_disable(struct packet_context *pc)
433 {
434         EVENT_FD_NOT_READABLE(pc->fde);
435         pc->recv_disable = True;
436 }
437
438 /*
439   re-enable receiving 
440 */
441 _PUBLIC_ void packet_recv_enable(struct packet_context *pc)
442 {
443         EVENT_FD_READABLE(pc->fde);
444         pc->recv_disable = False;
445         if (pc->num_read != 0 && pc->packet_size >= pc->num_read) {
446                 event_add_timed(pc->ev, pc, timeval_zero(), packet_next_event, pc);
447         }
448 }
449
450 /*
451   trigger a run of the send queue
452 */
453 _PUBLIC_ void packet_queue_run(struct packet_context *pc)
454 {
455         while (pc->send_queue) {
456                 struct send_element *el = pc->send_queue;
457                 NTSTATUS status;
458                 size_t nwritten;
459                 DATA_BLOB blob = data_blob_const(el->blob.data + el->nsent,
460                                                  el->blob.length - el->nsent);
461
462                 status = socket_send(pc->sock, &blob, &nwritten);
463
464                 if (NT_STATUS_IS_ERR(status)) {
465                         packet_error(pc, status);
466                         return;
467                 }
468                 if (!NT_STATUS_IS_OK(status)) {
469                         return;
470                 }
471                 el->nsent += nwritten;
472                 if (el->nsent == el->blob.length) {
473                         DLIST_REMOVE(pc->send_queue, el);
474                         if (el->send_callback) {
475                                 el->send_callback(el->send_callback_private);
476                         }
477                         talloc_free(el);
478                 }
479         }
480
481         /* we're out of requests to send, so don't wait for write
482            events any more */
483         EVENT_FD_NOT_WRITEABLE(pc->fde);
484 }
485
486 /*
487   put a packet in the send queue.  When the packet is actually sent,
488   call send_callback.  
489
490   Useful for operations that must occour after sending a message, such
491   as the switch to SASL encryption after as sucessful LDAP bind relpy.
492 */
493 _PUBLIC_ NTSTATUS packet_send_callback(struct packet_context *pc, DATA_BLOB blob,
494                                        packet_send_callback_fn_t send_callback, 
495                                        void *private)
496 {
497         struct send_element *el;
498         el = talloc(pc, struct send_element);
499         NT_STATUS_HAVE_NO_MEMORY(el);
500
501         DLIST_ADD_END(pc->send_queue, el, struct send_element *);
502         el->blob = blob;
503         el->nsent = 0;
504         el->send_callback = send_callback;
505         el->send_callback_private = private;
506
507         /* if we aren't going to free the packet then we must reference it
508            to ensure it doesn't disappear before going out */
509         if (pc->nofree) {
510                 if (!talloc_reference(el, blob.data)) {
511                         return NT_STATUS_NO_MEMORY;
512                 }
513         } else {
514                 talloc_steal(el, blob.data);
515         }
516
517         if (private && !talloc_reference(el, private)) {
518                 return NT_STATUS_NO_MEMORY;
519         }
520
521         EVENT_FD_WRITEABLE(pc->fde);
522
523         return NT_STATUS_OK;
524 }
525
526 /*
527   put a packet in the send queue
528 */
529 _PUBLIC_ NTSTATUS packet_send(struct packet_context *pc, DATA_BLOB blob)
530 {
531         return packet_send_callback(pc, blob, NULL, NULL);
532 }
533
534
535 /*
536   a full request checker for NBT formatted packets (first 3 bytes are length)
537 */
538 _PUBLIC_ NTSTATUS packet_full_request_nbt(void *private, DATA_BLOB blob, size_t *size)
539 {
540         if (blob.length < 4) {
541                 return STATUS_MORE_ENTRIES;
542         }
543         *size = 4 + smb_len(blob.data);
544         if (*size > blob.length) {
545                 return STATUS_MORE_ENTRIES;
546         }
547         return NT_STATUS_OK;
548 }
549
550
551 /*
552   work out if a packet is complete for protocols that use a 32 bit network byte
553   order length
554 */
555 _PUBLIC_ NTSTATUS packet_full_request_u32(void *private, DATA_BLOB blob, size_t *size)
556 {
557         if (blob.length < 4) {
558                 return STATUS_MORE_ENTRIES;
559         }
560         *size = 4 + RIVAL(blob.data, 0);
561         if (*size > blob.length) {
562                 return STATUS_MORE_ENTRIES;
563         }
564         return NT_STATUS_OK;
565 }