lib: Update socket_wrapper to version 1.1.6
[kai/samba-autobuild/.git] / lib / socket_wrapper / socket_wrapper.c
index 45282edeea0e434d81a55f4551ae2d3159fbb340..3b0499d1756c19cba4388359e50874592a812b06 100644 (file)
@@ -248,6 +248,7 @@ struct socket_info
        int connected;
        int defer_connect;
        int pktinfo;
+       int tcp_nodelay;
 
        /* The unix path so we can unlink it on close() */
        struct sockaddr_un un_addr;
@@ -397,6 +398,7 @@ struct swrap_libc_fns {
 #ifdef HAVE_TIMERFD_CREATE
        int (*libc_timerfd_create)(int clockid, int flags);
 #endif
+       ssize_t (*libc_write)(int fd, const void *buf, size_t count);
        ssize_t (*libc_writev)(int fd, const struct iovec *iov, int iovcnt);
 };
 
@@ -836,6 +838,13 @@ static int libc_timerfd_create(int clockid, int flags)
 }
 #endif
 
+static ssize_t libc_write(int fd, const void *buf, size_t count)
+{
+       swrap_load_lib_function(SWRAP_LIBC, write);
+
+       return swrap.fns.libc_write(fd, buf, count);
+}
+
 static ssize_t libc_writev(int fd, const struct iovec *iov, int iovcnt)
 {
        swrap_load_lib_function(SWRAP_LIBSOCKET, writev);
@@ -1846,11 +1855,10 @@ static uint8_t *swrap_pcap_packet_init(struct timeval *tval,
                alloc_len = SWRAP_PACKET_MIN_ALLOC;
        }
 
-       base = (uint8_t *)malloc(alloc_len);
+       base = (uint8_t *)calloc(1, alloc_len);
        if (base == NULL) {
                return NULL;
        }
-       memset(base, 0x0, alloc_len);
 
        buf = base;
 
@@ -2375,6 +2383,9 @@ static int swrap_socket(int family, int type, int protocol)
        case AF_INET6:
 #endif
                break;
+#ifdef AF_NETLINK
+       case AF_NETLINK:
+#endif /* AF_NETLINK */
        case AF_UNIX:
                return libc_socket(family, type, protocol);
        default:
@@ -2426,8 +2437,7 @@ static int swrap_socket(int family, int type, int protocol)
                swrap_remove_stale(fd);
        }
 
-       si = (struct socket_info *)malloc(sizeof(struct socket_info));
-       memset(si, 0, sizeof(struct socket_info));
+       si = (struct socket_info *)calloc(1, sizeof(struct socket_info));
        if (si == NULL) {
                errno = ENOMEM;
                return -1;
@@ -2621,8 +2631,12 @@ static int swrap_accept(int s, struct sockaddr *addr, socklen_t *addrlen)
                return ret;
        }
 
-       child_si = (struct socket_info *)malloc(sizeof(struct socket_info));
-       memset(child_si, 0, sizeof(struct socket_info));
+       child_si = (struct socket_info *)calloc(1, sizeof(struct socket_info));
+       if (child_si == NULL) {
+               close(fd);
+               errno = ENOMEM;
+               return -1;
+       }
 
        child_fi = (struct socket_info_fd *)calloc(1, sizeof(struct socket_info_fd));
        if (child_fi == NULL) {
@@ -3145,6 +3159,14 @@ static int swrap_listen(int s, int backlog)
                return libc_listen(s, backlog);
        }
 
+       if (si->bound == 0) {
+               ret = swrap_auto_bind(s, si, si->family);
+               if (ret == -1) {
+                       errno = EADDRINUSE;
+                       return ret;
+               }
+       }
+
        ret = libc_listen(s, backlog);
 
        return ret;
@@ -3350,6 +3372,29 @@ static int swrap_getsockopt(int s, int level, int optname,
                                               optval,
                                               optlen);
                }
+       } else if (level == IPPROTO_TCP) {
+               switch (optname) {
+#ifdef TCP_NODELAY
+               case TCP_NODELAY:
+                       /*
+                        * This enables sending packets directly out over TCP.
+                        * As a unix socket is doing that any way, report it as
+                        * enabled.
+                        */
+                       if (optval == NULL || optlen == NULL ||
+                           *optlen < (socklen_t)sizeof(int)) {
+                               errno = EINVAL;
+                               return -1;
+                       }
+
+                       *optlen = sizeof(int);
+                       *(int *)optval = si->tcp_nodelay;
+
+                       return 0;
+#endif /* TCP_NODELAY */
+               default:
+                       break;
+               }
        }
 
        errno = ENOPROTOOPT;
@@ -3388,6 +3433,35 @@ static int swrap_setsockopt(int s, int level, int optname,
                                       optname,
                                       optval,
                                       optlen);
+       } else if (level == IPPROTO_TCP) {
+               switch (optname) {
+#ifdef TCP_NODELAY
+               case TCP_NODELAY: {
+                       int i;
+
+                       /*
+                        * This enables sending packets directly out over TCP.
+                        * A unix socket is doing that any way.
+                        */
+                       if (optval == NULL || optlen == 0 ||
+                           optlen < (socklen_t)sizeof(int)) {
+                               errno = EINVAL;
+                               return -1;
+                       }
+
+                       i = *discard_const_p(int, optval);
+                       if (i != 0 && i != 1) {
+                               errno = EINVAL;
+                               return -1;
+                       }
+                       si->tcp_nodelay = i;
+
+                       return 0;
+               }
+#endif /* TCP_NODELAY */
+               default:
+                       break;
+               }
        }
 
        switch (si->family) {
@@ -3686,9 +3760,7 @@ static int swrap_sendmsg_copy_cmsg(struct cmsghdr *cmsg,
        size_t cmspace;
        uint8_t *p;
 
-       cmspace =
-               (*cm_data_space) +
-               CMSG_SPACE(cmsg->cmsg_len - CMSG_ALIGN(sizeof(struct cmsghdr)));
+       cmspace = *cm_data_space + CMSG_ALIGN(cmsg->cmsg_len);
 
        p = realloc((*cm_data), cmspace);
        if (p == NULL) {
@@ -3799,7 +3871,8 @@ static ssize_t swrap_sendmsg_before(int fd,
                msg->msg_iovlen = i;
                if (msg->msg_iovlen == 0) {
                        *tmp_iov = msg->msg_iov[0];
-                       tmp_iov->iov_len = MIN(tmp_iov->iov_len, (size_t)mtu);
+                       tmp_iov->iov_len = MIN((size_t)tmp_iov->iov_len,
+                                              (size_t)mtu);
                        msg->msg_iov = tmp_iov;
                        msg->msg_iovlen = 1;
                }
@@ -4016,7 +4089,8 @@ static int swrap_recvmsg_before(int fd,
                msg->msg_iovlen = i;
                if (msg->msg_iovlen == 0) {
                        *tmp_iov = msg->msg_iov[0];
-                       tmp_iov->iov_len = MIN(tmp_iov->iov_len, (size_t)mtu);
+                       tmp_iov->iov_len = MIN((size_t)tmp_iov->iov_len,
+                                              (size_t)mtu);
                        msg->msg_iov = tmp_iov;
                        msg->msg_iovlen = 1;
                }
@@ -4490,6 +4564,58 @@ ssize_t read(int s, void *buf, size_t len)
        return swrap_read(s, buf, len);
 }
 
+/****************************************************************************
+ *   WRITE
+ ***************************************************************************/
+
+static ssize_t swrap_write(int s, const void *buf, size_t len)
+{
+       struct msghdr msg;
+       struct iovec tmp;
+       struct sockaddr_un un_addr;
+       ssize_t ret;
+       int rc;
+       struct socket_info *si;
+
+       si = find_socket_info(s);
+       if (si == NULL) {
+               return libc_write(s, buf, len);
+       }
+
+       tmp.iov_base = discard_const_p(char, buf);
+       tmp.iov_len = len;
+
+       ZERO_STRUCT(msg);
+       msg.msg_name = NULL;           /* optional address */
+       msg.msg_namelen = 0;           /* size of address */
+       msg.msg_iov = &tmp;            /* scatter/gather array */
+       msg.msg_iovlen = 1;            /* # elements in msg_iov */
+#if HAVE_STRUCT_MSGHDR_MSG_CONTROL
+       msg.msg_control = NULL;        /* ancillary data, see below */
+       msg.msg_controllen = 0;        /* ancillary data buffer len */
+       msg.msg_flags = 0;             /* flags on received message */
+#endif
+
+       rc = swrap_sendmsg_before(s, si, &msg, &tmp, &un_addr, NULL, NULL, NULL);
+       if (rc < 0) {
+               return -1;
+       }
+
+       buf = msg.msg_iov[0].iov_base;
+       len = msg.msg_iov[0].iov_len;
+
+       ret = libc_write(s, buf, len);
+
+       swrap_sendmsg_after(s, si, &msg, NULL, ret);
+
+       return ret;
+}
+
+ssize_t write(int s, const void *buf, size_t len)
+{
+       return swrap_write(s, buf, len);
+}
+
 /****************************************************************************
  *   SEND
  ***************************************************************************/
@@ -4550,6 +4676,9 @@ static ssize_t swrap_recvmsg(int s, struct msghdr *omsg, int flags)
        struct swrap_address from_addr = {
                .sa_socklen = sizeof(struct sockaddr_un),
        };
+       struct swrap_address convert_addr = {
+               .sa_socklen = sizeof(struct sockaddr_storage),
+       };
        struct socket_info *si;
        struct msghdr msg;
        struct iovec tmp;
@@ -4608,6 +4737,13 @@ static ssize_t swrap_recvmsg(int s, struct msghdr *omsg, int flags)
        }
 #endif
 
+       /*
+        * We convert the unix address to a IP address so we need a buffer
+        * which can store the address in case of SOCK_DGRAM, see below.
+        */
+       msg.msg_name = &convert_addr.sa;
+       msg.msg_namelen = convert_addr.sa_socklen;
+
        rc = swrap_recvmsg_after(s,
                                 si,
                                 &msg,