lib: unix_dgram_msg does not need "num_fds"
[samba.git] / source3 / lib / unix_msg / unix_msg.c
index 00438cebfd6385ee065c075cb75da3418d5e5bd6..d85cde9fb3ffe7664b1f4cbd9dae4e7a6c490577 100644 (file)
@@ -23,6 +23,7 @@
 #include "system/network.h"
 #include "dlinklist.h"
 #include "pthreadpool/pthreadpool.h"
+#include "lib/iov_buf.h"
 #include <fcntl.h>
 
 /*
@@ -42,8 +43,8 @@ struct unix_dgram_msg {
        int sock;
        ssize_t sent;
        int sys_errno;
-       size_t buflen;
-       uint8_t buf[];
+       struct msghdr msg;
+       struct iovec iov;
 };
 
 struct unix_dgram_send_queue {
@@ -62,6 +63,7 @@ struct unix_dgram_ctx {
 
        void (*recv_callback)(struct unix_dgram_ctx *ctx,
                              uint8_t *msg, size_t msg_len,
+                             int *fds, size_t num_fds,
                              void *private_data);
        void *private_data;
 
@@ -75,7 +77,6 @@ struct unix_dgram_ctx {
        char path[];
 };
 
-static ssize_t iov_buflen(const struct iovec *iov, int iovlen);
 static void unix_dgram_recv_handler(struct poll_watch *w, int fd, short events,
                                    void *private_data);
 
@@ -135,10 +136,64 @@ static int prepare_socket(int sock)
        return prepare_socket_cloexec(sock);
 }
 
+static void extract_fd_array_from_msghdr(struct msghdr *msg, int **fds,
+                                        size_t *num_fds)
+{ /* Find the first SOL_SOCKET/SCM_RIGHTS control message in msg and report its fd array via *fds / *num_fds. */
+#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
+       struct cmsghdr *cmsg;
+
+       for(cmsg = CMSG_FIRSTHDR(msg);
+           cmsg != NULL;
+           cmsg = CMSG_NXTHDR(msg, cmsg))
+       {
+               void *data = CMSG_DATA(cmsg); /* payload of this control message */
+
+               if (cmsg->cmsg_type != SCM_RIGHTS) {
+                       continue; /* not an fd-passing message */
+               }
+               if (cmsg->cmsg_level != SOL_SOCKET) {
+                       continue;
+               }
+
+               *fds = (int *)data; /* points into msg's control buffer, nothing is copied */
+               *num_fds = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof (int); /* payload bytes / size of one fd */
+               break; /* only the first matching control message is reported */
+       }
+#endif /* if nothing matched (or msg_control is unsupported) *fds and *num_fds are left untouched */
+}
+
+static void close_fd_array(int *fds, size_t num_fds)
+{ /* Close every descriptor in fds[0..num_fds) that is not -1 and overwrite it with -1. */
+       size_t i;
+
+       for (i = 0; i < num_fds; i++) { /* fds is never dereferenced when num_fds is 0 */
+               if (fds[i] == -1) {
+                       continue; /* -1 marks a slot with nothing to close */
+               }
+
+               close(fds[i]); /* the result of close() is ignored */
+               fds[i] = -1; /* a second pass over the same array becomes a no-op */
+       }
+}
+
+static void close_fd_array_cmsg(struct msghdr *msg)
+{ /* Close the fds carried in msg's SCM_RIGHTS control message, if there is one. */
+       int *fds = NULL;
+       size_t num_fds = 0;
+
+       extract_fd_array_from_msghdr(msg, &fds, &num_fds); /* leaves fds == NULL, num_fds == 0 when nothing is found */
+
+       /*
+        * TODO: caveat - side effect: close_fd_array() overwrites the fds stored inside msg's control data with -1.
+        */
+       close_fd_array(fds, num_fds);
+}
+
 static int unix_dgram_init(const struct sockaddr_un *addr, size_t max_msg,
                           const struct poll_funcs *ev_funcs,
                           void (*recv_callback)(struct unix_dgram_ctx *ctx,
                                                 uint8_t *msg, size_t msg_len,
+                                                int *fds, size_t num_fds,
                                                 void *private_data),
                           void *private_data,
                           struct unix_dgram_ctx **result)
@@ -226,8 +281,14 @@ static void unix_dgram_recv_handler(struct poll_watch *w, int fd, short events,
 {
        struct unix_dgram_ctx *ctx = (struct unix_dgram_ctx *)private_data;
        ssize_t received;
+       int flags = 0;
        struct msghdr msg;
        struct iovec iov;
+#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
+       char buf[CMSG_SPACE(sizeof(int)*INT8_MAX)] = { 0, };
+#endif /* HAVE_STRUCT_MSGHDR_MSG_CONTROL */
+       int *fds = NULL;
+       size_t i, num_fds = 0;
 
        iov = (struct iovec) {
                .iov_base = (void *)ctx->recv_buf,
@@ -238,17 +299,19 @@ static void unix_dgram_recv_handler(struct poll_watch *w, int fd, short events,
                .msg_iov = &iov,
                .msg_iovlen = 1,
 #ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
-               .msg_control = NULL,
-               .msg_controllen = 0,
+               .msg_control = buf,
+               .msg_controllen = sizeof(buf),
 #endif
        };
 
-       received = recvmsg(fd, &msg, 0);
+#ifdef MSG_CMSG_CLOEXEC
+       flags |= MSG_CMSG_CLOEXEC;
+#endif
+
+       received = recvmsg(fd, &msg, flags);
        if (received == -1) {
                if ((errno == EAGAIN) ||
-#ifdef EWOULDBLOCK
                    (errno == EWOULDBLOCK) ||
-#endif
                    (errno == EINTR) || (errno == ENOMEM)) {
                        /* Not really an error - just try again. */
                        return;
@@ -261,7 +324,33 @@ static void unix_dgram_recv_handler(struct poll_watch *w, int fd, short events,
                /* More than we expected, not for us */
                return;
        }
-       ctx->recv_callback(ctx, ctx->recv_buf, received, ctx->private_data);
+
+       extract_fd_array_from_msghdr(&msg, &fds, &num_fds);
+
+       for (i = 0; i < num_fds; i++) {
+               int err;
+
+               err = prepare_socket_cloexec(fds[i]);
+               if (err != 0) {
+                       goto cleanup_fds;
+               }
+       }
+
+       ctx->recv_callback(ctx, ctx->recv_buf, received,
+                          fds, num_fds, ctx->private_data);
+
+       /*
+        * Close those fds that the callback has not set to -1.
+        */
+       close_fd_array(fds, num_fds);
+
+       return;
+
+cleanup_fds:
+       close_fd_array(fds, num_fds);
+
+       ctx->recv_callback(ctx, ctx->recv_buf, received,
+                          NULL, 0, ctx->private_data);
 }
 
 static void unix_dgram_job_finished(struct poll_watch *w, int fd, short events,
@@ -359,6 +448,7 @@ static void unix_dgram_send_queue_free(struct unix_dgram_send_queue *q)
                struct unix_dgram_msg *msg;
                msg = q->msgs;
                DLIST_REMOVE(q->msgs, msg);
+               close_fd_array_cmsg(&msg->msg);
                free(msg);
        }
        close(q->sock);
@@ -380,49 +470,140 @@ static struct unix_dgram_send_queue *find_send_queue(
 }
 
 static int queue_msg(struct unix_dgram_send_queue *q,
-                    const struct iovec *iov, int iovlen)
+                    const struct iovec *iov, int iovlen,
+                    const int *fds, size_t num_fds)
 {
        struct unix_dgram_msg *msg;
-       ssize_t buflen;
-       size_t msglen;
+       ssize_t data_len;
+       uint8_t *data_buf;
+       size_t msglen = sizeof(struct unix_dgram_msg);
        int i;
+       size_t tmp;
+       int ret = -1;
+#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
+       size_t fds_size = sizeof(int) * MIN(num_fds, INT8_MAX);
+       int fds_copy[MIN(num_fds, INT8_MAX)];
+       size_t cmsg_len = CMSG_LEN(fds_size);
+       size_t cmsg_space = CMSG_SPACE(fds_size);
+       char *cmsg_buf;
 
-       buflen = iov_buflen(iov, iovlen);
-       if (buflen == -1) {
+       /*
+        * Note: No need to check for overflow here,
+        * since cmsg will store <= INT8_MAX fds.
+        */
+       msglen += cmsg_space;
+
+#endif /*  HAVE_STRUCT_MSGHDR_MSG_CONTROL */
+
+       if (num_fds > INT8_MAX) {
+               return EINVAL;
+       }
+
+#ifndef HAVE_STRUCT_MSGHDR_MSG_CONTROL
+       if (num_fds > 0) {
+               return ENOSYS;
+       }
+#endif
+
+       data_len = iov_buflen(iov, iovlen);
+       if (data_len == -1) {
                return EINVAL;
        }
 
-       msglen = offsetof(struct unix_dgram_msg, buf) + buflen;
-       if ((msglen < buflen) ||
-           (msglen < offsetof(struct unix_dgram_msg, buf))) {
+       tmp = msglen + data_len;
+       if ((tmp < msglen) || (tmp < data_len)) {
                /* overflow */
                return EINVAL;
        }
+       msglen = tmp;
+
+#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
+       for (i = 0; i < num_fds; i++) {
+               fds_copy[i] = -1;
+       }
+
+       for (i = 0; i < num_fds; i++) {
+               fds_copy[i] = dup(fds[i]);
+               if (fds_copy[i] == -1) {
+                       ret = errno;
+                       goto fail;
+               }
+       }
+#endif
 
        msg = malloc(msglen);
        if (msg == NULL) {
-               return ENOMEM;
+               ret = ENOMEM;
+               goto fail;
        }
-       msg->buflen = buflen;
+
        msg->sock = q->sock;
 
-       buflen = 0;
-       for (i=0; i<iovlen; i++) {
-               memcpy(&msg->buf[buflen], iov[i].iov_base, iov[i].iov_len);
-               buflen += iov[i].iov_len;
+       data_buf = (uint8_t *)(msg + 1);
+
+#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
+       if (num_fds > 0) {
+               cmsg_buf = (char *)data_buf;
+               memset(cmsg_buf, 0, cmsg_space);
+               data_buf += cmsg_space;
+       } else {
+               cmsg_buf = NULL;
+               cmsg_space = 0;
        }
+#endif
+
+       msg->iov = (struct iovec) {
+               .iov_base = (void *)data_buf,
+               .iov_len = data_len,
+       };
+
+       msg->msg = (struct msghdr) {
+               .msg_iov = &msg->iov,
+               .msg_iovlen = 1,
+#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
+               .msg_control = cmsg_buf,
+               .msg_controllen = cmsg_space,
+#endif
+       };
+
+#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
+       if (num_fds > 0) {
+               struct cmsghdr *cmsg;
+               void *fdptr;
+
+               cmsg = CMSG_FIRSTHDR(&msg->msg);
+               cmsg->cmsg_level = SOL_SOCKET;
+               cmsg->cmsg_type = SCM_RIGHTS;
+               cmsg->cmsg_len = cmsg_len;
+               fdptr = CMSG_DATA(cmsg);
+               memcpy(fdptr, fds_copy, fds_size);
+               msg->msg.msg_controllen = cmsg->cmsg_len;
+       }
+#endif /*  HAVE_STRUCT_MSGHDR_MSG_CONTROL */
+
+       iov_buf(iov, iovlen, data_buf, data_len);
 
        DLIST_ADD_END(q->msgs, msg, struct unix_dgram_msg);
        return 0;
+
+fail:
+#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
+       close_fd_array(fds_copy, num_fds);
+#endif
+       return ret;
 }
 
 static void unix_dgram_send_job(void *private_data)
 {
-       struct unix_dgram_msg *msg = private_data;
+       struct unix_dgram_msg *dmsg = private_data; /* Send the prepared message in private_data with sendmsg() and record the outcome in it. */
 
        do {
-               msg->sent = send(msg->sock, msg->buf, msg->buflen, 0);
-       } while ((msg->sent == -1) && (errno == EINTR));
+               dmsg->sent = sendmsg(dmsg->sock, &dmsg->msg, 0); /* sends payload and control data in one call */
+       } while ((dmsg->sent == -1) && (errno == EINTR)); /* retry when interrupted by a signal */
+
+       if (dmsg->sent == -1) {
+               dmsg->sys_errno = errno; /* saved right after the failing call; left unchanged on success */
+       }
 }
 
 static void unix_dgram_job_finished(struct poll_watch *w, int fd, short events,
@@ -451,6 +632,7 @@ static void unix_dgram_job_finished(struct poll_watch *w, int fd, short events,
 
        msg = q->msgs;
        DLIST_REMOVE(q->msgs, msg);
+       close_fd_array_cmsg(&msg->msg);
        free(msg);
 
        if (q->msgs != NULL) {
@@ -466,11 +648,52 @@ static void unix_dgram_job_finished(struct poll_watch *w, int fd, short events,
 
 static int unix_dgram_send(struct unix_dgram_ctx *ctx,
                           const struct sockaddr_un *dst,
-                          const struct iovec *iov, int iovlen)
+                          const struct iovec *iov, int iovlen,
+                          const int *fds, size_t num_fds)
 {
        struct unix_dgram_send_queue *q;
        struct msghdr msg;
+#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
+       struct cmsghdr *cmsg;
+       size_t fds_size = sizeof(int) * num_fds;
+       size_t cmsg_len = CMSG_LEN(fds_size);
+       size_t cmsg_space = CMSG_SPACE(fds_size);
+       char cmsg_buf[cmsg_space];
+#endif /* HAVE_STRUCT_MSGHDR_MSG_CONTROL */
        int ret;
+       int i;
+
+       if (num_fds > INT8_MAX) {
+               return EINVAL;
+       }
+
+#ifndef HAVE_STRUCT_MSGHDR_MSG_CONTROL
+       if (num_fds > 0) {
+               return ENOSYS;
+       }
+#endif /* ! HAVE_STRUCT_MSGHDR_MSG_CONTROL */
+
+       for (i = 0; i < num_fds; i++) {
+               /*
+                * Make sure we only allow fd passing
+                * for communication channels,
+                * e.g. sockets, pipes, fifos, ...
+                */
+               ret = lseek(fds[i], 0, SEEK_CUR);
+               if (ret == -1 && errno == ESPIPE) {
+                       /* ok */
+                       continue;
+               }
+
+               /*
+                * Reject the message as we may need to call dup(),
+                * if we queue the message.
+                *
+                * That might result in unexpected behavior for the caller
+                * for files and broken posix locking.
+                */
+               return EINVAL;
+       }
 
        /*
         * To preserve message ordering, we have to queue a message when
@@ -478,7 +701,7 @@ static int unix_dgram_send(struct unix_dgram_ctx *ctx,
         */
        q = find_send_queue(ctx, dst->sun_path);
        if (q != NULL) {
-               return queue_msg(q, iov, iovlen);
+               return queue_msg(q, iov, iovlen, fds, num_fds);
        }
 
        /*
@@ -491,16 +714,35 @@ static int unix_dgram_send(struct unix_dgram_ctx *ctx,
                .msg_iov = discard_const_p(struct iovec, iov),
                .msg_iovlen = iovlen
        };
+#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
+       if (num_fds > 0) {
+               void *fdptr;
+
+               memset(cmsg_buf, 0, cmsg_space);
+
+               msg.msg_control = cmsg_buf;
+               msg.msg_controllen = cmsg_space;
+               cmsg = CMSG_FIRSTHDR(&msg);
+               cmsg->cmsg_level = SOL_SOCKET;
+               cmsg->cmsg_type = SCM_RIGHTS;
+               cmsg->cmsg_len = cmsg_len;
+               fdptr = CMSG_DATA(cmsg);
+               memcpy(fdptr, fds, fds_size);
+               msg.msg_controllen = cmsg->cmsg_len;
+       }
+#endif /*  HAVE_STRUCT_MSGHDR_MSG_CONTROL */
 
        ret = sendmsg(ctx->sock, &msg, 0);
        if (ret >= 0) {
                return 0;
        }
-#ifdef EWOULDBLOCK
-       if ((errno != EWOULDBLOCK) && (errno != EAGAIN) && (errno != EINTR)) {
-#else
-       if ((errno != EAGAIN) && (errno != EINTR)) {
+       if ((errno != EWOULDBLOCK) &&
+           (errno != EAGAIN) &&
+#ifdef ENOBUFS
+           /* FreeBSD can give this for large messages */
+           (errno != ENOBUFS) &&
 #endif
+           (errno != EINTR)) {
                return errno;
        }
 
@@ -508,7 +750,7 @@ static int unix_dgram_send(struct unix_dgram_ctx *ctx,
        if (ret != 0) {
                return ret;
        }
-       ret = queue_msg(q, iov, iovlen);
+       ret = queue_msg(q, iov, iovlen, fds, num_fds);
        if (ret != 0) {
                unix_dgram_send_queue_free(q);
                return ret;
@@ -589,14 +831,16 @@ struct unix_msg_ctx {
 
        void (*recv_callback)(struct unix_msg_ctx *ctx,
                              uint8_t *msg, size_t msg_len,
+                             int *fds, size_t num_fds,
                              void *private_data);
        void *private_data;
 
        struct unix_msg *msgs;
 };
 
-static void unix_msg_recv(struct unix_dgram_ctx *ctx,
-                         uint8_t *msg, size_t msg_len,
+static void unix_msg_recv(struct unix_dgram_ctx *dgram_ctx,
+                         uint8_t *buf, size_t buflen,
+                         int *fds, size_t num_fds,
                          void *private_data);
 
 int unix_msg_init(const struct sockaddr_un *addr,
@@ -604,6 +848,7 @@ int unix_msg_init(const struct sockaddr_un *addr,
                  size_t fragment_len, uint64_t cookie,
                  void (*recv_callback)(struct unix_msg_ctx *ctx,
                                        uint8_t *msg, size_t msg_len,
+                                       int *fds, size_t num_fds,
                                        void *private_data),
                  void *private_data,
                  struct unix_msg_ctx **result)
@@ -635,7 +880,8 @@ int unix_msg_init(const struct sockaddr_un *addr,
 }
 
 int unix_msg_send(struct unix_msg_ctx *ctx, const struct sockaddr_un *dst,
-                 const struct iovec *iov, int iovlen)
+                 const struct iovec *iov, int iovlen,
+                 const int *fds, size_t num_fds)
 {
        ssize_t msglen;
        size_t sent;
@@ -653,6 +899,16 @@ int unix_msg_send(struct unix_msg_ctx *ctx, const struct sockaddr_un *dst,
                return EINVAL;
        }
 
+#ifndef HAVE_STRUCT_MSGHDR_MSG_CONTROL
+       if (num_fds > 0) {
+               return ENOSYS;
+       }
+#endif /* ! HAVE_STRUCT_MSGHDR_MSG_CONTROL */
+
+       if (num_fds > INT8_MAX) {
+               return EINVAL;
+       }
+
        if (msglen <= (ctx->fragment_len - sizeof(uint64_t))) {
                uint64_t cookie = 0;
 
@@ -663,7 +919,8 @@ int unix_msg_send(struct unix_msg_ctx *ctx, const struct sockaddr_un *dst,
                               sizeof(struct iovec) * iovlen);
                }
 
-               return unix_dgram_send(ctx->dgram, dst, iov_copy, iovlen+1);
+               return unix_dgram_send(ctx->dgram, dst, iov_copy, iovlen+1,
+                                      fds, num_fds);
        }
 
        hdr = (struct unix_msg_hdr) {
@@ -719,7 +976,19 @@ int unix_msg_send(struct unix_msg_ctx *ctx, const struct sockaddr_un *dst,
                }
                sent += (fragment_len - sizeof(ctx->cookie) - sizeof(hdr));
 
-               ret = unix_dgram_send(ctx->dgram, dst, iov_copy, iov_index);
+               /*
+                * only the last fragment should pass the fd array.
+                * That simplifies the receiver a lot.
+                */
+               if (sent < msglen) {
+                       ret = unix_dgram_send(ctx->dgram, dst,
+                                             iov_copy, iov_index,
+                                             NULL, 0);
+               } else {
+                       ret = unix_dgram_send(ctx->dgram, dst,
+                                             iov_copy, iov_index,
+                                             fds, num_fds);
+               }
                if (ret != 0) {
                        break;
                }
@@ -735,6 +1004,7 @@ int unix_msg_send(struct unix_msg_ctx *ctx, const struct sockaddr_un *dst,
 
 static void unix_msg_recv(struct unix_dgram_ctx *dgram_ctx,
                          uint8_t *buf, size_t buflen,
+                         int *fds, size_t num_fds,
                          void *private_data)
 {
        struct unix_msg_ctx *ctx = (struct unix_msg_ctx *)private_data;
@@ -744,20 +1014,21 @@ static void unix_msg_recv(struct unix_dgram_ctx *dgram_ctx,
        uint64_t cookie;
 
        if (buflen < sizeof(cookie)) {
-               return;
+               goto close_fds;
        }
+
        memcpy(&cookie, buf, sizeof(cookie));
 
        buf += sizeof(cookie);
        buflen -= sizeof(cookie);
 
        if (cookie == 0) {
-               ctx->recv_callback(ctx, buf, buflen, ctx->private_data);
+               ctx->recv_callback(ctx, buf, buflen, fds, num_fds, ctx->private_data);
                return;
        }
 
        if (buflen < sizeof(hdr)) {
-               return;
+               goto close_fds;
        }
        memcpy(&hdr, buf, sizeof(hdr));
 
@@ -780,7 +1051,7 @@ static void unix_msg_recv(struct unix_dgram_ctx *dgram_ctx,
        if (msg == NULL) {
                msg = malloc(offsetof(struct unix_msg, buf) + hdr.msglen);
                if (msg == NULL) {
-                       return;
+                       goto close_fds;
                }
                *msg = (struct unix_msg) {
                        .msglen = hdr.msglen,
@@ -793,19 +1064,23 @@ static void unix_msg_recv(struct unix_dgram_ctx *dgram_ctx,
 
        space = msg->msglen - msg->received;
        if (buflen > space) {
-               return;
+               goto close_fds;
        }
 
        memcpy(msg->buf + msg->received, buf, buflen);
        msg->received += buflen;
 
        if (msg->received < msg->msglen) {
-               return;
+               goto close_fds;
        }
 
        DLIST_REMOVE(ctx->msgs, msg);
-       ctx->recv_callback(ctx, msg->buf, msg->msglen, ctx->private_data);
+       ctx->recv_callback(ctx, msg->buf, msg->msglen, fds, num_fds, ctx->private_data);
        free(msg);
+       return;
+
+close_fds:
+       close_fd_array(fds, num_fds);
 }
 
 int unix_msg_free(struct unix_msg_ctx *ctx)
@@ -826,21 +1101,3 @@ int unix_msg_free(struct unix_msg_ctx *ctx)
        free(ctx);
        return 0;
 }
-
-static ssize_t iov_buflen(const struct iovec *iov, int iovlen)
-{
-       size_t buflen = 0;
-       int i;
-
-       for (i=0; i<iovlen; i++) {
-               size_t thislen = iov[i].iov_len;
-               size_t tmp = buflen + thislen;
-
-               if ((tmp < buflen) || (tmp < thislen)) {
-                       /* overflow */
-                       return -1;
-               }
-               buflen = tmp;
-       }
-       return buflen;
-}