s3: Eliminate select from packet_fd_read_sync
[ira/wip.git] / source3 / lib / packet.c
1 /* 
2    Unix SMB/CIFS implementation.
3    Packet handling
4    Copyright (C) Volker Lendecke 2007
5
6    This program is free software; you can redistribute it and/or modify
7    it under the terms of the GNU General Public License as published by
8    the Free Software Foundation; either version 3 of the License, or
9    (at your option) any later version.
10
11    This program is distributed in the hope that it will be useful,
12    but WITHOUT ANY WARRANTY; without even the implied warranty of
13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14    GNU General Public License for more details.
15
16    You should have received a copy of the GNU General Public License
17    along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 */
19
20 #include "includes.h"
21 #include "../lib/util/select.h"
22 #include "system/select.h"
23
24 struct packet_context {
25         int fd;
26         DATA_BLOB in, out;
27 };
28
29 /*
30  * Close the underlying fd
31  */
32 static int packet_context_destructor(struct packet_context *ctx)
33 {
34         return close(ctx->fd);
35 }
36
37 /*
38  * Initialize a packet context. The fd is given to the packet context, meaning
39  * that it is automatically closed when the packet context is freed.
40  */
41 struct packet_context *packet_init(TALLOC_CTX *mem_ctx, int fd)
42 {
43         struct packet_context *result;
44
45         if (!(result = TALLOC_ZERO_P(mem_ctx, struct packet_context))) {
46                 return NULL;
47         }
48
49         result->fd = fd;
50         talloc_set_destructor(result, packet_context_destructor);
51         return result;
52 }
53
54 /*
55  * Pull data from the fd
56  */
57 NTSTATUS packet_fd_read(struct packet_context *ctx)
58 {
59         int res, available;
60         size_t new_size;
61         uint8 *in;
62
63         res = ioctl(ctx->fd, FIONREAD, &available);
64
65         if (res == -1) {
66                 DEBUG(10, ("ioctl(FIONREAD) failed: %s\n", strerror(errno)));
67                 return map_nt_error_from_unix(errno);
68         }
69
70         SMB_ASSERT(available >= 0);
71
72         if (available == 0) {
73                 return NT_STATUS_END_OF_FILE;
74         }
75
76         new_size = ctx->in.length + available;
77
78         if (new_size < ctx->in.length) {
79                 DEBUG(0, ("integer wrap\n"));
80                 return NT_STATUS_NO_MEMORY;
81         }
82
83         if (!(in = TALLOC_REALLOC_ARRAY(ctx, ctx->in.data, uint8, new_size))) {
84                 DEBUG(10, ("talloc failed\n"));
85                 return NT_STATUS_NO_MEMORY;
86         }
87
88         ctx->in.data = in;
89
90         res = recv(ctx->fd, in + ctx->in.length, available, 0);
91
92         if (res < 0) {
93                 DEBUG(10, ("recv failed: %s\n", strerror(errno)));
94                 return map_nt_error_from_unix(errno);
95         }
96
97         if (res == 0) {
98                 return NT_STATUS_END_OF_FILE;
99         }
100
101         ctx->in.length += res;
102
103         return NT_STATUS_OK;
104 }
105
106 NTSTATUS packet_fd_read_sync(struct packet_context *ctx, int timeout)
107 {
108         int res, revents;
109
110         res = poll_one_fd(ctx->fd, POLLIN|POLLHUP, timeout, &revents);
111         if (res == 0) {
112                 DEBUG(10, ("poll timed out\n"));
113                 return NT_STATUS_IO_TIMEOUT;
114         }
115
116         if (res == -1) {
117                 DEBUG(10, ("poll returned %s\n", strerror(errno)));
118                 return map_nt_error_from_unix(errno);
119         }
120         if ((revents & (POLLIN|POLLHUP|POLLERR)) == 0) {
121                 DEBUG(10, ("socket not readable\n"));
122                 return NT_STATUS_IO_TIMEOUT;
123         }
124
125         return packet_fd_read(ctx);
126 }
127
128 bool packet_handler(struct packet_context *ctx,
129                     bool (*full_req)(const uint8_t *buf,
130                                      size_t available,
131                                      size_t *length,
132                                      void *priv),
133                     NTSTATUS (*callback)(uint8_t *buf, size_t length,
134                                          void *priv),
135                     void *priv, NTSTATUS *status)
136 {
137         size_t length;
138         uint8_t *buf;
139
140         if (!full_req(ctx->in.data, ctx->in.length, &length, priv)) {
141                 return False;
142         }
143
144         if (length > ctx->in.length) {
145                 *status = NT_STATUS_INTERNAL_ERROR;
146                 return true;
147         }
148
149         if (length == ctx->in.length) {
150                 buf = ctx->in.data;
151                 ctx->in.data = NULL;
152                 ctx->in.length = 0;
153         } else {
154                 buf = (uint8_t *)TALLOC_MEMDUP(ctx, ctx->in.data, length);
155                 if (buf == NULL) {
156                         *status = NT_STATUS_NO_MEMORY;
157                         return true;
158                 }
159
160                 memmove(ctx->in.data, ctx->in.data + length,
161                         ctx->in.length - length);
162                 ctx->in.length -= length;
163         }
164
165         *status = callback(buf, length, priv);
166         return True;
167 }
168
169 /*
170  * How many bytes of outgoing data do we have pending?
171  */
172 size_t packet_outgoing_bytes(struct packet_context *ctx)
173 {
174         return ctx->out.length;
175 }
176
177 /*
178  * Push data to the fd
179  */
180 NTSTATUS packet_fd_write(struct packet_context *ctx)
181 {
182         ssize_t sent;
183
184         sent = send(ctx->fd, ctx->out.data, ctx->out.length, 0);
185
186         if (sent == -1) {
187                 DEBUG(0, ("send failed: %s\n", strerror(errno)));
188                 return map_nt_error_from_unix(errno);
189         }
190
191         memmove(ctx->out.data, ctx->out.data + sent,
192                 ctx->out.length - sent);
193         ctx->out.length -= sent;
194
195         return NT_STATUS_OK;
196 }
197
198 /*
199  * Sync flush all outgoing bytes
200  */
201 NTSTATUS packet_flush(struct packet_context *ctx)
202 {
203         while (ctx->out.length != 0) {
204                 NTSTATUS status = packet_fd_write(ctx);
205                 if (!NT_STATUS_IS_OK(status)) {
206                         return status;
207                 }
208         }
209         return NT_STATUS_OK;
210 }
211
212 /*
213  * Send a list of DATA_BLOBs
214  *
215  * Example:  packet_send(ctx, 2, data_blob_const(&size, sizeof(size)),
216  *                       data_blob_const(buf, size));
217  */
218 NTSTATUS packet_send(struct packet_context *ctx, int num_blobs, ...)
219 {
220         va_list ap;
221         int i;
222         size_t len;
223         uint8 *out;
224
225         len = ctx->out.length;
226
227         va_start(ap, num_blobs);
228         for (i=0; i<num_blobs; i++) {
229                 size_t tmp;
230                 DATA_BLOB blob = va_arg(ap, DATA_BLOB);
231
232                 tmp = len + blob.length;
233                 if (tmp < len) {
234                         DEBUG(0, ("integer overflow\n"));
235                         va_end(ap);
236                         return NT_STATUS_NO_MEMORY;
237                 }
238                 len = tmp;
239         }
240         va_end(ap);
241
242         if (len == 0) {
243                 return NT_STATUS_OK;
244         }
245
246         if (!(out = TALLOC_REALLOC_ARRAY(ctx, ctx->out.data, uint8, len))) {
247                 DEBUG(0, ("talloc failed\n"));
248                 return NT_STATUS_NO_MEMORY;
249         }
250
251         ctx->out.data = out;
252
253         va_start(ap, num_blobs);
254         for (i=0; i<num_blobs; i++) {
255                 DATA_BLOB blob = va_arg(ap, DATA_BLOB);
256
257                 memcpy(ctx->out.data+ctx->out.length, blob.data, blob.length);
258                 ctx->out.length += blob.length;
259         }
260         va_end(ap);
261
262         SMB_ASSERT(ctx->out.length == len);
263         return NT_STATUS_OK;
264 }
265
266 /*
267  * Get the packet context's file descriptor
268  */
269 int packet_get_fd(struct packet_context *ctx)
270 {
271         return ctx->fd;
272 }
273