Add support for sendmsg() in socket_wrapper
[ira/wip.git] / lib / socket_wrapper / socket_wrapper.c
index 44082e78a1e97ca86ac9b1f6cdefe061c66cd8f1..bd848f920b9327bde4d708b9f76fe6c670ade60a 100644 (file)
 #define real_setsockopt setsockopt
 #define real_recvfrom recvfrom
 #define real_sendto sendto
+#define real_sendmsg sendmsg
 #define real_ioctl ioctl
 #define real_recv recv
 #define real_send send
@@ -218,6 +219,7 @@ struct socket_info
        int bcast;
        int is_server;
        int connected;
+       int defer_connect;
 
        char *path;
        char *tmp_path;
@@ -1101,8 +1103,10 @@ static uint8_t *swrap_marshall_packet(struct socket_info *si,
        switch (si->family) {
        case AF_INET:
                break;
+#ifdef HAVE_IPV6
        case AF_INET6:
                break;
+#endif
        default:
                return NULL;
        }
@@ -1686,10 +1690,15 @@ _PUBLIC_ int swrap_connect(int s, const struct sockaddr *serv_addr, socklen_t ad
        ret = sockaddr_convert_to_un(si, (const struct sockaddr *)serv_addr, addrlen, &un_addr, 0, NULL);
        if (ret == -1) return -1;
 
-       swrap_dump_packet(si, serv_addr, SWRAP_CONNECT_SEND, NULL, 0);
+       if (si->type == SOCK_DGRAM) {
+               si->defer_connect = 1;
+               ret = 0;
+       } else {
+               swrap_dump_packet(si, serv_addr, SWRAP_CONNECT_SEND, NULL, 0);
 
-       ret = real_connect(s, (struct sockaddr *)&un_addr, 
-                          sizeof(struct sockaddr_un));
+               ret = real_connect(s, (struct sockaddr *)&un_addr,
+                                  sizeof(struct sockaddr_un));
+       }
 
        /* to give better errors */
        if (ret == -1 && errno == ENOENT) {
@@ -1917,7 +1926,22 @@ _PUBLIC_ ssize_t swrap_sendto(int s, const void *buf, size_t len, int flags, con
                        
                        return len;
                }
-               
+
+               if (si->defer_connect) {
+                       ret = real_connect(s, (struct sockaddr *)&un_addr,
+                                          sizeof(un_addr));
+
+                       /* to give better errors */
+                       if (ret == -1 && errno == ENOENT) {
+                               errno = EHOSTUNREACH;
+                       }
+
+                       if (ret == -1) {
+                               return ret;
+                       }
+                       si->defer_connect = 0;
+               }
+
                ret = real_sendto(s, buf, len, flags, (struct sockaddr *)&un_addr, sizeof(un_addr));
                break;
        default:
@@ -2002,6 +2026,33 @@ _PUBLIC_ ssize_t swrap_send(int s, const void *buf, size_t len, int flags)
 
        len = MIN(len, 1500);
 
+       if (si->defer_connect) {
+               struct sockaddr_un un_addr;
+               int bcast = 0;
+
+               if (si->bound == 0) {
+                       ret = swrap_auto_bind(si, si->family);
+                       if (ret == -1) return -1;
+               }
+
+               ret = sockaddr_convert_to_un(si, si->peername, si->peername_len,
+                                            &un_addr, 0, &bcast);
+               if (ret == -1) return -1;
+
+               ret = real_connect(s, (struct sockaddr *)&un_addr,
+                                  sizeof(un_addr));
+
+               /* to give better errors */
+               if (ret == -1 && errno == ENOENT) {
+                       errno = EHOSTUNREACH;
+               }
+
+               if (ret == -1) {
+                       return ret;
+               }
+               si->defer_connect = 0;
+       }
+
        ret = real_send(s, buf, len, flags);
 
        if (ret == -1) {
@@ -2014,6 +2065,76 @@ _PUBLIC_ ssize_t swrap_send(int s, const void *buf, size_t len, int flags)
        return ret;
 }
 
+_PUBLIC_ ssize_t swrap_sendmsg(int s, const struct msghdr *msg, int flags)
+{
+       int ret;
+       uint8_t *buf;
+       off_t ofs = 0;
+       size_t i;
+       size_t remain;
+       
+       struct socket_info *si = find_socket_info(s);
+
+       if (!si) {
+               return real_sendmsg(s, msg, flags);
+       }
+
+       if (si->defer_connect) {
+               struct sockaddr_un un_addr;
+               int bcast = 0;
+
+               if (si->bound == 0) {
+                       ret = swrap_auto_bind(si, si->family);
+                       if (ret == -1) return -1;
+               }
+
+               ret = sockaddr_convert_to_un(si, si->peername, si->peername_len,
+                                            &un_addr, 0, &bcast);
+               if (ret == -1) return -1;
+
+               ret = real_connect(s, (struct sockaddr *)&un_addr,
+                                  sizeof(un_addr));
+
+               /* to give better errors */
+               if (ret == -1 && errno == ENOENT) {
+                       errno = EHOSTUNREACH;
+               }
+
+               if (ret == -1) {
+                       return ret;
+               }
+               si->defer_connect = 0;
+       }
+
+       ret = real_sendmsg(s, msg, flags);
+       remain = ret;
+               
+       /* we capture it as one single packet */
+       buf = (uint8_t *)malloc(ret);
+       if (!buf) {
+               /* we just not capture the packet */
+               errno = 0;
+               return ret;
+       }
+       
+       for (i=0; i < msg->msg_iovlen; i++) {
+               size_t this_time = MIN(remain, msg->msg_iov[i].iov_len);
+               memcpy(buf + ofs,
+                      msg->msg_iov[i].iov_base,
+                      this_time);
+               ofs += this_time;
+               remain -= this_time;
+       }
+       
+       swrap_dump_packet(si, NULL, SWRAP_SEND, buf, ret);
+       free(buf);
+       if (ret == -1) {
+               swrap_dump_packet(si, NULL, SWRAP_SEND_RST, NULL, 0);
+       }
+
+       return ret;
+}
+
 int swrap_readv(int s, const struct iovec *vector, size_t count)
 {
        int ret;
@@ -2053,6 +2174,7 @@ int swrap_readv(int s, const struct iovec *vector, size_t count)
                uint8_t *buf;
                off_t ofs = 0;
                size_t i;
+               size_t remain = ret;
 
                /* we capture it as one single packet */
                buf = (uint8_t *)malloc(ret);
@@ -2063,10 +2185,12 @@ int swrap_readv(int s, const struct iovec *vector, size_t count)
                }
 
                for (i=0; i < count; i++) {
+                       size_t this_time = MIN(remain, vector[i].iov_len);
                        memcpy(buf + ofs,
                               vector[i].iov_base,
-                              vector[i].iov_len);
-                       ofs += vector[i].iov_len;
+                              this_time);
+                       ofs += this_time;
+                       remain -= this_time;
                }
 
                swrap_dump_packet(si, NULL, SWRAP_RECV, buf, ret);
@@ -2113,6 +2237,7 @@ int swrap_writev(int s, const struct iovec *vector, size_t count)
                uint8_t *buf;
                off_t ofs = 0;
                size_t i;
+               size_t remain = ret;
 
                /* we capture it as one single packet */
                buf = (uint8_t *)malloc(ret);
@@ -2123,10 +2248,12 @@ int swrap_writev(int s, const struct iovec *vector, size_t count)
                }
 
                for (i=0; i < count; i++) {
+                       size_t this_time = MIN(remain, vector[i].iov_len);
                        memcpy(buf + ofs,
                               vector[i].iov_base,
-                              vector[i].iov_len);
-                       ofs += vector[i].iov_len;
+                              this_time);
+                       ofs += this_time;
+                       remain -= this_time;
                }
 
                swrap_dump_packet(si, NULL, SWRAP_SEND, buf, ret);