rwrap: Add initial support for res_[n]init().
[obnox/cwrap/resolv_wrapper.git] / src / resolv_wrapper.c
1 /*
2  * Copyright (c) 2014      Andreas Schneider <asn@samba.org>
3  *
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions
8  * are met:
9  *
10  * 1. Redistributions of source code must retain the above copyright
11  *    notice, this list of conditions and the following disclaimer.
12  *
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in the
15  *    documentation and/or other materials provided with the distribution.
16  *
17  * 3. Neither the name of the author nor the names of its contributors
18  *    may be used to endorse or promote products derived from this software
19  *    without specific prior written permission.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
22  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
24  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
25  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
27  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
28  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
30  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
31  * SUCH DAMAGE.
32  */
33
34 #include "config.h"
35
36 #include <arpa/inet.h>
37 #include <sys/types.h>
38 #include <stdarg.h>
39 #include <stdlib.h>
40 #include <stdio.h>
41 #include <stdbool.h>
42 #include <string.h>
43 #include <unistd.h>
44
45 #include <resolv.h>
46
/*
 * PRINTF_ATTRIBUTE(a, b): when the compiler supports the printf format
 * attribute, ask it to check argument <b>... against the format string
 * in argument <a>. Expands to nothing otherwise.
 */
#if defined(HAVE_ATTRIBUTE_PRINTF_FORMAT)
#define PRINTF_ATTRIBUTE(a,b) __attribute__ ((__format__ (__printf__, a, b)))
#else
#define PRINTF_ATTRIBUTE(a,b)
#endif /* HAVE_ATTRIBUTE_PRINTF_FORMAT */

/*
 * DESTRUCTOR_ATTRIBUTE: marks a function with the destructor attribute
 * when the compiler supports it. Expands to nothing otherwise.
 */
#if defined(HAVE_DESTRUCTOR_ATTRIBUTE)
#define DESTRUCTOR_ATTRIBUTE __attribute__ ((destructor))
#else
#define DESTRUCTOR_ATTRIBUTE
#endif /* HAVE_DESTRUCTOR_ATTRIBUTE */
59
/*
 * Log severities, ordered from least verbose (errors) to most verbose
 * (trace). The numeric value is what RESOLV_WRAPPER_DEBUGLEVEL is
 * compared against.
 */
enum rwrap_dbglvl_e {
        RWRAP_LOG_ERROR = 0,
        RWRAP_LOG_WARN  = 1,
        RWRAP_LOG_DEBUG = 2,
        RWRAP_LOG_TRACE = 3
};
66
67 #ifdef NDEBUG
68 # define RWRAP_LOG(...)
69 #else
70
71 static void rwrap_log(enum rwrap_dbglvl_e dbglvl, const char *func, const char *format, ...) PRINTF_ATTRIBUTE(3, 4);
72 # define RWRAP_LOG(dbglvl, ...) rwrap_log((dbglvl), __func__, __VA_ARGS__)
73
74 static void rwrap_log(enum rwrap_dbglvl_e dbglvl,
75                       const char *func,
76                       const char *format, ...)
77 {
78         char buffer[1024];
79         va_list va;
80         const char *d;
81         unsigned int lvl = 0;
82         int pid = getpid();
83
84         d = getenv("RESOLV_WRAPPER_DEBUGLEVEL");
85         if (d != NULL) {
86                 lvl = atoi(d);
87         }
88
89         va_start(va, format);
90         vsnprintf(buffer, sizeof(buffer), format, va);
91         va_end(va);
92
93         if (lvl >= dbglvl) {
94                 switch (dbglvl) {
95                         case RWRAP_LOG_ERROR:
96                                 fprintf(stderr,
97                                         "RWRAP_ERROR(%d) - %s: %s\n",
98                                         pid, func, buffer);
99                                 break;
100                         case RWRAP_LOG_WARN:
101                                 fprintf(stderr,
102                                         "RWRAP_WARN(%d) - %s: %s\n",
103                                         pid, func, buffer);
104                                 break;
105                         case RWRAP_LOG_DEBUG:
106                                 fprintf(stderr,
107                                         "RWRAP_DEBUG(%d) - %s: %s\n",
108                                         pid, func, buffer);
109                                 break;
110                         case RWRAP_LOG_TRACE:
111                                 fprintf(stderr,
112                                         "RWRAP_TRACE(%d) - %s: %s\n",
113                                         pid, func, buffer);
114                                 break;
115                 }
116         }
117 }
118 #endif /* NDEBUG RWRAP_LOG */
119
120 /*********************************************************
121  * RWRAP LOADING LIBC FUNCTIONS
122  *********************************************************/
123
124 #include <dlfcn.h>
125
/*
 * Pointers to the real resolver entry points. Each slot starts out NULL
 * and is filled in by rwrap_load_lib_function() the first time the
 * corresponding symbol is needed.
 */
struct rwrap_libc_fns {
        /* res_init() and its double-underscore spelling */
        int (*libc_res_init)(void);
        int (*libc___res_init)(void);

        /* res_ninit() and its double-underscore spelling */
        int (*libc_res_ninit)(struct __res_state *state);
        int (*libc___res_ninit)(struct __res_state *state);
};
132
133 struct rwrap {
134         void *libc_handle;
135         void *libresolv_handle;
136
137         bool initialised;
138         bool enabled;
139
140         char *socket_dir;
141
142         struct rwrap_libc_fns fns;
143 };
144
145 static struct rwrap rwrap;
146
/* Identifies the shared library a symbol should be looked up in. */
enum rwrap_lib {
        RWRAP_LIBC      = 0,
        RWRAP_LIBRESOLV = 1
};
151
152 #ifndef NDEBUG
153 static const char *rwrap_str_lib(enum rwrap_lib lib)
154 {
155         switch (lib) {
156         case RWRAP_LIBC:
157                 return "libc";
158         case RWRAP_LIBRESOLV:
159                 return "libresolv";
160         }
161
162         /* Compiler would warn us about unhandled enum value if we get here */
163         return "unknown";
164 }
165 #endif
166
167 static void *rwrap_load_lib_handle(enum rwrap_lib lib)
168 {
169         int flags = RTLD_LAZY;
170         void *handle = NULL;
171         int i;
172
173 #ifdef RTLD_DEEPBIND
174         flags |= RTLD_DEEPBIND;
175 #endif
176
177         switch (lib) {
178         case RWRAP_LIBRESOLV:
179 #ifdef HAVE_LIBRESOLV
180                 handle = rwrap.libresolv_handle;
181                 if (handle == NULL) {
182                         for (i = 10; i >= 0; i--) {
183                                 char soname[256] = {0};
184
185                                 snprintf(soname, sizeof(soname), "libresolv.so.%d", i);
186                                 handle = dlopen(soname, flags);
187                                 if (handle != NULL) {
188                                         break;
189                                 }
190                         }
191
192                         rwrap.libresolv_handle = handle;
193                 }
194                 break;
195 #endif
196                 /* FALL TROUGH */
197         case RWRAP_LIBC:
198                 handle = rwrap.libc_handle;
199 #ifdef LIBC_SO
200                 if (handle == NULL) {
201                         handle = dlopen(LIBC_SO, flags);
202
203                         rwrap.libc_handle = handle;
204                 }
205 #endif
206                 if (handle == NULL) {
207                         for (i = 10; i >= 0; i--) {
208                                 char soname[256] = {0};
209
210                                 snprintf(soname, sizeof(soname), "libc.so.%d", i);
211                                 handle = dlopen(soname, flags);
212                                 if (handle != NULL) {
213                                         break;
214                                 }
215                         }
216
217                         rwrap.libc_handle = handle;
218                 }
219                 break;
220         }
221
222         if (handle == NULL) {
223 #ifdef RTLD_NEXT
224                 handle = rwrap.libc_handle = rwrap.libresolv_handle = RTLD_NEXT;
225 #else
226                 RWRAP_LOG(RWRAP_LOG_ERROR,
227                           "Failed to dlopen library: %s\n",
228                           dlerror());
229                 exit(-1);
230 #endif
231         }
232
233         return handle;
234 }
235
236 static void *_rwrap_load_lib_function(enum rwrap_lib lib, const char *fn_name)
237 {
238         void *handle;
239         void *func;
240
241         handle = rwrap_load_lib_handle(lib);
242
243         func = dlsym(handle, fn_name);
244         if (func == NULL) {
245                 RWRAP_LOG(RWRAP_LOG_ERROR,
246                                 "Failed to find %s: %s\n",
247                                 fn_name, dlerror());
248                 exit(-1);
249         }
250
251         RWRAP_LOG(RWRAP_LOG_TRACE,
252                         "Loaded %s from %s",
253                         fn_name, rwrap_str_lib(lib));
254         return func;
255 }
256
/*
 * Resolve <fn_name> from <lib> into rwrap.fns.libc_<fn_name> unless that
 * slot is already filled.
 *
 * Wrapped in do { } while (0) so the macro behaves as a single statement:
 * it can be used in an unbraced if/else and needs a terminating semicolon
 * at the call site.
 */
#define rwrap_load_lib_function(lib, fn_name) \
        do { \
                if (rwrap.fns.libc_##fn_name == NULL) { \
                        *(void **) (&rwrap.fns.libc_##fn_name) = \
                                _rwrap_load_lib_function((lib), #fn_name); \
                } \
        } while (0)
262
/*
 * IMPORTANT
 *
 * Functions, especially those from libc, need to be loaded individually. You
 * can't load them all at once or gdb will segfault at startup. The same
 * applies to valgrind and probably has something to do with the linker.
 * So we need to load each function at the point it is called the first time.
 */
#if 0
/*
 * Call the real res_init() (or __res_init(), depending on what configure
 * found), resolving the symbol on first use.
 *
 * Currently compiled out.
 */
static int libc_res_init(void)
{
#if defined(HAVE_RES_INIT)
        rwrap_load_lib_function(RWRAP_LIBRESOLV, res_init);

        return rwrap.fns.libc_res_init();
#elif defined(HAVE___RES_INIT)
        rwrap_load_lib_function(RWRAP_LIBRESOLV, __res_init);

        return rwrap.fns.libc___res_init();
#else
        /* Without this the function would fall off the end with no return. */
#error "No res_init function"
#endif
}
#endif
285
286 static int libc_res_ninit(struct __res_state *state)
287 {
288 #if defined(HAVE_RES_NINIT)
289         rwrap_load_lib_function(RWRAP_LIBC, res_ninit);
290
291         return rwrap.fns.libc_res_ninit(state);
292 #elif defined(HAVE___RES_NINIT)
293         rwrap_load_lib_function(RWRAP_LIBC, __res_ninit);
294
295         return rwrap.fns.libc___res_ninit(state);
296 #else
297 #error "No res_ninit function"
298 #endif
299 }
300
301 /****************************************************************************
302  *   RES_NINIT
303  ***************************************************************************/
304
305 static int rwrap_res_ninit(struct __res_state *state)
306 {
307         int rc;
308
309         rc = libc_res_ninit(state);
310         if (rc == 0) {
311                 const char *rwrap_ns_env = getenv("RESOLV_WRAPPER_NAMESERVER");
312
313                 if (rwrap_ns_env != NULL) {
314                         int ok;
315
316                         /* Delete name servers */
317                         state->nscount = 1;
318                         memset(state->nsaddr_list, 0, sizeof(state->nsaddr_list));
319
320                         /* Simply zero the the padding array in the union */
321                         memset(state->_u.pad, 0, sizeof(state->_u.pad));
322
323                         state->nsaddr_list[0] = (struct sockaddr_in) {
324                                 .sin_family = AF_INET,
325                                 .sin_port = htons(53),
326                         };
327
328                         ok = inet_pton(AF_INET, rwrap_ns_env, &state->nsaddr_list[0].sin_addr);
329                         if (!ok) {
330                                 return -1;
331                         }
332
333                         RWRAP_LOG(RWRAP_LOG_DEBUG,
334                                   "Using [%s] as new nameserver",
335                                   rwrap_ns_env);
336                 }
337         }
338
339         return rc;
340 }
341
342 #if defined(HAVE_RES_NINIT)
343 int res_ninit(struct __res_state *state)
344 #elif defined(HAVE___RES_NINIT)
345 int __res_ninit(struct __res_state *state)
346 #endif
347 {
348         return rwrap_res_ninit(state);
349 }
350
351 /****************************************************************************
352  *   RES_INIT
353  ***************************************************************************/
354
/*
 * Wrapper implementation of res_init().
 *
 * res_init() operates on the resolver state `_res` declared by <resolv.h>,
 * so that is the state that has to be initialised (and have its name
 * server overridden). Initialising a private structure instead would leave
 * code that reads `_res` after res_init() without the override.
 *
 * Returns the result of rwrap_res_ninit().
 */
static int rwrap_res_init(void)
{
        int rc;

        rc = rwrap_res_ninit(&_res);

        return rc;
}
365
366 #if defined(HAVE_RES_INIT)
367 int res_init(void)
368 #elif defined(HAVE___RES_INIT)
369 int __res_init(void)
370 #endif
371 {
372         return rwrap_res_init();
373 }