TODO msg: start re-adding fd=passing with accrights
[obnox/samba/samba-obnox.git] / source3 / lib / msghdr.c
1 /*
2  * Unix SMB/CIFS implementation.
3  * Copyright (C) Volker Lendecke 2014
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation; either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17  */
18
19 #include "replace.h"
20 #include "lib/msghdr.h"
21 #include "lib/util/iov_buf.h"
22 #include <sys/socket.h>
23
24  #if !defined(HAVE_STRUCT_MSGHDR_MSG_CONTROL) && !defined(HAVE_STRUCT_MSGHDR_MSG_ACCRIGHTS)
25  # error Can not pass file descriptors
26  #endif
27
28 ssize_t msghdr_prep_fds(struct msghdr *msg, uint8_t *buf, size_t bufsize,
29                         const int *fds, size_t num_fds)
30 {
31 #ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
32         size_t fds_size = sizeof(int) * MIN(num_fds, INT8_MAX);
33         size_t cmsg_len = CMSG_LEN(fds_size);
34         size_t cmsg_space = CMSG_SPACE(fds_size);
35         struct cmsghdr *cmsg;
36         void *fdptr;
37
38         if (num_fds == 0) {
39                 if (msg != NULL) {
40                         msg->msg_control = NULL;
41                         msg->msg_controllen = 0;
42                 }
43                 return 0;
44         }
45         if (num_fds > INT8_MAX) {
46                 return -1;
47         }
48         if ((msg == NULL) || (cmsg_space > bufsize)) {
49                 return cmsg_space;
50         }
51
52         msg->msg_control = buf;
53         msg->msg_controllen = cmsg_space;
54
55         cmsg = CMSG_FIRSTHDR(msg);
56         cmsg->cmsg_level = SOL_SOCKET;
57         cmsg->cmsg_type = SCM_RIGHTS;
58         cmsg->cmsg_len = cmsg_len;
59         fdptr = CMSG_DATA(cmsg);
60         memcpy(fdptr, fds, fds_size);
61         msg->msg_controllen = cmsg->cmsg_len;
62
63         return cmsg_space;
64 #else /* HAVE_STRUCT_MSGHDR_MSG_ACCRIGHTS */
65
66         if (num_fds == 0) {
67                 if (msg != NULL) {
68                         msg->msg_accrights = NULL;
69                         msg->msg_accrightslen = 0;
70                 }
71                 return 0;
72         }
73         if (num_fds > INT8_MAX) {
74                 return -1;
75         }
76         if ((msg == NULL) || (num_fds > bufsize)) {
77                 return num_fds;
78         }
79
80         msg->accrights = buf;
81         msg->accrights_len = num_fds;
82
83         return num_fds;
84 #endif
85 }
86
87 struct msghdr_buf {
88         struct msghdr msg;
89         struct sockaddr_storage addr;
90         struct iovec iov;
91         uint8_t buf[];
92 };
93
94 ssize_t msghdr_copy(struct msghdr_buf *msg, size_t msgsize,
95                     const void *addr, socklen_t addrlen,
96                     const struct iovec *iov, int iovcnt,
97                     const int *fds, size_t num_fds)
98 {
99         ssize_t fd_len;
100         size_t iov_len, needed, bufsize;
101
102         bufsize = (msgsize > offsetof(struct msghdr_buf, buf)) ?
103                 msgsize - offsetof(struct msghdr_buf, buf) : 0;
104
105         fd_len = msghdr_prep_fds(&msg->msg, msg->buf, bufsize, fds, num_fds);
106
107         if (fd_len == -1) {
108                 return -1;
109         }
110
111         if (bufsize >= fd_len) {
112                 bufsize -= fd_len;
113         } else {
114                 bufsize = 0;
115         }
116
117         if (msg != NULL) {
118
119                 if (addr != NULL) {
120                         if (addrlen > sizeof(struct sockaddr_storage)) {
121                                 errno = EMSGSIZE;
122                                 return -1;
123                         }
124                         memcpy(&msg->addr, addr, addrlen);
125                         msg->msg.msg_name = &msg->addr;
126                         msg->msg.msg_namelen = addrlen;
127                 } else {
128                         msg->msg.msg_name = NULL;
129                         msg->msg.msg_namelen = 0;
130                 }
131
132                 msg->iov.iov_base = msg->buf + fd_len;
133                 msg->iov.iov_len = iov_buf(
134                         iov, iovcnt, msg->iov.iov_base, bufsize);
135                 iov_len = msg->iov.iov_len;
136
137                 msg->msg.msg_iov = &msg->iov;
138                 msg->msg.msg_iovlen = 1;
139         } else {
140                 iov_len = iov_buflen(iov, iovcnt);
141         }
142
143         needed = offsetof(struct msghdr_buf, buf) + fd_len;
144         if (needed < fd_len) {
145                 return -1;
146         }
147         needed += iov_len;
148         if (needed < iov_len) {
149                 return -1;
150         }
151
152         return needed;
153 }
154
155 struct msghdr *msghdr_buf_msghdr(struct msghdr_buf *msg)
156 {
157         return &msg->msg;
158 }
159
160 size_t msghdr_prep_recv_fds(struct msghdr *msg, uint8_t *buf, size_t bufsize,
161                             size_t num_fds)
162 {
163         size_t ret = CMSG_SPACE(sizeof(int) * num_fds);
164
165         if (bufsize < ret) {
166                 return ret;
167         }
168         if (msg != NULL) {
169                 if (num_fds != 0) {
170                         msg->msg_control = buf;
171                         msg->msg_controllen = ret;
172                 } else {
173                         msg->msg_control = NULL;
174                         msg->msg_controllen = 0;
175                 }
176         }
177         return ret;
178 }
179
180 size_t msghdr_extract_fds(struct msghdr *msg, int *fds, size_t fds_size)
181 {
182         struct cmsghdr *cmsg;
183         size_t num_fds;
184
185         for(cmsg = CMSG_FIRSTHDR(msg);
186             cmsg != NULL;
187             cmsg = CMSG_NXTHDR(msg, cmsg))
188         {
189                 if ((cmsg->cmsg_type == SCM_RIGHTS) &&
190                     (cmsg->cmsg_level == SOL_SOCKET)) {
191                         break;
192                 }
193         }
194
195         if (cmsg == NULL) {
196                 return 0;
197         }
198
199         num_fds = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
200
201         if ((num_fds != 0) && (fds != NULL) && (fds_size >= num_fds)) {
202                 memcpy(fds, CMSG_DATA(cmsg), num_fds * sizeof(int));
203         }
204
205         return num_fds;
206 }