r11967: Fix more 64-bit warnings.
[kai/samba.git] / source4 / lib / stream / packet.c
1 /* 
2    Unix SMB/CIFS mplementation.
3
4    helper layer for breaking up streams into discrete requests
5    
6    Copyright (C) Andrew Tridgell  2005
7     
8    This program is free software; you can redistribute it and/or modify
9    it under the terms of the GNU General Public License as published by
10    the Free Software Foundation; either version 2 of the License, or
11    (at your option) any later version.
12    
13    This program is distributed in the hope that it will be useful,
14    but WITHOUT ANY WARRANTY; without even the implied warranty of
15    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16    GNU General Public License for more details.
17    
18    You should have received a copy of the GNU General Public License
19    along with this program; if not, write to the Free Software
20    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
21    
22 */
23
24 #include "includes.h"
25 #include "dlinklist.h"
26 #include "lib/events/events.h"
27 #include "lib/socket/socket.h"
28 #include "lib/tls/tls.h"
29 #include "lib/stream/packet.h"
30
31
32 struct packet_context {
33         packet_callback_fn_t callback;
34         packet_full_request_fn_t full_request;
35         packet_error_handler_fn_t error_handler;
36         DATA_BLOB partial;
37         uint32_t num_read;
38         uint32_t initial_read;
39         struct tls_context *tls;
40         struct socket_context *sock;
41         struct event_context *ev;
42         size_t packet_size;
43         void *private;
44         struct fd_event *fde;
45         BOOL serialise;
46         int processing;
47         BOOL recv_disable;
48         BOOL nofree;
49
50         BOOL busy;
51         BOOL destructor_called;
52
53         struct send_element {
54                 struct send_element *next, *prev;
55                 DATA_BLOB blob;
56                 size_t nsent;
57         } *send_queue;
58 };
59
60 /*
61   a destructor used when we are processing packets to prevent freeing of this
62   context while it is being used
63 */
64 static int packet_destructor(void *p)
65 {
66         struct packet_context *pc = talloc_get_type(p, struct packet_context);
67
68         if (pc->busy) {
69                 pc->destructor_called = True;
70                 /* now we refuse the talloc_free() request. The free will
71                    happen again in the packet_recv() code */
72                 return -1;
73         }
74
75         return 0;
76 }
77
78
79 /*
80   initialise a packet receiver
81 */
82 struct packet_context *packet_init(TALLOC_CTX *mem_ctx)
83 {
84         struct packet_context *pc = talloc_zero(mem_ctx, struct packet_context);
85         if (pc != NULL) {
86                 talloc_set_destructor(pc, packet_destructor);
87         }
88         return pc;
89 }
90
91
92 /*
93   set the request callback, called when a full request is ready
94 */
95 void packet_set_callback(struct packet_context *pc, packet_callback_fn_t callback)
96 {
97         pc->callback = callback;
98 }
99
100 /*
101   set the error handler
102 */
103 void packet_set_error_handler(struct packet_context *pc, packet_error_handler_fn_t handler)
104 {
105         pc->error_handler = handler;
106 }
107
108 /*
109   set the private pointer passed to the callback functions
110 */
111 void packet_set_private(struct packet_context *pc, void *private)
112 {
113         pc->private = private;
114 }
115
116 /*
117   set the full request callback. Should return as follows:
118      NT_STATUS_OK == blob is a full request.
119      STATUS_MORE_ENTRIES == blob is not complete yet
120      any error == blob is not a valid 
121 */
122 void packet_set_full_request(struct packet_context *pc, packet_full_request_fn_t callback)
123 {
124         pc->full_request = callback;
125 }
126
127 /*
128   set a tls context to use. You must either set a tls_context or a socket_context
129 */
130 void packet_set_tls(struct packet_context *pc, struct tls_context *tls)
131 {
132         pc->tls = tls;
133 }
134
135 /*
136   set a socket context to use. You must either set a tls_context or a socket_context
137 */
138 void packet_set_socket(struct packet_context *pc, struct socket_context *sock)
139 {
140         pc->sock = sock;
141 }
142
143 /*
144   set an event context. If this is set then the code will ensure that
145   packets arrive with separate events, by creating a immediate event
146   for any secondary packets when more than one packet is read at one
147   time on a socket. This can matter for code that relies on not
148   getting more than one packet per event
149 */
150 void packet_set_event_context(struct packet_context *pc, struct event_context *ev)
151 {
152         pc->ev = ev;
153 }
154
155 /*
156   tell the packet layer the fde for the socket
157 */
158 void packet_set_fde(struct packet_context *pc, struct fd_event *fde)
159 {
160         pc->fde = fde;
161 }
162
163 /*
164   tell the packet layer to serialise requests, so we don't process two
165   requests at once on one connection. You must have set the
166   event_context and fde
167 */
168 void packet_set_serialise(struct packet_context *pc)
169 {
170         pc->serialise = True;
171 }
172
173 /*
174   tell the packet layer how much to read when starting a new packet
175   this ensures it doesn't overread
176 */
177 void packet_set_initial_read(struct packet_context *pc, uint32_t initial_read)
178 {
179         pc->initial_read = initial_read;
180 }
181
182 /*
183   tell the packet system not to steal/free blobs given to packet_send()
184 */
185 void packet_set_nofree(struct packet_context *pc)
186 {
187         pc->nofree = True;
188 }
189
190
191 /*
192   tell the caller we have an error
193 */
194 static void packet_error(struct packet_context *pc, NTSTATUS status)
195 {
196         pc->tls = NULL;
197         pc->sock = NULL;
198         if (pc->error_handler) {
199                 pc->error_handler(pc->private, status);
200                 return;
201         }
202         /* default error handler is to free the callers private pointer */
203         if (!NT_STATUS_EQUAL(status, NT_STATUS_END_OF_FILE)) {
204                 DEBUG(0,("packet_error on %s - %s\n", 
205                          talloc_get_name(pc->private), nt_errstr(status)));
206         }
207         talloc_free(pc->private);
208         return;
209 }
210
211
212 /*
213   tell the caller we have EOF
214 */
215 static void packet_eof(struct packet_context *pc)
216 {
217         packet_error(pc, NT_STATUS_END_OF_FILE);
218 }
219
220
221 /*
222   used to put packets on event boundaries
223 */
224 static void packet_next_event(struct event_context *ev, struct timed_event *te, 
225                               struct timeval t, void *private)
226 {
227         struct packet_context *pc = talloc_get_type(private, struct packet_context);
228         if (pc->num_read != 0 && pc->packet_size != 0 &&
229             pc->packet_size <= pc->num_read) {
230                 packet_recv(pc);
231         }
232 }
233
234
235 /*
236   call this when the socket becomes readable to kick off the whole
237   stream parsing process
238 */
239 void packet_recv(struct packet_context *pc)
240 {
241         size_t npending;
242         NTSTATUS status;
243         size_t nread = 0;
244         DATA_BLOB blob;
245
246         if (pc->processing) {
247                 EVENT_FD_NOT_READABLE(pc->fde);
248                 pc->processing++;
249                 return;
250         }
251
252         if (pc->recv_disable) {
253                 EVENT_FD_NOT_READABLE(pc->fde);
254                 return;
255         }
256
257         if (pc->packet_size != 0 && pc->num_read >= pc->packet_size) {
258                 goto next_partial;
259         }
260
261         if (pc->packet_size != 0) {
262                 /* we've already worked out how long this next packet is, so skip the
263                    socket_pending() call */
264                 npending = pc->packet_size - pc->num_read;
265         } else if (pc->initial_read != 0) {
266                 npending = pc->initial_read - pc->num_read;
267         } else {
268                 if (pc->tls) {
269                         status = tls_socket_pending(pc->tls, &npending);
270                 } else if (pc->sock) {
271                         status = socket_pending(pc->sock, &npending);
272                 } else {
273                         status = NT_STATUS_CONNECTION_DISCONNECTED;
274                 }
275                 if (!NT_STATUS_IS_OK(status)) {
276                         packet_error(pc, status);
277                         return;
278                 }
279         }
280
281         if (npending == 0) {
282                 packet_eof(pc);
283                 return;
284         }
285
286         /* possibly expand the partial packet buffer */
287         if (npending + pc->num_read > pc->partial.length) {
288                 status = data_blob_realloc(pc, &pc->partial, npending+pc->num_read);
289                 if (!NT_STATUS_IS_OK(status)) {
290                         packet_error(pc, status);
291                         return;
292                 }
293         }
294
295         if (pc->tls) {
296                 status = tls_socket_recv(pc->tls, pc->partial.data + pc->num_read, 
297                                          npending, &nread);
298         } else {
299                 status = socket_recv(pc->sock, pc->partial.data + pc->num_read, 
300                                      npending, &nread, 0);
301         }
302         if (NT_STATUS_IS_ERR(status)) {
303                 packet_error(pc, status);
304                 return;
305         }
306         if (!NT_STATUS_IS_OK(status)) {
307                 return;
308         }
309
310         if (nread == 0) {
311                 packet_eof(pc);
312                 return;
313         }
314
315         pc->num_read += nread;
316
317 next_partial:
318         if (pc->partial.length != pc->num_read) {
319                 status = data_blob_realloc(pc, &pc->partial, pc->num_read);
320                 if (!NT_STATUS_IS_OK(status)) {
321                         packet_error(pc, status);
322                         return;
323                 }
324         }
325
326         /* see if its a full request */
327         blob = pc->partial;
328         blob.length = pc->num_read;
329         status = pc->full_request(pc->private, blob, &pc->packet_size);
330         if (NT_STATUS_IS_ERR(status)) {
331                 packet_error(pc, status);
332                 return;
333         }
334         if (!NT_STATUS_IS_OK(status)) {
335                 return;
336         }
337
338         if (pc->packet_size > pc->num_read) {
339                 /* the caller made an error */
340                 DEBUG(0,("Invalid packet_size %lu greater than num_read %lu\n",
341                          (long)pc->packet_size, (long)pc->num_read));
342                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
343                 return;
344         }
345
346         /* it is a full request - give it to the caller */
347         blob = pc->partial;
348         blob.length = pc->num_read;
349
350         if (pc->packet_size < pc->num_read) {
351                 pc->partial = data_blob_talloc(pc, blob.data + pc->packet_size, 
352                                                pc->num_read - pc->packet_size);
353                 if (pc->partial.data == NULL) {
354                         packet_error(pc, NT_STATUS_NO_MEMORY);
355                         return;
356                 }
357                 status = data_blob_realloc(pc, &blob, pc->packet_size);
358                 if (!NT_STATUS_IS_OK(status)) {
359                         packet_error(pc, status);
360                         return;
361                 }
362         } else {
363                 pc->partial = data_blob(NULL, 0);
364         }
365         pc->num_read -= pc->packet_size;
366         pc->packet_size = 0;
367         
368         if (pc->serialise) {
369                 pc->processing = 1;
370         }
371
372         pc->busy = True;
373
374         status = pc->callback(pc->private, blob);
375
376         pc->busy = False;
377
378         if (pc->destructor_called) {
379                 talloc_free(pc);
380                 return;
381         }
382
383         if (pc->processing) {
384                 if (pc->processing > 1) {
385                         EVENT_FD_READABLE(pc->fde);
386                 }
387                 pc->processing = 0;
388         }
389
390         if (!NT_STATUS_IS_OK(status)) {
391                 packet_error(pc, status);
392                 return;
393         }
394
395         if (pc->partial.length == 0) {
396                 return;
397         }
398
399         /* we got multiple packets in one tcp read */
400         if (pc->ev == NULL) {
401                 goto next_partial;
402         }
403
404         blob = pc->partial;
405         blob.length = pc->num_read;
406
407         status = pc->full_request(pc->private, blob, &pc->packet_size);
408         if (NT_STATUS_IS_ERR(status)) {
409                 packet_error(pc, status);
410                 return;
411         }
412
413         if (!NT_STATUS_IS_OK(status)) {
414                 return;
415         }
416
417         event_add_timed(pc->ev, pc, timeval_zero(), packet_next_event, pc);
418 }
419
420
421 /*
422   temporarily disable receiving 
423 */
424 void packet_recv_disable(struct packet_context *pc)
425 {
426         EVENT_FD_NOT_READABLE(pc->fde);
427         pc->recv_disable = True;
428 }
429
430 /*
431   re-enable receiving 
432 */
433 void packet_recv_enable(struct packet_context *pc)
434 {
435         EVENT_FD_READABLE(pc->fde);
436         pc->recv_disable = False;
437         if (pc->num_read != 0 && pc->packet_size >= pc->num_read) {
438                 event_add_timed(pc->ev, pc, timeval_zero(), packet_next_event, pc);
439         }
440 }
441
442 /*
443   trigger a run of the send queue
444 */
445 void packet_queue_run(struct packet_context *pc)
446 {
447         while (pc->send_queue) {
448                 struct send_element *el = pc->send_queue;
449                 NTSTATUS status;
450                 size_t nwritten;
451                 DATA_BLOB blob = data_blob_const(el->blob.data + el->nsent,
452                                                  el->blob.length - el->nsent);
453
454                 if (pc->tls) {
455                         status = tls_socket_send(pc->tls, &blob, &nwritten);
456                 } else {
457                         status = socket_send(pc->sock, &blob, &nwritten, 0);
458                 }
459                 if (NT_STATUS_IS_ERR(status)) {
460                         packet_error(pc, NT_STATUS_NET_WRITE_FAULT);
461                         return;
462                 }
463                 if (!NT_STATUS_IS_OK(status)) {
464                         return;
465                 }
466                 el->nsent += nwritten;
467                 if (el->nsent == el->blob.length) {
468                         DLIST_REMOVE(pc->send_queue, el);
469                         talloc_free(el);
470                 }
471         }
472
473         /* we're out of requests to send, so don't wait for write
474            events any more */
475         EVENT_FD_NOT_WRITEABLE(pc->fde);
476 }
477
478 /*
479   put a packet in the send queue
480 */
481 NTSTATUS packet_send(struct packet_context *pc, DATA_BLOB blob)
482 {
483         struct send_element *el;
484         el = talloc(pc, struct send_element);
485         NT_STATUS_HAVE_NO_MEMORY(el);
486
487         DLIST_ADD_END(pc->send_queue, el, struct send_element *);
488         el->blob = blob;
489         el->nsent = 0;
490
491         /* if we aren't going to free the packet then we must reference it
492            to ensure it doesn't disappear before going out */
493         if (pc->nofree) {
494                 if (!talloc_reference(el, blob.data)) {
495                         return NT_STATUS_NO_MEMORY;
496                 }
497         } else {
498                 talloc_steal(el, blob.data);
499         }
500
501         EVENT_FD_WRITEABLE(pc->fde);
502
503         return NT_STATUS_OK;
504 }
505
506
507 /*
508   a full request checker for NBT formatted packets (first 3 bytes are length)
509 */
510 NTSTATUS packet_full_request_nbt(void *private, DATA_BLOB blob, size_t *size)
511 {
512         if (blob.length < 4) {
513                 return STATUS_MORE_ENTRIES;
514         }
515         *size = 4 + smb_len(blob.data);
516         if (*size > blob.length) {
517                 return STATUS_MORE_ENTRIES;
518         }
519         return NT_STATUS_OK;
520 }
521
522
523 /*
524   work out if a packet is complete for protocols that use a 32 bit network byte
525   order length
526 */
527 NTSTATUS packet_full_request_u32(void *private, DATA_BLOB blob, size_t *size)
528 {
529         if (blob.length < 4) {
530                 return STATUS_MORE_ENTRIES;
531         }
532         *size = 4 + RIVAL(blob.data, 0);
533         if (*size > blob.length) {
534                 return STATUS_MORE_ENTRIES;
535         }
536         return NT_STATUS_OK;
537 }