lib/async_req: remove the tevent_fd as early as possible via a writev_cleanup() hook
[samba.git] / lib / async_req / async_sock.c
1 /*
2    Unix SMB/CIFS implementation.
3    async socket syscalls
4    Copyright (C) Volker Lendecke 2008
5
6      ** NOTE! The following LGPL license applies to the async_sock
7      ** library. This does NOT imply that all of Samba is released
8      ** under the LGPL
9
10    This library is free software; you can redistribute it and/or
11    modify it under the terms of the GNU Lesser General Public
12    License as published by the Free Software Foundation; either
13    version 3 of the License, or (at your option) any later version.
14
15    This library is distributed in the hope that it will be useful,
16    but WITHOUT ANY WARRANTY; without even the implied warranty of
17    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18    Library General Public License for more details.
19
20    You should have received a copy of the GNU Lesser General Public License
21    along with this program.  If not, see <http://www.gnu.org/licenses/>.
22 */
23
24 #include "replace.h"
25 #include "system/network.h"
26 #include "system/filesys.h"
27 #include <talloc.h>
28 #include <tevent.h>
29 #include "lib/async_req/async_sock.h"
30 #include "lib/util/iov_buf.h"
31
32 /* Note: lib/util/ is currently GPL */
33 #include "lib/util/tevent_unix.h"
34 #include "lib/util/samba_util.h"
35
36 struct async_connect_state {
37         int fd;
38         struct tevent_fd *fde;
39         int result;
40         long old_sockflags;
41         socklen_t address_len;
42         struct sockaddr_storage address;
43
44         void (*before_connect)(void *private_data);
45         void (*after_connect)(void *private_data);
46         void *private_data;
47 };
48
49 static void async_connect_cleanup(struct tevent_req *req,
50                                   enum tevent_req_state req_state);
51 static void async_connect_connected(struct tevent_context *ev,
52                                     struct tevent_fd *fde, uint16_t flags,
53                                     void *priv);
54
55 /**
56  * @brief async version of connect(2)
57  * @param[in] mem_ctx   The memory context to hang the result off
58  * @param[in] ev        The event context to work from
59  * @param[in] fd        The socket to recv from
60  * @param[in] address   Where to connect?
61  * @param[in] address_len Length of *address
62  * @retval The async request
63  *
64  * This function sets the socket into non-blocking state to be able to call
65  * connect in an async state. This will be reset when the request is finished.
66  */
67
68 struct tevent_req *async_connect_send(
69         TALLOC_CTX *mem_ctx, struct tevent_context *ev, int fd,
70         const struct sockaddr *address, socklen_t address_len,
71         void (*before_connect)(void *private_data),
72         void (*after_connect)(void *private_data),
73         void *private_data)
74 {
75         struct tevent_req *req;
76         struct async_connect_state *state;
77
78         req = tevent_req_create(mem_ctx, &state, struct async_connect_state);
79         if (req == NULL) {
80                 return NULL;
81         }
82
83         /**
84          * We have to set the socket to nonblocking for async connect(2). Keep
85          * the old sockflags around.
86          */
87
88         state->fd = fd;
89         state->before_connect = before_connect;
90         state->after_connect = after_connect;
91         state->private_data = private_data;
92
93         state->old_sockflags = fcntl(fd, F_GETFL, 0);
94         if (state->old_sockflags == -1) {
95                 tevent_req_error(req, errno);
96                 return tevent_req_post(req, ev);
97         }
98
99         tevent_req_set_cleanup_fn(req, async_connect_cleanup);
100
101         state->address_len = address_len;
102         if (address_len > sizeof(state->address)) {
103                 tevent_req_error(req, EINVAL);
104                 return tevent_req_post(req, ev);
105         }
106         memcpy(&state->address, address, address_len);
107
108         set_blocking(fd, false);
109
110         if (state->before_connect != NULL) {
111                 state->before_connect(state->private_data);
112         }
113
114         state->result = connect(fd, address, address_len);
115
116         if (state->after_connect != NULL) {
117                 state->after_connect(state->private_data);
118         }
119
120         if (state->result == 0) {
121                 tevent_req_done(req);
122                 return tevent_req_post(req, ev);
123         }
124
125         /**
126          * A number of error messages show that something good is progressing
127          * and that we have to wait for readability.
128          *
129          * If none of them are present, bail out.
130          */
131
132         if (!(errno == EINPROGRESS || errno == EALREADY ||
133 #ifdef EISCONN
134               errno == EISCONN ||
135 #endif
136               errno == EAGAIN || errno == EINTR)) {
137                 tevent_req_error(req, errno);
138                 return tevent_req_post(req, ev);
139         }
140
141         state->fde = tevent_add_fd(ev, state, fd,
142                                    TEVENT_FD_READ | TEVENT_FD_WRITE,
143                                    async_connect_connected, req);
144         if (state->fde == NULL) {
145                 tevent_req_error(req, ENOMEM);
146                 return tevent_req_post(req, ev);
147         }
148         return req;
149 }
150
151 static void async_connect_cleanup(struct tevent_req *req,
152                                   enum tevent_req_state req_state)
153 {
154         struct async_connect_state *state =
155                 tevent_req_data(req, struct async_connect_state);
156
157         TALLOC_FREE(state->fde);
158         if (state->fd != -1) {
159                 fcntl(state->fd, F_SETFL, state->old_sockflags);
160                 state->fd = -1;
161         }
162 }
163
164 /**
165  * fde event handler for connect(2)
166  * @param[in] ev        The event context that sent us here
167  * @param[in] fde       The file descriptor event associated with the connect
168  * @param[in] flags     Indicate read/writeability of the socket
169  * @param[in] priv      private data, "struct async_req *" in this case
170  */
171
172 static void async_connect_connected(struct tevent_context *ev,
173                                     struct tevent_fd *fde, uint16_t flags,
174                                     void *priv)
175 {
176         struct tevent_req *req = talloc_get_type_abort(
177                 priv, struct tevent_req);
178         struct async_connect_state *state =
179                 tevent_req_data(req, struct async_connect_state);
180         int ret;
181
182         if (state->before_connect != NULL) {
183                 state->before_connect(state->private_data);
184         }
185
186         ret = connect(state->fd, (struct sockaddr *)(void *)&state->address,
187                       state->address_len);
188
189         if (state->after_connect != NULL) {
190                 state->after_connect(state->private_data);
191         }
192
193         if (ret == 0) {
194                 tevent_req_done(req);
195                 return;
196         }
197         if (errno == EINPROGRESS) {
198                 /* Try again later, leave the fde around */
199                 return;
200         }
201         tevent_req_error(req, errno);
202         return;
203 }
204
205 int async_connect_recv(struct tevent_req *req, int *perrno)
206 {
207         int err = tevent_req_simple_recv_unix(req);
208
209         if (err != 0) {
210                 *perrno = err;
211                 return -1;
212         }
213
214         return 0;
215 }
216
217 struct writev_state {
218         struct tevent_context *ev;
219         int fd;
220         struct tevent_fd *fde;
221         struct iovec *iov;
222         int count;
223         size_t total_size;
224         uint16_t flags;
225         bool err_on_readability;
226 };
227
228 static void writev_cleanup(struct tevent_req *req,
229                            enum tevent_req_state req_state);
230 static void writev_trigger(struct tevent_req *req, void *private_data);
231 static void writev_handler(struct tevent_context *ev, struct tevent_fd *fde,
232                            uint16_t flags, void *private_data);
233
234 struct tevent_req *writev_send(TALLOC_CTX *mem_ctx, struct tevent_context *ev,
235                                struct tevent_queue *queue, int fd,
236                                bool err_on_readability,
237                                struct iovec *iov, int count)
238 {
239         struct tevent_req *req;
240         struct writev_state *state;
241
242         req = tevent_req_create(mem_ctx, &state, struct writev_state);
243         if (req == NULL) {
244                 return NULL;
245         }
246         state->ev = ev;
247         state->fd = fd;
248         state->total_size = 0;
249         state->count = count;
250         state->iov = (struct iovec *)talloc_memdup(
251                 state, iov, sizeof(struct iovec) * count);
252         if (tevent_req_nomem(state->iov, req)) {
253                 return tevent_req_post(req, ev);
254         }
255         state->flags = TEVENT_FD_WRITE|TEVENT_FD_READ;
256         state->err_on_readability = err_on_readability;
257
258         tevent_req_set_cleanup_fn(req, writev_cleanup);
259
260         if (queue == NULL) {
261                 state->fde = tevent_add_fd(state->ev, state, state->fd,
262                                     state->flags, writev_handler, req);
263                 if (tevent_req_nomem(state->fde, req)) {
264                         return tevent_req_post(req, ev);
265                 }
266                 return req;
267         }
268
269         if (!tevent_queue_add(queue, ev, req, writev_trigger, NULL)) {
270                 tevent_req_nomem(NULL, req);
271                 return tevent_req_post(req, ev);
272         }
273         return req;
274 }
275
276 static void writev_cleanup(struct tevent_req *req,
277                            enum tevent_req_state req_state)
278 {
279         struct writev_state *state = tevent_req_data(req, struct writev_state);
280
281         TALLOC_FREE(state->fde);
282 }
283
284 static void writev_trigger(struct tevent_req *req, void *private_data)
285 {
286         struct writev_state *state = tevent_req_data(req, struct writev_state);
287
288         state->fde = tevent_add_fd(state->ev, state, state->fd, state->flags,
289                             writev_handler, req);
290         if (tevent_req_nomem(state->fde, req)) {
291                 return;
292         }
293 }
294
295 static void writev_handler(struct tevent_context *ev, struct tevent_fd *fde,
296                            uint16_t flags, void *private_data)
297 {
298         struct tevent_req *req = talloc_get_type_abort(
299                 private_data, struct tevent_req);
300         struct writev_state *state =
301                 tevent_req_data(req, struct writev_state);
302         size_t written;
303         bool ok;
304
305         if ((state->flags & TEVENT_FD_READ) && (flags & TEVENT_FD_READ)) {
306                 int ret, value;
307
308                 if (state->err_on_readability) {
309                         /* Readable and the caller wants an error on read. */
310                         tevent_req_error(req, EPIPE);
311                         return;
312                 }
313
314                 /* Might be an error. Check if there are bytes to read */
315                 ret = ioctl(state->fd, FIONREAD, &value);
316                 /* FIXME - should we also check
317                    for ret == 0 and value == 0 here ? */
318                 if (ret == -1) {
319                         /* There's an error. */
320                         tevent_req_error(req, EPIPE);
321                         return;
322                 }
323                 /* A request for TEVENT_FD_READ will succeed from now and
324                    forevermore until the bytes are read so if there was
325                    an error we'll wait until we do read, then get it in
326                    the read callback function. Until then, remove TEVENT_FD_READ
327                    from the flags we're waiting for. */
328                 state->flags &= ~TEVENT_FD_READ;
329                 TEVENT_FD_NOT_READABLE(fde);
330
331                 /* If not writable, we're done. */
332                 if (!(flags & TEVENT_FD_WRITE)) {
333                         return;
334                 }
335         }
336
337         written = writev(state->fd, state->iov, state->count);
338         if ((written == -1) && (errno == EINTR)) {
339                 /* retry */
340                 return;
341         }
342         if (written == -1) {
343                 tevent_req_error(req, errno);
344                 return;
345         }
346         if (written == 0) {
347                 tevent_req_error(req, EPIPE);
348                 return;
349         }
350         state->total_size += written;
351
352         ok = iov_advance(&state->iov, &state->count, written);
353         if (!ok) {
354                 tevent_req_error(req, EIO);
355                 return;
356         }
357
358         if (state->count == 0) {
359                 tevent_req_done(req);
360                 return;
361         }
362 }
363
364 ssize_t writev_recv(struct tevent_req *req, int *perrno)
365 {
366         struct writev_state *state =
367                 tevent_req_data(req, struct writev_state);
368         ssize_t ret;
369
370         if (tevent_req_is_unix_error(req, perrno)) {
371                 tevent_req_received(req);
372                 return -1;
373         }
374         ret = state->total_size;
375         tevent_req_received(req);
376         return ret;
377 }
378
379 struct read_packet_state {
380         int fd;
381         uint8_t *buf;
382         size_t nread;
383         ssize_t (*more)(uint8_t *buf, size_t buflen, void *private_data);
384         void *private_data;
385 };
386
387 static void read_packet_handler(struct tevent_context *ev,
388                                 struct tevent_fd *fde,
389                                 uint16_t flags, void *private_data);
390
391 struct tevent_req *read_packet_send(TALLOC_CTX *mem_ctx,
392                                     struct tevent_context *ev,
393                                     int fd, size_t initial,
394                                     ssize_t (*more)(uint8_t *buf,
395                                                     size_t buflen,
396                                                     void *private_data),
397                                     void *private_data)
398 {
399         struct tevent_req *result;
400         struct read_packet_state *state;
401         struct tevent_fd *fde;
402
403         result = tevent_req_create(mem_ctx, &state, struct read_packet_state);
404         if (result == NULL) {
405                 return NULL;
406         }
407         state->fd = fd;
408         state->nread = 0;
409         state->more = more;
410         state->private_data = private_data;
411
412         state->buf = talloc_array(state, uint8_t, initial);
413         if (state->buf == NULL) {
414                 goto fail;
415         }
416
417         fde = tevent_add_fd(ev, state, fd, TEVENT_FD_READ, read_packet_handler,
418                             result);
419         if (fde == NULL) {
420                 goto fail;
421         }
422         return result;
423  fail:
424         TALLOC_FREE(result);
425         return NULL;
426 }
427
428 static void read_packet_handler(struct tevent_context *ev,
429                                 struct tevent_fd *fde,
430                                 uint16_t flags, void *private_data)
431 {
432         struct tevent_req *req = talloc_get_type_abort(
433                 private_data, struct tevent_req);
434         struct read_packet_state *state =
435                 tevent_req_data(req, struct read_packet_state);
436         size_t total = talloc_get_size(state->buf);
437         ssize_t nread, more;
438         uint8_t *tmp;
439
440         nread = recv(state->fd, state->buf+state->nread, total-state->nread,
441                      0);
442         if ((nread == -1) && (errno == ENOTSOCK)) {
443                 nread = read(state->fd, state->buf+state->nread,
444                              total-state->nread);
445         }
446         if ((nread == -1) && (errno == EINTR)) {
447                 /* retry */
448                 return;
449         }
450         if (nread == -1) {
451                 tevent_req_error(req, errno);
452                 return;
453         }
454         if (nread == 0) {
455                 tevent_req_error(req, EPIPE);
456                 return;
457         }
458
459         state->nread += nread;
460         if (state->nread < total) {
461                 /* Come back later */
462                 return;
463         }
464
465         /*
466          * We got what was initially requested. See if "more" asks for -- more.
467          */
468         if (state->more == NULL) {
469                 /* Nobody to ask, this is a async read_data */
470                 tevent_req_done(req);
471                 return;
472         }
473
474         more = state->more(state->buf, total, state->private_data);
475         if (more == -1) {
476                 /* We got an invalid packet, tell the caller */
477                 tevent_req_error(req, EIO);
478                 return;
479         }
480         if (more == 0) {
481                 /* We're done, full packet received */
482                 tevent_req_done(req);
483                 return;
484         }
485
486         if (total + more < total) {
487                 tevent_req_error(req, EMSGSIZE);
488                 return;
489         }
490
491         tmp = talloc_realloc(state, state->buf, uint8_t, total+more);
492         if (tevent_req_nomem(tmp, req)) {
493                 return;
494         }
495         state->buf = tmp;
496 }
497
498 ssize_t read_packet_recv(struct tevent_req *req, TALLOC_CTX *mem_ctx,
499                          uint8_t **pbuf, int *perrno)
500 {
501         struct read_packet_state *state =
502                 tevent_req_data(req, struct read_packet_state);
503
504         if (tevent_req_is_unix_error(req, perrno)) {
505                 return -1;
506         }
507         *pbuf = talloc_move(mem_ctx, &state->buf);
508         return talloc_get_size(*pbuf);
509 }
510
511 struct wait_for_read_state {
512         struct tevent_req *req;
513         struct tevent_fd *fde;
514 };
515
516 static void wait_for_read_done(struct tevent_context *ev,
517                                struct tevent_fd *fde,
518                                uint16_t flags,
519                                void *private_data);
520
521 struct tevent_req *wait_for_read_send(TALLOC_CTX *mem_ctx,
522                                       struct tevent_context *ev,
523                                       int fd)
524 {
525         struct tevent_req *req;
526         struct wait_for_read_state *state;
527
528         req = tevent_req_create(mem_ctx, &state, struct wait_for_read_state);
529         if (req == NULL) {
530                 return NULL;
531         }
532         state->req = req;
533         state->fde = tevent_add_fd(ev, state, fd, TEVENT_FD_READ,
534                                    wait_for_read_done, state);
535         if (tevent_req_nomem(state->fde, req)) {
536                 return tevent_req_post(req, ev);
537         }
538         return req;
539 }
540
541 static void wait_for_read_done(struct tevent_context *ev,
542                                struct tevent_fd *fde,
543                                uint16_t flags,
544                                void *private_data)
545 {
546         struct wait_for_read_state *state = talloc_get_type_abort(
547                 private_data, struct wait_for_read_state);
548
549         if (flags & TEVENT_FD_READ) {
550                 TALLOC_FREE(state->fde);
551                 tevent_req_done(state->req);
552         }
553 }
554
555 bool wait_for_read_recv(struct tevent_req *req, int *perr)
556 {
557         int err;
558
559         if (tevent_req_is_unix_error(req, &err)) {
560                 *perr = err;
561                 return false;
562         }
563         return true;
564 }