smbd: Use msghdr.[ch] in vfs_aio_fork
[samba.git] / source3 / modules / vfs_aio_fork.c
index d8a99b00359dd7ff44c6312ff3fac98ae7aaa2c9..7d2aff9064462a29d3559a3351523f3771617145 100644 (file)
 #include "smbd/globals.h"
 #include "lib/async_req/async_sock.h"
 #include "lib/util/tevent_unix.h"
+#include "lib/sys_rw.h"
+#include "lib/sys_rw_data.h"
+#include "lib/msghdr.h"
 
-#if !defined(HAVE_STRUCT_MSGHDR_MSG_CONTROL) && !defined(HAVE_MSGHDR_MSG_ACCTRIGHTS)
+#if !defined(HAVE_STRUCT_MSGHDR_MSG_CONTROL) && !defined(HAVE_STRUCT_MSGHDR_MSG_ACCRIGHTS)
 # error Can not pass file descriptors
 #endif
 
@@ -154,26 +157,13 @@ static ssize_t read_fd(int fd, void *ptr, size_t nbytes, int *recvfd)
        struct msghdr msg;
        struct iovec iov[1];
        ssize_t n;
-#ifndef HAVE_STRUCT_MSGHDR_MSG_CONTROL
-       int newfd;
+       size_t bufsize = msghdr_prep_recv_fds(NULL, NULL, 0, 1);
+       uint8_t buf[bufsize];
 
-       msg.msg_accrights = (caddr_t) &newfd;
-       msg.msg_accrightslen = sizeof(int);
-#else
-
-       union {
-         struct cmsghdr        cm;
-         char                          control[CMSG_SPACE(sizeof(int))];
-       } control_un;
-       struct cmsghdr  *cmptr;
-
-       msg.msg_control = control_un.control;
-       msg.msg_controllen = sizeof(control_un.control);
-#endif
+       msghdr_prep_recv_fds(&msg, buf, bufsize, 1);
 
        msg.msg_name = NULL;
        msg.msg_namelen = 0;
-       msg.msg_flags = 0;
 
        iov[0].iov_base = (void *)ptr;
        iov[0].iov_len = nbytes;
@@ -184,71 +174,43 @@ static ssize_t read_fd(int fd, void *ptr, size_t nbytes, int *recvfd)
                return(n);
        }
 
-#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
-       if ((cmptr = CMSG_FIRSTHDR(&msg)) != NULL
-           && cmptr->cmsg_len == CMSG_LEN(sizeof(int))) {
-               if (cmptr->cmsg_level != SOL_SOCKET) {
-                       DEBUG(10, ("control level != SOL_SOCKET"));
-                       errno = EINVAL;
-                       return -1;
-               }
-               if (cmptr->cmsg_type != SCM_RIGHTS) {
-                       DEBUG(10, ("control type != SCM_RIGHTS"));
-                       errno = EINVAL;
-                       return -1;
+       {
+               size_t num_fds = msghdr_extract_fds(&msg, NULL, 0);
+               int fds[num_fds];
+
+               msghdr_extract_fds(&msg, fds, num_fds);
+
+               if (num_fds != 1) {
+                       size_t i;
+
+                       for (i=0; i<num_fds; i++) {
+                               close(fds[i]);
+                       }
+
+                       *recvfd = -1;
+                       return n;
                }
-               memcpy(recvfd, CMSG_DATA(cmptr), sizeof(*recvfd));
-       } else {
-               *recvfd = -1;           /* descriptor was not passed */
-       }
-#else
-       if (msg.msg_accrightslen == sizeof(int)) {
-               *recvfd = newfd;
-       }
-       else {
-               *recvfd = -1;           /* descriptor was not passed */
+
+               *recvfd = fds[0];
        }
-#endif
 
        return(n);
 }
 
 static ssize_t write_fd(int fd, void *ptr, size_t nbytes, int sendfd)
 {
-       struct msghdr   msg;
-       struct iovec    iov[1];
-
-#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
-       union {
-               struct cmsghdr  cm;
-               char control[CMSG_SPACE(sizeof(int))];
-       } control_un;
-       struct cmsghdr  *cmptr;
-
-       ZERO_STRUCT(msg);
-       ZERO_STRUCT(control_un);
-
-       msg.msg_control = control_un.control;
-       msg.msg_controllen = sizeof(control_un.control);
-
-       cmptr = CMSG_FIRSTHDR(&msg);
-       cmptr->cmsg_len = CMSG_LEN(sizeof(int));
-       cmptr->cmsg_level = SOL_SOCKET;
-       cmptr->cmsg_type = SCM_RIGHTS;
-       memcpy(CMSG_DATA(cmptr), &sendfd, sizeof(sendfd));
-#else
-       ZERO_STRUCT(msg);
-       msg.msg_accrights = (caddr_t) &sendfd;
-       msg.msg_accrightslen = sizeof(int);
-#endif
+       struct msghdr msg;
+       size_t bufsize = msghdr_prep_fds(NULL, NULL, 0, &sendfd, 1);
+       uint8_t buf[bufsize];
+       struct iovec iov;
 
+       msghdr_prep_fds(&msg, buf, bufsize, &sendfd, 1);
        msg.msg_name = NULL;
        msg.msg_namelen = 0;
 
-       ZERO_STRUCT(iov);
-       iov[0].iov_base = (void *)ptr;
-       iov[0].iov_len = nbytes;
-       msg.msg_iov = iov;
+       iov.iov_base = (void *)ptr;
+       iov.iov_len = nbytes;
+       msg.msg_iov = &iov;
        msg.msg_iovlen = 1;
 
        return (sendmsg(fd, &msg, 0));