r17516: Change helper function names to make more clear what they are meant to do
[kamenim/samba.git] / source4 / auth / gensec / schannel_state.c
index 7ef64ca00b33bbb1b5fa650b8770fe0d8ef3e702..e8d0f8388b8b40605fba8d301e86a40ef66dc092 100644 (file)
 #include "lib/ldb/include/ldb.h"
 #include "lib/ldb/include/ldb_errors.h"
 #include "dsdb/samdb/samdb.h"
+#include "db_wrap.h"
 
-/*
+/**
   connect to the schannel ldb
 */
-static struct ldb_context *schannel_db_connect(TALLOC_CTX *mem_ctx)
+struct ldb_context *schannel_db_connect(TALLOC_CTX *mem_ctx)
 {
        char *path;
        struct ldb_context *ldb;
@@ -44,7 +45,7 @@ static struct ldb_context *schannel_db_connect(TALLOC_CTX *mem_ctx)
                return NULL;
        }
 
-       existed = file_exists(path);
+       existed = file_exist(path);
        
        ldb = ldb_wrap_connect(mem_ctx, path, system_session(mem_ctx), 
                               NULL, LDB_FLG_NOSYNC, NULL);
@@ -64,44 +65,35 @@ static struct ldb_context *schannel_db_connect(TALLOC_CTX *mem_ctx)
   remember an established session key for a netr server authentication
   use a simple ldb structure
 */
-NTSTATUS schannel_store_session_key(TALLOC_CTX *mem_ctx,
-                                   struct creds_CredentialState *creds)
+NTSTATUS schannel_store_session_key_ldb(TALLOC_CTX *mem_ctx,
+                                       struct ldb_context *ldb,
+                                       struct creds_CredentialState *creds)
 {
-       struct ldb_context *ldb;
        struct ldb_message *msg;
-       struct ldb_val val, seed;
+       struct ldb_val val, seed, client_state, server_state;
        char *f;
        char *sct;
        int ret;
 
-       ldb = schannel_db_connect(mem_ctx);
-       if (ldb == NULL) {
-               return NT_STATUS_NO_MEMORY;
-       }
-
        f = talloc_asprintf(mem_ctx, "%u", (unsigned int)creds->negotiate_flags);
 
        if (f == NULL) {
-               talloc_free(ldb);
                return NT_STATUS_NO_MEMORY;
        }
 
        sct = talloc_asprintf(mem_ctx, "%u", (unsigned int)creds->secure_channel_type);
 
        if (sct == NULL) {
-               talloc_free(ldb);
                return NT_STATUS_NO_MEMORY;
        }
 
        msg = ldb_msg_new(ldb);
        if (msg == NULL) {
-               talloc_free(ldb);
                return NT_STATUS_NO_MEMORY;
        }
 
        msg->dn = ldb_dn_build_child(msg, "computerName", creds->computer_name, NULL);
        if (msg->dn == NULL) {
-               talloc_free(ldb);
                return NT_STATUS_NO_MEMORY;
        }
 
@@ -111,9 +103,16 @@ NTSTATUS schannel_store_session_key(TALLOC_CTX *mem_ctx,
        seed.data = creds->seed.data;
        seed.length = sizeof(creds->seed.data);
 
+       client_state.data = creds->client.data;
+       client_state.length = sizeof(creds->client.data);
+       server_state.data = creds->server.data;
+       server_state.length = sizeof(creds->server.data);
+
        ldb_msg_add_string(msg, "objectClass", "schannelState");
        ldb_msg_add_value(msg, "sessionKey", &val);
        ldb_msg_add_value(msg, "seed", &seed);
+       ldb_msg_add_value(msg, "clientState", &client_state);
+       ldb_msg_add_value(msg, "serverState", &server_state);
        ldb_msg_add_string(msg, "negotiateFlags", f);
        ldb_msg_add_string(msg, "secureChannelType", sct);
        ldb_msg_add_string(msg, "accountName", creds->account_name);
@@ -121,49 +120,65 @@ NTSTATUS schannel_store_session_key(TALLOC_CTX *mem_ctx,
        ldb_msg_add_string(msg, "flatname", creds->domain);
        samdb_msg_add_dom_sid(ldb, mem_ctx, msg, "objectSid", creds->sid);
 
-       ret = ldb_transaction_start(ldb);
+       ldb_delete(ldb, msg->dn);
+
+       ret = ldb_add(ldb, msg);
+
        if (ret != 0) {
-               DEBUG(0,("Unable to start transaction to add %s to session key db - %s\n", 
+               DEBUG(0,("Unable to add %s to session key db - %s\n", 
                         ldb_dn_linearize(msg, msg->dn), ldb_errstring(ldb)));
-               talloc_free(ldb);
                return NT_STATUS_INTERNAL_DB_CORRUPTION;
        }
 
-       ldb_delete(ldb, msg->dn);
+       return NT_STATUS_OK;
+}
 
-       ret = ldb_add(ldb, msg);
+NTSTATUS schannel_store_session_key(TALLOC_CTX *mem_ctx,
+                                   struct creds_CredentialState *creds)
+{
+       struct ldb_context *ldb;
+       NTSTATUS nt_status;
+       int ret;
+               
+       ldb = schannel_db_connect(mem_ctx);
+       if (!ldb) {
+               return NT_STATUS_ACCESS_DENIED;
+       }
 
+       ret = ldb_transaction_start(ldb);
        if (ret != 0) {
-               DEBUG(0,("Unable to add %s to session key db - %s\n", 
-                        ldb_dn_linearize(msg, msg->dn), ldb_errstring(ldb)));
                talloc_free(ldb);
                return NT_STATUS_INTERNAL_DB_CORRUPTION;
        }
 
-       ret = ldb_transaction_commit(ldb);
+       nt_status = schannel_store_session_key_ldb(mem_ctx, ldb, creds);
+
+       if (NT_STATUS_IS_OK(nt_status)) {
+               ret = ldb_transaction_commit(ldb);
+       } else {
+               ret = ldb_transaction_cancel(ldb);
+       }
 
        if (ret != 0) {
-               DEBUG(0,("Unable to commit adding %s to session key db - %s\n", 
-                        ldb_dn_linearize(msg, msg->dn), ldb_errstring(ldb)));
+               DEBUG(0,("Unable to commit adding credentials for %s to schannel key db - %s\n", 
+                        creds->computer_name, ldb_errstring(ldb)));
                talloc_free(ldb);
                return NT_STATUS_INTERNAL_DB_CORRUPTION;
        }
 
        talloc_free(ldb);
-
-       return NT_STATUS_OK;
+       return nt_status;
 }
 
-
 /*
   read back a credentials back for a computer
 */
-NTSTATUS schannel_fetch_session_key(TALLOC_CTX *mem_ctx,
-                                   const char *computer_name, 
-                                   const char *domain,
-                                   struct creds_CredentialState **creds)
+NTSTATUS schannel_fetch_session_key_ldb(TALLOC_CTX *mem_ctx,
+                                       struct ldb_context *ldb,
+                                       const char *computer_name, 
+                                       const char *domain,
+                                       struct creds_CredentialState **creds)
 {
-       struct ldb_context *ldb;
        struct ldb_result *res;
        int ret;
        const struct ldb_val *val;
@@ -174,27 +189,21 @@ NTSTATUS schannel_fetch_session_key(TALLOC_CTX *mem_ctx,
                return NT_STATUS_NO_MEMORY;
        }
 
-       ldb = schannel_db_connect(mem_ctx);
-       if (ldb == NULL) {
-               return NT_STATUS_NO_MEMORY;
-       }
-
-       expr = talloc_asprintf(mem_ctx, "(&(computerName=%s)(flatname=%s))", computer_name, domain);
+       expr = talloc_asprintf(mem_ctx, "(&(computerName=%s)(flatname=%s))", 
+                              computer_name, domain);
        if (expr == NULL) {
-               talloc_free(ldb);
                return NT_STATUS_NO_MEMORY;
        }
 
        ret = ldb_search(ldb, NULL, LDB_SCOPE_SUBTREE, expr, NULL, &res);
        if (ret != LDB_SUCCESS || res->count != 1) {
-               talloc_free(ldb);
+               DEBUG(3,("schannel: Failed to find a record for client: %s\n", computer_name));
                return NT_STATUS_INVALID_HANDLE;
        }
 
        val = ldb_msg_find_ldb_val(res->msgs[0], "sessionKey");
        if (val == NULL || val->length != 16) {
                DEBUG(1,("schannel: record in schannel DB must contain a sessionKey of length 16, when searching for client: %s\n", computer_name));
-               talloc_free(ldb);
                return NT_STATUS_INTERNAL_ERROR;
        }
 
@@ -203,25 +212,56 @@ NTSTATUS schannel_fetch_session_key(TALLOC_CTX *mem_ctx,
        val = ldb_msg_find_ldb_val(res->msgs[0], "seed");
        if (val == NULL || val->length != 8) {
                DEBUG(1,("schannel: record in schannel DB must contain a vaid seed of length 8, when searching for client: %s\n", computer_name));
-               talloc_free(ldb);
                return NT_STATUS_INTERNAL_ERROR;
        }
 
        memcpy((*creds)->seed.data, val->data, 8);
 
-       (*creds)->negotiate_flags = ldb_msg_find_int(res->msgs[0], "negotiateFlags", 0);
+       val = ldb_msg_find_ldb_val(res->msgs[0], "clientState");
+       if (val == NULL || val->length != 8) {
+               DEBUG(1,("schannel: record in schannel DB must contain a vaid clientState of length 8, when searching for client: %s\n", computer_name));
+               return NT_STATUS_INTERNAL_ERROR;
+       }
+       memcpy((*creds)->client.data, val->data, 8);
+
+       val = ldb_msg_find_ldb_val(res->msgs[0], "serverState");
+       if (val == NULL || val->length != 8) {
+               DEBUG(1,("schannel: record in schannel DB must contain a vaid serverState of length 8, when searching for client: %s\n", computer_name));
+               return NT_STATUS_INTERNAL_ERROR;
+       }
+       memcpy((*creds)->server.data, val->data, 8);
 
-       (*creds)->secure_channel_type = ldb_msg_find_int(res->msgs[0], "secureChannelType", 0);
+       (*creds)->negotiate_flags = ldb_msg_find_attr_as_int(res->msgs[0], "negotiateFlags", 0);
 
-       (*creds)->account_name = talloc_reference(*creds, ldb_msg_find_string(res->msgs[0], "accountName", NULL));
+       (*creds)->secure_channel_type = ldb_msg_find_attr_as_int(res->msgs[0], "secureChannelType", 0);
 
-       (*creds)->computer_name = talloc_reference(*creds, ldb_msg_find_string(res->msgs[0], "computerName", NULL));
+       (*creds)->account_name = talloc_reference(*creds, ldb_msg_find_attr_as_string(res->msgs[0], "accountName", NULL));
 
-       (*creds)->domain = talloc_reference(*creds, ldb_msg_find_string(res->msgs[0], "flatname", NULL));
+       (*creds)->computer_name = talloc_reference(*creds, ldb_msg_find_attr_as_string(res->msgs[0], "computerName", NULL));
 
-       (*creds)->sid = samdb_result_dom_sid(*creds, res->msgs[0], "objectSid");
+       (*creds)->domain = talloc_reference(*creds, ldb_msg_find_attr_as_string(res->msgs[0], "flatname", NULL));
 
-       talloc_free(ldb);
+       (*creds)->sid = samdb_result_dom_sid(*creds, res->msgs[0], "objectSid");
 
        return NT_STATUS_OK;
 }
+
+NTSTATUS schannel_fetch_session_key(TALLOC_CTX *mem_ctx,
+                                       const char *computer_name, 
+                                       const char *domain, 
+                                       struct creds_CredentialState **creds)
+{
+       NTSTATUS nt_status;
+       struct ldb_context *ldb;
+
+       ldb = schannel_db_connect(mem_ctx);
+       if (!ldb) {
+               return NT_STATUS_ACCESS_DENIED;
+       }
+
+       nt_status = schannel_fetch_session_key_ldb(mem_ctx, ldb,
+                                                  computer_name, domain, 
+                                                  creds);
+       talloc_free(ldb);
+       return nt_status;
+}