r6570: Add socket_wrapper library to 3.0. Can be enabled by passing
[tprouty/samba.git] / source / lib / socket_wrapper.c
1 /* 
2    Socket wrapper library. Passes all socket communication over 
3    unix domain sockets if the environment variable SOCKET_WRAPPER_DIR 
4    is set.
5    Copyright (C) Jelmer Vernooij 2005
6    
7    This program is free software; you can redistribute it and/or modify
8    it under the terms of the GNU General Public License as published by
9    the Free Software Foundation; either version 2 of the License, or
10    (at your option) any later version.
11    
12    This program is distributed in the hope that it will be useful,
13    but WITHOUT ANY WARRANTY; without even the implied warranty of
14    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15    GNU General Public License for more details.
16    
17    You should have received a copy of the GNU General Public License
18    along with this program; if not, write to the Free Software
19    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
20 */
21
22 #ifdef _SAMBA_BUILD
23 #include "includes.h"
24 #include "system/network.h"
25 #else
26 #include <sys/types.h>
27 #include <sys/socket.h>
28 #include <errno.h>
29 #include <sys/un.h>
30 #include <netinet/in.h>
31 #include <netinet/tcp.h>
32 #include <stdlib.h>
33 #include <unistd.h>
34 #include <string.h>
35 #include <stdio.h>
36 #include "dlinklist.h"
37 #endif
38
39 /* LD_PRELOAD doesn't work yet, so REWRITE_CALLS is all we support
40  * for now */
41 #define REWRITE_CALLS 
42
43 #ifdef REWRITE_CALLS
44 #define real_accept accept
45 #define real_connect connect
46 #define real_bind bind
47 #define real_getpeername getpeername
48 #define real_getsockname getsockname
49 #define real_getsockopt getsockopt
50 #define real_setsockopt setsockopt
51 #define real_recvfrom recvfrom
52 #define real_sendto sendto
53 #define real_socket socket
54 #define real_close close
55 #endif
56
57 static struct sockaddr *sockaddr_dup(const void *data, socklen_t len)
58 {
59         struct sockaddr *ret = (struct sockaddr *)malloc(len);
60         memcpy(ret, data, len);
61         return ret;
62 }
63
64 struct socket_info
65 {
66         int fd;
67
68         int domain;
69         int type;
70         int protocol;
71         int bound;
72
73         char *path;
74         char *tmp_path;
75
76         struct sockaddr *myname;
77         socklen_t myname_len;
78
79         struct sockaddr *peername;
80         socklen_t peername_len;
81
82         struct socket_info *prev, *next;
83 };
84
85 static struct socket_info *sockets = NULL;
86
87 static int convert_un_in(const struct sockaddr_un *un, struct sockaddr_in *in, socklen_t *len)
88 {
89         unsigned int prt;
90         const char *p;
91         int type;
92
93         if ((*len) < sizeof(struct sockaddr_in)) {
94                 return 0;
95         }
96
97         in->sin_family = AF_INET;
98         in->sin_port = 1025; /* Default to 1025 */
99         p = strchr(un->sun_path, '/');
100         if (p) p++; else p = un->sun_path;
101
102         if (sscanf(p, "sock_ip_%d_%u", &type, &prt) == 2) {
103                 in->sin_port = htons(prt);
104         }
105         in->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
106         *len = sizeof(struct sockaddr_in);
107         return 0;
108 }
109
110 static int convert_in_un(int type, const struct sockaddr_in *in, struct sockaddr_un *un)
111 {
112         uint16_t prt = ntohs(in->sin_port);
113         snprintf(un->sun_path, sizeof(un->sun_path), "%s/sock_ip_%d_%u", 
114                  getenv("SOCKET_WRAPPER_DIR"), type, prt);
115         return 0;
116 }
117
118 static struct socket_info *find_socket_info(int fd)
119 {
120         struct socket_info *i;
121         for (i = sockets; i; i = i->next) {
122                 if (i->fd == fd) 
123                         return i;
124         }
125
126         return NULL;
127 }
128
129 static int sockaddr_convert_to_un(const struct socket_info *si, const struct sockaddr *in_addr, socklen_t in_len, 
130                                          struct sockaddr_un *out_addr)
131 {
132         if (!out_addr)
133                 return 0;
134
135         out_addr->sun_family = AF_UNIX;
136
137         switch (in_addr->sa_family) {
138         case AF_INET:
139                 return convert_in_un(si->type, (const struct sockaddr_in *)in_addr, out_addr);
140         case AF_UNIX:
141                 memcpy(out_addr, in_addr, sizeof(*out_addr));
142                 return 0;
143         default:
144                 break;
145         }
146         
147         errno = EAFNOSUPPORT;
148         return -1;
149 }
150
151 static int sockaddr_convert_from_un(const struct socket_info *si, 
152                                     const struct sockaddr_un *in_addr, 
153                                     socklen_t un_addrlen,
154                                     int family,
155                                     struct sockaddr *out_addr,
156                                     socklen_t *out_len)
157 {
158         if (out_addr == NULL || out_len == NULL) 
159                 return 0;
160
161         if (un_addrlen == 0) {
162                 *out_len = 0;
163                 return 0;
164         }
165
166         switch (family) {
167         case AF_INET:
168                 return convert_un_in(in_addr, (struct sockaddr_in *)out_addr, out_len);
169         case AF_UNIX:
170                 memcpy(out_addr, in_addr, sizeof(*in_addr));
171                 *out_len = sizeof(*in_addr);
172                 return 0;
173         default:
174                 break;
175         }
176
177         errno = EAFNOSUPPORT;
178         return -1;
179 }
180
181 int swrap_socket(int domain, int type, int protocol)
182 {
183         struct socket_info *si;
184         int fd;
185
186         if (!getenv("SOCKET_WRAPPER_DIR")) {
187                 return real_socket(domain, type, protocol);
188         }
189         
190         fd = real_socket(AF_UNIX, type, 0);
191
192         if (fd == -1) return -1;
193
194         si = calloc(1, sizeof(struct socket_info));
195
196         si->domain = domain;
197         si->type = type;
198         si->protocol = protocol;
199         si->fd = fd;
200
201         DLIST_ADD(sockets, si);
202
203         return si->fd;
204 }
205
206 int swrap_accept(int s, struct sockaddr *addr, socklen_t *addrlen)
207 {
208         struct socket_info *parent_si, *child_si;
209         int fd;
210         socklen_t un_addrlen = sizeof(struct sockaddr_un);
211         struct sockaddr_un un_addr;
212         int ret;
213
214         parent_si = find_socket_info(s);
215         if (!parent_si) {
216                 return real_accept(s, addr, addrlen);
217         }
218
219         ret = real_accept(s, (struct sockaddr *)&un_addr, &un_addrlen);
220         if (ret == -1) return ret;
221
222         fd = ret;
223
224         ret = sockaddr_convert_from_un(parent_si, &un_addr, un_addrlen,
225                                        parent_si->domain, addr, addrlen);
226         if (ret == -1) return ret;
227
228         child_si = malloc(sizeof(struct socket_info));
229         memset(child_si, 0, sizeof(*child_si));
230
231         child_si->fd = fd;
232
233         if (addr && addrlen) {
234                 child_si->myname_len = *addrlen;
235                 child_si->myname = sockaddr_dup(addr, *addrlen);
236         }
237
238         return fd;
239 }
240
241 int swrap_connect(int s, const struct sockaddr *serv_addr, socklen_t addrlen)
242 {
243         int ret;
244         struct sockaddr_un un_addr;
245         struct socket_info *si = find_socket_info(s);
246
247         if (!si) {
248                 return real_connect(s, serv_addr, addrlen);
249         }
250
251         /* only allow pseudo loopback connections */
252         if (serv_addr->sa_family == AF_INET &&
253                 ((const struct sockaddr_in *)serv_addr)->sin_addr.s_addr != 
254             htonl(INADDR_LOOPBACK)) {
255                 errno = ENETUNREACH;
256                 return -1;
257         }
258
259         ret = sockaddr_convert_to_un(si, (const struct sockaddr *)serv_addr, addrlen, &un_addr);
260         if (ret == -1) return -1;
261
262         ret = real_connect(s, (struct sockaddr *)&un_addr, 
263                            sizeof(struct sockaddr_un));
264
265         if (ret == 0) {
266                 si->peername_len = addrlen;
267                 si->peername = sockaddr_dup(serv_addr, addrlen);
268         }
269
270         return ret;
271 }
272
273 int swrap_bind(int s, const struct sockaddr *myaddr, socklen_t addrlen)
274 {
275         int ret;
276         struct sockaddr_un un_addr;
277         struct socket_info *si = find_socket_info(s);
278
279         if (!si) {
280                 return real_bind(s, myaddr, addrlen);
281         }
282
283         ret = sockaddr_convert_to_un(si, (const struct sockaddr *)myaddr, addrlen, &un_addr);
284         if (ret == -1) return -1;
285
286         unlink(un_addr.sun_path);
287
288         ret = real_bind(s, (struct sockaddr *)&un_addr,
289                         sizeof(struct sockaddr_un));
290
291         if (ret == 0) {
292                 si->myname_len = addrlen;
293                 si->myname = sockaddr_dup(myaddr, addrlen);
294                 si->bound = 1;
295         }
296
297         return ret;
298 }
299
300 int swrap_getpeername(int s, struct sockaddr *name, socklen_t *addrlen)
301 {
302         struct socket_info *si = find_socket_info(s);
303
304         if (!si) {
305                 return real_getpeername(s, name, addrlen);
306         }
307
308         if (!si->peername) 
309         {
310                 errno = ENOTCONN;
311                 return -1;
312         }
313
314         memcpy(name, si->peername, si->peername_len);
315         *addrlen = si->peername_len;
316
317         return 0;
318 }
319
320 int swrap_getsockname(int s, struct sockaddr *name, socklen_t *addrlen)
321 {
322         struct socket_info *si = find_socket_info(s);
323
324         if (!si) {
325                 return real_getsockname(s, name, addrlen);
326         }
327
328         memcpy(name, si->myname, si->myname_len);
329         *addrlen = si->myname_len;
330
331         return 0;
332 }
333
334 int swrap_getsockopt(int s, int level, int optname, void *optval, socklen_t *optlen)
335 {
336         struct socket_info *si = find_socket_info(s);
337
338         if (!si) {
339                 return real_getsockopt(s, level, optname, optval, optlen);
340         }
341
342         if (level == SOL_SOCKET) {
343                 return real_getsockopt(s, level, optname, optval, optlen);
344         } 
345
346         switch (si->domain) {
347         case AF_UNIX:
348                 return real_getsockopt(s, level, optname, optval, optlen);
349         default:
350                 errno = ENOPROTOOPT;
351                 return -1;
352         }
353 }
354
355 int swrap_setsockopt(int s, int  level,  int  optname,  const  void  *optval, socklen_t optlen)
356 {
357         struct socket_info *si = find_socket_info(s);
358
359         if (!si) {
360                 return real_setsockopt(s, level, optname, optval, optlen);
361         }
362
363         if (level == SOL_SOCKET) {
364                 return real_setsockopt(s, level, optname, optval, optlen);
365         }
366
367         switch (si->domain) {
368         case AF_UNIX:
369                 return real_setsockopt(s, level, optname, optval, optlen);
370         case AF_INET:
371                 /* Silence some warnings */
372 #ifdef TCP_NODELAY
373                 if (optname == TCP_NODELAY) 
374                         return 0;
375 #endif
376         default:
377                 errno = ENOPROTOOPT;
378                 return -1;
379         }
380 }
381
382 ssize_t swrap_recvfrom(int s, void *buf, size_t len, int flags, struct sockaddr *from, socklen_t *fromlen)
383 {
384         struct sockaddr_un un_addr;
385         socklen_t un_addrlen = sizeof(un_addr);
386         int ret;
387         struct socket_info *si = find_socket_info(s);
388
389         if (!si) {
390                 return real_recvfrom(s, buf, len, flags, from, fromlen);
391         }
392
393         ret = real_recvfrom(s, buf, len, flags, (struct sockaddr *)&un_addr, &un_addrlen);
394         if (ret == -1) 
395                 return ret;
396
397         if (sockaddr_convert_from_un(si, &un_addr, un_addrlen,
398                                      si->domain, from, fromlen) == -1) {
399                 return -1;
400         }
401         
402         return ret;
403 }
404
405 ssize_t swrap_sendto(int  s,  const  void *buf, size_t len, int flags, const struct sockaddr *to, socklen_t tolen)
406 {
407         struct sockaddr_un un_addr;
408         int ret;
409         struct socket_info *si = find_socket_info(s);
410
411         if (!si) {
412                 return real_sendto(s, buf, len, flags, to, tolen);
413         }
414
415         /* using sendto() on an unbound DGRAM socket would give the
416            recipient no way to reply, as unlike UDP, a unix domain socket
417            can't auto-assign emphemeral port numbers, so we need to assign
418            it here */
419         if (si->bound == 0 && si->type == SOCK_DGRAM) {
420                 int i;
421
422                 un_addr.sun_family = AF_UNIX;
423
424                 for (i=0;i<1000;i++) {
425                         snprintf(un_addr.sun_path, sizeof(un_addr.sun_path), 
426                                  "%s/sock_ip_%u_%u", getenv("SOCKET_WRAPPER_DIR"), 
427                                  SOCK_DGRAM, i + 10000);
428                         if (bind(si->fd, (struct sockaddr *)&un_addr, 
429                                  sizeof(un_addr)) == 0) {
430                                 si->tmp_path = strdup(un_addr.sun_path);
431                                 si->bound = 1;
432                                 break;
433                         }
434                 }
435                 if (i == 1000) {
436                         return -1;
437                 }
438         }
439         
440
441         ret = sockaddr_convert_to_un(si, to, tolen, &un_addr);
442         if (ret == -1) return -1;
443
444         ret = real_sendto(s, buf, len, flags, (struct sockaddr *)&un_addr, sizeof(un_addr));
445
446         return ret;
447 }
448
449 int swrap_close(int fd)
450 {
451         struct socket_info *si = find_socket_info(fd);
452
453         if (si) {
454                 DLIST_REMOVE(sockets, si);
455
456                 free(si->path);
457                 free(si->myname);
458                 free(si->peername);
459                 if (si->tmp_path) {
460                         unlink(si->tmp_path);
461                         free(si->tmp_path);
462                 }
463                 free(si);
464         }
465
466         return real_close(fd);
467 }