lib: unix_dgram_msg does not need "num_fds"
[samba.git] / source3 / lib / unix_msg / unix_msg.c
index ad415ddc01061bedc7bf362c36dffa3c5e42082c..d85cde9fb3ffe7664b1f4cbd9dae4e7a6c490577 100644 (file)
@@ -23,6 +23,7 @@
 #include "system/network.h"
 #include "dlinklist.h"
 #include "pthreadpool/pthreadpool.h"
+#include "lib/iov_buf.h"
 #include <fcntl.h>
 
 /*
@@ -42,10 +43,8 @@ struct unix_dgram_msg {
        int sock;
        ssize_t sent;
        int sys_errno;
-       size_t num_fds;
-       int *fds;
-       size_t buflen;
-       uint8_t buf[];
+       struct msghdr msg;
+       struct iovec iov;
 };
 
 struct unix_dgram_send_queue {
@@ -78,7 +77,6 @@ struct unix_dgram_ctx {
        char path[];
 };
 
-static ssize_t iov_buflen(const struct iovec *iov, int iovlen);
 static void unix_dgram_recv_handler(struct poll_watch *w, int fd, short events,
                                    void *private_data);
 
@@ -138,6 +136,32 @@ static int prepare_socket(int sock)
        return prepare_socket_cloexec(sock);
 }
 
+static void extract_fd_array_from_msghdr(struct msghdr *msg, int **fds,
+                                        size_t *num_fds)
+{
+#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
+       struct cmsghdr *cmsg;
+
+       for(cmsg = CMSG_FIRSTHDR(msg);
+           cmsg != NULL;
+           cmsg = CMSG_NXTHDR(msg, cmsg))
+       {
+               void *data = CMSG_DATA(cmsg);
+
+               if (cmsg->cmsg_type != SCM_RIGHTS) {
+                       continue;
+               }
+               if (cmsg->cmsg_level != SOL_SOCKET) {
+                       continue;
+               }
+
+               *fds = (int *)data;
+               *num_fds = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof (int);
+               break;
+       }
+#endif
+}
+
 static void close_fd_array(int *fds, size_t num_fds)
 {
        size_t i;
@@ -152,6 +176,19 @@ static void close_fd_array(int *fds, size_t num_fds)
        }
 }
 
+static void close_fd_array_cmsg(struct msghdr *msg)
+{
+       int *fds = NULL;
+       size_t num_fds = 0;
+
+       extract_fd_array_from_msghdr(msg, &fds, &num_fds);
+
+       /*
+        * TODO: caveat - side-effect - changing msg ???
+        */
+       close_fd_array(fds, num_fds);
+}
+
 static int unix_dgram_init(const struct sockaddr_un *addr, size_t max_msg,
                           const struct poll_funcs *ev_funcs,
                           void (*recv_callback)(struct unix_dgram_ctx *ctx,
@@ -249,7 +286,6 @@ static void unix_dgram_recv_handler(struct poll_watch *w, int fd, short events,
        struct iovec iov;
 #ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
        char buf[CMSG_SPACE(sizeof(int)*INT8_MAX)] = { 0, };
-       struct cmsghdr *cmsg;
 #endif /* HAVE_STRUCT_MSGHDR_MSG_CONTROL */
        int *fds = NULL;
        size_t i, num_fds = 0;
@@ -289,24 +325,7 @@ static void unix_dgram_recv_handler(struct poll_watch *w, int fd, short events,
                return;
        }
 
-#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
-       for(cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL;
-           cmsg = CMSG_NXTHDR(&msg, cmsg))
-       {
-               void *data = CMSG_DATA(cmsg);
-
-               if (cmsg->cmsg_type != SCM_RIGHTS) {
-                       continue;
-               }
-               if (cmsg->cmsg_level != SOL_SOCKET) {
-                       continue;
-               }
-
-               fds = (int *)data;
-               num_fds = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof (int);
-               break;
-       }
-#endif
+       extract_fd_array_from_msghdr(&msg, &fds, &num_fds);
 
        for (i = 0; i < num_fds; i++) {
                int err;
@@ -319,6 +338,12 @@ static void unix_dgram_recv_handler(struct poll_watch *w, int fd, short events,
 
        ctx->recv_callback(ctx, ctx->recv_buf, received,
                           fds, num_fds, ctx->private_data);
+
+       /*
+        * Close those fds that the callback has not set to -1.
+        */
+       close_fd_array(fds, num_fds);
+
        return;
 
 cleanup_fds:
@@ -423,7 +448,7 @@ static void unix_dgram_send_queue_free(struct unix_dgram_send_queue *q)
                struct unix_dgram_msg *msg;
                msg = q->msgs;
                DLIST_REMOVE(q->msgs, msg);
-               close_fd_array(msg->fds, msg->num_fds);
+               close_fd_array_cmsg(&msg->msg);
                free(msg);
        }
        close(q->sock);
@@ -449,138 +474,136 @@ static int queue_msg(struct unix_dgram_send_queue *q,
                     const int *fds, size_t num_fds)
 {
        struct unix_dgram_msg *msg;
-       ssize_t buflen;
-       size_t msglen;
-       size_t fds_size = sizeof(int) * num_fds;
-       int fds_copy[MIN(num_fds, INT8_MAX)];
-       size_t fds_padding = 0;
+       ssize_t data_len;
+       uint8_t *data_buf;
+       size_t msglen = sizeof(struct unix_dgram_msg);
        int i;
        size_t tmp;
        int ret = -1;
+#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
+       size_t fds_size = sizeof(int) * MIN(num_fds, INT8_MAX);
+       int fds_copy[MIN(num_fds, INT8_MAX)];
+       size_t cmsg_len = CMSG_LEN(fds_size);
+       size_t cmsg_space = CMSG_SPACE(fds_size);
+       char *cmsg_buf;
+
+       /*
+        * Note: No need to check for overflow here,
+        * since cmsg will store <= INT8_MAX fds.
+        */
+       msglen += cmsg_space;
+
+#endif /*  HAVE_STRUCT_MSGHDR_MSG_CONTROL */
 
        if (num_fds > INT8_MAX) {
                return EINVAL;
        }
 
-       for (i = 0; i < num_fds; i++) {
-               fds_copy[i] = -1;
-       }
-
-       for (i = 0; i < num_fds; i++) {
-               fds_copy[i] = dup(fds[i]);
-               if (fds_copy[i] == -1) {
-                       ret = errno;
-                       goto fail;
-               }
+#ifndef HAVE_STRUCT_MSGHDR_MSG_CONTROL
+       if (num_fds > 0) {
+               return ENOSYS;
        }
+#endif
 
-       buflen = iov_buflen(iov, iovlen);
-       if (buflen == -1) {
-               goto invalid;
+       data_len = iov_buflen(iov, iovlen);
+       if (data_len == -1) {
+               return EINVAL;
        }
 
-       msglen = offsetof(struct unix_dgram_msg, buf);
-       tmp = msglen + buflen;
-       if ((tmp < msglen) || (tmp < buflen)) {
+       tmp = msglen + data_len;
+       if ((tmp < msglen) || (tmp < data_len)) {
                /* overflow */
-               goto invalid;
+               return EINVAL;
        }
        msglen = tmp;
 
-       if (num_fds > 0) {
-               const size_t fds_align = sizeof(int) - 1;
-
-               tmp = msglen + fds_align;
-               if ((tmp < msglen) || (tmp < fds_align)) {
-                       /* overflow */
-                       goto invalid;
-               }
-               tmp &= ~fds_align;
-
-               fds_padding = tmp - msglen;
-               msglen = tmp;
+#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
+       for (i = 0; i < num_fds; i++) {
+               fds_copy[i] = -1;
+       }
 
-               tmp = msglen + fds_size;
-               if ((tmp < msglen) || (tmp < fds_size)) {
-                       /* overflow */
-                       goto invalid;
+       for (i = 0; i < num_fds; i++) {
+               fds_copy[i] = dup(fds[i]);
+               if (fds_copy[i] == -1) {
+                       ret = errno;
+                       goto fail;
                }
-               msglen = tmp;
        }
+#endif
 
        msg = malloc(msglen);
        if (msg == NULL) {
                ret = ENOMEM;
                goto fail;
        }
-       msg->buflen = buflen;
+
        msg->sock = q->sock;
 
-       buflen = 0;
-       for (i=0; i<iovlen; i++) {
-               memcpy(&msg->buf[buflen], iov[i].iov_base, iov[i].iov_len);
-               buflen += iov[i].iov_len;
-       }
+       data_buf = (uint8_t *)(msg + 1);
 
-       msg->num_fds = num_fds;
-       if (msg->num_fds > 0) {
-               void *fds_ptr = (void *)&msg->buf[buflen+fds_padding];
-               memcpy(fds_ptr, fds_copy, fds_size);
-               msg->fds = (int *)fds_ptr;
+#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
+       if (num_fds > 0) {
+               cmsg_buf = (char *)data_buf;
+               memset(cmsg_buf, 0, cmsg_space);
+               data_buf += cmsg_space;
        } else {
-               msg->fds = NULL;
+               cmsg_buf = NULL;
+               cmsg_space = 0;
        }
+#endif
 
-       DLIST_ADD_END(q->msgs, msg, struct unix_dgram_msg);
-       return 0;
-
-invalid:
-       ret = EINVAL;
-fail:
-       close_fd_array(fds_copy, num_fds);
-       return ret;
-}
-
-static void unix_dgram_send_job(void *private_data)
-{
-       struct unix_dgram_msg *dmsg = private_data;
-       struct iovec iov = {
-               .iov_base = (void *)dmsg->buf,
-               .iov_len = dmsg->buflen,
+       msg->iov = (struct iovec) {
+               .iov_base = (void *)data_buf,
+               .iov_len = data_len,
        };
-       struct msghdr msg = {
-               .msg_iov = &iov,
+
+       msg->msg = (struct msghdr) {
+               .msg_iov = &msg->iov,
                .msg_iovlen = 1,
-       };
 #ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
-       struct cmsghdr *cmsg;
-       size_t fds_size = sizeof(int) * dmsg->num_fds;
-       size_t cmsg_len = CMSG_LEN(fds_size);
-       size_t cmsg_space = CMSG_SPACE(fds_size);
-       char cmsg_buf[cmsg_space];
+               .msg_control = cmsg_buf,
+               .msg_controllen = cmsg_space,
+#endif
+       };
 
-       if (dmsg->num_fds > 0) {
+#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
+       if (num_fds > 0) {
+               struct cmsghdr *cmsg;
                void *fdptr;
 
-               memset(cmsg_buf, 0, cmsg_space);
-
-               msg.msg_control = cmsg_buf;
-               msg.msg_controllen = cmsg_space;
-               cmsg = CMSG_FIRSTHDR(&msg);
+               cmsg = CMSG_FIRSTHDR(&msg->msg);
                cmsg->cmsg_level = SOL_SOCKET;
                cmsg->cmsg_type = SCM_RIGHTS;
                cmsg->cmsg_len = cmsg_len;
                fdptr = CMSG_DATA(cmsg);
-               memcpy(fdptr, dmsg->fds, fds_size);
-               msg.msg_controllen = cmsg->cmsg_len;
+               memcpy(fdptr, fds_copy, fds_size);
+               msg->msg.msg_controllen = cmsg->cmsg_len;
        }
 #endif /*  HAVE_STRUCT_MSGHDR_MSG_CONTROL */
 
+       iov_buf(iov, iovlen, data_buf, data_len);
+
+       DLIST_ADD_END(q->msgs, msg, struct unix_dgram_msg);
+       return 0;
+
+fail:
+#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
+       close_fd_array(fds_copy, num_fds);
+#endif
+       return ret;
+}
+
+static void unix_dgram_send_job(void *private_data)
+{
+       struct unix_dgram_msg *dmsg = private_data;
+
        do {
-               dmsg->sent = sendmsg(dmsg->sock, &msg, 0);
+               dmsg->sent = sendmsg(dmsg->sock, &dmsg->msg, 0);
        } while ((dmsg->sent == -1) && (errno == EINTR));
 
-       close_fd_array(dmsg->fds, dmsg->num_fds);
+       if (dmsg->sent == -1) {
+               dmsg->sys_errno = errno;
+       }
 }
 
 static void unix_dgram_job_finished(struct poll_watch *w, int fd, short events,
@@ -609,7 +632,7 @@ static void unix_dgram_job_finished(struct poll_watch *w, int fd, short events,
 
        msg = q->msgs;
        DLIST_REMOVE(q->msgs, msg);
-       close_fd_array(msg->fds, msg->num_fds);
+       close_fd_array_cmsg(&msg->msg);
        free(msg);
 
        if (q->msgs != NULL) {
@@ -713,7 +736,13 @@ static int unix_dgram_send(struct unix_dgram_ctx *ctx,
        if (ret >= 0) {
                return 0;
        }
-       if ((errno != EWOULDBLOCK) && (errno != EAGAIN) && (errno != EINTR)) {
+       if ((errno != EWOULDBLOCK) &&
+           (errno != EAGAIN) &&
+#ifdef ENOBUFS
+           /* FreeBSD can give this for large messages */
+           (errno != ENOBUFS) &&
+#endif
+           (errno != EINTR)) {
                return errno;
        }
 
@@ -1072,21 +1101,3 @@ int unix_msg_free(struct unix_msg_ctx *ctx)
        free(ctx);
        return 0;
 }
-
-static ssize_t iov_buflen(const struct iovec *iov, int iovlen)
-{
-       size_t buflen = 0;
-       int i;
-
-       for (i=0; i<iovlen; i++) {
-               size_t thislen = iov[i].iov_len;
-               size_t tmp = buflen + thislen;
-
-               if ((tmp < buflen) || (tmp < thislen)) {
-                       /* overflow */
-                       return -1;
-               }
-               buflen = tmp;
-       }
-       return buflen;
-}