Implement --privacy option, though apparently not working yet.
[rsync.git] / io.c
1 /* 
2    Copyright (C) Andrew Tridgell 1996
3    Copyright (C) Paul Mackerras 1996
4    
5    This program is free software; you can redistribute it and/or modify
6    it under the terms of the GNU General Public License as published by
7    the Free Software Foundation; either version 2 of the License, or
8    (at your option) any later version.
9    
10    This program is distributed in the hope that it will be useful,
11    but WITHOUT ANY WARRANTY; without even the implied warranty of
12    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13    GNU General Public License for more details.
14    
15    You should have received a copy of the GNU General Public License
16    along with this program; if not, write to the Free Software
17    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
18 */
19
20 /*
21   socket and pipe IO utilities used in rsync 
22
23   tridge, June 1996
24
25   Midstrength stream cypher privacy added by Martin Pool, October 2000.
26   */
27 #include "rsync.h"
28 #include "lib/arcfour.h"
29 #include "assert.h"
30
31 /* if no timeout is specified then use a 60 second select timeout */
32 #define SELECT_TIMEOUT 60
33
34 extern int bwlimit;
35
36 static int io_multiplexing_out;
37 static int io_multiplexing_in;
38 static int multiplex_in_fd;
39 static int multiplex_out_fd;
40 static time_t last_io;
41 static int eof_error=1;
42 extern int verbose;
43 extern int io_timeout;
44 extern struct stats stats;
45
46 extern ArcfourContext arcfour_enc_ctx, arcfour_dec_ctx;
47
48 static int io_error_fd = -1;
49
50 static void read_loop(int fd, char *buf, int len);
51
52 static void check_timeout(void)
53 {
54         extern int am_server, am_daemon;
55         time_t t;
56         
57         if (!io_timeout) return;
58
59         if (!last_io) {
60                 last_io = time(NULL);
61                 return;
62         }
63
64         t = time(NULL);
65
66         if (last_io && io_timeout && (t-last_io) >= io_timeout) {
67                 if (!am_server && !am_daemon) {
68                         rprintf(FERROR,"io timeout after %d second - exiting\n", 
69                                 (int)(t-last_io));
70                 }
71                 exit_cleanup(RERR_TIMEOUT);
72         }
73 }
74
75 /* setup the fd used to propogate errors */
76 void io_set_error_fd(int fd)
77 {
78         io_error_fd = fd;
79 }
80
81 /* read some data from the error fd and write it to the write log code */
82 static void read_error_fd(void)
83 {
84         char buf[200];
85         int n;
86         int fd = io_error_fd;
87         int tag, len;
88
89         io_error_fd = -1;
90
91         read_loop(fd, buf, 4);
92         tag = IVAL(buf, 0);
93
94         len = tag & 0xFFFFFF;
95         tag = tag >> 24;
96         tag -= MPLEX_BASE;
97
98         while (len) {
99                 n = len;
100                 if (n > (sizeof(buf)-1)) n = sizeof(buf)-1;
101                 read_loop(fd, buf, n);
102                 rwrite((enum logcode)tag, buf, n);
103                 len -= n;
104         }
105
106         io_error_fd = fd;
107 }
108
109
110 static int no_flush;
111
112 /*
113  * This is the most fundamental socket read function -- the only one that
114  * actually calls the kernel.
115  *
116  * It reads from a socket with IO timeout. return the number of bytes
117  * read. If no bytes can be read then exit, never return a number <= 0
118  *
119  * If arcfour_enabled is set, it decrypts data while reading using the
120  * global arcfour state.
121  */
122 static int read_timeout(int fd, char *buf, int len)
123 {
124         int n, ret=0;
125
126         io_flush();
127
128         while (ret == 0) {
129                 fd_set fds;
130                 struct timeval tv;
131                 int fd_count = fd+1;
132
133                 FD_ZERO(&fds);
134                 FD_SET(fd, &fds);
135                 if (io_error_fd != -1) {
136                         FD_SET(io_error_fd, &fds);
137                         if (io_error_fd > fd) fd_count = io_error_fd+1;
138                 }
139
140                 tv.tv_sec = io_timeout?io_timeout:SELECT_TIMEOUT;
141                 tv.tv_usec = 0;
142
143                 errno = 0;
144
145                 if (select(fd_count, &fds, NULL, NULL, &tv) < 1) {
146                         if (errno == EBADF) {
147                                 exit_cleanup(RERR_SOCKETIO);
148                         }
149                         check_timeout();
150                         continue;
151                 }
152
153                 if (io_error_fd != -1 && FD_ISSET(io_error_fd, &fds)) {
154                         read_error_fd();
155                 }
156
157                 if (!FD_ISSET(fd, &fds)) continue;
158
159                 n = read(fd, buf, len);
160
161                 if (n > 0) {
162                         /* arcfour can decrypt in place. */
163                         if (arcfour_enabled) {
164                                 rprintf(FERROR, "decrypt %d bytes..", n);
165                                 arcfour_decrypt(&arcfour_dec_ctx, buf, buf, n);
166                                 rprintf(FERROR, "done\n");
167                         }
168                         buf += n;
169                         len -= n;
170                         ret += n;
171                         if (io_timeout)
172                                 last_io = time(NULL);
173                         continue;
174                 }
175
176                 if (n == -1 && errno == EINTR) {
177                         continue;
178                 }
179
180                 if (n == -1 && 
181                     (errno == EWOULDBLOCK || errno == EAGAIN)) {
182                         continue;
183                 }
184
185
186                 if (n == 0) {
187                         if (eof_error) {
188                                 rprintf(FERROR,"unexpected EOF in read_timeout\n");
189                         }
190                         exit_cleanup(RERR_STREAMIO);
191                 }
192
193                 /* this prevents us trying to write errors on a dead socket */
194                 io_multiplexing_close();
195
196                 rprintf(FERROR,"read error: %s\n", strerror(errno));
197                 exit_cleanup(RERR_STREAMIO);
198         }
199
200         return ret;
201 }
202
203 /* continue trying to read len bytes - don't return until len
204    has been read */
205 static void read_loop(int fd, char *buf, int len)
206 {
207         while (len) {
208                 int n = read_timeout(fd, buf, len);
209
210                 buf += n;
211                 len -= n;
212         }
213 }
214
215 /* read from the file descriptor handling multiplexing - 
216    return number of bytes read
217    never return <= 0 */
218 static int read_unbuffered(int fd, char *buf, int len)
219 {
220         static int remaining;
221         int tag, ret=0;
222         char line[1024];
223
224         if (!io_multiplexing_in || fd != multiplex_in_fd) 
225                 return read_timeout(fd, buf, len);
226
227         while (ret == 0) {
228                 if (remaining) {
229                         len = MIN(len, remaining);
230                         read_loop(fd, buf, len);
231                         remaining -= len;
232                         ret = len;
233                         continue;
234                 }
235
236                 read_loop(fd, line, 4);
237                 tag = IVAL(line, 0);
238
239                 remaining = tag & 0xFFFFFF;
240                 tag = tag >> 24;
241
242                 if (tag == MPLEX_BASE) continue;
243
244                 tag -= MPLEX_BASE;
245
246                 if (tag != FERROR && tag != FINFO) {
247                         rprintf(FERROR,"unexpected tag %d\n", tag);
248                         exit_cleanup(RERR_STREAMIO);
249                 }
250
251                 if (remaining > sizeof(line)-1) {
252                         rprintf(FERROR,"multiplexing overflow %d\n\n", 
253                                 remaining);
254                         exit_cleanup(RERR_STREAMIO);
255                 }
256
257                 read_loop(fd, line, remaining);
258                 line[remaining] = 0;
259
260                 rprintf((enum logcode)tag,"%s", line);
261                 remaining = 0;
262         }
263
264         return ret;
265 }
266
267
268 /* do a buffered read from fd. don't return until all N bytes
269    have been read. If all N can't be read then exit with an error */
270 static void readfd(int fd,char *buffer,int N)
271 {
272         int  ret;
273         int total=0;  
274         
275         while (total < N) {
276                 io_flush();
277
278                 ret = read_unbuffered(fd,buffer + total,N-total);
279                 total += ret;
280         }
281
282         stats.total_read += total;
283 }
284
285
286 int32 read_int(int f)
287 {
288         char b[4];
289         int32 ret;
290
291         readfd(f,b,4);
292         ret = IVAL(b,0);
293         if (ret == (int32)0xffffffff) return -1;
294         return ret;
295 }
296
297 int64 read_longint(int f)
298 {
299         extern int remote_version;
300         int64 ret;
301         char b[8];
302         ret = read_int(f);
303
304         if ((int32)ret != (int32)0xffffffff) {
305                 return ret;
306         }
307
308 #ifdef NO_INT64
309         rprintf(FERROR,"Integer overflow - attempted 64 bit offset\n");
310         exit_cleanup(RERR_UNSUPPORTED);
311 #else
312         if (remote_version >= 16) {
313                 readfd(f,b,8);
314                 ret = IVAL(b,0) | (((int64)IVAL(b,4))<<32);
315         }
316 #endif
317
318         return ret;
319 }
320
321 void read_buf(int f,char *buf,int len)
322 {
323         readfd(f,buf,len);
324 }
325
326 void read_sbuf(int f,char *buf,int len)
327 {
328         read_buf(f,buf,len);
329         buf[len] = 0;
330 }
331
332 unsigned char read_byte(int f)
333 {
334         unsigned char c;
335         read_buf(f,(char *)&c,1);
336         return c;
337 }
338
339
340 /*
341  * Write len bytes to fd.
342  *
343  * If arcfour_enabled is true, encrypt all data as it passes onto the
344  * wire using the global arcfour state.
345  */
346 static void writefd_unbuffered(int fd, char const *buf, int len)
347 {
348         int total = 0;
349         fd_set w_fds, r_fds;
350         int fd_count, count;
351         struct timeval tv;
352
353         no_flush++;
354
355         while (total < len) {
356                 FD_ZERO(&w_fds);
357                 FD_ZERO(&r_fds);
358                 FD_SET(fd,&w_fds);
359                 fd_count = fd;
360
361                 if (io_error_fd != -1) {
362                         FD_SET(io_error_fd,&r_fds);
363                         if (io_error_fd > fd_count) 
364                                 fd_count = io_error_fd;
365                 }
366
367                 tv.tv_sec = io_timeout?io_timeout:SELECT_TIMEOUT;
368                 tv.tv_usec = 0;
369
370                 errno = 0;
371
372                 count = select(fd_count+1,
373                                io_error_fd != -1?&r_fds:NULL,
374                                &w_fds,NULL,
375                                &tv);
376
377                 if (count <= 0) {
378                         if (errno == EBADF) {
379                                 exit_cleanup(RERR_SOCKETIO);
380                         }
381                         check_timeout();
382                         continue;
383                 }
384
385                 if (io_error_fd != -1 && FD_ISSET(io_error_fd, &r_fds)) {
386                         read_error_fd();
387                 }
388
389                 if (FD_ISSET(fd, &w_fds)) {
390                         int ret, n = len-total;
391
392                         ret = write(fd,buf+total,n);
393
394                         if (ret == -1 && errno == EINTR) {
395                                 continue;
396                         }
397
398                         if (ret == -1 && 
399                             (errno == EWOULDBLOCK || errno == EAGAIN)) {
400                                 msleep(1);
401                                 continue;
402                         }
403
404                         if (ret <= 0) {
405                                 rprintf(FERROR,"erroring writing %d bytes - exiting\n", len);
406                                 exit_cleanup(RERR_STREAMIO);
407                         }
408
409                         /* Sleep after writing to limit I/O bandwidth */
410                         if (bwlimit)
411                         {
412                             tv.tv_sec = 0;
413                             tv.tv_usec = ret * 1000 / bwlimit;
414                             while (tv.tv_usec > 1000000)
415                             {
416                                 tv.tv_sec++;
417                                 tv.tv_usec -= 1000000;
418                             }
419                             select(0, NULL, NULL, NULL, &tv);
420                         }
421  
422                         total += ret;
423
424                         if (io_timeout)
425                                 last_io = time(NULL);
426                 }
427         }
428
429         no_flush--;
430 }
431
432
433 /* build up a temporary buffer of encrypted data */
434 static void writefd_encrypt(int fd, char const *buf, int len)
435 {
436         static char *arcbuf = NULL;
437         static int buf_len = 0;
438         
439         if (arcfour_enabled) {
440                 assert(len > 0);
441                 if (len > buf_len  ||  !arcbuf) {
442                         if (arcbuf)
443                                 free(arcbuf);
444                         arcbuf = malloc(len);
445                         buf_len = len;
446                 }
447
448                 rprintf(FERROR, "encrypt %d bytes ..", len);
449
450                 arcfour_encrypt(&arcfour_dec_ctx, arcbuf, buf, len);
451 //                writefd_unbuffered(fd, arcbuf, len);
452                 writefd_unbuffered(fd, buf, len);
453                 rprintf(FERROR, "done\n");
454         } else {
455                 writefd_unbuffered(fd, buf, len);
456         }
457 }
458
459
460 static char *io_buffer;
461 static int io_buffer_count;
462
463 void io_start_buffering(int fd)
464 {
465         if (io_buffer) return;
466         multiplex_out_fd = fd;
467         io_buffer = (char *)malloc(IO_BUFFER_SIZE);
468         if (!io_buffer) out_of_memory("writefd");
469         io_buffer_count = 0;
470 }
471
472 /* write an message to a multiplexed stream. If this fails then rsync
473    exits */
474 static void mplex_write(int fd, enum logcode code, char *buf, int len)
475 {
476         char buffer[4096];
477         int n = len;
478
479         SIVAL(buffer, 0, ((MPLEX_BASE + (int)code)<<24) + len);
480
481         if (n > (sizeof(buffer)-4)) {
482                 n = sizeof(buffer)-4;
483         }
484
485         memcpy(&buffer[4], buf, n);
486         writefd_encrypt(fd, buffer, n+4);
487
488         len -= n;
489         buf += n;
490
491         if (len) {
492                 writefd_encrypt(fd, buf, len);
493         }
494 }
495
496
497 void io_flush(void)
498 {
499         int fd = multiplex_out_fd;
500         if (!io_buffer_count || no_flush) return;
501
502         if (io_multiplexing_out) {
503                 mplex_write(fd, FNONE, io_buffer, io_buffer_count);
504         } else {
505                 writefd_encrypt(fd, io_buffer, io_buffer_count);
506         }
507         io_buffer_count = 0;
508 }
509
510 void io_end_buffering(int fd)
511 {
512         io_flush();
513         if (!io_multiplexing_out) {
514                 free(io_buffer);
515                 io_buffer = NULL;
516         }
517 }
518
519 /* some OSes have a bug where an exit causes the pending writes on
520    a socket to be flushed. Do an explicit shutdown to try to prevent this */
521 void io_shutdown(void)
522 {
523         if (multiplex_out_fd != -1) close(multiplex_out_fd);
524         if (io_error_fd != -1) close(io_error_fd);
525         multiplex_out_fd = -1;
526         io_error_fd = -1;
527 }
528
529
530 static void writefd(int fd,char *buf,int len)
531 {
532         stats.total_written += len;
533
534         if (!io_buffer || fd != multiplex_out_fd) {
535                 writefd_encrypt(fd, buf, len);
536                 return;
537         }
538
539         while (len) {
540                 int n = MIN(len, IO_BUFFER_SIZE-io_buffer_count);
541                 if (n > 0) {
542                         memcpy(io_buffer+io_buffer_count, buf, n);
543                         buf += n;
544                         len -= n;
545                         io_buffer_count += n;
546                 }
547                 
548                 if (io_buffer_count == IO_BUFFER_SIZE) io_flush();
549         }
550 }
551
552
553 void write_int(int f,int32 x)
554 {
555         char b[4];
556         SIVAL(b,0,x);
557         writefd(f,b,4);
558 }
559
560 void write_longint(int f, int64 x)
561 {
562         extern int remote_version;
563         char b[8];
564
565         if (remote_version < 16 || x <= 0x7FFFFFFF) {
566                 write_int(f, (int)x);
567                 return;
568         }
569
570         write_int(f, (int32)0xFFFFFFFF);
571         SIVAL(b,0,(x&0xFFFFFFFF));
572         SIVAL(b,4,((x>>32)&0xFFFFFFFF));
573
574         writefd(f,b,8);
575 }
576
577 void write_buf(int f,char *buf,int len)
578 {
579         writefd(f,buf,len);
580 }
581
582 /* write a string to the connection */
583 static void write_sbuf(int f,char *buf)
584 {
585         write_buf(f, buf, strlen(buf));
586 }
587
588
589 void write_byte(int f,unsigned char c)
590 {
591         write_buf(f,(char *)&c,1);
592 }
593
594 int read_line(int f, char *buf, int maxlen)
595 {
596         eof_error = 0;
597
598         while (maxlen) {
599                 buf[0] = 0;
600                 read_buf(f, buf, 1);
601                 if (buf[0] == 0) return 0;
602                 if (buf[0] == '\n') {
603                         buf[0] = 0;
604                         break;
605                 }
606                 if (buf[0] != '\r') {
607                         buf++;
608                         maxlen--;
609                 }
610         }
611         if (maxlen == 0) {
612                 *buf = 0;
613                 return 0;
614         }
615
616         eof_error = 1;
617
618         return 1;
619 }
620
621
622 void io_printf(int fd, const char *format, ...)
623 {
624         va_list ap;  
625         char buf[1024];
626         int len;
627         
628         va_start(ap, format);
629         len = vslprintf(buf, sizeof(buf), format, ap);
630         va_end(ap);
631
632         if (len < 0) exit_cleanup(RERR_STREAMIO);
633
634         write_sbuf(fd, buf);
635 }
636
637
638 /* setup for multiplexing an error stream with the data stream */
639 void io_start_multiplex_out(int fd)
640 {
641         multiplex_out_fd = fd;
642         io_flush();
643         io_start_buffering(fd);
644         io_multiplexing_out = 1;
645 }
646
647 /* setup for multiplexing an error stream with the data stream */
648 void io_start_multiplex_in(int fd)
649 {
650         multiplex_in_fd = fd;
651         io_flush();
652         io_multiplexing_in = 1;
653 }
654
655 /* write an message to the multiplexed error stream */
656 int io_multiplex_write(enum logcode code, char *buf, int len)
657 {
658         if (!io_multiplexing_out) return 0;
659
660         io_flush();
661         stats.total_written += (len+4);
662         mplex_write(multiplex_out_fd, code, buf, len);
663         return 1;
664 }
665
666 /* write a message to the special error fd */
667 int io_error_write(int f, enum logcode code, char *buf, int len)
668 {
669         if (f == -1) return 0;
670         mplex_write(f, code, buf, len);
671         return 1;
672 }
673
674 /* stop output multiplexing */
675 void io_multiplexing_close(void)
676 {
677         io_multiplexing_out = 0;
678 }
679