socket_wrapper: correctly handle dup()/dup2() ref counting
[amitay/samba.git] / lib / socket_wrapper / socket_wrapper.c
index cd913275f63f8edf0d1aa13b8e104cc410e05bcd..a54f50f2f5741c0f7c7cdb3ab3f4fc973699a1b4 100644 (file)
 #define real_writev writev
 #define real_socket socket
 #define real_close close
+#define real_dup dup
+#define real_dup2 dup2
 #endif
 
 #ifdef HAVE_GETTIMEOFDAY_TZ
 #define SOCKET_TYPE_CHAR_TCP_V6                'X'
 #define SOCKET_TYPE_CHAR_UDP_V6                'Y'
 
-#define MAX_WRAPPED_INTERFACES 16
+/* This limit is to avoid broadcast sendto() needing to stat too many
+ * files.  It may be raised (with a performance cost) to up to 254
+ * without changing the format above */
+#define MAX_WRAPPED_INTERFACES 32
 
 #ifdef HAVE_IPV6
 /*
@@ -207,11 +212,14 @@ static size_t socket_length(int family)
        return 0;
 }
 
-
+struct socket_info_fd {
+       struct socket_info_fd *prev, *next;
+       int fd;
+};
 
 struct socket_info
 {
-       int fd;
+       struct socket_info_fd *fds;
 
        int family;
        int type;
@@ -222,7 +230,6 @@ struct socket_info
        int connected;
        int defer_connect;
 
-       char *path;
        char *tmp_path;
 
        struct sockaddr *myname;
@@ -551,6 +558,11 @@ static int convert_in_un_alloc(struct socket_info *si, const struct sockaddr *in
 
        if (bcast) *bcast = is_bcast;
 
+       if (iface == 0 || iface > MAX_WRAPPED_INTERFACES) {
+               errno = EINVAL;
+               return -1;
+       }
+
        if (prt == 0) {
                /* handle auto-allocation of ephemeral ports */
                for (prt = 5001; prt < 10000; prt++) {
@@ -576,8 +588,12 @@ static struct socket_info *find_socket_info(int fd)
 {
        struct socket_info *i;
        for (i = sockets; i; i = i->next) {
-               if (i->fd == fd) 
-                       return i;
+               struct socket_info_fd *f;
+               for (f = i->fds; f; f = f->next) {
+                       if (f->fd == fd) {
+                               return i;
+                       }
+               }
        }
 
        return NULL;
@@ -1396,6 +1412,7 @@ static void swrap_dump_packet(struct socket_info *si,
 _PUBLIC_ int swrap_socket(int family, int type, int protocol)
 {
        struct socket_info *si;
+       struct socket_info_fd *fi;
        int fd;
        int real_type = type;
 #ifdef SOCK_CLOEXEC
@@ -1457,6 +1474,10 @@ _PUBLIC_ int swrap_socket(int family, int type, int protocol)
        if (fd == -1) return -1;
 
        si = (struct socket_info *)calloc(1, sizeof(struct socket_info));
+       if (si == NULL) {
+               errno = ENOMEM;
+               return -1;
+       }
 
        si->family = family;
 
@@ -1464,16 +1485,26 @@ _PUBLIC_ int swrap_socket(int family, int type, int protocol)
         * the type, not the flags */
        si->type = real_type;
        si->protocol = protocol;
-       si->fd = fd;
 
+       fi = (struct socket_info_fd *)calloc(1, sizeof(struct socket_info_fd));
+       if (fi == NULL) {
+               free(si);
+               errno = ENOMEM;
+               return -1;
+       }
+
+       fi->fd = fd;
+
+       SWRAP_DLIST_ADD(si->fds, fi);
        SWRAP_DLIST_ADD(sockets, si);
 
-       return si->fd;
+       return fd;
 }
 
 _PUBLIC_ int swrap_accept(int s, struct sockaddr *addr, socklen_t *addrlen)
 {
        struct socket_info *parent_si, *child_si;
+       struct socket_info_fd *child_fi;
        int fd;
        struct sockaddr_un un_addr;
        socklen_t un_addrlen = sizeof(un_addr);
@@ -1526,7 +1557,19 @@ _PUBLIC_ int swrap_accept(int s, struct sockaddr *addr, socklen_t *addrlen)
        child_si = (struct socket_info *)malloc(sizeof(struct socket_info));
        memset(child_si, 0, sizeof(*child_si));
 
-       child_si->fd = fd;
+       child_fi = (struct socket_info_fd *)calloc(1, sizeof(struct socket_info_fd));
+       if (child_fi == NULL) {
+               free(child_si);
+               free(my_addr);
+               close(fd);
+               errno = ENOMEM;
+               return -1;
+       }
+
+       child_fi->fd = fd;
+
+       SWRAP_DLIST_ADD(child_si->fds, child_fi);
+
        child_si->family = parent_si->family;
        child_si->type = parent_si->type;
        child_si->protocol = parent_si->protocol;
@@ -1538,15 +1581,17 @@ _PUBLIC_ int swrap_accept(int s, struct sockaddr *addr, socklen_t *addrlen)
        child_si->peername = sockaddr_dup(my_addr, len);
 
        if (addr != NULL && addrlen != NULL) {
-           *addrlen = len;
-           if (*addrlen >= len)
-               memcpy(addr, my_addr, len);
-           *addrlen = 0;
+               size_t copy_len = MIN(*addrlen, len);
+               if (copy_len > 0) {
+                       memcpy(addr, my_addr, copy_len);
+               }
+               *addrlen = len;
        }
 
        ret = real_getsockname(fd, (struct sockaddr *)(void *)&un_my_addr,
                               &un_my_addrlen);
        if (ret == -1) {
+               free(child_fi);
                free(child_si);
                close(fd);
                return ret;
@@ -1556,6 +1601,7 @@ _PUBLIC_ int swrap_accept(int s, struct sockaddr *addr, socklen_t *addrlen)
        ret = sockaddr_convert_from_un(child_si, &un_my_addr, un_my_addrlen,
                                       child_si->family, my_addr, &len);
        if (ret == -1) {
+               free(child_fi);
                free(child_si);
                free(my_addr);
                close(fd);
@@ -1584,7 +1630,7 @@ static int autobind_start;
    assign it here.
    Note: this might change the family from ipv6 to ipv4
 */
-static int swrap_auto_bind(struct socket_info *si, int family)
+static int swrap_auto_bind(int fd, struct socket_info *si, int family)
 {
        struct sockaddr_un un_addr;
        int i;
@@ -1673,7 +1719,7 @@ static int swrap_auto_bind(struct socket_info *si, int family)
                         type, socket_wrapper_default_iface(), port);
                if (stat(un_addr.sun_path, &st) == 0) continue;
 
-               ret = real_bind(si->fd, (struct sockaddr *)(void *)&un_addr,
+               ret = real_bind(fd, (struct sockaddr *)(void *)&un_addr,
                                sizeof(un_addr));
                if (ret == -1) return ret;
 
@@ -1706,7 +1752,7 @@ _PUBLIC_ int swrap_connect(int s, const struct sockaddr *serv_addr, socklen_t ad
        }
 
        if (si->bound == 0) {
-               ret = swrap_auto_bind(si, serv_addr->sa_family);
+               ret = swrap_auto_bind(s, si, serv_addr->sa_family);
                if (ret == -1) return -1;
        }
 
@@ -1896,7 +1942,8 @@ _PUBLIC_ int swrap_ioctl(int s, int r, void *p)
        return ret;
 }
 
-static ssize_t swrap_sendmsg_before(struct socket_info *si,
+static ssize_t swrap_sendmsg_before(int fd,
+                                   struct socket_info *si,
                                    struct msghdr *msg,
                                    struct iovec *tmp_iov,
                                    struct sockaddr_un *tmp_un,
@@ -1981,7 +2028,7 @@ static ssize_t swrap_sendmsg_before(struct socket_info *si,
                }
 
                if (si->bound == 0) {
-                       ret = swrap_auto_bind(si, si->family);
+                       ret = swrap_auto_bind(fd, si, si->family);
                        if (ret == -1) return -1;
                }
 
@@ -1993,7 +2040,7 @@ static ssize_t swrap_sendmsg_before(struct socket_info *si,
                                             tmp_un, 0, NULL);
                if (ret == -1) return -1;
 
-               ret = real_connect(si->fd, (struct sockaddr *)(void *)tmp_un,
+               ret = real_connect(fd, (struct sockaddr *)(void *)tmp_un,
                                   sizeof(*tmp_un));
 
                /* to give better errors */
@@ -2158,7 +2205,7 @@ _PUBLIC_ ssize_t swrap_sendto(int s, const void *buf, size_t len, int flags, con
        msg.msg_flags = 0;             /* flags on received message */
 #endif
 
-       ret = swrap_sendmsg_before(si, &msg, &tmp, &un_addr, &to_un, &to, &bcast);
+       ret = swrap_sendmsg_before(s, si, &msg, &tmp, &un_addr, &to_un, &to, &bcast);
        if (ret == -1) return -1;
 
        buf = msg.msg_iov[0].iov_base;
@@ -2188,7 +2235,8 @@ _PUBLIC_ ssize_t swrap_sendto(int s, const void *buf, size_t len, int flags, con
                return len;
        }
 
-       ret = real_sendto(s, buf, len, flags, msg.msg_name, msg.msg_namelen);
+       ret = real_sendto(s, buf, len, flags, (struct sockaddr *)msg.msg_name,
+                         msg.msg_namelen);
 
        swrap_sendmsg_after(si, &msg, to, ret);
 
@@ -2278,7 +2326,7 @@ _PUBLIC_ ssize_t swrap_send(int s, const void *buf, size_t len, int flags)
        msg.msg_flags = 0;             /* flags on received message */
 #endif
 
-       ret = swrap_sendmsg_before(si, &msg, &tmp, &un_addr, NULL, NULL, NULL);
+       ret = swrap_sendmsg_before(s, si, &msg, &tmp, &un_addr, NULL, NULL, NULL);
        if (ret == -1) return -1;
 
        buf = msg.msg_iov[0].iov_base;
@@ -2321,7 +2369,7 @@ _PUBLIC_ ssize_t swrap_sendmsg(int s, const struct msghdr *omsg, int flags)
        msg.msg_flags = omsg->msg_flags;           /* flags on received message */
 #endif
 
-       ret = swrap_sendmsg_before(si, &msg, &tmp, &un_addr, &to_un, &to, &bcast);
+       ret = swrap_sendmsg_before(s, si, &msg, &tmp, &un_addr, &to_un, &to, &bcast);
        if (ret == -1) return -1;
 
        if (bcast) {
@@ -2482,7 +2530,7 @@ int swrap_writev(int s, const struct iovec *vector, size_t count)
        msg.msg_flags = 0;             /* flags on received message */
 #endif
 
-       ret = swrap_sendmsg_before(si, &msg, &tmp, &un_addr, NULL, NULL, NULL);
+       ret = swrap_sendmsg_before(s, si, &msg, &tmp, &un_addr, NULL, NULL, NULL);
        if (ret == -1) return -1;
 
        ret = real_writev(s, msg.msg_iov, msg.msg_iovlen);
@@ -2495,12 +2543,26 @@ int swrap_writev(int s, const struct iovec *vector, size_t count)
 _PUBLIC_ int swrap_close(int fd)
 {
        struct socket_info *si = find_socket_info(fd);
+       struct socket_info_fd *fi;
        int ret;
 
        if (!si) {
                return real_close(fd);
        }
 
+       for (fi = si->fds; fi; fi = fi->next) {
+               if (fi->fd == fd) {
+                       SWRAP_DLIST_REMOVE(si->fds, fi);
+                       free(fi);
+                       break;
+               }
+       }
+
+       if (si->fds) {
+               /* there are still references left */
+               return real_close(fd);
+       }
+
        SWRAP_DLIST_REMOVE(sockets, si);
 
        if (si->myname && si->peername) {
@@ -2514,7 +2576,6 @@ _PUBLIC_ int swrap_close(int fd)
                swrap_dump_packet(si, NULL, SWRAP_CLOSE_ACK, NULL, 0);
        }
 
-       if (si->path) free(si->path);
        if (si->myname) free(si->myname);
        if (si->peername) free(si->peername);
        if (si->tmp_path) {
@@ -2525,3 +2586,67 @@ _PUBLIC_ int swrap_close(int fd)
 
        return ret;
 }
+
+_PUBLIC_ int swrap_dup(int fd)
+{
+       struct socket_info *si;
+       struct socket_info_fd *fi;
+
+       si = find_socket_info(fd);
+
+       if (!si) {
+               return real_dup(fd);
+       }
+
+       fi = (struct socket_info_fd *)calloc(1, sizeof(struct socket_info_fd));
+       if (fi == NULL) {
+               errno = ENOMEM;
+               return -1;
+       }
+
+       fi->fd = real_dup(fd);
+       if (fi->fd == -1) {
+               int saved_errno = errno;
+               free(fi);
+               errno = saved_errno;
+               return -1;
+       }
+
+       SWRAP_DLIST_ADD(si->fds, fi);
+       return fi->fd;
+}
+
+_PUBLIC_ int swrap_dup2(int fd, int newfd)
+{
+       struct socket_info *si;
+       struct socket_info_fd *fi;
+
+       si = find_socket_info(fd);
+
+       if (!si) {
+               return real_dup2(fd, newfd);
+       }
+
+       if (find_socket_info(newfd)) {
+               /* dup2() does an implicit close of newfd, which we
+                * need to emulate */
+               swrap_close(newfd);
+       }
+
+       fi = (struct socket_info_fd *)calloc(1, sizeof(struct socket_info_fd));
+       if (fi == NULL) {
+               errno = ENOMEM;
+               return -1;
+       }
+
+       fi->fd = real_dup2(fd, newfd);
+       if (fi->fd == -1) {
+               int saved_errno = errno;
+               free(fi);
+               errno = saved_errno;
+               return -1;
+       }
+
+       SWRAP_DLIST_ADD(si->fds, fi);
+       return fi->fd;
+}