Remove __contains__ from mock object for consistency with actual
[samba.git] / client / cifs.upcall.c
index bf6a861544c80b01c8c1ac904bfe54ff3f3d70c2..42632a0da09308ebc1f54256932b489626253c18 100644 (file)
@@ -27,6 +27,7 @@ create dns_resolver * * /usr/local/sbin/cifs.upcall %k
 
 #include "includes.h"
 #include "../libcli/auth/spnego.h"
+#include "smb_krb5.h"
 #include <keyutils.h>
 #include <getopt.h>
 
@@ -45,18 +46,6 @@ typedef enum _sectype {
        MS_KRB5
 } sectype_t;
 
-static inline int
-k5_data_equal(krb5_data d1, krb5_data d2, unsigned int length)
-{
-       if (!length)
-               length = d1.length;
-
-       return (d1.length == length &&
-               d1.length == d2.length &&
-               memcmp(d1.data, d2.data, length) == 0);
-
-}
-
 /* does the ccache have a valid TGT? */
 static time_t
 get_tgt_time(const char *ccname) {
@@ -65,9 +54,9 @@ get_tgt_time(const char *ccname) {
        krb5_cc_cursor cur;
        krb5_creds creds;
        krb5_principal principal;
-       krb5_data tgt = { .data =       "krbtgt",
-                         .length =     6 };
        time_t credtime = 0;
+       char *realm = NULL;
+       TALLOC_CTX *mem_ctx;
 
        if (krb5_init_context(&context)) {
                syslog(LOG_DEBUG, "%s: unable to init krb5 context", __func__);
@@ -94,20 +83,35 @@ get_tgt_time(const char *ccname) {
                goto err_ccstart;
        }
 
+       if ((realm = smb_krb5_principal_get_realm(context, principal)) == NULL) {
+               syslog(LOG_DEBUG, "%s: unable to get realm", __func__);
+               goto err_ccstart;
+       }
+
+       mem_ctx = talloc_init("cifs.upcall");
        while (!credtime && !krb5_cc_next_cred(context, ccache, &cur, &creds)) {
-               if (k5_data_equal(creds.server->realm, principal->realm, 0) &&
-                   k5_data_equal(creds.server->data[0], tgt, tgt.length) &&
-                   k5_data_equal(creds.server->data[1], principal->realm, 0) &&
+               char *name;
+               if (smb_krb5_unparse_name(mem_ctx, context, creds.server, &name)) {
+                       syslog(LOG_DEBUG, "%s: unable to unparse name", __func__);
+                       goto err_endseq;
+               }
+               if (krb5_realm_compare(context, creds.server, principal) &&
+                   strnequal(name, KRB5_TGS_NAME, KRB5_TGS_NAME_SIZE) &&
+                   strnequal(name+KRB5_TGS_NAME_SIZE+1, realm, strlen(realm)) &&
                    creds.times.endtime > time(NULL))
                        credtime = creds.times.endtime;
                 krb5_free_cred_contents(context, &creds);
+               TALLOC_FREE(name);
         }
+err_endseq:
+       TALLOC_FREE(mem_ctx);
         krb5_cc_end_seq_get(context, ccache, &cur);
-
 err_ccstart:
        krb5_free_principal(context, principal);
 err_princ:
+#if defined(KRB5_TC_OPENCLOSE)
        krb5_cc_set_flags(context, ccache, KRB5_TC_OPENCLOSE);
+#endif
        krb5_cc_close(context, ccache);
 err_cache:
        krb5_free_context(context);
@@ -221,7 +225,7 @@ handle_krb5_mech(const char *oid, const char *principal, DATA_BLOB *secblob,
 
        /* get a kerberos ticket for the service and extract the session key */
        retval = cli_krb5_get_ticket(principal, 0, &tkt, sess_key, 0, ccname,
-                                    NULL);
+                                    NULL, NULL);
 
        if (retval) {
                syslog(LOG_DEBUG, "%s: failed to obtain service ticket (%d)",