auth/credentials: Add cli_credentials_{set,get}_forced_sasl_mech()
[mat/samba.git] / auth / credentials / credentials.c
index e5978099ca4838ab5ccc7d965f1517b3cb8c470f..e98dfbdae4e49187b3f315be59672959c91b1507 100644 (file)
@@ -24,6 +24,7 @@
 #include "includes.h"
 #include "librpc/gen_ndr/samr.h" /* for struct samrPassword */
 #include "auth/credentials/credentials.h"
+#include "auth/credentials/credentials_internal.h"
 #include "libcli/auth/libcli_auth.h"
 #include "tevent.h"
 #include "param/param.h"
@@ -103,7 +104,7 @@ _PUBLIC_ struct cli_credentials *cli_credentials_init(TALLOC_CTX *mem_ctx)
 
        cred->machine_account = false;
 
-       cred->tries = 3;
+       cred->password_tries = 0;
 
        cred->callback_running = false;
 
@@ -111,9 +112,37 @@ _PUBLIC_ struct cli_credentials *cli_credentials_init(TALLOC_CTX *mem_ctx)
        cli_credentials_set_gensec_features(cred, 0);
        cli_credentials_set_krb_forwardable(cred, CRED_AUTO_KRB_FORWARDABLE);
 
+       cred->forced_sasl_mech = NULL;
+
        return cred;
 }
 
+_PUBLIC_ void cli_credentials_set_callback_data(struct cli_credentials *cred,
+                                               void *callback_data)
+{
+       cred->priv_data = callback_data;
+}
+
+_PUBLIC_ void *_cli_credentials_callback_data(struct cli_credentials *cred)
+{
+       return cred->priv_data;
+}
+
+_PUBLIC_ struct cli_credentials *cli_credentials_shallow_copy(TALLOC_CTX *mem_ctx,
+                                               struct cli_credentials *src)
+{
+       struct cli_credentials *dst;
+
+       dst = talloc(mem_ctx, struct cli_credentials);
+       if (dst == NULL) {
+               return NULL;
+       }
+
+       *dst = *src;
+
+       return dst;
+}
+
 /**
  * Create a new anonymous credential
  * @param mem_ctx TALLOC_CTX parent for credentials structure 
@@ -134,6 +163,13 @@ _PUBLIC_ void cli_credentials_set_kerberos_state(struct cli_credentials *creds,
        creds->use_kerberos = use_kerberos;
 }
 
+_PUBLIC_ void cli_credentials_set_forced_sasl_mech(struct cli_credentials *creds,
+                                                  const char *sasl_mech)
+{
+       TALLOC_FREE(creds->forced_sasl_mech);
+       creds->forced_sasl_mech = talloc_strdup(creds, sasl_mech);
+}
+
 _PUBLIC_ void cli_credentials_set_krb_forwardable(struct cli_credentials *creds,
                                                  enum credentials_krb_forwardable krb_forwardable)
 {
@@ -145,6 +181,11 @@ _PUBLIC_ enum credentials_use_kerberos cli_credentials_get_kerberos_state(struct
        return creds->use_kerberos;
 }
 
+_PUBLIC_ const char *cli_credentials_get_forced_sasl_mech(struct cli_credentials *creds)
+{
+       return creds->forced_sasl_mech;
+}
+
 _PUBLIC_ enum credentials_krb_forwardable cli_credentials_get_krb_forwardable(struct cli_credentials *creds)
 {
        return creds->krb_forwardable;
@@ -179,8 +220,10 @@ _PUBLIC_ const char *cli_credentials_get_username(struct cli_credentials *cred)
                cred->callback_running = true;
                cred->username = cred->username_cb(cred);
                cred->callback_running = false;
-               cred->username_obtained = CRED_SPECIFIED;
-               cli_credentials_invalidate_ccache(cred, cred->username_obtained);
+               if (cred->username_obtained == CRED_CALLBACK) {
+                       cred->username_obtained = CRED_CALLBACK_RESULT;
+                       cli_credentials_invalidate_ccache(cred, cred->username_obtained);
+               }
        }
 
        return cred->username;
@@ -248,8 +291,10 @@ _PUBLIC_ const char *cli_credentials_get_principal_and_obtained(struct cli_crede
                cred->callback_running = true;
                cred->principal = cred->principal_cb(cred);
                cred->callback_running = false;
-               cred->principal_obtained = CRED_SPECIFIED;
-               cli_credentials_invalidate_ccache(cred, cred->principal_obtained);
+               if (cred->principal_obtained == CRED_CALLBACK) {
+                       cred->principal_obtained = CRED_CALLBACK_RESULT;
+                       cli_credentials_invalidate_ccache(cred, cred->principal_obtained);
+               }
        }
 
        if (cred->principal_obtained < cred->username_obtained
@@ -267,7 +312,7 @@ _PUBLIC_ const char *cli_credentials_get_principal_and_obtained(struct cli_crede
                }
        }
        *obtained = cred->principal_obtained;
-       return talloc_reference(mem_ctx, cred->principal);
+       return talloc_strdup(mem_ctx, cred->principal);
 }
 
 /**
@@ -355,8 +400,10 @@ _PUBLIC_ const char *cli_credentials_get_password(struct cli_credentials *cred)
                cred->callback_running = true;
                cred->password = cred->password_cb(cred);
                cred->callback_running = false;
-               cred->password_obtained = CRED_CALLBACK_RESULT;
-               cli_credentials_invalidate_ccache(cred, cred->password_obtained);
+               if (cred->password_obtained == CRED_CALLBACK) {
+                       cred->password_obtained = CRED_CALLBACK_RESULT;
+                       cli_credentials_invalidate_ccache(cred, cred->password_obtained);
+               }
        }
 
        return cred->password;
@@ -370,6 +417,7 @@ _PUBLIC_ bool cli_credentials_set_password(struct cli_credentials *cred,
                                  enum credentials_obtained obtained)
 {
        if (obtained >= cred->password_obtained) {
+               cred->password_tries = 0;
                cred->password = talloc_strdup(cred, val);
                if (cred->password) {
                        /* Don't print the actual password in talloc memory dumps */
@@ -391,6 +439,7 @@ _PUBLIC_ bool cli_credentials_set_password_callback(struct cli_credentials *cred
                                           const char *(*password_cb) (struct cli_credentials *))
 {
        if (cred->password_obtained < CRED_CALLBACK) {
+               cred->password_tries = 3;
                cred->password_cb = password_cb;
                cred->password_obtained = CRED_CALLBACK;
                cli_credentials_invalidate_ccache(cred, cred->password_obtained);
@@ -436,8 +485,8 @@ _PUBLIC_ bool cli_credentials_set_old_password(struct cli_credentials *cred,
  * @param cred credentials context
  * @retval If set, the cleartext password, otherwise NULL
  */
-_PUBLIC_ const struct samr_Password *cli_credentials_get_nt_hash(struct cli_credentials *cred, 
-                                                       TALLOC_CTX *mem_ctx)
+_PUBLIC_ struct samr_Password *cli_credentials_get_nt_hash(struct cli_credentials *cred,
+                                                          TALLOC_CTX *mem_ctx)
 {
        const char *password = cli_credentials_get_password(cred);
 
@@ -446,13 +495,22 @@ _PUBLIC_ const struct samr_Password *cli_credentials_get_nt_hash(struct cli_cred
                if (!nt_hash) {
                        return NULL;
                }
-               
+
                E_md4hash(password, nt_hash->hash);    
 
                return nt_hash;
-       } else {
-               return cred->nt_hash;
+       } else if (cred->nt_hash != NULL) {
+               struct samr_Password *nt_hash = talloc(mem_ctx, struct samr_Password);
+               if (!nt_hash) {
+                       return NULL;
+               }
+
+               *nt_hash = *cred->nt_hash;
+
+               return nt_hash;
        }
+
+       return NULL;
 }
 
 /**
@@ -473,8 +531,10 @@ _PUBLIC_ const char *cli_credentials_get_domain(struct cli_credentials *cred)
                cred->callback_running = true;
                cred->domain = cred->domain_cb(cred);
                cred->callback_running = false;
-               cred->domain_obtained = CRED_SPECIFIED;
-               cli_credentials_invalidate_ccache(cred, cred->domain_obtained);
+               if (cred->domain_obtained == CRED_CALLBACK) {
+                       cred->domain_obtained = CRED_CALLBACK_RESULT;
+                       cli_credentials_invalidate_ccache(cred, cred->domain_obtained);
+               }
        }
 
        return cred->domain;
@@ -532,8 +592,10 @@ _PUBLIC_ const char *cli_credentials_get_realm(struct cli_credentials *cred)
                cred->callback_running = true;
                cred->realm = cred->realm_cb(cred);
                cred->callback_running = false;
-               cred->realm_obtained = CRED_SPECIFIED;
-               cli_credentials_invalidate_ccache(cred, cred->realm_obtained);
+               if (cred->realm_obtained == CRED_CALLBACK) {
+                       cred->realm_obtained = CRED_CALLBACK_RESULT;
+                       cli_credentials_invalidate_ccache(cred, cred->realm_obtained);
+               }
        }
 
        return cred->realm;
@@ -583,7 +645,9 @@ _PUBLIC_ const char *cli_credentials_get_workstation(struct cli_credentials *cre
                cred->callback_running = true;
                cred->workstation = cred->workstation_cb(cred);
                cred->callback_running = false;
-               cred->workstation_obtained = CRED_SPECIFIED;
+               if (cred->workstation_obtained == CRED_CALLBACK) {
+                       cred->workstation_obtained = CRED_CALLBACK_RESULT;
+               }
        }
 
        return cred->workstation;
@@ -870,12 +934,19 @@ _PUBLIC_ bool cli_credentials_wrong_password(struct cli_credentials *cred)
        if (cred->password_obtained != CRED_CALLBACK_RESULT) {
                return false;
        }
-       
-       cred->password_obtained = CRED_CALLBACK;
 
-       cred->tries--;
+       if (cred->password_tries == 0) {
+               return false;
+       }
+
+       cred->password_tries--;
+
+       if (cred->password_tries == 0) {
+               return false;
+       }
 
-       return (cred->tries > 0);
+       cred->password_obtained = CRED_CALLBACK;
+       return true;
 }
 
 _PUBLIC_ void cli_credentials_get_ntlm_username_domain(struct cli_credentials *cred, TALLOC_CTX *mem_ctx,