Merge commit 'release-4-0-0alpha1' into v4-0-test
[kai/samba.git] / source / lib / stream / packet.c
1 /* 
2    Unix SMB/CIFS mplementation.
3
4    helper layer for breaking up streams into discrete requests
5    
6    Copyright (C) Andrew Tridgell  2005
7     
8    This program is free software; you can redistribute it and/or modify
9    it under the terms of the GNU General Public License as published by
10    the Free Software Foundation; either version 3 of the License, or
11    (at your option) any later version.
12    
13    This program is distributed in the hope that it will be useful,
14    but WITHOUT ANY WARRANTY; without even the implied warranty of
15    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16    GNU General Public License for more details.
17    
18    You should have received a copy of the GNU General Public License
19    along with this program.  If not, see <http://www.gnu.org/licenses/>.
20    
21 */
22
23 #include "includes.h"
24 #include "lib/util/dlinklist.h"
25 #include "lib/events/events.h"
26 #include "lib/socket/socket.h"
27 #include "lib/stream/packet.h"
28 #include "libcli/raw/smb.h"
29
30 struct packet_context {
31         packet_callback_fn_t callback;
32         packet_full_request_fn_t full_request;
33         packet_error_handler_fn_t error_handler;
34         DATA_BLOB partial;
35         uint32_t num_read;
36         uint32_t initial_read;
37         struct socket_context *sock;
38         struct event_context *ev;
39         size_t packet_size;
40         void *private;
41         struct fd_event *fde;
42         bool serialise;
43         int processing;
44         bool recv_disable;
45         bool nofree;
46
47         bool busy;
48         bool destructor_called;
49
50         struct send_element {
51                 struct send_element *next, *prev;
52                 DATA_BLOB blob;
53                 size_t nsent;
54                 packet_send_callback_fn_t send_callback;
55                 void *send_callback_private;
56         } *send_queue;
57 };
58
59 /*
60   a destructor used when we are processing packets to prevent freeing of this
61   context while it is being used
62 */
63 static int packet_destructor(struct packet_context *pc)
64 {
65         if (pc->busy) {
66                 pc->destructor_called = true;
67                 /* now we refuse the talloc_free() request. The free will
68                    happen again in the packet_recv() code */
69                 return -1;
70         }
71
72         return 0;
73 }
74
75
76 /*
77   initialise a packet receiver
78 */
79 _PUBLIC_ struct packet_context *packet_init(TALLOC_CTX *mem_ctx)
80 {
81         struct packet_context *pc = talloc_zero(mem_ctx, struct packet_context);
82         if (pc != NULL) {
83                 talloc_set_destructor(pc, packet_destructor);
84         }
85         return pc;
86 }
87
88
89 /*
90   set the request callback, called when a full request is ready
91 */
92 _PUBLIC_ void packet_set_callback(struct packet_context *pc, packet_callback_fn_t callback)
93 {
94         pc->callback = callback;
95 }
96
97 /*
98   set the error handler
99 */
100 _PUBLIC_ void packet_set_error_handler(struct packet_context *pc, packet_error_handler_fn_t handler)
101 {
102         pc->error_handler = handler;
103 }
104
105 /*
106   set the private pointer passed to the callback functions
107 */
108 _PUBLIC_ void packet_set_private(struct packet_context *pc, void *private)
109 {
110         pc->private = private;
111 }
112
113 /*
114   set the full request callback. Should return as follows:
115      NT_STATUS_OK == blob is a full request.
116      STATUS_MORE_ENTRIES == blob is not complete yet
117      any error == blob is not a valid 
118 */
119 _PUBLIC_ void packet_set_full_request(struct packet_context *pc, packet_full_request_fn_t callback)
120 {
121         pc->full_request = callback;
122 }
123
124 /*
125   set a socket context to use. You must set a socket_context
126 */
127 _PUBLIC_ void packet_set_socket(struct packet_context *pc, struct socket_context *sock)
128 {
129         pc->sock = sock;
130 }
131
132 /*
133   set an event context. If this is set then the code will ensure that
134   packets arrive with separate events, by creating a immediate event
135   for any secondary packets when more than one packet is read at one
136   time on a socket. This can matter for code that relies on not
137   getting more than one packet per event
138 */
139 _PUBLIC_ void packet_set_event_context(struct packet_context *pc, struct event_context *ev)
140 {
141         pc->ev = ev;
142 }
143
144 /*
145   tell the packet layer the fde for the socket
146 */
147 _PUBLIC_ void packet_set_fde(struct packet_context *pc, struct fd_event *fde)
148 {
149         pc->fde = fde;
150 }
151
152 /*
153   tell the packet layer to serialise requests, so we don't process two
154   requests at once on one connection. You must have set the
155   event_context and fde
156 */
157 _PUBLIC_ void packet_set_serialise(struct packet_context *pc)
158 {
159         pc->serialise = true;
160 }
161
162 /*
163   tell the packet layer how much to read when starting a new packet
164   this ensures it doesn't overread
165 */
166 _PUBLIC_ void packet_set_initial_read(struct packet_context *pc, uint32_t initial_read)
167 {
168         pc->initial_read = initial_read;
169 }
170
171 /*
172   tell the packet system not to steal/free blobs given to packet_send()
173 */
174 _PUBLIC_ void packet_set_nofree(struct packet_context *pc)
175 {
176         pc->nofree = true;
177 }
178
179
180 /*
181   tell the caller we have an error
182 */
183 static void packet_error(struct packet_context *pc, NTSTATUS status)
184 {
185         pc->sock = NULL;
186         if (pc->error_handler) {
187                 pc->error_handler(pc->private, status);
188                 return;
189         }
190         /* default error handler is to free the callers private pointer */
191         if (!NT_STATUS_EQUAL(status, NT_STATUS_END_OF_FILE)) {
192                 DEBUG(0,("packet_error on %s - %s\n", 
193                          talloc_get_name(pc->private), nt_errstr(status)));
194         }
195         talloc_free(pc->private);
196         return;
197 }
198
199
200 /*
201   tell the caller we have EOF
202 */
203 static void packet_eof(struct packet_context *pc)
204 {
205         packet_error(pc, NT_STATUS_END_OF_FILE);
206 }
207
208
209 /*
210   used to put packets on event boundaries
211 */
212 static void packet_next_event(struct event_context *ev, struct timed_event *te, 
213                               struct timeval t, void *private)
214 {
215         struct packet_context *pc = talloc_get_type(private, struct packet_context);
216         if (pc->num_read != 0 && pc->packet_size != 0 &&
217             pc->packet_size <= pc->num_read) {
218                 packet_recv(pc);
219         }
220 }
221
222
223 /*
224   call this when the socket becomes readable to kick off the whole
225   stream parsing process
226 */
227 _PUBLIC_ void packet_recv(struct packet_context *pc)
228 {
229         size_t npending;
230         NTSTATUS status;
231         size_t nread = 0;
232         DATA_BLOB blob;
233
234         if (pc->processing) {
235                 EVENT_FD_NOT_READABLE(pc->fde);
236                 pc->processing++;
237                 return;
238         }
239
240         if (pc->recv_disable) {
241                 EVENT_FD_NOT_READABLE(pc->fde);
242                 return;
243         }
244
245         if (pc->packet_size != 0 && pc->num_read >= pc->packet_size) {
246                 goto next_partial;
247         }
248
249         if (pc->packet_size != 0) {
250                 /* we've already worked out how long this next packet is, so skip the
251                    socket_pending() call */
252                 npending = pc->packet_size - pc->num_read;
253         } else if (pc->initial_read != 0) {
254                 npending = pc->initial_read - pc->num_read;
255         } else {
256                 if (pc->sock) {
257                         status = socket_pending(pc->sock, &npending);
258                 } else {
259                         status = NT_STATUS_CONNECTION_DISCONNECTED;
260                 }
261                 if (!NT_STATUS_IS_OK(status)) {
262                         packet_error(pc, status);
263                         return;
264                 }
265         }
266
267         if (npending == 0) {
268                 packet_eof(pc);
269                 return;
270         }
271
272         if (npending + pc->num_read < npending) {
273                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
274                 return;
275         }
276
277         if (npending + pc->num_read < pc->num_read) {
278                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
279                 return;
280         }
281
282         /* possibly expand the partial packet buffer */
283         if (npending + pc->num_read > pc->partial.length) {
284                 if (!data_blob_realloc(pc, &pc->partial, npending+pc->num_read)) {
285                         packet_error(pc, NT_STATUS_NO_MEMORY);
286                         return;
287                 }
288         }
289
290         if (pc->partial.length < pc->num_read + npending) {
291                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
292                 return;
293         }
294
295         if ((uint8_t *)pc->partial.data + pc->num_read < (uint8_t *)pc->partial.data) {
296                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
297                 return;
298         }
299         if ((uint8_t *)pc->partial.data + pc->num_read + npending < (uint8_t *)pc->partial.data) {
300                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
301                 return;
302         }
303
304         status = socket_recv(pc->sock, pc->partial.data + pc->num_read, 
305                              npending, &nread);
306
307         if (NT_STATUS_IS_ERR(status)) {
308                 packet_error(pc, status);
309                 return;
310         }
311         if (!NT_STATUS_IS_OK(status)) {
312                 return;
313         }
314
315         if (nread == 0) {
316                 packet_eof(pc);
317                 return;
318         }
319
320         pc->num_read += nread;
321
322 next_partial:
323         if (pc->partial.length != pc->num_read) {
324                 if (!data_blob_realloc(pc, &pc->partial, pc->num_read)) {
325                         packet_error(pc, NT_STATUS_NO_MEMORY);
326                         return;
327                 }
328         }
329
330         /* see if its a full request */
331         blob = pc->partial;
332         blob.length = pc->num_read;
333         status = pc->full_request(pc->private, blob, &pc->packet_size);
334         if (NT_STATUS_IS_ERR(status)) {
335                 packet_error(pc, status);
336                 return;
337         }
338         if (!NT_STATUS_IS_OK(status)) {
339                 return;
340         }
341
342         if (pc->packet_size > pc->num_read) {
343                 /* the caller made an error */
344                 DEBUG(0,("Invalid packet_size %lu greater than num_read %lu\n",
345                          (long)pc->packet_size, (long)pc->num_read));
346                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
347                 return;
348         }
349
350         /* it is a full request - give it to the caller */
351         blob = pc->partial;
352         blob.length = pc->num_read;
353
354         if (pc->packet_size < pc->num_read) {
355                 pc->partial = data_blob_talloc(pc, blob.data + pc->packet_size, 
356                                                pc->num_read - pc->packet_size);
357                 if (pc->partial.data == NULL) {
358                         packet_error(pc, NT_STATUS_NO_MEMORY);
359                         return;
360                 }
361                 /* Trunate the blob sent to the caller to only the packet length */
362                 if (!data_blob_realloc(pc, &blob, pc->packet_size)) {
363                         packet_error(pc, NT_STATUS_NO_MEMORY);
364                         return;
365                 }
366         } else {
367                 pc->partial = data_blob(NULL, 0);
368         }
369         pc->num_read -= pc->packet_size;
370         pc->packet_size = 0;
371         
372         if (pc->serialise) {
373                 pc->processing = 1;
374         }
375
376         pc->busy = true;
377
378         status = pc->callback(pc->private, blob);
379
380         pc->busy = false;
381
382         if (pc->destructor_called) {
383                 talloc_free(pc);
384                 return;
385         }
386
387         if (pc->processing) {
388                 if (pc->processing > 1) {
389                         EVENT_FD_READABLE(pc->fde);
390                 }
391                 pc->processing = 0;
392         }
393
394         if (!NT_STATUS_IS_OK(status)) {
395                 packet_error(pc, status);
396                 return;
397         }
398
399         /* Have we consumed the whole buffer yet? */
400         if (pc->partial.length == 0) {
401                 return;
402         }
403
404         /* we got multiple packets in one tcp read */
405         if (pc->ev == NULL) {
406                 goto next_partial;
407         }
408
409         blob = pc->partial;
410         blob.length = pc->num_read;
411
412         status = pc->full_request(pc->private, blob, &pc->packet_size);
413         if (NT_STATUS_IS_ERR(status)) {
414                 packet_error(pc, status);
415                 return;
416         }
417
418         if (!NT_STATUS_IS_OK(status)) {
419                 return;
420         }
421
422         event_add_timed(pc->ev, pc, timeval_zero(), packet_next_event, pc);
423 }
424
425
426 /*
427   temporarily disable receiving 
428 */
429 _PUBLIC_ void packet_recv_disable(struct packet_context *pc)
430 {
431         EVENT_FD_NOT_READABLE(pc->fde);
432         pc->recv_disable = true;
433 }
434
435 /*
436   re-enable receiving 
437 */
438 _PUBLIC_ void packet_recv_enable(struct packet_context *pc)
439 {
440         EVENT_FD_READABLE(pc->fde);
441         pc->recv_disable = false;
442         if (pc->num_read != 0 && pc->packet_size >= pc->num_read) {
443                 event_add_timed(pc->ev, pc, timeval_zero(), packet_next_event, pc);
444         }
445 }
446
447 /*
448   trigger a run of the send queue
449 */
450 _PUBLIC_ void packet_queue_run(struct packet_context *pc)
451 {
452         while (pc->send_queue) {
453                 struct send_element *el = pc->send_queue;
454                 NTSTATUS status;
455                 size_t nwritten;
456                 DATA_BLOB blob = data_blob_const(el->blob.data + el->nsent,
457                                                  el->blob.length - el->nsent);
458
459                 status = socket_send(pc->sock, &blob, &nwritten);
460
461                 if (NT_STATUS_IS_ERR(status)) {
462                         packet_error(pc, status);
463                         return;
464                 }
465                 if (!NT_STATUS_IS_OK(status)) {
466                         return;
467                 }
468                 el->nsent += nwritten;
469                 if (el->nsent == el->blob.length) {
470                         DLIST_REMOVE(pc->send_queue, el);
471                         if (el->send_callback) {
472                                 el->send_callback(el->send_callback_private);
473                         }
474                         talloc_free(el);
475                 }
476         }
477
478         /* we're out of requests to send, so don't wait for write
479            events any more */
480         EVENT_FD_NOT_WRITEABLE(pc->fde);
481 }
482
483 /*
484   put a packet in the send queue.  When the packet is actually sent,
485   call send_callback.  
486
487   Useful for operations that must occour after sending a message, such
488   as the switch to SASL encryption after as sucessful LDAP bind relpy.
489 */
490 _PUBLIC_ NTSTATUS packet_send_callback(struct packet_context *pc, DATA_BLOB blob,
491                                        packet_send_callback_fn_t send_callback, 
492                                        void *private)
493 {
494         struct send_element *el;
495         el = talloc(pc, struct send_element);
496         NT_STATUS_HAVE_NO_MEMORY(el);
497
498         DLIST_ADD_END(pc->send_queue, el, struct send_element *);
499         el->blob = blob;
500         el->nsent = 0;
501         el->send_callback = send_callback;
502         el->send_callback_private = private;
503
504         /* if we aren't going to free the packet then we must reference it
505            to ensure it doesn't disappear before going out */
506         if (pc->nofree) {
507                 if (!talloc_reference(el, blob.data)) {
508                         return NT_STATUS_NO_MEMORY;
509                 }
510         } else {
511                 talloc_steal(el, blob.data);
512         }
513
514         if (private && !talloc_reference(el, private)) {
515                 return NT_STATUS_NO_MEMORY;
516         }
517
518         EVENT_FD_WRITEABLE(pc->fde);
519
520         return NT_STATUS_OK;
521 }
522
523 /*
524   put a packet in the send queue
525 */
526 _PUBLIC_ NTSTATUS packet_send(struct packet_context *pc, DATA_BLOB blob)
527 {
528         return packet_send_callback(pc, blob, NULL, NULL);
529 }
530
531
532 /*
533   a full request checker for NBT formatted packets (first 3 bytes are length)
534 */
535 _PUBLIC_ NTSTATUS packet_full_request_nbt(void *private, DATA_BLOB blob, size_t *size)
536 {
537         if (blob.length < 4) {
538                 return STATUS_MORE_ENTRIES;
539         }
540         *size = 4 + smb_len(blob.data);
541         if (*size > blob.length) {
542                 return STATUS_MORE_ENTRIES;
543         }
544         return NT_STATUS_OK;
545 }
546
547
548 /*
549   work out if a packet is complete for protocols that use a 32 bit network byte
550   order length
551 */
552 _PUBLIC_ NTSTATUS packet_full_request_u32(void *private, DATA_BLOB blob, size_t *size)
553 {
554         if (blob.length < 4) {
555                 return STATUS_MORE_ENTRIES;
556         }
557         *size = 4 + RIVAL(blob.data, 0);
558         if (*size > blob.length) {
559                 return STATUS_MORE_ENTRIES;
560         }
561         return NT_STATUS_OK;
562 }