From 1dd5fc7ac12671578f03ae2fb3a69902cb183061 Mon Sep 17 00:00:00 2001 From: Stefan Metzmacher Date: Mon, 29 Jun 2020 18:28:56 +0200 Subject: [PATCH] third_party/socket_wrapper/socket_wrapper.c ... fd-passing simple --- third_party/socket_wrapper/socket_wrapper.c | 1004 +++++++++++++++++-- 1 file changed, 948 insertions(+), 56 deletions(-) diff --git a/third_party/socket_wrapper/socket_wrapper.c b/third_party/socket_wrapper/socket_wrapper.c index ffdd31a51bfd..fdabc111b67e 100644 --- a/third_party/socket_wrapper/socket_wrapper.c +++ b/third_party/socket_wrapper/socket_wrapper.c @@ -268,6 +268,7 @@ struct socket_info int pktinfo; int tcp_nodelay; int listening; + int fd_passed; /* The unix path so we can unlink it on close() */ struct sockaddr_un un_addr; @@ -1779,7 +1780,7 @@ static int find_socket_info_index(int fd) return socket_fds_idx[fd]; } -static int swrap_add_socket_info(struct socket_info *si_input) +static int swrap_add_socket_info(const struct socket_info *si_input) { struct socket_info *si = NULL; int si_index = -1; @@ -1822,6 +1823,7 @@ static int swrap_create_socket(struct socket_info *si, int fd) "trying to add %d", socket_fds_max, fd); + errno = EMFILE; return -1; } @@ -2932,7 +2934,7 @@ static int swrap_pcap_get_fd(const char *fname) file_hdr.link_type = 0x0065; /* 101 RAW IP */ if (write(fd, &file_hdr, sizeof(file_hdr)) != sizeof(file_hdr)) { - close(fd); + libc_close(fd); fd = -1; } return fd; @@ -3437,6 +3439,9 @@ static int swrap_socket(int family, int type, int protocol) ret = swrap_create_socket(si, fd); if (ret == -1) { + int saved_errno = errno; + libc_close(fd); + errno = saved_errno; return -1; } @@ -3607,7 +3612,7 @@ static int swrap_accept(int s, &in_addr.sa_socklen); if (ret == -1) { SWRAP_UNLOCK_SI(parent_si); - close(fd); + libc_close(fd); return ret; } @@ -3639,7 +3644,7 @@ static int swrap_accept(int s, &un_my_addr.sa.s, &un_my_addr.sa_socklen); if (ret == -1) { - close(fd); + libc_close(fd); return ret; } @@ -3650,7 +3655,7 @@ static int swrap_accept(int s, &in_my_addr.sa.s, &in_my_addr.sa_socklen); if (ret == -1) { - close(fd); + libc_close(fd); return ret; } @@ -3665,7 +3670,9 @@ static int swrap_accept(int s, idx = swrap_create_socket(&new_si, fd); if (idx == -1) { - close (fd); + int saved_errno = errno; + libc_close(fd); + errno = saved_errno; return -1; } @@ -4959,16 +4966,21 @@ static int swrap_msghdr_add_socket_info(struct socket_info *si, return rc; } -static int swrap_sendmsg_copy_cmsg(struct cmsghdr *cmsg, +static int swrap_sendmsg_copy_cmsg(const struct cmsghdr *cmsg, uint8_t **cm_data, size_t *cm_data_space); -static int swrap_sendmsg_filter_cmsg_socket(struct cmsghdr *cmsg, - uint8_t **cm_data, - size_t *cm_data_space); - -static int swrap_sendmsg_filter_cmsghdr(struct msghdr *msg, +static int swrap_sendmsg_filter_cmsg_ipproto_ip(const struct cmsghdr *cmsg, + uint8_t **cm_data, + size_t *cm_data_space); +static int swrap_sendmsg_filter_cmsg_sol_socket(const struct cmsghdr *cmsg, + uint8_t **cm_data, + size_t *cm_data_space); + +static int swrap_sendmsg_filter_cmsghdr(const struct msghdr *_msg, uint8_t **cm_data, - size_t *cm_data_space) { + size_t *cm_data_space) +{ + struct msghdr *msg = discard_const_p(struct msghdr, _msg); struct cmsghdr *cmsg; int rc = -1; @@ -4982,9 +4994,14 @@ static int swrap_sendmsg_filter_cmsghdr(struct msghdr *msg, cmsg = CMSG_NXTHDR(msg, cmsg)) { switch (cmsg->cmsg_level) { case IPPROTO_IP: - rc = swrap_sendmsg_filter_cmsg_socket(cmsg, - cm_data, - cm_data_space); + rc = swrap_sendmsg_filter_cmsg_ipproto_ip(cmsg, + cm_data, + cm_data_space); + break; + case SOL_SOCKET: + rc = swrap_sendmsg_filter_cmsg_sol_socket(cmsg, + cm_data, + cm_data_space); break; default: rc = swrap_sendmsg_copy_cmsg(cmsg, @@ -4992,12 +5009,19 @@ static int swrap_sendmsg_filter_cmsghdr(struct msghdr *msg, cm_data_space); break; } + if (rc < 0) { + int saved_errno = errno; + SAFE_FREE(*cm_data); + *cm_data_space = 0; + errno = saved_errno; + return rc; + } } return rc; } -static int swrap_sendmsg_copy_cmsg(struct cmsghdr *cmsg, +static int swrap_sendmsg_copy_cmsg(const struct cmsghdr *cmsg, uint8_t **cm_data, size_t *cm_data_space) { @@ -5020,14 +5044,14 @@ static int swrap_sendmsg_copy_cmsg(struct cmsghdr *cmsg, return 0; } -static int swrap_sendmsg_filter_cmsg_pktinfo(struct cmsghdr *cmsg, +static int swrap_sendmsg_filter_cmsg_pktinfo(const struct cmsghdr *cmsg, uint8_t **cm_data, size_t *cm_data_space); -static int swrap_sendmsg_filter_cmsg_socket(struct cmsghdr *cmsg, - uint8_t **cm_data, - size_t *cm_data_space) +static int swrap_sendmsg_filter_cmsg_ipproto_ip(const struct cmsghdr *cmsg, + uint8_t **cm_data, + size_t *cm_data_space) { int rc = -1; @@ -5053,7 +5077,7 @@ static int swrap_sendmsg_filter_cmsg_socket(struct cmsghdr *cmsg, return rc; } -static int swrap_sendmsg_filter_cmsg_pktinfo(struct cmsghdr *cmsg, +static int swrap_sendmsg_filter_cmsg_pktinfo(const struct cmsghdr *cmsg, uint8_t **cm_data, size_t *cm_data_space) { @@ -5067,7 +5091,816 @@ static int swrap_sendmsg_filter_cmsg_pktinfo(struct cmsghdr *cmsg, */ return 0; } + +static int swrap_sendmsg_filter_cmsg_sol_socket(const struct cmsghdr *cmsg, + uint8_t **cm_data, + size_t *cm_data_space) +{ + int rc = -1; + + switch (cmsg->cmsg_type) { + case SCM_RIGHTS: + SWRAP_LOG(SWRAP_LOG_TRACE, + "Ignoring SCM_RIGHTS on inet socket!"); + rc = 0; + break; +#ifdef SCM_CREDENTIALS + case SCM_CREDENTIALS: + SWRAP_LOG(SWRAP_LOG_TRACE, + "Ignoring SCM_CREDENTIALS on inet socket!"); + rc = 0; + break; +#endif /* SCM_CREDENTIALS */ + default: + rc = swrap_sendmsg_copy_cmsg(cmsg, + cm_data, + cm_data_space); + break; + } + + return rc; +} + +#if 0 +#define __STRUCT_WITH_MAGIC(__magic_string_var, __magic_size_var, __payload_definition, __main_definition) \ +static const char *__magic_string_var = ## __payload_definition; \ +static const size_t __magic_size_var = sizeof(## __payload_definition); \ +__payload_definition; \ +__main_definition; +#define __STRUCT_WITH_MAGIC(__magic_string_var, __magic_size_var, __payload_definition, __main_definition) \ +extern const char *__magic_string_var; \ +const char *__magic_string_var = # __payload_definition; \ +extern const size_t __magic_size_var; \ +const size_t __magic_size_var = sizeof(# __payload_definition); \ +__payload_definition; \ +__main_definition; + +__STRUCT_WITH_MAGIC( +swrap_unix_scm_rights_payload_magic_string, \ +swrap_unix_scm_rights_payload_magic_size, \ + +#endif + +#define SWRAP_MAX_PASSED_FDS ((size_t)8) +#define SWRAP_MAX_PASSED_SOCKET_INFO ((size_t)8) +struct swrap_unix_scm_rights_payload { \ + uint8_t num_idxs; \ + int8_t idxs[SWRAP_MAX_PASSED_FDS]; \ + struct socket_info infos[SWRAP_MAX_PASSED_SOCKET_INFO]; \ +}; +struct swrap_unix_scm_rights { \ + uint32_t full_size; + uint32_t payload_size; +// char payload_definition[sizeof(# __payload_definition)]; + struct swrap_unix_scm_rights_payload payload; +}; + +static void swrap_dec_fd_passed_array(size_t num, struct socket_info **array) +{ + int saved_errno = errno; + size_t i; + + for (i = 0; i < num; i++) { + struct socket_info *si = array[i]; + if (si == NULL) { + continue; + } + + SWRAP_LOCK_SI(si); + swrap_dec_refcount(si); + if (si->fd_passed > 0) { + si->fd_passed -= 1; + } + SWRAP_UNLOCK_SI(si); + array[i] = NULL; + } + + errno = saved_errno; +} + +static void swrap_undo_si_idx_array(size_t num, int *array) +{ + int saved_errno = errno; + size_t i; + + swrap_mutex_lock(&first_free_mutex); + + for (i = 0; i < num; i++) { + struct socket_info *si = NULL; + + if (array[i] == -1) { + continue; + } + + si = swrap_get_socket_info(array[i]); + if (si == NULL) { + continue; + } + + SWRAP_LOCK_SI(si); + swrap_dec_refcount(si); + SWRAP_UNLOCK_SI(si); + + swrap_set_next_free(si, first_free); + first_free = array[i]; + array[i] = -1; + } + + swrap_mutex_unlock(&first_free_mutex); + errno = saved_errno; +} + +static void swrap_close_fd_array(size_t num, const int *array) +{ + int saved_errno = errno; + size_t i; + + for (i = 0; i < num; i++) { + if (array[i] == -1) { + continue; + } + libc_close(array[i]); + } + + errno = saved_errno; +} + +union __swrap_fds { + const uint8_t *p; + int *fds; +}; + +union __swrap_cmsghdr { + const uint8_t *p; + struct cmsghdr *cmsg; +}; + +static int swrap_sendmsg_unix_scm_rights(const struct cmsghdr *cmsg, + uint8_t **cm_data, + size_t *cm_data_space, + int *scm_rights_pipe_fd) +{ + struct swrap_unix_scm_rights info; + struct swrap_unix_scm_rights_payload *payload = NULL; + int si_idx_array[SWRAP_MAX_PASSED_FDS]; + struct socket_info *si_array[SWRAP_MAX_PASSED_FDS] = { NULL, }; + size_t info_idx = 0; + size_t size_fds_in; + size_t num_fds_in; + union __swrap_fds __fds_in = { .p = NULL, }; + const int *fds_in = NULL; + size_t num_fds_out; + size_t size_fds_out; + union __swrap_fds __fds_out = { .p = NULL, }; + int *fds_out = NULL; + size_t cmsg_len; + size_t cmsg_space; + size_t new_cm_data_space; + union __swrap_cmsghdr __new_cmsg = { .p = NULL, }; + struct cmsghdr *new_cmsg = NULL; + uint8_t *p = NULL; + size_t i; + int pipefd[2] = { -1, -1 }; + int rc; + ssize_t sret; + + /* + * We pass this a bufer to the kernel make sure any padding + * is also cleared. + */ + ZERO_STRUCT(info); + info.full_size = sizeof(info); + info.payload_size = sizeof(info.payload); + payload = &info.payload; + + if (*scm_rights_pipe_fd != -1) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "Two SCM_RIGHTS headers are not supported by socket_wrapper"); + errno = EINVAL; + return -1; + } + + if (cmsg->cmsg_len < CMSG_LEN(0)) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "cmsg->cmsg_len=%zu < CMSG_LEN(0)=%zu", + (size_t)cmsg->cmsg_len, + CMSG_LEN(0)); + errno = EINVAL; + return -1; + } + size_fds_in = cmsg->cmsg_len - CMSG_LEN(0); + if ((size_fds_in % sizeof(int)) != 0) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "cmsg->cmsg_len=%zu => (size_fds_in=%zu %% sizeof(int)=%zu) != 0", + (size_t)cmsg->cmsg_len, + size_fds_in, + sizeof(int)); + errno = EINVAL; + return -1; + } + num_fds_in = size_fds_in / sizeof(int); + if (num_fds_in > SWRAP_MAX_PASSED_FDS) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "cmsg->cmsg_len=%zu,size_fds_in=%zu => " + "num_fds_in=%zu > " + "SWRAP_MAX_PASSED_FDS(%zu)", + (size_t)cmsg->cmsg_len, + size_fds_in, + num_fds_in, + SWRAP_MAX_PASSED_FDS); + errno = EINVAL; + return -1; + } + if (num_fds_in == 0) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "cmsg->cmsg_len=%zu,size_fds_in=%zu => " + "num_fds_in=%zu", + (size_t)cmsg->cmsg_len, + size_fds_in, + num_fds_in); + errno = EINVAL; + return -1; + } + __fds_in.p = CMSG_DATA(cmsg); + fds_in = __fds_in.fds; + num_fds_out = num_fds_in + 1; + + SWRAP_LOG(SWRAP_LOG_TRACE, + "num_fds_in=%zu num_fds_out=%zu", + num_fds_in, num_fds_out); + + size_fds_out = sizeof(int) * num_fds_out; + cmsg_len = CMSG_LEN(size_fds_out); + cmsg_space = CMSG_SPACE(size_fds_out); + + new_cm_data_space = *cm_data_space + cmsg_space; + + p = realloc((*cm_data), new_cm_data_space); + if (p == NULL) { + return -1; + } + (*cm_data) = p; + p = (*cm_data) + (*cm_data_space); + memset(p, 0, cmsg_space); + __new_cmsg.p = p; + new_cmsg = __new_cmsg.cmsg; + *new_cmsg = *cmsg; + __fds_out.p = CMSG_DATA(new_cmsg); + fds_out = __fds_out.fds; + memcpy(fds_out, fds_in, size_fds_out); + new_cmsg->cmsg_len = cmsg->cmsg_len; + + for (i = 0; i < num_fds_in; i++) { + size_t j; + + payload->idxs[i] = -1; + payload->num_idxs++; + + si_idx_array[i] = find_socket_info_index(fds_in[i]); + if (si_idx_array[i] == -1) { + continue; + } + + si_array[i] = swrap_get_socket_info(si_idx_array[i]); + if (si_array[i] == NULL) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "fds_in[%zu]=%d si_idx_array[%zu]=%d missing!", + i, fds_in[i], i, si_idx_array[i]); + errno = EINVAL; + return -1; + } + + for (j = 0; j < i; j++) { + if (si_array[j] == si_array[i]) { + payload->idxs[i] = payload->idxs[j]; + break; + } + } + if (payload->idxs[i] == -1) { + if (info_idx >= SWRAP_MAX_PASSED_SOCKET_INFO) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "fds_in[%zu]=%d,si_idx_array[%zu]=%d: " + "info_idx=%zu >= SWRAP_MAX_PASSED_FDS(%zu)!", + i, fds_in[i], i, si_idx_array[i], + info_idx, + SWRAP_MAX_PASSED_SOCKET_INFO); + errno = EINVAL; + return -1; + } + payload->idxs[i] = info_idx; + info_idx += 1; + continue; + } + } + + for (i = 0; i < num_fds_in; i++) { + struct socket_info *si = si_array[i]; + + if (si == NULL) { + SWRAP_LOG(SWRAP_LOG_TRACE, + "fds_in[%zu]=%d not an inet socket", + i, fds_in[i]); + continue; + } + + SWRAP_LOG(SWRAP_LOG_TRACE, + "fds_in[%zu]=%d si_idx_array[%zu]=%d " + "passing as info.idxs[%zu]=%d!", + i, fds_in[i], + i, si_idx_array[i], + i, payload->idxs[i]); + + SWRAP_LOCK_SI(si); + si->fd_passed += 1; + payload->infos[payload->idxs[i]] = *si; + payload->infos[payload->idxs[i]].fd_passed = 0; + SWRAP_UNLOCK_SI(si); + } + + rc = pipe(pipefd); + if (rc == -1) { + int saved_errno = errno; + SWRAP_LOG(SWRAP_LOG_ERROR, + "pipe() failed - %d %s", + saved_errno, + strerror(saved_errno)); + swrap_dec_fd_passed_array(num_fds_in, si_array); + errno = saved_errno; + return -1; + } + + sret = write(pipefd[1], &info, sizeof(info)); + if (sret != sizeof(info)) { + int saved_errno = errno; + if (sret != -1) { + saved_errno = EINVAL; + } + SWRAP_LOG(SWRAP_LOG_ERROR, + "write() failed - sret=%zd - %d %s", + sret, saved_errno, + strerror(saved_errno)); + swrap_dec_fd_passed_array(num_fds_in, si_array); + libc_close(pipefd[1]); + libc_close(pipefd[0]); + errno = saved_errno; + return -1; + } + libc_close(pipefd[1]); + + /* + * Add the pipe read end to the end of the passed fd array + */ + fds_out[num_fds_in] = pipefd[0]; + new_cmsg->cmsg_len = cmsg_len; + + /* we're done ... */ + *scm_rights_pipe_fd = pipefd[0]; + *cm_data_space = new_cm_data_space; + + return 0; +} + +static int swrap_sendmsg_unix_sol_socket(const struct cmsghdr *cmsg, + uint8_t **cm_data, + size_t *cm_data_space, + int *scm_rights_pipe_fd) +{ + int rc = -1; + + switch (cmsg->cmsg_type) { + case SCM_RIGHTS: + rc = swrap_sendmsg_unix_scm_rights(cmsg, + cm_data, + cm_data_space, + scm_rights_pipe_fd); + break; + default: + rc = swrap_sendmsg_copy_cmsg(cmsg, + cm_data, + cm_data_space); + break; + } + + return rc; +} + +static int swrap_recvmsg_unix_scm_rights(const struct cmsghdr *cmsg, + uint8_t **cm_data, + size_t *cm_data_space) +{ + int scm_rights_pipe_fd = -1; + struct swrap_unix_scm_rights info; + struct swrap_unix_scm_rights_payload *payload = NULL; + int si_idx_array[SWRAP_MAX_PASSED_FDS]; + size_t size_fds_in; + size_t num_fds_in; + union __swrap_fds __fds_in = { .p = NULL, }; + const int *fds_in = NULL; + size_t num_fds_out; + size_t size_fds_out; + union __swrap_fds __fds_out = { .p = NULL, }; + int *fds_out = NULL; + size_t cmsg_len; + size_t cmsg_space; + size_t new_cm_data_space; + union __swrap_cmsghdr __new_cmsg = { .p = NULL, }; + struct cmsghdr *new_cmsg = NULL; + uint8_t *p = NULL; + size_t i; + ssize_t sret; + + if (cmsg->cmsg_len < CMSG_LEN(0)) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "cmsg->cmsg_len=%zu < CMSG_LEN(0)=%zu", + (size_t)cmsg->cmsg_len, + CMSG_LEN(0)); + errno = EINVAL; + return -1; + } + size_fds_in = cmsg->cmsg_len - CMSG_LEN(0); + if ((size_fds_in % sizeof(int)) != 0) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "cmsg->cmsg_len=%zu => (size_fds_in=%zu %% sizeof(int)=%zu) != 0", + (size_t)cmsg->cmsg_len, + size_fds_in, + sizeof(int)); + errno = EINVAL; + return -1; + } + num_fds_in = size_fds_in / sizeof(int); + if (num_fds_in > (SWRAP_MAX_PASSED_FDS + 1)) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "cmsg->cmsg_len=%zu,size_fds_in=%zu => " + "num_fds_in=%zu > SWRAP_MAX_PASSED_FDS+1(%zu)", + (size_t)cmsg->cmsg_len, + size_fds_in, + num_fds_in, + SWRAP_MAX_PASSED_FDS+1); + errno = EINVAL; + return -1; + } + if (num_fds_in <= 1) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "cmsg->cmsg_len=%zu,size_fds_in=%zu => " + "num_fds_in=%zu", + (size_t)cmsg->cmsg_len, + size_fds_in, + num_fds_in); + errno = EINVAL; + return -1; + } + __fds_in.p = CMSG_DATA(cmsg); + fds_in = __fds_in.fds; + num_fds_out = num_fds_in - 1; + + SWRAP_LOG(SWRAP_LOG_TRACE, + "num_fds_in=%zu num_fds_out=%zu", + num_fds_in, num_fds_out); + + for (i = 0; i < num_fds_in; i++) { + /* Check if we have a stale fd and remove it */ + swrap_remove_stale(fds_in[i]); + } + + scm_rights_pipe_fd = fds_in[num_fds_out]; + size_fds_out = sizeof(int) * num_fds_out; + cmsg_len = CMSG_LEN(size_fds_out); + cmsg_space = CMSG_SPACE(size_fds_out); + + new_cm_data_space = *cm_data_space + cmsg_space; + + p = realloc((*cm_data), new_cm_data_space); + if (p == NULL) { + swrap_close_fd_array(num_fds_in, fds_in); + return -1; + } + (*cm_data) = p; + p = (*cm_data) + (*cm_data_space); + memset(p, 0, cmsg_space); + __new_cmsg.p = p; + new_cmsg = __new_cmsg.cmsg; + *new_cmsg = *cmsg; + __fds_out.p = CMSG_DATA(new_cmsg); + fds_out = __fds_out.fds; + memcpy(fds_out, fds_in, size_fds_out); + new_cmsg->cmsg_len = cmsg_len; + + sret = read(scm_rights_pipe_fd, &info, sizeof(info)); + if (sret != sizeof(info)) { + int saved_errno = errno; + if (sret != -1) { + saved_errno = EINVAL; + } + SWRAP_LOG(SWRAP_LOG_ERROR, + "read() failed - sret=%zd - %d %s", + sret, saved_errno, + strerror(saved_errno)); + swrap_close_fd_array(num_fds_in, fds_in); + errno = saved_errno; + return -1; + } + libc_close(scm_rights_pipe_fd); + payload = &info.payload; + + if (payload->num_idxs != num_fds_out) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "info.num_idxs=%u != num_fds_out=%zu", + payload->num_idxs, num_fds_out); + swrap_close_fd_array(num_fds_out, fds_out); + errno = EINVAL; + return -1; + } + + for (i = 0; i < num_fds_out; i++) { + size_t j; + + si_idx_array[i] = -1; + + if (payload->idxs[i] == -1) { + SWRAP_LOG(SWRAP_LOG_TRACE, + "fds_out[%zu]=%d not an inet socket", + i, fds_out[i]); + continue; + } + + if (payload->idxs[i] < 0) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "fds_out[%zu]=%d info.idxs[%zu]=%d < 0!", + i, fds_out[i], i, payload->idxs[i]); + swrap_close_fd_array(num_fds_out, fds_out); + errno = EINVAL; + return -1; + } + + if (payload->idxs[i] >= payload->num_idxs) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "fds_out[%zu]=%d info.idxs[%zu]=%d >= %u!", + i, fds_out[i], i, payload->idxs[i], + payload->num_idxs); + swrap_close_fd_array(num_fds_out, fds_out); + errno = EINVAL; + return -1; + } + + if ((size_t)fds_out[i] >= socket_fds_max) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "The max socket index limit of %zu has been reached, " + "trying to add %d", + socket_fds_max, + fds_out[i]); + swrap_close_fd_array(num_fds_out, fds_out); + errno = EMFILE; + return -1; + } + + SWRAP_LOG(SWRAP_LOG_TRACE, + "fds_in[%zu]=%d " + "received as info.idxs[%zu]=%d!", + i, fds_out[i], + i, payload->idxs[i]); + + for (j = 0; j < i; j++) { + if (payload->idxs[j] == -1) { + continue; + } + if (payload->idxs[j] == payload->idxs[i]) { + si_idx_array[i] = si_idx_array[j]; + } + } + if (si_idx_array[i] == -1) { + const struct socket_info *si = &payload->infos[payload->idxs[i]]; + + si_idx_array[i] = swrap_add_socket_info(si); + if (si_idx_array[i] == -1) { + int saved_errno = errno; + SWRAP_LOG(SWRAP_LOG_ERROR, + "The max socket index limit of %zu has been reached, " + "trying to add %d", + socket_fds_max, + fds_out[i]); + swrap_undo_si_idx_array(i, si_idx_array); + swrap_close_fd_array(num_fds_out, fds_out); + errno = saved_errno; + return -1; + } + SWRAP_LOG(SWRAP_LOG_TRACE, + "Imported %s socket for protocol %s, fd=%d", + si->family == AF_INET ? "IPv4" : "IPv6", + si->type == SOCK_DGRAM ? "UDP" : "TCP", + fds_out[i]); + } + } + + for (i = 0; i < num_fds_out; i++) { + if (si_idx_array[i] == -1) { + continue; + } + set_socket_info_index(fds_out[i], si_idx_array[i]); + } + +#if 0 + si->family = family; + + /* however, the rest of the socket_wrapper code expects just + * the type, not the flags */ + si->type = real_type; + si->protocol = protocol; + + ret = swrap_create_socket(si, fd); + if (ret == -1) { + return -1; + } + + SWRAP_LOG(SWRAP_LOG_TRACE, + "Created %s socket for protocol %s, fd=%d", + family == AF_INET ? "IPv4" : "IPv6", + real_type == SOCK_DGRAM ? "UDP" : "TCP", + fd); + + si_array[i] = swrap_get_socket_info(info.idxs[i]); + if (si_array[i] == NULL) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "fds_out[%zu]=%d info.idxs[%zu]=%d missing!", + i, fds_out[i], i, info.idxs[i]); + swrap_close_fd_array(num_fds_out, fds_out); + errno = EINVAL; + return -1; + } + } + + for (i = 0; i < num_fds_in; i++) { + struct socket_info *si = si_array[i]; + + if (si == NULL) { + SWRAP_LOG(SWRAP_LOG_TRACE, + "fds_out[%zu]=%d not an inet socket", + i, fds_in[i]); + continue; + } + + SWRAP_LOG(SWRAP_LOG_TRACE, + "fds_out[%zu]=%d si_idx_array[%zu]=%d passed!", + i, fds_in[i], i, info.idxs[i]); + set_socket_info_index(fds_out[i], info.idxs[i]); + } +#endif + + /* we're done ... */ + *cm_data_space = new_cm_data_space; + + return 0; +} + +static int swrap_recvmsg_unix_sol_socket(const struct cmsghdr *cmsg, + uint8_t **cm_data, + size_t *cm_data_space) +{ + int rc = -1; + + switch (cmsg->cmsg_type) { + case SCM_RIGHTS: + rc = swrap_recvmsg_unix_scm_rights(cmsg, + cm_data, + cm_data_space); + break; + default: + rc = swrap_sendmsg_copy_cmsg(cmsg, + cm_data, + cm_data_space); + break; + } + + return rc; +} + +#endif /* HAVE_STRUCT_MSGHDR_MSG_CONTROL */ + +static int swrap_sendmsg_before_unix(const struct msghdr *_msg_in, + struct msghdr *msg_tmp, + int *scm_rights_pipe_fd) +{ +#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL + struct msghdr *msg_in = discard_const_p(struct msghdr, _msg_in); + struct cmsghdr *cmsg = NULL; + uint8_t *cm_data = NULL; + size_t cm_data_space = 0; + int rc = -1; + + *msg_tmp = *msg_in; + *scm_rights_pipe_fd = -1; + + /* Nothing to do */ + if (msg_in->msg_controllen == 0 || msg_in->msg_control == NULL) { + return 0; + } + + for (cmsg = CMSG_FIRSTHDR(msg_in); + cmsg != NULL; + cmsg = CMSG_NXTHDR(msg_in, cmsg)) { + switch (cmsg->cmsg_level) { + case SOL_SOCKET: + rc = swrap_sendmsg_unix_sol_socket(cmsg, + &cm_data, + &cm_data_space, + scm_rights_pipe_fd); + break; + + default: + rc = swrap_sendmsg_copy_cmsg(cmsg, + &cm_data, + &cm_data_space); + break; + } + if (rc < 0) { + int saved_errno = errno; + SAFE_FREE(cm_data); + errno = saved_errno; + return rc; + } + } + + msg_tmp->msg_controllen = cm_data_space; + msg_tmp->msg_control = cm_data; + + return 0; +#else /* HAVE_STRUCT_MSGHDR_MSG_CONTROL */ + *msg_tmp = *_msg_in; + return 0; +#endif /* ! HAVE_STRUCT_MSGHDR_MSG_CONTROL */ +} + +static ssize_t swrap_sendmsg_after_unix(struct msghdr *msg_tmp, + ssize_t ret, + int scm_rights_pipe_fd) +{ +#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL + int saved_errno = errno; + SAFE_FREE(msg_tmp->msg_control); + if (scm_rights_pipe_fd != -1) { + libc_close(scm_rights_pipe_fd); + } + errno = saved_errno; #endif /* HAVE_STRUCT_MSGHDR_MSG_CONTROL */ + return ret; +} + +static int swrap_recvmsg_before_unix(struct msghdr *msg_in, + struct msghdr *msg_tmp) +{ + *msg_tmp = *msg_in; + return 0; +} + +static ssize_t swrap_recvmsg_after_unix(struct msghdr *msg_tmp, + struct msghdr *msg_out, + ssize_t ret) +{ +#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL + struct cmsghdr *cmsg = NULL; + uint8_t *cm_data = NULL; + size_t cm_data_space = 0; + int rc = -1; + + /* Nothing to do */ + if (msg_tmp->msg_controllen == 0 || msg_tmp->msg_control == NULL) { + goto done; + } + + for (cmsg = CMSG_FIRSTHDR(msg_tmp); + cmsg != NULL; + cmsg = CMSG_NXTHDR(msg_tmp, cmsg)) { + switch (cmsg->cmsg_level) { + case SOL_SOCKET: + rc = swrap_recvmsg_unix_sol_socket(cmsg, + &cm_data, + &cm_data_space); + break; + + default: + rc = swrap_sendmsg_copy_cmsg(cmsg, + &cm_data, + &cm_data_space); + break; + } + if (rc < 0) { + int saved_errno = errno; + SAFE_FREE(cm_data); + errno = saved_errno; + return rc; + } + } + + /* + * msg_tmp->msg_control is still the buffer of the caller. + */ + memcpy(msg_tmp->msg_control, cm_data, cm_data_space); + msg_tmp->msg_controllen = cm_data_space; + SAFE_FREE(cm_data); +done: +#endif /* ! HAVE_STRUCT_MSGHDR_MSG_CONTROL */ + *msg_out = *msg_tmp; + return ret; +} static ssize_t swrap_sendmsg_before(int fd, struct socket_info *si, @@ -5215,28 +6048,6 @@ static ssize_t swrap_sendmsg_before(int fd, goto out; } -#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL - if (msg->msg_controllen > 0 && msg->msg_control != NULL) { - uint8_t *cmbuf = NULL; - size_t cmlen = 0; - - ret = swrap_sendmsg_filter_cmsghdr(msg, &cmbuf, &cmlen); - if (ret < 0) { - free(cmbuf); - goto out; - } - - if (cmlen == 0) { - msg->msg_controllen = 0; - msg->msg_control = NULL; - } else if (cmlen < msg->msg_controllen && cmbuf != NULL) { - memcpy(msg->msg_control, cmbuf, cmlen); - msg->msg_controllen = cmlen; - } - free(cmbuf); - } -#endif - ret = 0; out: SWRAP_UNLOCK_SI(si); @@ -6003,7 +6814,12 @@ static ssize_t swrap_recvmsg(int s, struct msghdr *omsg, int flags) si = find_socket_info(s); if (si == NULL) { - return libc_recvmsg(s, omsg, flags); + rc = swrap_recvmsg_before_unix(omsg, &msg); + if (rc < 0) { + return rc; + } + ret = libc_recvmsg(s, &msg, flags); + return swrap_recvmsg_after_unix(&msg, omsg, ret); } tmp.iov_base = NULL; @@ -6126,7 +6942,15 @@ static ssize_t swrap_sendmsg(int s, const struct msghdr *omsg, int flags) int bcast = 0; if (!si) { - return libc_sendmsg(s, omsg, flags); + int scm_rights_pipe_fd = -1; + + rc = swrap_sendmsg_before_unix(omsg, &msg, + &scm_rights_pipe_fd); + if (rc < 0) { + return rc; + } + ret = libc_sendmsg(s, &msg, flags); + return swrap_sendmsg_after_unix(&msg, ret, scm_rights_pipe_fd); } ZERO_STRUCT(un_addr); @@ -6148,20 +6972,32 @@ static ssize_t swrap_sendmsg(int s, const struct msghdr *omsg, int flags) SWRAP_UNLOCK_SI(si); #ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL - if (msg.msg_controllen > 0 && msg.msg_control != NULL) { - /* omsg is a const so use a local buffer for modifications */ - uint8_t cmbuf[omsg->msg_controllen]; + if (omsg != NULL && omsg->msg_controllen > 0 && omsg->msg_control != NULL) { + uint8_t *cmbuf = NULL; + size_t cmlen = 0; - memcpy(cmbuf, omsg->msg_control, omsg->msg_controllen); + rc = swrap_sendmsg_filter_cmsghdr(omsg, &cmbuf, &cmlen); + if (rc < 0) { + return rc; + } - msg.msg_control = cmbuf; /* ancillary data, see below */ - msg.msg_controllen = omsg->msg_controllen; /* ancillary data buffer len */ + if (cmlen == 0) { + msg.msg_controllen = 0; + msg.msg_control = NULL; + } else { + msg.msg_control = cmbuf; + msg.msg_controllen = cmlen; + } } msg.msg_flags = omsg->msg_flags; /* flags on received message */ #endif - rc = swrap_sendmsg_before(s, si, &msg, &tmp, &un_addr, &to_un, &to, &bcast); if (rc < 0) { + int saved_errno = errno; +#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL + SAFE_FREE(msg.msg_control); +#endif + errno = saved_errno; return -1; } @@ -6187,6 +7023,11 @@ static ssize_t swrap_sendmsg(int s, const struct msghdr *omsg, int flags) /* we capture it as one single packet */ buf = (uint8_t *)malloc(remain); if (!buf) { + int saved_errno = errno; +#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL + SAFE_FREE(msg.msg_control); +#endif + errno = saved_errno; return -1; } @@ -6203,7 +7044,12 @@ static ssize_t swrap_sendmsg(int s, const struct msghdr *omsg, int flags) swrap_dir = socket_wrapper_dir(); if (swrap_dir == NULL) { - free(buf); + int saved_errno = errno; +#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL + SAFE_FREE(msg.msg_control); +#endif + SAFE_FREE(buf); + errno = saved_errno; return -1; } @@ -6234,6 +7080,14 @@ static ssize_t swrap_sendmsg(int s, const struct msghdr *omsg, int flags) swrap_sendmsg_after(s, si, &msg, to, ret); +#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL + { + int saved_errno = errno; + SAFE_FREE(msg.msg_control); + errno = saved_errno; + } +#endif + return ret; } @@ -6385,6 +7239,10 @@ static int swrap_close(int fd) goto out; } + if (si->fd_passed) { + goto set_next_free; + } + if (si->myname.sa_socklen > 0 && si->peername.sa_socklen > 0) { swrap_pcap_dump_packet(si, NULL, SWRAP_CLOSE_SEND, NULL, 0); } @@ -6398,6 +7256,7 @@ static int swrap_close(int fd) unlink(si->un_addr.sun_path); } +set_next_free: swrap_set_next_free(si, first_free); first_free = si_index; @@ -6437,6 +7296,17 @@ static int swrap_dup(int fd) return -1; } + if ((size_t)dup_fd >= socket_fds_max) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "The max socket index limit of %zu has been reached, " + "trying to add %d", + socket_fds_max, + dup_fd); + libc_close(dup_fd); + errno = EMFILE; + return -1; + } + SWRAP_LOCK_SI(si); swrap_inc_refcount(si); @@ -6482,6 +7352,16 @@ static int swrap_dup2(int fd, int newfd) return newfd; } + if ((size_t)newfd >= socket_fds_max) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "The max socket index limit of %zu has been reached, " + "trying to add %d", + socket_fds_max, + newfd); + errno = EMFILE; + return -1; + } + if (find_socket_info(newfd)) { /* dup2() does an implicit close of newfd, which we * need to emulate */ @@ -6539,14 +7419,26 @@ static int swrap_vfcntl(int fd, int cmd, va_list va) return -1; } + /* Make sure we don't have an entry for the fd */ + swrap_remove_stale(dup_fd); + + if ((size_t)dup_fd >= socket_fds_max) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "The max socket index limit of %zu has been reached, " + "trying to add %d", + socket_fds_max, + dup_fd); + libc_close(dup_fd); + errno = EMFILE; + return -1; + } + SWRAP_LOCK_SI(si); swrap_inc_refcount(si); SWRAP_UNLOCK_SI(si); - /* Make sure we don't have an entry for the fd */ - swrap_remove_stale(dup_fd); set_socket_info_index(dup_fd, idx); -- 2.34.1