socket_wrapper: correctly handle dup()/dup2() ref counting
authorStefan Metzmacher <metze@samba.org>
Wed, 28 Sep 2011 21:09:49 +0000 (23:09 +0200)
committerStefan Metzmacher <metze@samba.org>
Thu, 29 Sep 2011 11:13:56 +0000 (13:13 +0200)
metze

Autobuild-User: Stefan Metzmacher <metze@samba.org>
Autobuild-Date: Thu Sep 29 13:13:56 CEST 2011 on sn-devel-104

lib/socket_wrapper/socket_wrapper.c

index 791da324fddc528c7951b93db550e808b0c7ce86..a54f50f2f5741c0f7c7cdb3ab3f4fc973699a1b4 100644 (file)
@@ -212,11 +212,14 @@ static size_t socket_length(int family)
        return 0;
 }
 
        return 0;
 }
 
-
+struct socket_info_fd {
+       struct socket_info_fd *prev, *next;
+       int fd;
+};
 
 struct socket_info
 {
 
 struct socket_info
 {
-       int fd;
+       struct socket_info_fd *fds;
 
        int family;
        int type;
 
        int family;
        int type;
@@ -585,8 +588,12 @@ static struct socket_info *find_socket_info(int fd)
 {
        struct socket_info *i;
        for (i = sockets; i; i = i->next) {
 {
        struct socket_info *i;
        for (i = sockets; i; i = i->next) {
-               if (i->fd == fd) 
-                       return i;
+               struct socket_info_fd *f;
+               for (f = i->fds; f; f = f->next) {
+                       if (f->fd == fd) {
+                               return i;
+                       }
+               }
        }
 
        return NULL;
        }
 
        return NULL;
@@ -1405,6 +1412,7 @@ static void swrap_dump_packet(struct socket_info *si,
 _PUBLIC_ int swrap_socket(int family, int type, int protocol)
 {
        struct socket_info *si;
 _PUBLIC_ int swrap_socket(int family, int type, int protocol)
 {
        struct socket_info *si;
+       struct socket_info_fd *fi;
        int fd;
        int real_type = type;
 #ifdef SOCK_CLOEXEC
        int fd;
        int real_type = type;
 #ifdef SOCK_CLOEXEC
@@ -1466,6 +1474,10 @@ _PUBLIC_ int swrap_socket(int family, int type, int protocol)
        if (fd == -1) return -1;
 
        si = (struct socket_info *)calloc(1, sizeof(struct socket_info));
        if (fd == -1) return -1;
 
        si = (struct socket_info *)calloc(1, sizeof(struct socket_info));
+       if (si == NULL) {
+               errno = ENOMEM;
+               return -1;
+       }
 
        si->family = family;
 
 
        si->family = family;
 
@@ -1473,16 +1485,26 @@ _PUBLIC_ int swrap_socket(int family, int type, int protocol)
         * the type, not the flags */
        si->type = real_type;
        si->protocol = protocol;
         * the type, not the flags */
        si->type = real_type;
        si->protocol = protocol;
-       si->fd = fd;
 
 
+       fi = (struct socket_info_fd *)calloc(1, sizeof(struct socket_info_fd));
+       if (fi == NULL) {
+               free(si);
+               errno = ENOMEM;
+               return -1;
+       }
+
+       fi->fd = fd;
+
+       SWRAP_DLIST_ADD(si->fds, fi);
        SWRAP_DLIST_ADD(sockets, si);
 
        SWRAP_DLIST_ADD(sockets, si);
 
-       return si->fd;
+       return fd;
 }
 
 _PUBLIC_ int swrap_accept(int s, struct sockaddr *addr, socklen_t *addrlen)
 {
        struct socket_info *parent_si, *child_si;
 }
 
 _PUBLIC_ int swrap_accept(int s, struct sockaddr *addr, socklen_t *addrlen)
 {
        struct socket_info *parent_si, *child_si;
+       struct socket_info_fd *child_fi;
        int fd;
        struct sockaddr_un un_addr;
        socklen_t un_addrlen = sizeof(un_addr);
        int fd;
        struct sockaddr_un un_addr;
        socklen_t un_addrlen = sizeof(un_addr);
@@ -1535,7 +1557,19 @@ _PUBLIC_ int swrap_accept(int s, struct sockaddr *addr, socklen_t *addrlen)
        child_si = (struct socket_info *)malloc(sizeof(struct socket_info));
        memset(child_si, 0, sizeof(*child_si));
 
        child_si = (struct socket_info *)malloc(sizeof(struct socket_info));
        memset(child_si, 0, sizeof(*child_si));
 
-       child_si->fd = fd;
+       child_fi = (struct socket_info_fd *)calloc(1, sizeof(struct socket_info_fd));
+       if (child_fi == NULL) {
+               free(child_si);
+               free(my_addr);
+               close(fd);
+               errno = ENOMEM;
+               return -1;
+       }
+
+       child_fi->fd = fd;
+
+       SWRAP_DLIST_ADD(child_si->fds, child_fi);
+
        child_si->family = parent_si->family;
        child_si->type = parent_si->type;
        child_si->protocol = parent_si->protocol;
        child_si->family = parent_si->family;
        child_si->type = parent_si->type;
        child_si->protocol = parent_si->protocol;
@@ -1557,6 +1591,7 @@ _PUBLIC_ int swrap_accept(int s, struct sockaddr *addr, socklen_t *addrlen)
        ret = real_getsockname(fd, (struct sockaddr *)(void *)&un_my_addr,
                               &un_my_addrlen);
        if (ret == -1) {
        ret = real_getsockname(fd, (struct sockaddr *)(void *)&un_my_addr,
                               &un_my_addrlen);
        if (ret == -1) {
+               free(child_fi);
                free(child_si);
                close(fd);
                return ret;
                free(child_si);
                close(fd);
                return ret;
@@ -1566,6 +1601,7 @@ _PUBLIC_ int swrap_accept(int s, struct sockaddr *addr, socklen_t *addrlen)
        ret = sockaddr_convert_from_un(child_si, &un_my_addr, un_my_addrlen,
                                       child_si->family, my_addr, &len);
        if (ret == -1) {
        ret = sockaddr_convert_from_un(child_si, &un_my_addr, un_my_addrlen,
                                       child_si->family, my_addr, &len);
        if (ret == -1) {
+               free(child_fi);
                free(child_si);
                free(my_addr);
                close(fd);
                free(child_si);
                free(my_addr);
                close(fd);
@@ -2507,12 +2543,26 @@ int swrap_writev(int s, const struct iovec *vector, size_t count)
 _PUBLIC_ int swrap_close(int fd)
 {
        struct socket_info *si = find_socket_info(fd);
 _PUBLIC_ int swrap_close(int fd)
 {
        struct socket_info *si = find_socket_info(fd);
+       struct socket_info_fd *fi;
        int ret;
 
        if (!si) {
                return real_close(fd);
        }
 
        int ret;
 
        if (!si) {
                return real_close(fd);
        }
 
+       for (fi = si->fds; fi; fi = fi->next) {
+               if (fi->fd == fd) {
+                       SWRAP_DLIST_REMOVE(si->fds, fi);
+                       free(fi);
+                       break;
+               }
+       }
+
+       if (si->fds) {
+               /* there are still references left */
+               return real_close(fd);
+       }
+
        SWRAP_DLIST_REMOVE(sockets, si);
 
        if (si->myname && si->peername) {
        SWRAP_DLIST_REMOVE(sockets, si);
 
        if (si->myname && si->peername) {
@@ -2539,8 +2589,8 @@ _PUBLIC_ int swrap_close(int fd)
 
 _PUBLIC_ int swrap_dup(int fd)
 {
 
 _PUBLIC_ int swrap_dup(int fd)
 {
-       struct socket_info *si, *si2;
-       int fd2;
+       struct socket_info *si;
+       struct socket_info_fd *fi;
 
        si = find_socket_info(fd);
 
 
        si = find_socket_info(fd);
 
@@ -2548,55 +2598,28 @@ _PUBLIC_ int swrap_dup(int fd)
                return real_dup(fd);
        }
 
                return real_dup(fd);
        }
 
-       if (si->tmp_path) {
-               /* we would need reference counting to handle this */
-               errno = EINVAL;
-               return -1;
-       }
-
-       fd2 = real_dup(fd);
-       if (fd2 == -1) {
-               return -1;
-       }
-
-       si2 = (struct socket_info *)malloc(sizeof(struct socket_info));
-       if (si2 == NULL) {
-               real_close(fd2);
+       fi = (struct socket_info_fd *)calloc(1, sizeof(struct socket_info_fd));
+       if (fi == NULL) {
                errno = ENOMEM;
                return -1;
        }
 
                errno = ENOMEM;
                return -1;
        }
 
-       /* copy the whole structure, then duplicate pointer elements */
-       *si2 = *si;
-
-       si2->fd = fd2;
-
-       if (si2->myname) {
-               si2->myname = sockaddr_dup(si2->myname, si2->myname_len);
-               if (si2->myname == NULL) {
-                       real_close(fd2);
-                       errno = ENOMEM;
-                       return -1;
-               }
-       }
-
-       if (si2->peername) {
-               si2->peername = sockaddr_dup(si2->peername, si2->peername_len);
-               if (si2->peername == NULL) {
-                       real_close(fd2);
-                       errno = ENOMEM;
-                       return -1;
-               }
+       fi->fd = real_dup(fd);
+       if (fi->fd == -1) {
+               int saved_errno = errno;
+               free(fi);
+               errno = saved_errno;
+               return -1;
        }
 
        }
 
-       SWRAP_DLIST_ADD(sockets, si2);
-       return fd2;
+       SWRAP_DLIST_ADD(si->fds, fi);
+       return fi->fd;
 }
 
 _PUBLIC_ int swrap_dup2(int fd, int newfd)
 {
 }
 
 _PUBLIC_ int swrap_dup2(int fd, int newfd)
 {
-       struct socket_info *si, *si2;
-       int fd2;
+       struct socket_info *si;
+       struct socket_info_fd *fi;
 
        si = find_socket_info(fd);
 
 
        si = find_socket_info(fd);
 
@@ -2604,53 +2627,26 @@ _PUBLIC_ int swrap_dup2(int fd, int newfd)
                return real_dup2(fd, newfd);
        }
 
                return real_dup2(fd, newfd);
        }
 
-       if (si->tmp_path) {
-               /* we would need reference counting to handle this */
-               errno = EINVAL;
-               return -1;
-       }
-
        if (find_socket_info(newfd)) {
                /* dup2() does an implicit close of newfd, which we
                 * need to emulate */
                swrap_close(newfd);
        }
 
        if (find_socket_info(newfd)) {
                /* dup2() does an implicit close of newfd, which we
                 * need to emulate */
                swrap_close(newfd);
        }
 
-       fd2 = real_dup2(fd, newfd);
-       if (fd2 == -1) {
-               return -1;
-       }
-
-       si2 = (struct socket_info *)malloc(sizeof(struct socket_info));
-       if (si2 == NULL) {
-               real_close(fd2);
+       fi = (struct socket_info_fd *)calloc(1, sizeof(struct socket_info_fd));
+       if (fi == NULL) {
                errno = ENOMEM;
                return -1;
        }
 
                errno = ENOMEM;
                return -1;
        }
 
-       /* copy the whole structure, then duplicate pointer elements */
-       *si2 = *si;
-
-       si2->fd = fd2;
-
-       if (si2->myname) {
-               si2->myname = sockaddr_dup(si2->myname, si2->myname_len);
-               if (si2->myname == NULL) {
-                       real_close(fd2);
-                       errno = ENOMEM;
-                       return -1;
-               }
-       }
-
-       if (si2->peername) {
-               si2->peername = sockaddr_dup(si2->peername, si2->peername_len);
-               if (si2->peername == NULL) {
-                       real_close(fd2);
-                       errno = ENOMEM;
-                       return -1;
-               }
+       fi->fd = real_dup2(fd, newfd);
+       if (fi->fd == -1) {
+               int saved_errno = errno;
+               free(fi);
+               errno = saved_errno;
+               return -1;
        }
 
        }
 
-       SWRAP_DLIST_ADD(sockets, si2);
-       return fd2;
+       SWRAP_DLIST_ADD(si->fds, fi);
+       return fi->fd;
 }
 }