45ca1feb45a6065f1bcbe8bfdfe38e178d721ca1
[jra/samba/.git] / source4 / lib / stream / packet.c
1 /* 
2    Unix SMB/CIFS mplementation.
3
4    helper layer for breaking up streams into discrete requests
5    
6    Copyright (C) Andrew Tridgell  2005
7     
8    This program is free software; you can redistribute it and/or modify
9    it under the terms of the GNU General Public License as published by
10    the Free Software Foundation; either version 2 of the License, or
11    (at your option) any later version.
12    
13    This program is distributed in the hope that it will be useful,
14    but WITHOUT ANY WARRANTY; without even the implied warranty of
15    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16    GNU General Public License for more details.
17    
18    You should have received a copy of the GNU General Public License
19    along with this program; if not, write to the Free Software
20    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
21    
22 */
23
24 #include "includes.h"
25 #include "dlinklist.h"
26 #include "lib/events/events.h"
27 #include "lib/socket/socket.h"
28 #include "lib/tls/tls.h"
29 #include "lib/stream/packet.h"
30
31
32 struct packet_context {
33         packet_callback_fn_t callback;
34         packet_full_request_fn_t full_request;
35         packet_error_handler_fn_t error_handler;
36         DATA_BLOB partial;
37         uint32_t num_read;
38         uint32_t initial_read;
39         struct tls_context *tls;
40         struct socket_context *sock;
41         struct event_context *ev;
42         size_t packet_size;
43         void *private;
44         struct fd_event *fde;
45         BOOL serialise;
46         int processing;
47         BOOL recv_disable;
48         BOOL nofree;
49
50         struct send_element {
51                 struct send_element *next, *prev;
52                 DATA_BLOB blob;
53                 size_t nsent;
54         } *send_queue;
55 };
56
57 /*
58   initialise a packet receiver
59 */
60 struct packet_context *packet_init(TALLOC_CTX *mem_ctx)
61 {
62         return talloc_zero(mem_ctx, struct packet_context);
63 }
64
65
66 /*
67   set the request callback, called when a full request is ready
68 */
69 void packet_set_callback(struct packet_context *pc, packet_callback_fn_t callback)
70 {
71         pc->callback = callback;
72 }
73
74 /*
75   set the error handler
76 */
77 void packet_set_error_handler(struct packet_context *pc, packet_error_handler_fn_t handler)
78 {
79         pc->error_handler = handler;
80 }
81
82 /*
83   set the private pointer passed to the callback functions
84 */
85 void packet_set_private(struct packet_context *pc, void *private)
86 {
87         pc->private = private;
88 }
89
90 /*
91   set the full request callback. Should return as follows:
92      NT_STATUS_OK == blob is a full request.
93      STATUS_MORE_ENTRIES == blob is not complete yet
94      any error == blob is not a valid 
95 */
96 void packet_set_full_request(struct packet_context *pc, packet_full_request_fn_t callback)
97 {
98         pc->full_request = callback;
99 }
100
101 /*
102   set a tls context to use. You must either set a tls_context or a socket_context
103 */
104 void packet_set_tls(struct packet_context *pc, struct tls_context *tls)
105 {
106         pc->tls = tls;
107 }
108
109 /*
110   set a socket context to use. You must either set a tls_context or a socket_context
111 */
112 void packet_set_socket(struct packet_context *pc, struct socket_context *sock)
113 {
114         pc->sock = sock;
115 }
116
117 /*
118   set an event context. If this is set then the code will ensure that
119   packets arrive with separate events, by creating a immediate event
120   for any secondary packets when more than one packet is read at one
121   time on a socket. This can matter for code that relies on not
122   getting more than one packet per event
123 */
124 void packet_set_event_context(struct packet_context *pc, struct event_context *ev)
125 {
126         pc->ev = ev;
127 }
128
129 /*
130   tell the packet layer the fde for the socket
131 */
132 void packet_set_fde(struct packet_context *pc, struct fd_event *fde)
133 {
134         pc->fde = fde;
135 }
136
137 /*
138   tell the packet layer to serialise requests, so we don't process two
139   requests at once on one connection. You must have set the
140   event_context and fde
141 */
142 void packet_set_serialise(struct packet_context *pc)
143 {
144         pc->serialise = True;
145 }
146
147 /*
148   tell the packet layer how much to read when starting a new packet
149   this ensures it doesn't overread
150 */
151 void packet_set_initial_read(struct packet_context *pc, uint32_t initial_read)
152 {
153         pc->initial_read = initial_read;
154 }
155
156 /*
157   tell the packet system not to steal/free blobs given to packet_send()
158 */
159 void packet_set_nofree(struct packet_context *pc)
160 {
161         pc->nofree = True;
162 }
163
164
165 /*
166   tell the caller we have an error
167 */
168 static void packet_error(struct packet_context *pc, NTSTATUS status)
169 {
170         pc->tls = NULL;
171         pc->sock = NULL;
172         if (pc->error_handler) {
173                 pc->error_handler(pc->private, status);
174                 return;
175         }
176         /* default error handler is to free the callers private pointer */
177         if (!NT_STATUS_EQUAL(status, NT_STATUS_END_OF_FILE)) {
178                 DEBUG(0,("packet_error on %s - %s\n", 
179                          talloc_get_name(pc->private), nt_errstr(status)));
180         }
181         talloc_free(pc->private);
182         return;
183 }
184
185
186 /*
187   tell the caller we have EOF
188 */
189 static void packet_eof(struct packet_context *pc)
190 {
191         packet_error(pc, NT_STATUS_END_OF_FILE);
192 }
193
194
195 /*
196   used to put packets on event boundaries
197 */
198 static void packet_next_event(struct event_context *ev, struct timed_event *te, 
199                               struct timeval t, void *private)
200 {
201         struct packet_context *pc = talloc_get_type(private, struct packet_context);
202         if (pc->num_read != 0 && pc->packet_size != 0 &&
203             pc->packet_size <= pc->num_read) {
204                 packet_recv(pc);
205         }
206 }
207
208 /*
209   call this when the socket becomes readable to kick off the whole
210   stream parsing process
211 */
212 void packet_recv(struct packet_context *pc)
213 {
214         size_t npending;
215         NTSTATUS status;
216         size_t nread = 0;
217         DATA_BLOB blob;
218
219         if (pc->processing) {
220                 EVENT_FD_NOT_READABLE(pc->fde);
221                 pc->processing++;
222                 return;
223         }
224
225         if (pc->recv_disable) {
226                 EVENT_FD_NOT_READABLE(pc->fde);
227                 return;
228         }
229
230         if (pc->packet_size != 0 && pc->num_read >= pc->packet_size) {
231                 goto next_partial;
232         }
233
234         if (pc->packet_size != 0) {
235                 /* we've already worked out how long this next packet is, so skip the
236                    socket_pending() call */
237                 npending = pc->packet_size - pc->num_read;
238         } else if (pc->initial_read != 0) {
239                 npending = pc->initial_read - pc->num_read;
240         } else {
241                 if (pc->tls) {
242                         status = tls_socket_pending(pc->tls, &npending);
243                 } else if (pc->sock) {
244                         status = socket_pending(pc->sock, &npending);
245                 } else {
246                         status = NT_STATUS_CONNECTION_DISCONNECTED;
247                 }
248                 if (!NT_STATUS_IS_OK(status)) {
249                         packet_error(pc, status);
250                         return;
251                 }
252         }
253
254         if (npending == 0) {
255                 packet_eof(pc);
256                 return;
257         }
258
259         /* possibly expand the partial packet buffer */
260         if (npending + pc->num_read > pc->partial.length) {
261                 status = data_blob_realloc(pc, &pc->partial, npending+pc->num_read);
262                 if (!NT_STATUS_IS_OK(status)) {
263                         packet_error(pc, status);
264                         return;
265                 }
266         }
267
268         if (pc->tls) {
269                 status = tls_socket_recv(pc->tls, pc->partial.data + pc->num_read, 
270                                          npending, &nread);
271         } else {
272                 status = socket_recv(pc->sock, pc->partial.data + pc->num_read, 
273                                      npending, &nread, 0);
274         }
275         if (NT_STATUS_IS_ERR(status)) {
276                 packet_error(pc, status);
277                 return;
278         }
279         if (!NT_STATUS_IS_OK(status)) {
280                 return;
281         }
282
283         if (nread == 0) {
284                 packet_eof(pc);
285                 return;
286         }
287
288         pc->num_read += nread;
289
290 next_partial:
291         if (pc->partial.length != pc->num_read) {
292                 status = data_blob_realloc(pc, &pc->partial, pc->num_read);
293                 if (!NT_STATUS_IS_OK(status)) {
294                         packet_error(pc, status);
295                         return;
296                 }
297         }
298
299         /* see if its a full request */
300         blob = pc->partial;
301         blob.length = pc->num_read;
302         status = pc->full_request(pc->private, blob, &pc->packet_size);
303         if (NT_STATUS_IS_ERR(status)) {
304                 packet_error(pc, status);
305                 return;
306         }
307         if (!NT_STATUS_IS_OK(status)) {
308                 return;
309         }
310
311         if (pc->packet_size > pc->num_read) {
312                 /* the caller made an error */
313                 DEBUG(0,("Invalid packet_size %u greater than num_read %u\n",
314                          pc->packet_size, pc->num_read));
315                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
316                 return;
317         }
318
319         /* it is a full request - give it to the caller */
320         blob = pc->partial;
321         blob.length = pc->num_read;
322
323         if (pc->packet_size < pc->num_read) {
324                 pc->partial = data_blob_talloc(pc, blob.data + pc->packet_size, 
325                                                pc->num_read - pc->packet_size);
326                 if (pc->partial.data == NULL) {
327                         packet_error(pc, NT_STATUS_NO_MEMORY);
328                         return;
329                 }
330                 status = data_blob_realloc(pc, &blob, pc->packet_size);
331                 if (!NT_STATUS_IS_OK(status)) {
332                         packet_error(pc, status);
333                         return;
334                 }
335         } else {
336                 pc->partial = data_blob(NULL, 0);
337         }
338         pc->num_read -= pc->packet_size;
339         pc->packet_size = 0;
340         
341         if (pc->serialise) {
342                 pc->processing = 1;
343         }
344
345         status = pc->callback(pc->private, blob);
346
347         if (pc->processing) {
348                 if (pc->processing > 1) {
349                         EVENT_FD_READABLE(pc->fde);
350                 }
351                 pc->processing = 0;
352         }
353
354         if (!NT_STATUS_IS_OK(status)) {
355                 packet_error(pc, status);
356                 return;
357         }
358
359         if (pc->partial.length == 0) {
360                 return;
361         }
362
363         /* we got multiple packets in one tcp read */
364         if (pc->ev == NULL) {
365                 goto next_partial;
366         }
367
368         blob = pc->partial;
369         blob.length = pc->num_read;
370
371         status = pc->full_request(pc->private, blob, &pc->packet_size);
372         if (NT_STATUS_IS_ERR(status)) {
373                 packet_error(pc, status);
374                 return;
375         }
376
377         if (!NT_STATUS_IS_OK(status)) {
378                 return;
379         }
380
381         event_add_timed(pc->ev, pc, timeval_zero(), packet_next_event, pc);
382 }
383
384
385 /*
386   temporarily disable receiving 
387 */
388 void packet_recv_disable(struct packet_context *pc)
389 {
390         EVENT_FD_NOT_READABLE(pc->fde);
391         pc->recv_disable = True;
392 }
393
394 /*
395   re-enable receiving 
396 */
397 void packet_recv_enable(struct packet_context *pc)
398 {
399         EVENT_FD_READABLE(pc->fde);
400         pc->recv_disable = False;
401         if (pc->num_read != 0 && pc->packet_size >= pc->num_read) {
402                 event_add_timed(pc->ev, pc, timeval_zero(), packet_next_event, pc);
403         }
404 }
405
406 /*
407   trigger a run of the send queue
408 */
409 void packet_queue_run(struct packet_context *pc)
410 {
411         while (pc->send_queue) {
412                 struct send_element *el = pc->send_queue;
413                 NTSTATUS status;
414                 size_t nwritten;
415                 DATA_BLOB blob = data_blob_const(el->blob.data + el->nsent,
416                                                  el->blob.length - el->nsent);
417
418                 if (pc->tls) {
419                         status = tls_socket_send(pc->tls, &blob, &nwritten);
420                 } else {
421                         status = socket_send(pc->sock, &blob, &nwritten, 0);
422                 }
423                 if (NT_STATUS_IS_ERR(status)) {
424                         packet_error(pc, NT_STATUS_NET_WRITE_FAULT);
425                         return;
426                 }
427                 if (!NT_STATUS_IS_OK(status)) {
428                         return;
429                 }
430                 el->nsent += nwritten;
431                 if (el->nsent == el->blob.length) {
432                         DLIST_REMOVE(pc->send_queue, el);
433                         talloc_free(el);
434                 }
435         }
436
437         /* we're out of requests to send, so don't wait for write
438            events any more */
439         EVENT_FD_NOT_WRITEABLE(pc->fde);
440 }
441
442 /*
443   put a packet in the send queue
444 */
445 NTSTATUS packet_send(struct packet_context *pc, DATA_BLOB blob)
446 {
447         struct send_element *el;
448         el = talloc(pc, struct send_element);
449         NT_STATUS_HAVE_NO_MEMORY(el);
450
451         DLIST_ADD_END(pc->send_queue, el, struct send_element *);
452         el->blob = blob;
453         el->nsent = 0;
454
455         /* if we aren't going to free the packet then we must reference it
456            to ensure it doesn't disappear before going out */
457         if (pc->nofree) {
458                 if (!talloc_reference(el, blob.data)) {
459                         return NT_STATUS_NO_MEMORY;
460                 }
461         } else {
462                 talloc_steal(el, blob.data);
463         }
464
465         EVENT_FD_WRITEABLE(pc->fde);
466
467         return NT_STATUS_OK;
468 }
469
470
471 /*
472   a full request checker for NBT formatted packets (first 3 bytes are length)
473 */
474 NTSTATUS packet_full_request_nbt(void *private, DATA_BLOB blob, size_t *size)
475 {
476         if (blob.length < 4) {
477                 return STATUS_MORE_ENTRIES;
478         }
479         *size = 4 + smb_len(blob.data);
480         if (*size > blob.length) {
481                 return STATUS_MORE_ENTRIES;
482         }
483         return NT_STATUS_OK;
484 }
485
486
487 /*
488   work out if a packet is complete for protocols that use a 32 bit network byte
489   order length
490 */
491 NTSTATUS packet_full_request_u32(void *private, DATA_BLOB blob, size_t *size)
492 {
493         if (blob.length < 4) {
494                 return STATUS_MORE_ENTRIES;
495         }
496         *size = 4 + RIVAL(blob.data, 0);
497         if (*size > blob.length) {
498                 return STATUS_MORE_ENTRIES;
499         }
500         return NT_STATUS_OK;
501 }