uwrap: Add support for thread safe functions.
authorAndreas Schneider <asn@cryptomilk.org>
Tue, 30 Jul 2013 13:04:34 +0000 (15:04 +0200)
committerAndreas Schneider <asn@cryptomilk.org>
Sat, 9 Nov 2013 10:42:11 +0000 (11:42 +0100)
This is not fully working yet, but a start.

src/CMakeLists.txt
src/uid_wrapper.c

index abd47011b7a401fa27e24b688104fc6b5effcbbd..e1fb0b2c69394368db717a483a09f27b174bd088 100644 (file)
@@ -2,7 +2,7 @@ project(libuid_wrapper C)
 
 include_directories(${CMAKE_BINARY_DIR})
 add_library(uid_wrapper SHARED uid_wrapper.c)
-target_link_libraries(uid_wrapper ${UIDWRAP_REQUIRED_LIBRARIES})
+target_link_libraries(uid_wrapper ${UIDWRAP_REQUIRED_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
 
 set_target_properties(
   uid_wrapper
index d4596fa6fbc2dba2b9dd678bfdf2583d64d4a772..d42ebe0c307762f701ed0f7ec79c43c2f33cdd85 100644 (file)
@@ -35,6 +35,8 @@
 #endif
 #include <dlfcn.h>
 
+#include <pthread.h>
+
 #ifdef HAVE_GCC_THREAD_LOCAL_STORAGE
 # define UWRAP_THREAD __thread
 #else
 #define UWRAP_DEBUG(...) fprintf(stderr, __VA_ARGS__)
 #endif
 
+#define UWRAP_DLIST_ADD(list,item) do { \
+       if (!(list)) { \
+               (item)->prev    = NULL; \
+               (item)->next    = NULL; \
+               (list)          = (item); \
+       } else { \
+               (item)->prev    = NULL; \
+               (item)->next    = (list); \
+               (list)->prev    = (item); \
+               (list)          = (item); \
+       } \
+} while (0)
+
+#define UWRAP_DLIST_REMOVE(list,item) do { \
+       if ((list) == (item)) { \
+               (list)          = (item)->next; \
+               if (list) { \
+                       (list)->prev    = NULL; \
+               } \
+       } else { \
+               if ((item)->prev) { \
+                       (item)->prev->next      = (item)->next; \
+               } \
+               if ((item)->next) { \
+                       (item)->next->prev      = (item)->prev; \
+               } \
+       } \
+       (item)->prev    = NULL; \
+       (item)->next    = NULL; \
+} while (0)
+
 #define LIBC_NAME "libc.so"
 
 struct uwrap_libc_fns {
@@ -86,23 +119,38 @@ struct uwrap_libc_fns {
 /*
  * We keep the virtualised euid/egid/groups information here
  */
+struct uwrap_thread {
+       pthread_t tid;
+       bool dead;
+
+       uid_t ruid;
+       uid_t euid;
+       uid_t suid;
+
+       gid_t rgid;
+       gid_t egid;
+       gid_t sgid;
+
+       gid_t *groups;
+       int ngroups;
+
+       struct uwrap_thread *next;
+       struct uwrap_thread *prev;
+};
+
 struct uwrap {
        struct {
                void *handle;
                struct uwrap_libc_fns fns;
        } libc;
+
        bool initialised;
        bool enabled;
+
        uid_t myuid;
-       uid_t ruid;
-       uid_t euid;
-       uid_t suid;
        uid_t mygid;
-       gid_t rgid;
-       gid_t egid;
-       gid_t sgid;
-       gid_t *groups;
-       int ngroups;
+
+       struct uwrap_thread *ids;
 };
 
 static struct uwrap uwrap;
@@ -181,11 +229,56 @@ static void uwrap_libc_init(struct uwrap *u)
 #endif
 }
 
+static struct uwrap_thread *find_uwrap_id(pthread_t tid)
+{
+       struct uwrap_thread *id;
+
+       for (id = uwrap.ids; id; id = id->next) {
+               if (pthread_equal(id->tid, tid)) {
+                       return id;
+               }
+       }
+
+       return NULL;
+}
+
+static int uwrap_new_id(pthread_t tid)
+{
+       struct uwrap_thread *id;
+
+       id = malloc(sizeof(struct uwrap_thread));
+       if (id == NULL) {
+               errno = ENOMEM;
+               return -1;
+       }
+
+       id->tid = tid;
+       id->dead = false;
+
+       id->ruid = id->euid = id->suid = uwrap.myuid;
+       id->rgid = id->egid = id->sgid = uwrap.mygid;
+
+       id->ngroups = 1;
+       id->groups = malloc(sizeof(gid_t) * id->ngroups);
+       id->groups[0] = uwrap.mygid;
+
+       UWRAP_DLIST_ADD(uwrap.ids, id);
+
+       return 0;
+}
+
 static void uwrap_init(void)
 {
        const char *env = getenv("UID_WRAPPER");
 
        if (uwrap.initialised) {
+               pthread_t tid = pthread_self();
+               struct uwrap_thread *id = find_uwrap_id(tid);
+
+               if (id == NULL) {
+                       uwrap_new_id(tid);
+               }
+
                return;
        }
 
@@ -196,7 +289,7 @@ static void uwrap_init(void)
 
        if (env != NULL && env[0] == '1') {
                const char *root = getenv("UID_WRAPPER_ROOT");
-               uwrap.enabled = true;
+
                /* put us in one group */
                if (root != NULL && root[0] == '1') {
                        uwrap.myuid = 0;
@@ -206,12 +299,7 @@ static void uwrap_init(void)
                        uwrap.mygid = uwrap.libc.fns._libc_getegid();
                }
 
-               uwrap.ruid = uwrap.euid = uwrap.suid = uwrap.myuid;
-               uwrap.rgid = uwrap.egid = uwrap.sgid = uwrap.mygid;
-
-               uwrap.ngroups = 1;
-               uwrap.groups = malloc(sizeof(gid_t) * uwrap.ngroups);
-               uwrap.groups[0] = uwrap.mygid;
+               uwrap.enabled = true;
        }
 }
 
@@ -222,55 +310,70 @@ static int uwrap_enabled(void)
        return uwrap.enabled ? 1 : 0;
 }
 
-static int uwrap_setresuid(uid_t ruid, uid_t euid, uid_t suid)
+static int uwrap_setresuid_thread(uid_t ruid, uid_t euid, uid_t suid)
 {
+       struct uwrap_thread *id = find_uwrap_id(pthread_self());
+
        if (ruid == (uid_t)-1 && euid == (uid_t)-1 && suid == (uid_t)-1) {
                errno = EINVAL;
                return -1;
        }
 
        if (ruid != (uid_t)-1) {
-               uwrap.ruid = ruid;
+               id->ruid = ruid;
        }
 
        if (euid != (uid_t)-1) {
-               uwrap.euid = euid;
+               id->euid = euid;
        }
 
        if (suid != (uid_t)-1) {
-               uwrap.suid = suid;
+               id->suid = suid;
        }
 
        return 0;
 }
 
-/*
- * SETUID
- */
-int setuid(uid_t uid)
+static int uwrap_setresuid(uid_t ruid, uid_t euid, uid_t suid)
 {
-       if (!uwrap_enabled()) {
-               return uwrap.libc.fns._libc_setuid(uid);
+       struct uwrap_thread *id;
+
+       if (ruid == (uid_t)-1 && euid == (uid_t)-1 && suid == (uid_t)-1) {
+               errno = EINVAL;
+               return -1;
        }
 
-       return uwrap_setresuid(uid, -1, -1);
+       for (id = uwrap.ids; id; id = id->next) {
+               if (id->dead) {
+                       continue;
+               }
+
+               if (ruid != (uid_t)-1) {
+                       id->ruid = ruid;
+               }
+
+               if (euid != (uid_t)-1) {
+                       id->euid = euid;
+               }
+
+               if (suid != (uid_t)-1) {
+                       id->suid = suid;
+               }
+       }
+
+       return 0;
 }
 
 /*
- * GETUID
+ * SETUID
  */
-static uid_t uwrap_getuid(void)
-{
-       return uwrap.ruid;
-}
-
-uid_t getuid(void)
+int setuid(uid_t uid)
 {
        if (!uwrap_enabled()) {
-               return uwrap.libc.fns._libc_getuid();
+               return uwrap.libc.fns._libc_setuid(uid);
        }
 
-       return uwrap_getuid();
+       return uwrap_setresuid(uid, -1, -1);
 }
 
 #ifdef HAVE_SETEUID
@@ -316,83 +419,112 @@ int setresuid(uid_t ruid, uid_t euid, uid_t suid)
 }
 #endif
 
-static uid_t uwrap_geteuid(void)
-{
-       return uwrap.euid;
-}
-
-uid_t geteuid(void)
-{
-       if (!uwrap_enabled()) {
-               return uwrap.libc.fns._libc_geteuid();
-       }
-
-       return uwrap_geteuid();
-}
-
 /*
- * SETGID
+ * GETUID
  */
-static int uwrap_setgid(gid_t gid)
+static uid_t uwrap_getuid(void)
 {
-       if (gid == (gid_t)-1) {
-               errno = EINVAL;
-               return -1;
-       }
-
-       uwrap.rgid = gid;
+       struct uwrap_thread *id = find_uwrap_id(pthread_self());
 
-       return 0;
+       return id->ruid;
 }
 
-int setgid(gid_t gid)
+uid_t getuid(void)
 {
        if (!uwrap_enabled()) {
-               return uwrap.libc.fns._libc_setgid(gid);
+               return uwrap.libc.fns._libc_getuid();
        }
 
-       return uwrap_setgid(gid);
+       return uwrap_getuid();
 }
 
 /*
- * GETGID
+ * GETEUID
  */
-static gid_t uwrap_getgid(void)
+static uid_t uwrap_geteuid(void)
 {
-       return uwrap.rgid;
+       struct uwrap_thread *id = find_uwrap_id(pthread_self());
+
+       return id->euid;
 }
 
-gid_t getgid(void)
+uid_t geteuid(void)
 {
        if (!uwrap_enabled()) {
-               return uwrap.libc.fns._libc_getgid();
+               return uwrap.libc.fns._libc_geteuid();
        }
 
-       return uwrap_getgid();
+       return uwrap_geteuid();
 }
 
-static int uwrap_setresgid(gid_t rgid, gid_t egid, gid_t sgid)
+static int uwrap_setresgid_thread(gid_t rgid, gid_t egid, gid_t sgid)
 {
+       struct uwrap_thread *id;
+
        if (rgid == (gid_t)-1 && egid == (gid_t)-1 && sgid == (gid_t)-1) {
                errno = EINVAL;
                return -1;
        }
 
+       id = find_uwrap_id(pthread_self());
+
        if (rgid != (gid_t)-1) {
-               uwrap.rgid = rgid;
+               id->rgid = rgid;
        }
 
        if (egid != (gid_t)-1) {
-               uwrap.egid = egid;
+               id->egid = egid;
        }
 
        if (sgid != (gid_t)-1) {
-               uwrap.sgid = sgid;
+               id->sgid = sgid;
        }
 
        return 0;
 }
 
+static int uwrap_setresgid(gid_t rgid, gid_t egid, gid_t sgid)
+{
+       struct uwrap_thread *id;
+
+       if (rgid == (gid_t)-1 && egid == (gid_t)-1 && sgid == (gid_t)-1) {
+               errno = EINVAL;
+               return -1;
+       }
+
+       for (id = uwrap.ids; id; id = id->next) {
+               if (id->dead) {
+                       continue;
+               }
+
+               if (rgid != (gid_t)-1) {
+                       id->rgid = rgid;
+               }
+
+               if (egid != (gid_t)-1) {
+                       id->egid = egid;
+               }
+
+               if (sgid != (gid_t)-1) {
+                       id->sgid = sgid;
+               }
+       }
+
+       return 0;
+}
+
+/*
+ * SETGID
+ */
+int setgid(gid_t gid)
+{
+       if (!uwrap_enabled()) {
+               return uwrap.libc.fns._libc_setgid(gid);
+       }
+
+       return uwrap_setresgid(gid, -1, -1);
+}
+
 #ifdef HAVE_SETEGID
 int setegid(gid_t egid)
 {
@@ -426,9 +558,33 @@ int setresgid(gid_t rgid, gid_t egid, gid_t sgid)
 }
 #endif
 
+/*
+ * GETGID
+ */
+static gid_t uwrap_getgid(void)
+{
+       struct uwrap_thread *id = find_uwrap_id(pthread_self());
+
+       return id->rgid;
+}
+
+gid_t getgid(void)
+{
+       if (!uwrap_enabled()) {
+               return uwrap.libc.fns._libc_getgid();
+       }
+
+       return uwrap_getgid();
+}
+
+/*
+ * GETEGID
+ */
 static uid_t uwrap_getegid(void)
 {
-       return uwrap.egid;
+       struct uwrap_thread *id = find_uwrap_id(pthread_self());
+
+       return id->egid;
 }
 
 uid_t getegid(void)
@@ -440,20 +596,22 @@ uid_t getegid(void)
        return uwrap_getegid();
 }
 
-static int uwrap_setgroups(size_t size, const gid_t *list)
+static int uwrap_setgroups_thread(size_t size, const gid_t *list)
 {
-       free(uwrap.groups);
-       uwrap.groups = NULL;
-       uwrap.ngroups = 0;
+       struct uwrap_thread *id = find_uwrap_id(pthread_self());
+
+       free(id->groups);
+       id->groups = NULL;
+       id->ngroups = 0;
 
        if (size != 0) {
-               uwrap.groups = malloc(sizeof(gid_t) * size);
-               if (uwrap.groups == NULL) {
+               id->groups = malloc(sizeof(gid_t) * size);
+               if (id->groups == NULL) {
                        errno = ENOMEM;
                        return -1;
                }
-               uwrap.ngroups = size;
-               memcpy(uwrap.groups, list, size*sizeof(gid_t));
+               id->ngroups = size;
+               memcpy(id->groups, list, size * sizeof(gid_t));
        }
 
        return 0;
@@ -465,14 +623,15 @@ int setgroups(size_t size, const gid_t *list)
                return uwrap.libc.fns._libc_setgroups(size, list);
        }
 
-       return uwrap_setgroups(size, list);
+       return uwrap_setgroups_thread(size, list);
 }
 
 static int uwrap_getgroups(int size, gid_t *list)
 {
+       struct uwrap_thread *id = find_uwrap_id(pthread_self());
        int ngroups;
 
-       ngroups = uwrap.ngroups;
+       ngroups = id->ngroups;
 
        if (size > ngroups) {
                size = ngroups;
@@ -484,7 +643,7 @@ static int uwrap_getgroups(int size, gid_t *list)
                errno = EINVAL;
                return -1;
        }
-       memcpy(list, uwrap.groups, size*sizeof(gid_t));
+       memcpy(list, id->groups, size * sizeof(gid_t));
 
        return ngroups;
 }
@@ -529,6 +688,22 @@ static long int uwrap_syscall (long int sysno, va_list vp)
 
        switch (sysno) {
                /* gid */
+               case SYS_getgid:
+#ifdef HAVE_LINUX_32BIT_SYSCALLS
+               case SYS_getgid32:
+#endif
+                       {
+                               rc = uwrap_getgid();
+                       }
+                       break;
+               case SYS_getegid:
+#ifdef HAVE_LINUX_32BIT_SYSCALLS
+               case SYS_getegid32:
+#endif
+                       {
+                               rc = uwrap_getegid();
+                       }
+                       break;
                case SYS_setgid:
 #ifdef HAVE_LINUX_32BIT_SYSCALLS
                case SYS_setgid32:
@@ -536,7 +711,7 @@ static long int uwrap_syscall (long int sysno, va_list vp)
                        {
                                gid_t gid = (gid_t) va_arg(vp, int);
 
-                               rc = uwrap_setresgid(gid, -1, -1);
+                               rc = uwrap_setresgid_thread(gid, -1, -1);
                        }
                        break;
                case SYS_setregid:
@@ -547,7 +722,7 @@ static long int uwrap_syscall (long int sysno, va_list vp)
                                uid_t rgid = (uid_t) va_arg(vp, int);
                                uid_t egid = (uid_t) va_arg(vp, int);
 
-                               rc = uwrap_setresgid(rgid, egid, -1);
+                               rc = uwrap_setresgid_thread(rgid, egid, -1);
                        }
                        break;
                case SYS_setresgid:
@@ -559,11 +734,27 @@ static long int uwrap_syscall (long int sysno, va_list vp)
                                uid_t egid = (uid_t) va_arg(vp, int);
                                uid_t sgid = (uid_t) va_arg(vp, int);
 
-                               rc = uwrap_setresgid(rgid, egid, sgid);
+                               rc = uwrap_setresgid_thread(rgid, egid, sgid);
                        }
                        break;
 
                /* uid */
+               case SYS_getuid:
+#ifdef HAVE_LINUX_32BIT_SYSCALLS
+               case SYS_getuid32:
+#endif
+                       {
+                               rc = uwrap_getuid();
+                       }
+                       break;
+               case SYS_geteuid:
+#ifdef HAVE_LINUX_32BIT_SYSCALLS
+               case SYS_geteuid32:
+#endif
+                       {
+                               rc = uwrap_geteuid();
+                       }
+                       break;
                case SYS_setuid:
 #ifdef HAVE_LINUX_32BIT_SYSCALLS
                case SYS_setuid32:
@@ -571,7 +762,7 @@ static long int uwrap_syscall (long int sysno, va_list vp)
                        {
                                uid_t uid = (uid_t) va_arg(vp, int);
 
-                               rc = uwrap_setresuid(uid, -1, -1);
+                               rc = uwrap_setresuid_thread(uid, -1, -1);
                        }
                        break;
                case SYS_setreuid:
@@ -582,7 +773,7 @@ static long int uwrap_syscall (long int sysno, va_list vp)
                                uid_t ruid = (uid_t) va_arg(vp, int);
                                uid_t euid = (uid_t) va_arg(vp, int);
 
-                               rc = uwrap_setresuid(ruid, euid, -1);
+                               rc = uwrap_setresuid_thread(ruid, euid, -1);
                        }
                        break;
                case SYS_setresuid:
@@ -594,7 +785,7 @@ static long int uwrap_syscall (long int sysno, va_list vp)
                                uid_t euid = (uid_t) va_arg(vp, int);
                                uid_t suid = (uid_t) va_arg(vp, int);
 
-                               rc = uwrap_setresuid(ruid, euid, suid);
+                               rc = uwrap_setresuid_thread(ruid, euid, suid);
                        }
                        break;
 
@@ -607,7 +798,7 @@ static long int uwrap_syscall (long int sysno, va_list vp)
                                size_t size = (size_t) va_arg(vp, size_t);
                                gid_t *list = (gid_t *) va_arg(vp, int *);
 
-                               rc = uwrap_setgroups(size, list);
+                               rc = uwrap_setgroups_thread(size, list);
                        }
                        break;
                default: