cifs: use krb5_kt_default() to determine default keytab location
[jlayton/cifs-utils.git] / cifscreds.c
index d771056525f30f319c33fa2ca5790e6a16260b70..fa05dc88b0e022e4c0012476044d5ac3a5aa97ac 100644 (file)
 #include <string.h>
 #include <ctype.h>
 #include <keyutils.h>
+#include <getopt.h>
+#include <errno.h>
+#include "cifskey.h"
 #include "mount.h"
 #include "resolve_host.h"
+#include "util.h"
 
 #define THIS_PROGRAM_NAME "cifscreds"
 
 /* max length of appropriate command */
 #define MAX_COMMAND_SIZE 32
 
-/* max length of username, password and domain name */
-#define MAX_USERNAME_SIZE 32
-#define MOUNT_PASSWD_SIZE 128
-#define MAX_DOMAIN_SIZE 64
-
-/* allowed and disallowed characters for user and domain name */
-#define USER_DISALLOWED_CHARS "\\/\"[]:|<>+=;,?*@"
-#define DOMAIN_ALLOWED_CHARS "abcdefghijklmnopqrstuvwxyz" \
-                            "ABCDEFGHIJKLMNOPQRSTUVWXYZ-."
-
-/* destination keyring */
-#define DEST_KEYRING KEY_SPEC_USER_KEYRING
+struct cmdarg {
+       char            *host;
+       char            *user;
+       char            keytype;
+};
 
 struct command {
-       int (*action)(int argc, char *argv[]);
+       int (*action)(struct cmdarg *arg);
        const char      name[MAX_COMMAND_SIZE];
        const char      *format;
 };
 
-static int cifscreds_add(int argc, char *argv[]);
-static int cifscreds_clear(int argc, char *argv[]);
-static int cifscreds_clearall(int argc, char *argv[]);
-static int cifscreds_update(int argc, char *argv[]);
+static int cifscreds_add(struct cmdarg *arg);
+static int cifscreds_clear(struct cmdarg *arg);
+static int cifscreds_clearall(struct cmdarg *arg);
+static int cifscreds_update(struct cmdarg *arg);
 
-const char *thisprogram;
+static const char *thisprogram;
 
-struct command commands[] = {
-       { cifscreds_add,        "add",          "<host> <user> [domain]" },
-       { cifscreds_clear,      "clear",        "<host> <user> [domain]" },
+static struct command commands[] = {
+       { cifscreds_add,        "add",          "[-u username] [-d] <host|domain>" },
+       { cifscreds_clear,      "clear",        "[-u username] [-d] <host|domain>" },
        { cifscreds_clearall,   "clearall",     "" },
-       { cifscreds_update,     "update",       "<host> <user> [domain]" },
+       { cifscreds_update,     "update",       "[-u username] [-d] <host|domain>" },
        { NULL, "", NULL }
 };
 
+static struct option longopts[] = {
+       {"username", 1, NULL, 'u'},
+       {"domain", 0, NULL, 'd' },
+       {NULL, 0, NULL, 0}
+};
+
 /* display usage information */
-static void usage(void)
+static int
+usage(void)
 {
        struct command *cmd;
 
@@ -80,88 +84,7 @@ static void usage(void)
                        cmd->name, cmd->format);
        fprintf(stderr, "\n");
 
-       exit(EXIT_FAILURE);
-}
-
-/* create key's description string from given credentials */
-static char *
-create_description(const char *addr, const char *user,
-                  const char *domain, char *desc)
-{
-       char *str_end;
-       int str_len;
-
-       sprintf(desc, "%s:%s:%s:", THIS_PROGRAM_NAME, addr, user);
-
-       if (domain != NULL) {
-               str_end = desc + strnlen(desc, INET6_ADDRSTRLEN + \
-                                       + MAX_USERNAME_SIZE + \
-                                       + sizeof(THIS_PROGRAM_NAME) + 3);
-               str_len = strnlen(domain, MAX_DOMAIN_SIZE);
-               while (str_len--) {
-                       *str_end = tolower(*domain++);
-                       str_end++;
-               }
-               *str_end = '\0';
-       }
-
-       return desc;
-}
-
-/* search a specific key in keyring */
-static key_serial_t
-key_search(const char *addr, const char *user, const char *domain)
-{
-       char desc[INET6_ADDRSTRLEN + MAX_USERNAME_SIZE + MAX_DOMAIN_SIZE + \
-               + sizeof(THIS_PROGRAM_NAME) + 3];
-       key_serial_t key, *pk;
-       void *keylist;
-       char *buffer;
-       int count, dpos, n, ret;
-
-       create_description(addr, user, domain, desc);
-
-       /* read the key payload data */
-       count = keyctl_read_alloc(DEST_KEYRING, &keylist);
-       if (count < 0)
-               return 0;
-
-       count /= sizeof(key_serial_t);
-
-       if (count == 0) {
-               ret = 0;
-               goto key_search_out;
-       }
-
-       /* list the keys in the keyring */
-       pk = keylist;
-       do {
-               key = *pk++;
-
-               ret = keyctl_describe_alloc(key, &buffer);
-               if (ret < 0)
-                       continue;
-
-               n = sscanf(buffer, "%*[^;];%*d;%*d;%*x;%n", &dpos);
-               if (n) {
-                       free(buffer);
-                       continue;
-               }
-
-               if (!strcmp(buffer + dpos, desc)) {
-                       ret = key;
-                       free(buffer);
-                       goto key_search_out;
-               }
-               free(buffer);
-
-       } while (--count);
-
-       ret = 0;
-
-key_search_out:
-       free(keylist);
-       return ret;
+       return EXIT_FAILURE;
 }
 
 /* search all program's keys in keyring */
@@ -199,7 +122,7 @@ static key_serial_t key_search_all(void)
                        continue;
                }
 
-               if (strstr(buffer + dpos, THIS_PROGRAM_NAME ":") ==
+               if (strstr(buffer + dpos, KEY_PREFIX ":") ==
                        buffer + dpos
                ) {
                        ret = key;
@@ -217,36 +140,26 @@ key_search_all_out:
        return ret;
 }
 
-/* add or update a specific key to keyring */
-static key_serial_t
-key_add(const char *addr, const char *user,
-       const char *domain, const char *pass)
-{
-       char desc[INET6_ADDRSTRLEN + MAX_USERNAME_SIZE + MAX_DOMAIN_SIZE + \
-               + sizeof(THIS_PROGRAM_NAME) + 3];
-
-       create_description(addr, user, domain, desc);
-
-       return add_key("user", desc, pass, strnlen(pass, MOUNT_PASSWD_SIZE) + 1,
-               DEST_KEYRING);
-}
-
 /* add command handler */
-static int cifscreds_add(int argc, char *argv[])
+static int cifscreds_add(struct cmdarg *arg)
 {
        char addrstr[MAX_ADDR_LIST_LEN];
        char *currentaddress, *nextaddress;
        char *pass;
-       int ret;
+       int ret = 0;
+
+       if (arg->host == NULL || arg->user == NULL)
+               return usage();
 
-       if (argc != 4 && argc != 5)
-               usage();
+       if (arg->keytype == 'd')
+               strlcpy(addrstr, arg->host, MAX_ADDR_LIST_LEN);
+       else
+               ret = resolve_host(arg->host, addrstr);
 
-       ret = resolve_host(argv[2], addrstr);
        switch (ret) {
        case EX_USAGE:
                fprintf(stderr, "error: Could not resolve address "
-                       "for %s\n", argv[2]);
+                       "for %s\n", arg->host);
                return EXIT_FAILURE;
 
        case EX_SYSERR:
@@ -254,20 +167,11 @@ static int cifscreds_add(int argc, char *argv[])
                return EXIT_FAILURE;
        }
 
-       if (strpbrk(argv[3], USER_DISALLOWED_CHARS)) {
+       if (strpbrk(arg->user, USER_DISALLOWED_CHARS)) {
                fprintf(stderr, "error: Incorrect username\n");
                return EXIT_FAILURE;
        }
 
-       if (argc == 5) {
-               if (strspn(argv[4], DOMAIN_ALLOWED_CHARS) !=
-                       strnlen(argv[4], MAX_DOMAIN_SIZE)
-               ) {
-                       fprintf(stderr, "error: Incorrect domain name\n");
-                       return EXIT_FAILURE;
-               }
-       }
-
        /* search for same credentials stashed for current host */
        currentaddress = addrstr;
        nextaddress = strchr(currentaddress, ',');
@@ -275,11 +179,9 @@ static int cifscreds_add(int argc, char *argv[])
                *nextaddress++ = '\0';
 
        while (currentaddress) {
-               if (key_search(currentaddress, argv[3],
-                       argc == 5 ? argv[4] : NULL) > 0
-               ) {
+               if (key_search(currentaddress, arg->keytype) > 0) {
                        printf("You already have stashed credentials "
-                               "for %s (%s)\n", currentaddress, argv[2]);
+                               "for %s (%s)\n", currentaddress, arg->host);
                        printf("If you want to update them use:\n");
                        printf("\t%s update\n", thisprogram);
 
@@ -307,23 +209,19 @@ static int cifscreds_add(int argc, char *argv[])
                *nextaddress++ = '\0';
 
        while (currentaddress) {
-               key_serial_t key = key_add(currentaddress, argv[3],
-                                          argc == 5 ? argv[4] : NULL, pass);
+               key_serial_t key = key_add(currentaddress, arg->user, pass, arg->keytype);
                if (key <= 0) {
                        fprintf(stderr, "error: Add credential key for %s\n",
                                currentaddress);
                } else {
-                       if (keyctl(KEYCTL_SETPERM, key, KEY_POS_VIEW | \
-                               KEY_POS_WRITE | KEY_USR_VIEW | \
-                               KEY_USR_WRITE) < 0
-                       ) {
+                       if (keyctl(KEYCTL_SETPERM, key, CIFS_KEY_PERMS) < 0) {
                                fprintf(stderr, "error: Setting permissons "
                                        "on key, attempt to delete...\n");
 
                                if (keyctl(KEYCTL_UNLINK, key, DEST_KEYRING) < 0) {
                                        fprintf(stderr, "error: Deleting key from "
                                                "keyring for %s (%s)\n",
-                                               currentaddress, argv[2]);
+                                               currentaddress, arg->host);
                                }
                        }
                }
@@ -340,20 +238,24 @@ static int cifscreds_add(int argc, char *argv[])
 }
 
 /* clear command handler */
-static int cifscreds_clear(int argc, char *argv[])
+static int cifscreds_clear(struct cmdarg *arg)
 {
        char addrstr[MAX_ADDR_LIST_LEN];
        char *currentaddress, *nextaddress;
-       int ret, count = 0, errors = 0;
+       int ret = 0, count = 0, errors = 0;
+
+       if (arg->host == NULL || arg->user == NULL)
+               return usage();
 
-       if (argc != 4 && argc != 5)
-               usage();
+       if (arg->keytype == 'd')
+               strlcpy(addrstr, arg->host, MAX_ADDR_LIST_LEN);
+       else
+               ret = resolve_host(arg->host, addrstr);
 
-       ret = resolve_host(argv[2], addrstr);
        switch (ret) {
        case EX_USAGE:
                fprintf(stderr, "error: Could not resolve address "
-                       "for %s\n", argv[2]);
+                       "for %s\n", arg->host);
                return EXIT_FAILURE;
 
        case EX_SYSERR:
@@ -361,20 +263,11 @@ static int cifscreds_clear(int argc, char *argv[])
                return EXIT_FAILURE;
        }
 
-       if (strpbrk(argv[3], USER_DISALLOWED_CHARS)) {
+       if (strpbrk(arg->user, USER_DISALLOWED_CHARS)) {
                fprintf(stderr, "error: Incorrect username\n");
                return EXIT_FAILURE;
        }
 
-       if (argc == 5) {
-               if (strspn(argv[4], DOMAIN_ALLOWED_CHARS) !=
-                       strnlen(argv[4], MAX_DOMAIN_SIZE)
-               ) {
-                       fprintf(stderr, "error: Incorrect domain name\n");
-                       return EXIT_FAILURE;
-               }
-       }
-
        /*
         * search for same credentials stashed for current host
         * and unlink them from session keyring
@@ -385,13 +278,12 @@ static int cifscreds_clear(int argc, char *argv[])
                *nextaddress++ = '\0';
 
        while (currentaddress) {
-               key_serial_t key = key_search(currentaddress, argv[3],
-                                               argc == 5 ? argv[4] : NULL);
+               key_serial_t key = key_search(currentaddress, arg->keytype);
                if (key > 0) {
                        if (keyctl(KEYCTL_UNLINK, key, DEST_KEYRING) < 0) {
                                fprintf(stderr, "error: Removing key from "
                                        "keyring for %s (%s)\n",
-                                       currentaddress, argv[2]);
+                                       currentaddress, arg->host);
                                errors++;
                        } else {
                                count++;
@@ -408,7 +300,7 @@ static int cifscreds_clear(int argc, char *argv[])
 
        if (!count && !errors) {
                printf("You have no same stashed credentials "
-                       " for %s\n", argv[2]);
+                       " for %s\n", arg->host);
                printf("If you want to add them use:\n");
                printf("\t%s add\n", thisprogram);
 
@@ -419,14 +311,11 @@ static int cifscreds_clear(int argc, char *argv[])
 }
 
 /* clearall command handler */
-static int cifscreds_clearall(int argc, char *argv[] __attribute__ ((unused)))
+static int cifscreds_clearall(struct cmdarg *arg __attribute__ ((unused)))
 {
        key_serial_t key;
        int count = 0, errors = 0;
 
-       if (argc != 2)
-               usage();
-
        /*
         * search for all program's credentials stashed in session keyring
         * and then unlink them
@@ -445,7 +334,7 @@ static int cifscreds_clearall(int argc, char *argv[] __attribute__ ((unused)))
        } while (key > 0);
 
        if (!count && !errors) {
-               printf("You have no stashed " THIS_PROGRAM_NAME
+               printf("You have no stashed " KEY_PREFIX
                        " credentials\n");
                printf("If you want to add them use:\n");
                printf("\t%s add\n", thisprogram);
@@ -457,21 +346,25 @@ static int cifscreds_clearall(int argc, char *argv[] __attribute__ ((unused)))
 }
 
 /* update command handler */
-static int cifscreds_update(int argc, char *argv[])
+static int cifscreds_update(struct cmdarg *arg)
 {
        char addrstr[MAX_ADDR_LIST_LEN];
        char *currentaddress, *nextaddress, *pass;
        char *addrs[16];
-       int ret, id, count = 0;
+       int ret = 0, id, count = 0;
 
-       if (argc != 4 && argc != 5)
-               usage();
+       if (arg->host == NULL || arg->user == NULL)
+               return usage();
+
+       if (arg->keytype == 'd')
+               strlcpy(addrstr, arg->host, MAX_ADDR_LIST_LEN);
+       else
+               ret = resolve_host(arg->host, addrstr);
 
-       ret = resolve_host(argv[2], addrstr);
        switch (ret) {
        case EX_USAGE:
                fprintf(stderr, "error: Could not resolve address "
-                       "for %s\n", argv[2]);
+                       "for %s\n", arg->host);
                return EXIT_FAILURE;
 
        case EX_SYSERR:
@@ -479,20 +372,11 @@ static int cifscreds_update(int argc, char *argv[])
                return EXIT_FAILURE;
        }
 
-       if (strpbrk(argv[3], USER_DISALLOWED_CHARS)) {
+       if (strpbrk(arg->user, USER_DISALLOWED_CHARS)) {
                fprintf(stderr, "error: Incorrect username\n");
                return EXIT_FAILURE;
        }
 
-       if (argc == 5) {
-               if (strspn(argv[4], DOMAIN_ALLOWED_CHARS) !=
-                       strnlen(argv[4], MAX_DOMAIN_SIZE)
-               ) {
-                       fprintf(stderr, "error: Incorrect domain name\n");
-                       return EXIT_FAILURE;
-               }
-       }
-
        /* search for necessary credentials stashed in session keyring */
        currentaddress = addrstr;
        nextaddress = strchr(currentaddress, ',');
@@ -500,9 +384,7 @@ static int cifscreds_update(int argc, char *argv[])
                *nextaddress++ = '\0';
 
        while (currentaddress) {
-               if (key_search(currentaddress, argv[3],
-                       argc == 5 ? argv[4] : NULL) > 0
-               ) {
+               if (key_search(currentaddress, arg->keytype) > 0) {
                        addrs[count] = currentaddress;
                        count++;
                }
@@ -517,7 +399,7 @@ static int cifscreds_update(int argc, char *argv[])
 
        if (!count) {
                printf("You have no same stashed credentials "
-                       "for %s\n", argv[2]);
+                       "for %s\n", arg->host);
                printf("If you want to add them use:\n");
                printf("\t%s add\n", thisprogram);
 
@@ -528,8 +410,7 @@ static int cifscreds_update(int argc, char *argv[])
        pass = getpass("Password: ");
 
        for (id = 0; id < count; id++) {
-               key_serial_t key = key_add(addrs[id], argv[3],
-                                       argc == 5 ? argv[4] : NULL, pass);
+               key_serial_t key = key_add(addrs[id], arg->user, pass, arg->keytype);
                if (key <= 0)
                        fprintf(stderr, "error: Update credential key "
                                "for %s\n", addrs[id]);
@@ -538,24 +419,71 @@ static int cifscreds_update(int argc, char *argv[])
        return EXIT_SUCCESS;
 }
 
+static int
+check_session_keyring(void)
+{
+       key_serial_t    ses_key, uses_key;
+
+       ses_key = keyctl_get_keyring_ID(KEY_SPEC_SESSION_KEYRING, 0);
+       if (ses_key == -1) {
+               if (errno == ENOKEY)
+                       fprintf(stderr, "Error: you have no session keyring. "
+                                       "Consider using pam_keyinit to "
+                                       "install one.\n");
+               else
+                       fprintf(stderr, "Error: unable to query session "
+                                       "keyring: %s\n", strerror(errno));
+               return (int)ses_key;
+       }
+
+       /* A problem querying the user-session keyring isn't fatal. */
+       uses_key = keyctl_get_keyring_ID(KEY_SPEC_USER_SESSION_KEYRING, 0);
+       if (uses_key == -1)
+               return 0;
+
+       if (ses_key == uses_key)
+               fprintf(stderr, "Warning: you have no persistent session "
+                               "keyring. cifscreds keys will not persist "
+                               "after this process exits. See "
+                               "pam_keyinit(8).\n");
+       return 0;
+}
+
 int main(int argc, char **argv)
 {
        struct command *cmd, *best;
+       struct cmdarg arg;
        int n;
 
+       memset(&arg, 0, sizeof(arg));
+       arg.keytype = 'a';
+
        thisprogram = (char *)basename(argv[0]);
        if (thisprogram == NULL)
                thisprogram = THIS_PROGRAM_NAME;
 
        if (argc == 1)
-               usage();
+               return usage();
+
+       while((n = getopt_long(argc, argv, "du:", longopts, NULL)) != -1) {
+               switch (n) {
+               case 'd':
+                       arg.keytype = (char) n;
+                       break;
+               case 'u':
+                       arg.user = optarg;
+                       break;
+               default:
+                       return usage();
+               }
+       }
 
        /* find the best fit command */
        best = NULL;
-       n = strnlen(argv[1], MAX_COMMAND_SIZE);
+       n = strnlen(argv[optind], MAX_COMMAND_SIZE);
 
        for (cmd = commands; cmd->action; cmd++) {
-               if (memcmp(cmd->name, argv[1], n) != 0)
+               if (memcmp(cmd->name, argv[optind], n) != 0)
                        continue;
 
                if (cmd->name[n] == 0) {
@@ -567,7 +495,7 @@ int main(int argc, char **argv)
                /* partial match */
                if (best) {
                        fprintf(stderr, "Ambiguous command\n");
-                       exit(EXIT_FAILURE);
+                       return EXIT_FAILURE;
                }
 
                best = cmd;
@@ -575,8 +503,24 @@ int main(int argc, char **argv)
 
        if (!best) {
                fprintf(stderr, "Unknown command\n");
-               exit(EXIT_FAILURE);
+               return EXIT_FAILURE;
        }
 
-       exit(best->action(argc, argv));
+       /* second argument should be host or domain */
+       if (argc >= 3)
+               arg.host = argv[optind + 1];
+
+       if (arg.host && arg.keytype == 'd' &&
+           strpbrk(arg.host, DOMAIN_DISALLOWED_CHARS)) {
+               fprintf(stderr, "error: Domain name contains invalid characters\n");
+               return EXIT_FAILURE;
+       }
+
+       if (arg.user == NULL)
+               arg.user = getusername(getuid());
+
+       if (check_session_keyring())
+               return EXIT_FAILURE;
+
+       return best->action(&arg);
 }