lib/replace: make sure krb5_cc_default[_name]() is no longer used directly
[samba.git] / source3 / lib / recvfile.c
index 31d9311498d2399df4318b3bdf56734c4a4b11d5..e1eb241d7bd8ce168443882d3cc7cad145cb6590 100644 (file)
 
 #include "includes.h"
 #include "system/filesys.h"
+#include "lib/util/sys_rw.h"
 
 /* Do this on our own in TRANSFER_BUF_SIZE chunks.
  * It's safe to make direct syscalls to lseek/write here
  * as we're below the Samba vfs layer.
  *
- * If tofd is -1 we just drain the incoming socket of count
- * bytes without writing to the outgoing fd.
- * If a write fails we do the same (to cope with disk full)
- * errors.
- *
  * Returns -1 on short reads from fromfd (read error)
  * and sets errno.
  *
  * Returns number of bytes written to 'tofd'
- * or thrown away if 'tofd == -1'.
  * return != count then sets errno.
  * Returns count if complete success.
  */
 
 static ssize_t default_sys_recvfile(int fromfd,
                        int tofd,
-                       SMB_OFF_T offset,
+                       off_t offset,
                        size_t count)
 {
        int saved_errno = 0;
        size_t total = 0;
        size_t bufsize = MIN(TRANSFER_BUF_SIZE,count);
        size_t total_written = 0;
-       char *buffer = NULL;
+       char buffer[bufsize];
 
        DEBUG(10,("default_sys_recvfile: from = %d, to = %d, "
                "offset=%.0f, count = %lu\n",
@@ -68,51 +63,69 @@ static ssize_t default_sys_recvfile(int fromfd,
                return 0;
        }
 
-       if (tofd != -1 && offset != (SMB_OFF_T)-1) {
-               if (sys_lseek(tofd, offset, SEEK_SET) == -1) {
+       if (tofd != -1 && offset != (off_t)-1) {
+               if (lseek(tofd, offset, SEEK_SET) == -1) {
                        if (errno != ESPIPE) {
                                return -1;
                        }
                }
        }
 
-       buffer = SMB_MALLOC_ARRAY(char, bufsize);
-       if (buffer == NULL) {
-               return -1;
-       }
-
        while (total < count) {
                size_t num_written = 0;
                ssize_t read_ret;
                size_t toread = MIN(bufsize,count - total);
 
-               /* Read from socket - ignore EINTR. */
-               read_ret = sys_read(fromfd, buffer, toread);
+               /*
+                * Read from socket - ignore EINTR.
+                * Can't use sys_read() as that also
+                * ignores EAGAIN and EWOULDBLOCK.
+                */
+               do {
+                       read_ret = read(fromfd, buffer, toread);
+               } while (read_ret == -1 && errno == EINTR);
+
+               if (read_ret == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
+                       /*
+                        * fromfd socket is in non-blocking mode.
+                        * If we already read some and wrote
+                        * it successfully, return that.
+                        * Only return -1 if this is the first read
+                        * attempt. Caller will handle both cases.
+                        */
+                       if (total_written != 0) {
+                               return total_written;
+                       }
+                       return -1;
+               }
+
                if (read_ret <= 0) {
                        /* EOF or socket error. */
-                       free(buffer);
                        return -1;
                }
 
                num_written = 0;
 
-               while (num_written < read_ret) {
+               /* Don't write any more after a write error. */
+               while (tofd != -1 && (num_written < read_ret)) {
                        ssize_t write_ret;
 
-                       if (tofd == -1) {
-                               write_ret = read_ret;
-                       } else {
-                               /* Write to file - ignore EINTR. */
-                               write_ret = sys_write(tofd,
-                                               buffer + num_written,
-                                               read_ret - num_written);
-
-                               if (write_ret <= 0) {
-                                       /* write error - stop writing. */
-                                       tofd = -1;
-                                       saved_errno = errno;
-                                       continue;
-                               }
+                       /* Write to file - ignore EINTR. */
+                       write_ret = sys_write(tofd,
+                                       buffer + num_written,
+                                       read_ret - num_written);
+
+                       if (write_ret <= 0) {
+                               /* write error - stop writing. */
+                               tofd = -1;
+                                if (total_written == 0) {
+                                       /* Ensure we return
+                                          -1 if the first
+                                          write failed. */
+                                        total_written = -1;
+                                }
+                               saved_errno = errno;
+                               break;
                        }
 
                        num_written += (size_t)write_ret;
@@ -122,7 +135,6 @@ static ssize_t default_sys_recvfile(int fromfd,
                total += read_ret;
        }
 
-       free(buffer);
        if (saved_errno) {
                /* Return the correct write error. */
                errno = saved_errno;
@@ -143,7 +155,7 @@ static ssize_t default_sys_recvfile(int fromfd,
 
 ssize_t sys_recvfile(int fromfd,
                        int tofd,
-                       SMB_OFF_T offset,
+                       off_t offset,
                        size_t count)
 {
        static int pipefd[2] = { -1, -1 };
@@ -194,6 +206,19 @@ ssize_t sys_recvfile(int fromfd,
                                return default_sys_recvfile(fromfd, tofd,
                                                            offset, count);
                        }
+                       if (errno == EAGAIN || errno == EWOULDBLOCK) {
+                               /*
+                                * fromfd socket is in non-blocking mode.
+                                * If we already read some and wrote
+                                * it successfully, return that.
+                                * Only return -1 if this is the first read
+                                * attempt. Caller will handle both cases.
+                                */
+                               if (total_written != 0) {
+                                       return total_written;
+                               }
+                               return -1;
+                       }
                        break;
                }
 
@@ -233,7 +258,7 @@ ssize_t sys_recvfile(int fromfd,
 
 ssize_t sys_recvfile(int fromfd,
                        int tofd,
-                       SMB_OFF_T offset,
+                       off_t offset,
                        size_t count)
 {
        return default_sys_recvfile(fromfd, tofd, offset, count);
@@ -243,20 +268,22 @@ ssize_t sys_recvfile(int fromfd,
 /*****************************************************************
  Throw away "count" bytes from the client socket.
  Returns count or -1 on error.
+ Must only operate on a blocking socket.
 *****************************************************************/
 
 ssize_t drain_socket(int sockfd, size_t count)
 {
        size_t total = 0;
        size_t bufsize = MIN(TRANSFER_BUF_SIZE,count);
-       char *buffer = NULL;
+       char buffer[bufsize];
+       int old_flags = 0;
 
        if (count == 0) {
                return 0;
        }
 
-       buffer = SMB_MALLOC_ARRAY(char, bufsize);
-       if (buffer == NULL) {
+       old_flags = fcntl(sockfd, F_GETFL, 0);
+       if (set_blocking(sockfd, true) == -1) {
                return -1;
        }
 
@@ -268,12 +295,16 @@ ssize_t drain_socket(int sockfd, size_t count)
                read_ret = sys_read(sockfd, buffer, toread);
                if (read_ret <= 0) {
                        /* EOF or socket error. */
-                       free(buffer);
-                       return -1;
+                       count = (size_t)-1;
+                       goto out;
                }
                total += read_ret;
        }
 
-       free(buffer);
+  out:
+
+       if (fcntl(sockfd, F_SETFL, old_flags) == -1) {
+               return -1;
+       }
        return count;
 }