d771056525f30f319c33fa2ca5790e6a16260b70
[jlayton/cifs-utils.git] / cifscreds.c
1 /*
2  * Credentials stashing utility for Linux CIFS VFS (virtual filesystem) client
3  * Copyright (C) 2010 Jeff Layton (jlayton@samba.org)
4  * Copyright (C) 2010 Igor Druzhinin (jaxbrigs@gmail.com)
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation; either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This program is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
18  */
19
20 #ifdef HAVE_CONFIG_H
21 #include "config.h"
22 #endif /* HAVE_CONFIG_H */
23
24 #include <stdio.h>
25 #include <stdlib.h>
26 #include <unistd.h>
27 #include <string.h>
28 #include <ctype.h>
29 #include <keyutils.h>
30 #include "mount.h"
31 #include "resolve_host.h"
32
#define THIS_PROGRAM_NAME "cifscreds"

/* size of the command name buffer in struct command, including the NUL */
#define MAX_COMMAND_SIZE 32

/* max length of username, password and domain name */
#define MAX_USERNAME_SIZE 32
#define MOUNT_PASSWD_SIZE 128
#define MAX_DOMAIN_SIZE 64

/*
 * allowed and disallowed characters for user and domain name.
 * Domain names may legitimately contain digits (e.g. "CORP2"), so those
 * are accepted alongside letters, '-' and '.'.
 */
#define USER_DISALLOWED_CHARS "\\/\"[]:|<>+=;,?*@"
#define DOMAIN_ALLOWED_CHARS "abcdefghijklmnopqrstuvwxyz" \
			     "ABCDEFGHIJKLMNOPQRSTUVWXYZ" \
			     "0123456789-."

/* destination keyring */
#define DEST_KEYRING KEY_SPEC_USER_KEYRING
50
51 struct command {
52         int (*action)(int argc, char *argv[]);
53         const char      name[MAX_COMMAND_SIZE];
54         const char      *format;
55 };
56
57 static int cifscreds_add(int argc, char *argv[]);
58 static int cifscreds_clear(int argc, char *argv[]);
59 static int cifscreds_clearall(int argc, char *argv[]);
60 static int cifscreds_update(int argc, char *argv[]);
61
62 const char *thisprogram;
63
64 struct command commands[] = {
65         { cifscreds_add,        "add",          "<host> <user> [domain]" },
66         { cifscreds_clear,      "clear",        "<host> <user> [domain]" },
67         { cifscreds_clearall,   "clearall",     "" },
68         { cifscreds_update,     "update",       "<host> <user> [domain]" },
69         { NULL, "", NULL }
70 };
71
72 /* display usage information */
73 static void usage(void)
74 {
75         struct command *cmd;
76
77         fprintf(stderr, "Usage:\n");
78         for (cmd = commands; cmd->action; cmd++)
79                 fprintf(stderr, "\t%s %s %s\n", thisprogram,
80                         cmd->name, cmd->format);
81         fprintf(stderr, "\n");
82
83         exit(EXIT_FAILURE);
84 }
85
86 /* create key's description string from given credentials */
87 static char *
88 create_description(const char *addr, const char *user,
89                    const char *domain, char *desc)
90 {
91         char *str_end;
92         int str_len;
93
94         sprintf(desc, "%s:%s:%s:", THIS_PROGRAM_NAME, addr, user);
95
96         if (domain != NULL) {
97                 str_end = desc + strnlen(desc, INET6_ADDRSTRLEN + \
98                                         + MAX_USERNAME_SIZE + \
99                                         + sizeof(THIS_PROGRAM_NAME) + 3);
100                 str_len = strnlen(domain, MAX_DOMAIN_SIZE);
101                 while (str_len--) {
102                         *str_end = tolower(*domain++);
103                         str_end++;
104                 }
105                 *str_end = '\0';
106         }
107
108         return desc;
109 }
110
111 /* search a specific key in keyring */
112 static key_serial_t
113 key_search(const char *addr, const char *user, const char *domain)
114 {
115         char desc[INET6_ADDRSTRLEN + MAX_USERNAME_SIZE + MAX_DOMAIN_SIZE + \
116                 + sizeof(THIS_PROGRAM_NAME) + 3];
117         key_serial_t key, *pk;
118         void *keylist;
119         char *buffer;
120         int count, dpos, n, ret;
121
122         create_description(addr, user, domain, desc);
123
124         /* read the key payload data */
125         count = keyctl_read_alloc(DEST_KEYRING, &keylist);
126         if (count < 0)
127                 return 0;
128
129         count /= sizeof(key_serial_t);
130
131         if (count == 0) {
132                 ret = 0;
133                 goto key_search_out;
134         }
135
136         /* list the keys in the keyring */
137         pk = keylist;
138         do {
139                 key = *pk++;
140
141                 ret = keyctl_describe_alloc(key, &buffer);
142                 if (ret < 0)
143                         continue;
144
145                 n = sscanf(buffer, "%*[^;];%*d;%*d;%*x;%n", &dpos);
146                 if (n) {
147                         free(buffer);
148                         continue;
149                 }
150
151                 if (!strcmp(buffer + dpos, desc)) {
152                         ret = key;
153                         free(buffer);
154                         goto key_search_out;
155                 }
156                 free(buffer);
157
158         } while (--count);
159
160         ret = 0;
161
162 key_search_out:
163         free(keylist);
164         return ret;
165 }
166
167 /* search all program's keys in keyring */
168 static key_serial_t key_search_all(void)
169 {
170         key_serial_t key, *pk;
171         void *keylist;
172         char *buffer;
173         int count, dpos, n, ret;
174
175         /* read the key payload data */
176         count = keyctl_read_alloc(DEST_KEYRING, &keylist);
177         if (count < 0)
178                 return 0;
179
180         count /= sizeof(key_serial_t);
181
182         if (count == 0) {
183                 ret = 0;
184                 goto key_search_all_out;
185         }
186
187         /* list the keys in the keyring */
188         pk = keylist;
189         do {
190                 key = *pk++;
191
192                 ret = keyctl_describe_alloc(key, &buffer);
193                 if (ret < 0)
194                         continue;
195
196                 n = sscanf(buffer, "%*[^;];%*d;%*d;%*x;%n", &dpos);
197                 if (n) {
198                         free(buffer);
199                         continue;
200                 }
201
202                 if (strstr(buffer + dpos, THIS_PROGRAM_NAME ":") ==
203                         buffer + dpos
204                 ) {
205                         ret = key;
206                         free(buffer);
207                         goto key_search_all_out;
208                 }
209                 free(buffer);
210
211         } while (--count);
212
213         ret = 0;
214
215 key_search_all_out:
216         free(keylist);
217         return ret;
218 }
219
220 /* add or update a specific key to keyring */
221 static key_serial_t
222 key_add(const char *addr, const char *user,
223         const char *domain, const char *pass)
224 {
225         char desc[INET6_ADDRSTRLEN + MAX_USERNAME_SIZE + MAX_DOMAIN_SIZE + \
226                 + sizeof(THIS_PROGRAM_NAME) + 3];
227
228         create_description(addr, user, domain, desc);
229
230         return add_key("user", desc, pass, strnlen(pass, MOUNT_PASSWD_SIZE) + 1,
231                 DEST_KEYRING);
232 }
233
234 /* add command handler */
235 static int cifscreds_add(int argc, char *argv[])
236 {
237         char addrstr[MAX_ADDR_LIST_LEN];
238         char *currentaddress, *nextaddress;
239         char *pass;
240         int ret;
241
242         if (argc != 4 && argc != 5)
243                 usage();
244
245         ret = resolve_host(argv[2], addrstr);
246         switch (ret) {
247         case EX_USAGE:
248                 fprintf(stderr, "error: Could not resolve address "
249                         "for %s\n", argv[2]);
250                 return EXIT_FAILURE;
251
252         case EX_SYSERR:
253                 fprintf(stderr, "error: Problem parsing address list\n");
254                 return EXIT_FAILURE;
255         }
256
257         if (strpbrk(argv[3], USER_DISALLOWED_CHARS)) {
258                 fprintf(stderr, "error: Incorrect username\n");
259                 return EXIT_FAILURE;
260         }
261
262         if (argc == 5) {
263                 if (strspn(argv[4], DOMAIN_ALLOWED_CHARS) !=
264                         strnlen(argv[4], MAX_DOMAIN_SIZE)
265                 ) {
266                         fprintf(stderr, "error: Incorrect domain name\n");
267                         return EXIT_FAILURE;
268                 }
269         }
270
271         /* search for same credentials stashed for current host */
272         currentaddress = addrstr;
273         nextaddress = strchr(currentaddress, ',');
274         if (nextaddress)
275                 *nextaddress++ = '\0';
276
277         while (currentaddress) {
278                 if (key_search(currentaddress, argv[3],
279                         argc == 5 ? argv[4] : NULL) > 0
280                 ) {
281                         printf("You already have stashed credentials "
282                                 "for %s (%s)\n", currentaddress, argv[2]);
283                         printf("If you want to update them use:\n");
284                         printf("\t%s update\n", thisprogram);
285
286                         return EXIT_FAILURE;
287                 }
288
289                 currentaddress = nextaddress;
290                 if (currentaddress) {
291                         *(currentaddress - 1) = ',';
292                         nextaddress = strchr(currentaddress, ',');
293                         if (nextaddress)
294                                 *nextaddress++ = '\0';
295                 }
296         }
297
298         /*
299          * if there isn't same credentials stashed add them to keyring
300          * and set permisson mask
301          */
302         pass = getpass("Password: ");
303
304         currentaddress = addrstr;
305         nextaddress = strchr(currentaddress, ',');
306         if (nextaddress)
307                 *nextaddress++ = '\0';
308
309         while (currentaddress) {
310                 key_serial_t key = key_add(currentaddress, argv[3],
311                                            argc == 5 ? argv[4] : NULL, pass);
312                 if (key <= 0) {
313                         fprintf(stderr, "error: Add credential key for %s\n",
314                                 currentaddress);
315                 } else {
316                         if (keyctl(KEYCTL_SETPERM, key, KEY_POS_VIEW | \
317                                 KEY_POS_WRITE | KEY_USR_VIEW | \
318                                 KEY_USR_WRITE) < 0
319                         ) {
320                                 fprintf(stderr, "error: Setting permissons "
321                                         "on key, attempt to delete...\n");
322
323                                 if (keyctl(KEYCTL_UNLINK, key, DEST_KEYRING) < 0) {
324                                         fprintf(stderr, "error: Deleting key from "
325                                                 "keyring for %s (%s)\n",
326                                                 currentaddress, argv[2]);
327                                 }
328                         }
329                 }
330
331                 currentaddress = nextaddress;
332                 if (currentaddress) {
333                         nextaddress = strchr(currentaddress, ',');
334                         if (nextaddress)
335                                 *nextaddress++ = '\0';
336                 }
337         }
338
339         return EXIT_SUCCESS;
340 }
341
342 /* clear command handler */
343 static int cifscreds_clear(int argc, char *argv[])
344 {
345         char addrstr[MAX_ADDR_LIST_LEN];
346         char *currentaddress, *nextaddress;
347         int ret, count = 0, errors = 0;
348
349         if (argc != 4 && argc != 5)
350                 usage();
351
352         ret = resolve_host(argv[2], addrstr);
353         switch (ret) {
354         case EX_USAGE:
355                 fprintf(stderr, "error: Could not resolve address "
356                         "for %s\n", argv[2]);
357                 return EXIT_FAILURE;
358
359         case EX_SYSERR:
360                 fprintf(stderr, "error: Problem parsing address list\n");
361                 return EXIT_FAILURE;
362         }
363
364         if (strpbrk(argv[3], USER_DISALLOWED_CHARS)) {
365                 fprintf(stderr, "error: Incorrect username\n");
366                 return EXIT_FAILURE;
367         }
368
369         if (argc == 5) {
370                 if (strspn(argv[4], DOMAIN_ALLOWED_CHARS) !=
371                         strnlen(argv[4], MAX_DOMAIN_SIZE)
372                 ) {
373                         fprintf(stderr, "error: Incorrect domain name\n");
374                         return EXIT_FAILURE;
375                 }
376         }
377
378         /*
379          * search for same credentials stashed for current host
380          * and unlink them from session keyring
381          */
382         currentaddress = addrstr;
383         nextaddress = strchr(currentaddress, ',');
384         if (nextaddress)
385                 *nextaddress++ = '\0';
386
387         while (currentaddress) {
388                 key_serial_t key = key_search(currentaddress, argv[3],
389                                                 argc == 5 ? argv[4] : NULL);
390                 if (key > 0) {
391                         if (keyctl(KEYCTL_UNLINK, key, DEST_KEYRING) < 0) {
392                                 fprintf(stderr, "error: Removing key from "
393                                         "keyring for %s (%s)\n",
394                                         currentaddress, argv[2]);
395                                 errors++;
396                         } else {
397                                 count++;
398                         }
399                 }
400
401                 currentaddress = nextaddress;
402                 if (currentaddress) {
403                         nextaddress = strchr(currentaddress, ',');
404                         if (nextaddress)
405                                 *nextaddress++ = '\0';
406                 }
407         }
408
409         if (!count && !errors) {
410                 printf("You have no same stashed credentials "
411                         " for %s\n", argv[2]);
412                 printf("If you want to add them use:\n");
413                 printf("\t%s add\n", thisprogram);
414
415                 return EXIT_FAILURE;
416         }
417
418         return EXIT_SUCCESS;
419 }
420
421 /* clearall command handler */
422 static int cifscreds_clearall(int argc, char *argv[] __attribute__ ((unused)))
423 {
424         key_serial_t key;
425         int count = 0, errors = 0;
426
427         if (argc != 2)
428                 usage();
429
430         /*
431          * search for all program's credentials stashed in session keyring
432          * and then unlink them
433          */
434         do {
435                 key = key_search_all();
436                 if (key > 0) {
437                         if (keyctl(KEYCTL_UNLINK, key, DEST_KEYRING) < 0) {
438                                 fprintf(stderr, "error: Deleting key "
439                                         "from keyring");
440                                 errors++;
441                         } else {
442                                 count++;
443                         }
444                 }
445         } while (key > 0);
446
447         if (!count && !errors) {
448                 printf("You have no stashed " THIS_PROGRAM_NAME
449                         " credentials\n");
450                 printf("If you want to add them use:\n");
451                 printf("\t%s add\n", thisprogram);
452
453                 return EXIT_FAILURE;
454         }
455
456         return EXIT_SUCCESS;
457 }
458
459 /* update command handler */
460 static int cifscreds_update(int argc, char *argv[])
461 {
462         char addrstr[MAX_ADDR_LIST_LEN];
463         char *currentaddress, *nextaddress, *pass;
464         char *addrs[16];
465         int ret, id, count = 0;
466
467         if (argc != 4 && argc != 5)
468                 usage();
469
470         ret = resolve_host(argv[2], addrstr);
471         switch (ret) {
472         case EX_USAGE:
473                 fprintf(stderr, "error: Could not resolve address "
474                         "for %s\n", argv[2]);
475                 return EXIT_FAILURE;
476
477         case EX_SYSERR:
478                 fprintf(stderr, "error: Problem parsing address list\n");
479                 return EXIT_FAILURE;
480         }
481
482         if (strpbrk(argv[3], USER_DISALLOWED_CHARS)) {
483                 fprintf(stderr, "error: Incorrect username\n");
484                 return EXIT_FAILURE;
485         }
486
487         if (argc == 5) {
488                 if (strspn(argv[4], DOMAIN_ALLOWED_CHARS) !=
489                         strnlen(argv[4], MAX_DOMAIN_SIZE)
490                 ) {
491                         fprintf(stderr, "error: Incorrect domain name\n");
492                         return EXIT_FAILURE;
493                 }
494         }
495
496         /* search for necessary credentials stashed in session keyring */
497         currentaddress = addrstr;
498         nextaddress = strchr(currentaddress, ',');
499         if (nextaddress)
500                 *nextaddress++ = '\0';
501
502         while (currentaddress) {
503                 if (key_search(currentaddress, argv[3],
504                         argc == 5 ? argv[4] : NULL) > 0
505                 ) {
506                         addrs[count] = currentaddress;
507                         count++;
508                 }
509
510                 currentaddress = nextaddress;
511                 if (currentaddress) {
512                         nextaddress = strchr(currentaddress, ',');
513                         if (nextaddress)
514                                 *nextaddress++ = '\0';
515                 }
516         }
517
518         if (!count) {
519                 printf("You have no same stashed credentials "
520                         "for %s\n", argv[2]);
521                 printf("If you want to add them use:\n");
522                 printf("\t%s add\n", thisprogram);
523
524                 return EXIT_FAILURE;
525         }
526
527         /* update payload of found keys */
528         pass = getpass("Password: ");
529
530         for (id = 0; id < count; id++) {
531                 key_serial_t key = key_add(addrs[id], argv[3],
532                                         argc == 5 ? argv[4] : NULL, pass);
533                 if (key <= 0)
534                         fprintf(stderr, "error: Update credential key "
535                                 "for %s\n", addrs[id]);
536         }
537
538         return EXIT_SUCCESS;
539 }
540
541 int main(int argc, char **argv)
542 {
543         struct command *cmd, *best;
544         int n;
545
546         thisprogram = (char *)basename(argv[0]);
547         if (thisprogram == NULL)
548                 thisprogram = THIS_PROGRAM_NAME;
549
550         if (argc == 1)
551                 usage();
552
553         /* find the best fit command */
554         best = NULL;
555         n = strnlen(argv[1], MAX_COMMAND_SIZE);
556
557         for (cmd = commands; cmd->action; cmd++) {
558                 if (memcmp(cmd->name, argv[1], n) != 0)
559                         continue;
560
561                 if (cmd->name[n] == 0) {
562                         /* exact match */
563                         best = cmd;
564                         break;
565                 }
566
567                 /* partial match */
568                 if (best) {
569                         fprintf(stderr, "Ambiguous command\n");
570                         exit(EXIT_FAILURE);
571                 }
572
573                 best = cmd;
574         }
575
576         if (!best) {
577                 fprintf(stderr, "Unknown command\n");
578                 exit(EXIT_FAILURE);
579         }
580
581         exit(best->action(argc, argv));
582 }