s3-talloc Change TALLOC_REALLOC_ARRAY() to talloc_realloc()
[samba.git] / source3 / smbd / session.c
index 3cc93c1a7f42d46eba012ffb87bc13523439a860..12d4818e90ca8c548c5e2b5572911ea459e1fe82 100644 (file)
 */
 
 #include "includes.h"
-
-/********************************************************************
-********************************************************************/
-
-static struct db_context *session_db_ctx(void)
-{
-       static struct db_context *ctx;
-
-       if (ctx)
-               return ctx;
-
-       ctx = db_open(NULL, lock_path("sessionid.tdb"), 0,
-                     TDB_CLEAR_IF_FIRST|TDB_DEFAULT, 
-                     O_RDWR | O_CREAT, 0644);
-       return ctx;
-}
-
-BOOL session_init(void)
-{
-       if (session_db_ctx() == NULL) {
-               DEBUG(1,("session_init: failed to open sessionid tdb\n"));
-               return False;
-       }
-
-       return True;
-}
+#include "smbd/smbd.h"
+#include "smbd/globals.h"
+#include "dbwrap.h"
+#include "session.h"
+#include "auth.h"
 
 /********************************************************************
  called when a session is created
 ********************************************************************/
 
-BOOL session_claim(user_struct *vuser)
+bool session_claim(struct smbd_server_connection *sconn, user_struct *vuser)
 {
-       TDB_DATA key, data;
+       struct server_id pid = sconn_server_id(sconn);
+       TDB_DATA data;
        int i = 0;
-       struct sockaddr sa;
-       struct in_addr *client_ip;
        struct sessionid sessionid;
-       struct server_id pid = procid_self();
        fstring keystr;
-       char * hostname;
-       struct db_context *ctx;
        struct db_record *rec;
        NTSTATUS status;
 
@@ -76,11 +51,11 @@ BOOL session_claim(user_struct *vuser)
 
        /* don't register sessions for the guest user - its just too
           expensive to go through pam session code for browsing etc */
-       if (vuser->guest) {
+       if (vuser->session_info->guest) {
                return True;
        }
 
-       if (!(ctx = session_db_ctx())) {
+       if (!sessionid_init()) {
                return False;
        }
 
@@ -100,10 +75,8 @@ BOOL session_claim(user_struct *vuser)
                        struct server_id sess_pid;
 
                        snprintf(keystr, sizeof(keystr), "ID/%d", i);
-                       key = string_term_tdb_data(keystr);
-
-                       rec = ctx->fetch_locked(ctx, NULL, key);
 
+                       rec = sessionid_fetch_record(NULL, keystr);
                        if (rec == NULL) {
                                DEBUG(1, ("Could not lock \"%s\"\n", keystr));
                                return False;
@@ -114,7 +87,10 @@ BOOL session_claim(user_struct *vuser)
                                break;
                        }
 
-                       sess_pid = ((struct sessionid *)rec->value.dptr)->pid;
+                       memcpy(&sess_pid,
+                              ((char *)rec->value.dptr)
+                              + offsetof(struct sessionid, pid),
+                              sizeof(sess_pid));
 
                        if (!process_exists(sess_pid)) {
                                DEBUG(5, ("%s has died -- re-using session\n",
@@ -124,7 +100,7 @@ BOOL session_claim(user_struct *vuser)
 
                        TALLOC_FREE(rec);
                }
-               
+
                if (i == MAX_SESSION_ID) {
                        SMB_ASSERT(rec == NULL);
                        DEBUG(1,("session_claim: out of session IDs "
@@ -138,45 +114,36 @@ BOOL session_claim(user_struct *vuser)
        {
                snprintf(keystr, sizeof(keystr), "ID/%s/%u",
                         procid_str_static(&pid), vuser->vuid);
-               key = string_term_tdb_data(keystr);
-
-               rec = ctx->fetch_locked(ctx, NULL, key);
 
+               rec = sessionid_fetch_record(NULL, keystr);
                if (rec == NULL) {
                        DEBUG(1, ("Could not lock \"%s\"\n", keystr));
                        return False;
                }
 
-               snprintf(sessionid.id_str, sizeof(sessionid.id_str), 
-                        SESSION_TEMPLATE, (long unsigned int)sys_getpid(), 
+               snprintf(sessionid.id_str, sizeof(sessionid.id_str),
+                        SESSION_TEMPLATE, (long unsigned int)sys_getpid(),
                         vuser->vuid);
        }
 
        SMB_ASSERT(rec != NULL);
 
        /* If 'hostname lookup' == yes, then do the DNS lookup.  This is
-           needed because utmp and PAM both expect DNS names 
-          
+           needed because utmp and PAM both expect DNS names
+
           client_name() handles this case internally.
        */
 
-       hostname = client_name();
-       if (strcmp(hostname, "UNKNOWN") == 0) {
-               hostname = client_addr();
-       }
-
-       fstrcpy(sessionid.username, vuser->user.unix_name);
-       fstrcpy(sessionid.hostname, hostname);
+       fstrcpy(sessionid.username, vuser->session_info->unix_name);
+       fstrcpy(sessionid.hostname, sconn->client_id.name);
        sessionid.id_num = i;  /* Only valid for utmp sessions */
        sessionid.pid = pid;
-       sessionid.uid = vuser->uid;
-       sessionid.gid = vuser->gid;
+       sessionid.uid = vuser->session_info->utok.uid;
+       sessionid.gid = vuser->session_info->utok.gid;
        fstrcpy(sessionid.remote_machine, get_remote_machine_name());
-       fstrcpy(sessionid.ip_addr, client_addr());
+       fstrcpy(sessionid.ip_addr_str, sconn->client_id.addr);
        sessionid.connect_start = time(NULL);
 
-       client_ip = client_inaddr(&sa);
-
        if (!smb_pam_claim_session(sessionid.username, sessionid.id_str,
                                   sessionid.hostname)) {
                DEBUG(1,("pam_session rejected the session for %s [%s]\n",
@@ -200,17 +167,14 @@ BOOL session_claim(user_struct *vuser)
        }
 
        if (lp_utmp()) {
-               sys_utmp_claim(sessionid.username, sessionid.hostname, 
-                              client_ip,
+               sys_utmp_claim(sessionid.username, sessionid.hostname,
+                              sessionid.ip_addr_str,
                               sessionid.id_str, sessionid.id_num);
        }
 
-       TALLOC_FREE(rec);
-
        vuser->session_keystr = talloc_strdup(vuser, keystr);
        if (!vuser->session_keystr) {
-               DEBUG(0, ("session_claim:  talloc_strdup() failed for "
-                         "session_keystr\n"));
+               DEBUG(0, ("session_claim:  talloc_strdup() failed for session_keystr\n"));
                return False;
        }
        return True;
@@ -222,21 +186,15 @@ BOOL session_claim(user_struct *vuser)
 
 void session_yield(user_struct *vuser)
 {
-       TDB_DATA key;
        struct sessionid sessionid;
-       struct in_addr *client_ip;
-       struct db_context *ctx;
        struct db_record *rec;
 
-       if (!(ctx = session_db_ctx())) return;
-
        if (!vuser->session_keystr) {
                return;
        }
 
-       key = string_term_tdb_data(vuser->session_keystr);
-
-       if (!(rec = ctx->fetch_locked(ctx, NULL, key))) {
+       rec = sessionid_fetch_record(NULL, vuser->session_keystr);
+       if (rec == NULL) {
                return;
        }
 
@@ -245,11 +203,9 @@ void session_yield(user_struct *vuser)
 
        memcpy(&sessionid, rec->value.dptr, sizeof(sessionid));
 
-       client_ip = interpret_addr2(sessionid.ip_addr);
-
        if (lp_utmp()) {
                sys_utmp_yield(sessionid.username, sessionid.hostname, 
-                              client_ip,
+                              sessionid.ip_addr_str,
                               sessionid.id_str, sessionid.id_num);
        }
 
@@ -264,37 +220,18 @@ void session_yield(user_struct *vuser)
 /********************************************************************
 ********************************************************************/
 
-static BOOL session_traverse(int (*fn)(struct db_record *db,
-                                      void *private_data),
-                            void *private_data)
-{
-       struct db_context *ctx;
-
-       if (!(ctx = session_db_ctx())) {
-               DEBUG(3, ("No tdb opened\n"));
-               return False;
-       }
-
-       ctx->traverse_read(ctx, fn, private_data);
-       return True;
-}
-
-/********************************************************************
-********************************************************************/
-
 struct session_list {
        TALLOC_CTX *mem_ctx;
        int count;
        struct sessionid *sessions;
 };
 
-static int gather_sessioninfo(struct db_record *rec, void *state)
+static int gather_sessioninfo(const char *key, struct sessionid *session,
+                             void *private_data)
 {
-       struct session_list *sesslist = (struct session_list *) state;
-       const struct sessionid *current =
-               (const struct sessionid *) rec->value.dptr;
+       struct session_list *sesslist = (struct session_list *)private_data;
 
-       sesslist->sessions = TALLOC_REALLOC_ARRAY(
+       sesslist->sessions = talloc_realloc(
                sesslist->mem_ctx, sesslist->sessions, struct sessionid,
                sesslist->count+1);
 
@@ -303,13 +240,13 @@ static int gather_sessioninfo(struct db_record *rec, void *state)
                return -1;
        }
 
-       memcpy(&sesslist->sessions[sesslist->count], current,
+       memcpy(&sesslist->sessions[sesslist->count], session,
               sizeof(struct sessionid));
 
        sesslist->count++;
 
-       DEBUG(7,("gather_sessioninfo session from %s@%s\n", 
-                current->username, current->remote_machine));
+       DEBUG(7, ("gather_sessioninfo session from %s@%s\n",
+                 session->username, session->remote_machine));
 
        return 0;
 }
@@ -320,12 +257,14 @@ static int gather_sessioninfo(struct db_record *rec, void *state)
 int list_sessions(TALLOC_CTX *mem_ctx, struct sessionid **session_list)
 {
        struct session_list sesslist;
+       int ret;
 
        sesslist.mem_ctx = mem_ctx;
        sesslist.count = 0;
        sesslist.sessions = NULL;
-       
-       if (!session_traverse(gather_sessioninfo, (void *) &sesslist)) {
+
+       ret = sessionid_traverse_read(gather_sessioninfo, (void *) &sesslist);
+       if (ret == -1) {
                DEBUG(3, ("Session traverse failed\n"));
                SAFE_FREE(sesslist.sessions);
                *session_list = NULL;