heaps of cleanup in the io code.
[rsync.git] / io.c
1 /* 
2    Copyright (C) Andrew Tridgell 1996
3    Copyright (C) Paul Mackerras 1996
4    
5    This program is free software; you can redistribute it and/or modify
6    it under the terms of the GNU General Public License as published by
7    the Free Software Foundation; either version 2 of the License, or
8    (at your option) any later version.
9    
10    This program is distributed in the hope that it will be useful,
11    but WITHOUT ANY WARRANTY; without even the implied warranty of
12    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13    GNU General Public License for more details.
14    
15    You should have received a copy of the GNU General Public License
16    along with this program; if not, write to the Free Software
17    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
18 */
19
20 /*
21   Utilities used in rsync 
22
23   tridge, June 1996
24   */
25 #include "rsync.h"
26
27 static int64 total_written;
28 static int64 total_read;
29
30 static int io_multiplexing_out;
31 static int io_multiplexing_in;
32 static int multiplex_in_fd;
33 static int multiplex_out_fd;
34 static time_t last_io;
35
36 extern int verbose;
37 extern int io_timeout;
38
39 int64 write_total(void)
40 {
41         return total_written;
42 }
43
44 int64 read_total(void)
45 {
46         return total_read;
47 }
48
/* fd whose input may be drained into the read buffer while we are
   blocked writing (see read_check); -1 disables this */
static int buffer_f_in = -1;

/* Register f_in as the descriptor to read ahead from. */
void setup_readbuffer(int f_in)
{
	buffer_f_in = f_in;
}
55
56 static void check_timeout(void)
57 {
58         time_t t;
59         
60         if (!io_timeout) return;
61
62         if (!last_io) {
63                 last_io = time(NULL);
64                 return;
65         }
66
67         t = time(NULL);
68
69         if (last_io && io_timeout && (t-last_io)>io_timeout) {
70                 rprintf(FERROR,"read timeout after %d second - exiting\n", 
71                         (int)(t-last_io));
72                 exit_cleanup(1);
73         }
74 }
75
76
/* Read-ahead buffer filled by read_check() and consumed by readfd():
   read_buffer is the allocation of read_buffer_size bytes, read_buffer_p
   points at the first unconsumed byte, and read_buffer_len bytes from
   there are pending. */
static char *read_buffer;
static char *read_buffer_p;
static int read_buffer_len, read_buffer_size;
81
82 /* read from a socket with IO timeout. return the number of
83    bytes read. If no bytes can be read then exit, never return
84    a number <= 0 */
85 static int read_timeout(int fd, char *buf, int len)
86 {
87         int n, ret=0;
88
89         while (ret == 0) {
90                 fd_set fds;
91                 struct timeval tv;
92
93                 FD_ZERO(&fds);
94                 FD_SET(fd, &fds);
95                 tv.tv_sec = io_timeout;
96                 tv.tv_usec = 0;
97
98                 if (select(fd+1, &fds, NULL, NULL, 
99                            io_timeout?&tv:NULL) != 1) {
100                         check_timeout();
101                         continue;
102                 }
103
104                 n = read(fd, buf, len);
105
106                 if (n > 0) {
107                         buf += n;
108                         len -= n;
109                         ret += n;
110                         if (io_timeout)
111                                 last_io = time(NULL);
112                         continue;
113                 }
114
115                 if (n == -1 && errno == EINTR) {
116                         continue;
117                 }
118
119                 if (n == 0) {
120                         rprintf(FERROR,"EOF in read_timeout\n");
121                         exit_cleanup(1);
122                 }
123
124                 rprintf(FERROR,"read error: %s\n", strerror(errno));
125                 exit_cleanup(1);
126         }
127
128         return ret;
129 }
130
/* Read exactly len bytes into buf, calling read_timeout() as often as
   needed; returns only once all len bytes have arrived. */
static void read_loop(int fd, char *buf, int len)
{
	char *p = buf;
	int left = len;

	while (left != 0) {
		int got = read_timeout(fd, p, left);

		p += got;
		left -= got;
	}
}
142
143 /* read from the file descriptor handing multiplexing - 
144    return number of bytes read
145    never return <= 0 */
146 static int read_unbuffered(int fd, char *buf, int len)
147 {
148         static int remaining;
149         char ibuf[4];
150         int tag, ret=0;
151         char line[1024];
152
153         if (!io_multiplexing_in || fd != multiplex_in_fd) 
154                 return read_timeout(fd, buf, len);
155
156         while (ret == 0) {
157                 if (remaining) {
158                         len = MIN(len, remaining);
159                         read_loop(fd, buf, len);
160                         remaining -= len;
161                         ret = len;
162                         continue;
163                 }
164
165                 read_loop(fd, ibuf, 4);
166                 tag = IVAL(ibuf, 0);
167
168                 remaining = tag & 0xFFFFFF;
169                 tag = tag >> 24;
170
171                 if (tag == MPLEX_BASE) continue;
172
173                 tag -= MPLEX_BASE;
174
175                 if (tag != FERROR && tag != FINFO) {
176                         rprintf(FERROR,"unexpected tag %d\n", tag);
177                         exit_cleanup(1);
178                 }
179
180                 if (remaining > sizeof(line)-1) {
181                         rprintf(FERROR,"multiplexing overflow %d\n\n", 
182                                 remaining);
183                         exit_cleanup(1);
184                 }
185
186                 read_loop(fd, line, remaining);
187                 line[remaining] = 0;
188
189                 rprintf(tag,"%s", line);
190                 remaining = 0;
191         }
192
193         return ret;
194 }
195
196
197
198 /* This function was added to overcome a deadlock problem when using
199  * ssh.  It looks like we can't allow our receive queue to get full or
200  * ssh will clag up. Uggh.  */
201 static void read_check(int f)
202 {
203         int n = 8192;
204
205         if (f == -1) return;
206
207         if (read_buffer_len == 0) {
208                 read_buffer_p = read_buffer;
209         }
210
211         if (n > MAX_READ_BUFFER/4)
212                 n = MAX_READ_BUFFER/4;
213
214         if (read_buffer_p != read_buffer) {
215                 memmove(read_buffer,read_buffer_p,read_buffer_len);
216                 read_buffer_p = read_buffer;
217         }
218
219         if (n > (read_buffer_size - read_buffer_len)) {
220                 read_buffer_size += n;
221                 if (!read_buffer)
222                         read_buffer = (char *)malloc(read_buffer_size);
223                 else
224                         read_buffer = (char *)realloc(read_buffer,read_buffer_size);
225                 if (!read_buffer) out_of_memory("read check");      
226                 read_buffer_p = read_buffer;      
227         }
228
229         n = read_unbuffered(f,read_buffer+read_buffer_len,n);
230         read_buffer_len += n;
231 }
232
233
234 /* do a buffered read from fd. don't return until all N bytes
235    have been read. If all N can't be read then exit with an error */
236 static void readfd(int fd,char *buffer,int N)
237 {
238         int  ret;
239         int total=0;  
240         
241         if (read_buffer_len < N && N < 1024) {
242                 read_check(buffer_f_in);
243         }
244         
245         while (total < N) {
246                 if (read_buffer_len > 0 && buffer_f_in == fd) {
247                         ret = MIN(read_buffer_len,N-total);
248                         memcpy(buffer+total,read_buffer_p,ret);
249                         read_buffer_p += ret;
250                         read_buffer_len -= ret;
251                         total += ret;
252                         continue;
253                 } 
254
255                 io_flush();
256
257                 ret = read_unbuffered(fd,buffer + total,N-total);
258                 total += ret;
259         }
260 }
261
262
263 int32 read_int(int f)
264 {
265         char b[4];
266         readfd(f,b,4);
267         total_read += 4;
268         return IVAL(b,0);
269 }
270
271 int64 read_longint(int f)
272 {
273         extern int remote_version;
274         int64 ret;
275         char b[8];
276         ret = read_int(f);
277
278         if ((int32)ret != (int32)0xffffffff) return ret;
279
280 #ifdef NO_INT64
281         rprintf(FERROR,"Integer overflow - attempted 64 bit offset\n");
282         exit_cleanup(1);
283 #else
284         if (remote_version >= 16) {
285                 readfd(f,b,8);
286                 total_read += 8;
287                 ret = IVAL(b,0) | (((int64)IVAL(b,4))<<32);
288         }
289 #endif
290
291         return ret;
292 }
293
294 void read_buf(int f,char *buf,int len)
295 {
296         readfd(f,buf,len);
297         total_read += len;
298 }
299
/* Read exactly len bytes into buf and NUL-terminate them; buf needs
   room for len+1 bytes. */
void read_sbuf(int f,char *buf,int len)
{
	read_buf(f, buf, len);
	*(buf + len) = 0;
}
305
/* Read and return a single byte. */
unsigned char read_byte(int f)
{
	unsigned char byte;

	read_buf(f, (char *)&byte, sizeof(byte));
	return byte;
}
312
313
314
315 /* write len bytes to fd, possibly reading from buffer_f_in if set
316    in order to unclog the pipe. don't return until all len
317    bytes have been written */
318 static void writefd_unbuffered(int fd,char *buf,int len)
319 {
320         int total = 0;
321         fd_set w_fds, r_fds;
322         int fd_count, count;
323         struct timeval tv;
324         int reading;
325
326         reading = (buffer_f_in != -1 && read_buffer_len < MAX_READ_BUFFER);
327
328         while (total < len) {
329                 FD_ZERO(&w_fds);
330                 FD_ZERO(&r_fds);
331                 FD_SET(fd,&w_fds);
332                 fd_count = fd+1;
333
334                 if (reading) {
335                         FD_SET(buffer_f_in,&r_fds);
336                         if (buffer_f_in > fd) 
337                                 fd_count = buffer_f_in+1;
338                 }
339
340                 tv.tv_sec = io_timeout;
341                 tv.tv_usec = 0;
342
343                 count = select(fd_count,
344                                reading?&r_fds:NULL,
345                                &w_fds,NULL,
346                                io_timeout?&tv:NULL);
347
348                 if (count <= 0) {
349                         check_timeout();
350                         continue;
351                 }
352
353                 if (FD_ISSET(fd, &w_fds)) {
354                         int ret = write(fd,buf+total,len-total);
355
356                         if (ret == -1 && errno == EINTR) {
357                                 continue;
358                         }
359
360                         if (ret <= 0) {
361                                 rprintf(FERROR,"erroring writing %d bytes - exiting\n", len);
362                                 exit_cleanup(1);
363                         }
364
365                         total += ret;
366                         if (io_timeout)
367                                 last_io = time(NULL);
368                         continue;
369                 }
370
371                 if (reading && FD_ISSET(buffer_f_in, &r_fds)) {
372                         read_check(buffer_f_in);
373                 }
374         }
375 }
376
377
/* Output buffer while buffering is on (NULL otherwise). The allocation
   begins 4 bytes before io_buffer, leaving room for a multiplex header;
   io_buffer_count bytes are waiting to be flushed. */
static char *io_buffer;
static int io_buffer_count;
380
381 void io_start_buffering(int fd)
382 {
383         if (io_buffer) return;
384         multiplex_out_fd = fd;
385         io_buffer = (char *)malloc(IO_BUFFER_SIZE+4);
386         if (!io_buffer) out_of_memory("writefd");
387         io_buffer_count = 0;
388
389         /* leave room for the multiplex header in case it's needed */
390         io_buffer += 4;
391 }
392
393 void io_flush(void)
394 {
395         int fd = multiplex_out_fd;
396         if (!io_buffer_count) return;
397
398         if (io_multiplexing_out) {
399                 SIVAL(io_buffer-4, 0, (MPLEX_BASE<<24) + io_buffer_count);
400                 writefd_unbuffered(fd, io_buffer-4, io_buffer_count+4);
401         } else {
402                 writefd_unbuffered(fd, io_buffer, io_buffer_count);
403         }
404         io_buffer_count = 0;
405 }
406
407 void io_end_buffering(int fd)
408 {
409         io_flush();
410         if (!io_multiplexing_out) {
411                 free(io_buffer-4);
412                 io_buffer = NULL;
413         }
414 }
415
416 static void writefd(int fd,char *buf,int len)
417 {
418         if (!io_buffer) {
419                 writefd_unbuffered(fd, buf, len);
420                 return;
421         }
422
423         while (len) {
424                 int n = MIN(len, IO_BUFFER_SIZE-io_buffer_count);
425                 if (n > 0) {
426                         memcpy(io_buffer+io_buffer_count, buf, n);
427                         buf += n;
428                         len -= n;
429                         io_buffer_count += n;
430                 }
431                 
432                 if (io_buffer_count == IO_BUFFER_SIZE) io_flush();
433         }
434 }
435
436
437 void write_int(int f,int32 x)
438 {
439         char b[4];
440         SIVAL(b,0,x);
441         writefd(f,b,4);
442         total_written += 4;
443 }
444
445 void write_longint(int f, int64 x)
446 {
447         extern int remote_version;
448         char b[8];
449
450         if (remote_version < 16 || x <= 0x7FFFFFFF) {
451                 write_int(f, (int)x);
452                 return;
453         }
454
455         write_int(f, -1);
456         SIVAL(b,0,(x&0xFFFFFFFF));
457         SIVAL(b,4,((x>>32)&0xFFFFFFFF));
458
459         writefd(f,b,8);
460         total_written += 8;
461 }
462
463 void write_buf(int f,char *buf,int len)
464 {
465         writefd(f,buf,len);
466         total_written += len;
467 }
468
/* Write the NUL-terminated string buf to the connection, without its
   terminator. */
void write_sbuf(int f,char *buf)
{
	int len = strlen(buf);

	write_buf(f, buf, len);
}
474
475
/* Write a single byte. */
void write_byte(int f,unsigned char c)
{
	char *p = (char *)&c;

	write_buf(f, p, sizeof(c));
}
480
/* Read one line from f into buf. '\r' characters are dropped and the
   terminating '\n' is replaced by a NUL; returns 1 in that case.
   At most maxlen characters are stored. If the line is longer, buf[maxlen]
   is set to NUL, 0 is returned, and the rest of the line stays unread.
   buf must have room for maxlen+1 bytes. */
int read_line(int f, char *buf, int maxlen)
{
	char *p = buf;
	int room = maxlen;

	for (; room != 0; ) {
		read_buf(f, p, 1);

		if (*p == '\n') {
			*p = 0;
			return 1;
		}

		if (*p == '\r')
			continue;	/* discard, reuse this slot */

		p++;
		room--;
	}

	*p = 0;
	return 0;
}
500
501
/* printf-style write to fd. The text is formatted with vslprintf() into
   a 1024-byte local buffer and sent with write_sbuf(). A negative
   vslprintf() result is fatal. */
void io_printf(int fd, const char *format, ...)
{
	char buf[1024];
	va_list args;
	int n;

	va_start(args, format);
	n = vslprintf(buf, sizeof(buf)-1, format, args);
	va_end(args);

	if (n < 0)
		exit_cleanup(1);

	write_sbuf(fd, buf);
}
516
517
518 /* setup for multiplexing an error stream with the data stream */
519 void io_start_multiplex_out(int fd)
520 {
521         multiplex_out_fd = fd;
522         io_flush();
523         io_start_buffering(fd);
524         io_multiplexing_out = 1;
525 }
526
527 /* setup for multiplexing an error stream with the data stream */
528 void io_start_multiplex_in(int fd)
529 {
530         multiplex_in_fd = fd;
531         io_flush();
532         if (read_buffer_len) {
533                 fprintf(stderr,"ERROR: data in read buffer at mplx start\n");
534                 exit_cleanup(1);
535         }
536
537         io_multiplexing_in = 1;
538 }
539
540 /* write an message to the error stream */
541 int io_multiplex_write(int f, char *buf, int len)
542 {
543         if (!io_multiplexing_out) return 0;
544
545         io_flush();
546
547         SIVAL(io_buffer-4, 0, ((MPLEX_BASE + f)<<24) + len);
548         memcpy(io_buffer, buf, len);
549
550         writefd_unbuffered(multiplex_out_fd, io_buffer-4, len+4);
551         return 1;
552 }
553
554 void io_close_input(int fd)
555 {
556         buffer_f_in = -1;
557 }