52a24d4ef5d8a841a6e0f1ad58c14873ae03e666
[sfrench/cifs-2.6.git] / net / rxrpc / rxkad.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /* Kerberos-based RxRPC security
3  *
4  * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved.
5  * Written by David Howells (dhowells@redhat.com)
6  */
7
8 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
9
10 #include <crypto/skcipher.h>
11 #include <linux/module.h>
12 #include <linux/net.h>
13 #include <linux/skbuff.h>
14 #include <linux/udp.h>
15 #include <linux/scatterlist.h>
16 #include <linux/ctype.h>
17 #include <linux/slab.h>
18 #include <net/sock.h>
19 #include <net/af_rxrpc.h>
20 #include <keys/rxrpc-type.h>
21 #include "ar-internal.h"
22
23 #define RXKAD_VERSION                   2
24 #define MAXKRB5TICKETLEN                1024
25 #define RXKAD_TKT_TYPE_KERBEROS_V5      256
26 #define ANAME_SZ                        40      /* size of authentication name */
27 #define INST_SZ                         40      /* size of principal's instance */
28 #define REALM_SZ                        40      /* size of principal's auth domain */
29 #define SNAME_SZ                        40      /* size of service name */
30
31 struct rxkad_level1_hdr {
32         __be32  data_size;      /* true data size (excluding padding) */
33 };
34
35 struct rxkad_level2_hdr {
36         __be32  data_size;      /* true data size (excluding padding) */
37         __be32  checksum;       /* decrypted data checksum */
38 };
39
40 /*
41  * this holds a pinned cipher so that keventd doesn't get called by the cipher
42  * alloc routine, but since we have it to hand, we use it to decrypt RESPONSE
43  * packets
44  */
45 static struct crypto_sync_skcipher *rxkad_ci;
46 static struct skcipher_request *rxkad_ci_req;
47 static DEFINE_MUTEX(rxkad_ci_mutex);
48
49 /*
50  * initialise connection security
51  */
52 static int rxkad_init_connection_security(struct rxrpc_connection *conn)
53 {
54         struct crypto_sync_skcipher *ci;
55         struct rxrpc_key_token *token;
56         int ret;
57
58         _enter("{%d},{%x}", conn->debug_id, key_serial(conn->params.key));
59
60         token = conn->params.key->payload.data[0];
61         conn->security_ix = token->security_index;
62
63         ci = crypto_alloc_sync_skcipher("pcbc(fcrypt)", 0, 0);
64         if (IS_ERR(ci)) {
65                 _debug("no cipher");
66                 ret = PTR_ERR(ci);
67                 goto error;
68         }
69
70         if (crypto_sync_skcipher_setkey(ci, token->kad->session_key,
71                                    sizeof(token->kad->session_key)) < 0)
72                 BUG();
73
74         switch (conn->params.security_level) {
75         case RXRPC_SECURITY_PLAIN:
76                 break;
77         case RXRPC_SECURITY_AUTH:
78                 conn->size_align = 8;
79                 conn->security_size = sizeof(struct rxkad_level1_hdr);
80                 break;
81         case RXRPC_SECURITY_ENCRYPT:
82                 conn->size_align = 8;
83                 conn->security_size = sizeof(struct rxkad_level2_hdr);
84                 break;
85         default:
86                 ret = -EKEYREJECTED;
87                 goto error;
88         }
89
90         conn->cipher = ci;
91         ret = 0;
92 error:
93         _leave(" = %d", ret);
94         return ret;
95 }
96
97 /*
98  * prime the encryption state with the invariant parts of a connection's
99  * description
100  */
101 static int rxkad_prime_packet_security(struct rxrpc_connection *conn)
102 {
103         struct skcipher_request *req;
104         struct rxrpc_key_token *token;
105         struct scatterlist sg;
106         struct rxrpc_crypt iv;
107         __be32 *tmpbuf;
108         size_t tmpsize = 4 * sizeof(__be32);
109
110         _enter("");
111
112         if (!conn->params.key)
113                 return 0;
114
115         tmpbuf = kmalloc(tmpsize, GFP_KERNEL);
116         if (!tmpbuf)
117                 return -ENOMEM;
118
119         req = skcipher_request_alloc(&conn->cipher->base, GFP_NOFS);
120         if (!req) {
121                 kfree(tmpbuf);
122                 return -ENOMEM;
123         }
124
125         token = conn->params.key->payload.data[0];
126         memcpy(&iv, token->kad->session_key, sizeof(iv));
127
128         tmpbuf[0] = htonl(conn->proto.epoch);
129         tmpbuf[1] = htonl(conn->proto.cid);
130         tmpbuf[2] = 0;
131         tmpbuf[3] = htonl(conn->security_ix);
132
133         sg_init_one(&sg, tmpbuf, tmpsize);
134         skcipher_request_set_sync_tfm(req, conn->cipher);
135         skcipher_request_set_callback(req, 0, NULL, NULL);
136         skcipher_request_set_crypt(req, &sg, &sg, tmpsize, iv.x);
137         crypto_skcipher_encrypt(req);
138         skcipher_request_free(req);
139
140         memcpy(&conn->csum_iv, tmpbuf + 2, sizeof(conn->csum_iv));
141         kfree(tmpbuf);
142         _leave(" = 0");
143         return 0;
144 }
145
146 /*
147  * Allocate and prepare the crypto request on a call.  For any particular call,
148  * this is called serially for the packets, so no lock should be necessary.
149  */
150 static struct skcipher_request *rxkad_get_call_crypto(struct rxrpc_call *call)
151 {
152         struct crypto_skcipher *tfm = &call->conn->cipher->base;
153         struct skcipher_request *cipher_req = call->cipher_req;
154
155         if (!cipher_req) {
156                 cipher_req = skcipher_request_alloc(tfm, GFP_NOFS);
157                 if (!cipher_req)
158                         return NULL;
159                 call->cipher_req = cipher_req;
160         }
161
162         return cipher_req;
163 }
164
165 /*
166  * Clean up the crypto on a call.
167  */
168 static void rxkad_free_call_crypto(struct rxrpc_call *call)
169 {
170         if (call->cipher_req)
171                 skcipher_request_free(call->cipher_req);
172         call->cipher_req = NULL;
173 }
174
175 /*
176  * partially encrypt a packet (level 1 security)
177  */
178 static int rxkad_secure_packet_auth(const struct rxrpc_call *call,
179                                     struct sk_buff *skb,
180                                     u32 data_size,
181                                     void *sechdr,
182                                     struct skcipher_request *req)
183 {
184         struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
185         struct rxkad_level1_hdr hdr;
186         struct rxrpc_crypt iv;
187         struct scatterlist sg;
188         u16 check;
189
190         _enter("");
191
192         check = sp->hdr.seq ^ call->call_id;
193         data_size |= (u32)check << 16;
194
195         hdr.data_size = htonl(data_size);
196         memcpy(sechdr, &hdr, sizeof(hdr));
197
198         /* start the encryption afresh */
199         memset(&iv, 0, sizeof(iv));
200
201         sg_init_one(&sg, sechdr, 8);
202         skcipher_request_set_sync_tfm(req, call->conn->cipher);
203         skcipher_request_set_callback(req, 0, NULL, NULL);
204         skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x);
205         crypto_skcipher_encrypt(req);
206         skcipher_request_zero(req);
207
208         _leave(" = 0");
209         return 0;
210 }
211
212 /*
213  * wholly encrypt a packet (level 2 security)
214  */
215 static int rxkad_secure_packet_encrypt(const struct rxrpc_call *call,
216                                        struct sk_buff *skb,
217                                        u32 data_size,
218                                        void *sechdr,
219                                        struct skcipher_request *req)
220 {
221         const struct rxrpc_key_token *token;
222         struct rxkad_level2_hdr rxkhdr;
223         struct rxrpc_skb_priv *sp;
224         struct rxrpc_crypt iv;
225         struct scatterlist sg[16];
226         unsigned int len;
227         u16 check;
228         int err;
229
230         sp = rxrpc_skb(skb);
231
232         _enter("");
233
234         check = sp->hdr.seq ^ call->call_id;
235
236         rxkhdr.data_size = htonl(data_size | (u32)check << 16);
237         rxkhdr.checksum = 0;
238         memcpy(sechdr, &rxkhdr, sizeof(rxkhdr));
239
240         /* encrypt from the session key */
241         token = call->conn->params.key->payload.data[0];
242         memcpy(&iv, token->kad->session_key, sizeof(iv));
243
244         sg_init_one(&sg[0], sechdr, sizeof(rxkhdr));
245         skcipher_request_set_sync_tfm(req, call->conn->cipher);
246         skcipher_request_set_callback(req, 0, NULL, NULL);
247         skcipher_request_set_crypt(req, &sg[0], &sg[0], sizeof(rxkhdr), iv.x);
248         crypto_skcipher_encrypt(req);
249
250         /* we want to encrypt the skbuff in-place */
251         err = -EMSGSIZE;
252         if (skb_shinfo(skb)->nr_frags > 16)
253                 goto out;
254
255         len = data_size + call->conn->size_align - 1;
256         len &= ~(call->conn->size_align - 1);
257
258         sg_init_table(sg, ARRAY_SIZE(sg));
259         err = skb_to_sgvec(skb, sg, 0, len);
260         if (unlikely(err < 0))
261                 goto out;
262         skcipher_request_set_crypt(req, sg, sg, len, iv.x);
263         crypto_skcipher_encrypt(req);
264
265         _leave(" = 0");
266         err = 0;
267
268 out:
269         skcipher_request_zero(req);
270         return err;
271 }
272
273 /*
274  * checksum an RxRPC packet header
275  */
276 static int rxkad_secure_packet(struct rxrpc_call *call,
277                                struct sk_buff *skb,
278                                size_t data_size,
279                                void *sechdr)
280 {
281         struct rxrpc_skb_priv *sp;
282         struct skcipher_request *req;
283         struct rxrpc_crypt iv;
284         struct scatterlist sg;
285         u32 x, y;
286         int ret;
287
288         sp = rxrpc_skb(skb);
289
290         _enter("{%d{%x}},{#%u},%zu,",
291                call->debug_id, key_serial(call->conn->params.key),
292                sp->hdr.seq, data_size);
293
294         if (!call->conn->cipher)
295                 return 0;
296
297         ret = key_validate(call->conn->params.key);
298         if (ret < 0)
299                 return ret;
300
301         req = rxkad_get_call_crypto(call);
302         if (!req)
303                 return -ENOMEM;
304
305         /* continue encrypting from where we left off */
306         memcpy(&iv, call->conn->csum_iv.x, sizeof(iv));
307
308         /* calculate the security checksum */
309         x = (call->cid & RXRPC_CHANNELMASK) << (32 - RXRPC_CIDSHIFT);
310         x |= sp->hdr.seq & 0x3fffffff;
311         call->crypto_buf[0] = htonl(call->call_id);
312         call->crypto_buf[1] = htonl(x);
313
314         sg_init_one(&sg, call->crypto_buf, 8);
315         skcipher_request_set_sync_tfm(req, call->conn->cipher);
316         skcipher_request_set_callback(req, 0, NULL, NULL);
317         skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x);
318         crypto_skcipher_encrypt(req);
319         skcipher_request_zero(req);
320
321         y = ntohl(call->crypto_buf[1]);
322         y = (y >> 16) & 0xffff;
323         if (y == 0)
324                 y = 1; /* zero checksums are not permitted */
325         sp->hdr.cksum = y;
326
327         switch (call->conn->params.security_level) {
328         case RXRPC_SECURITY_PLAIN:
329                 ret = 0;
330                 break;
331         case RXRPC_SECURITY_AUTH:
332                 ret = rxkad_secure_packet_auth(call, skb, data_size, sechdr,
333                                                req);
334                 break;
335         case RXRPC_SECURITY_ENCRYPT:
336                 ret = rxkad_secure_packet_encrypt(call, skb, data_size,
337                                                   sechdr, req);
338                 break;
339         default:
340                 ret = -EPERM;
341                 break;
342         }
343
344         _leave(" = %d [set %hx]", ret, y);
345         return ret;
346 }
347
348 /*
349  * decrypt partial encryption on a packet (level 1 security)
350  */
351 static int rxkad_verify_packet_1(struct rxrpc_call *call, struct sk_buff *skb,
352                                  unsigned int offset, unsigned int len,
353                                  rxrpc_seq_t seq,
354                                  struct skcipher_request *req)
355 {
356         struct rxkad_level1_hdr sechdr;
357         struct rxrpc_crypt iv;
358         struct scatterlist sg[16];
359         bool aborted;
360         u32 data_size, buf;
361         u16 check;
362         int ret;
363
364         _enter("");
365
366         if (len < 8) {
367                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_hdr", "V1H",
368                                            RXKADSEALEDINCON);
369                 goto protocol_error;
370         }
371
372         /* Decrypt the skbuff in-place.  TODO: We really want to decrypt
373          * directly into the target buffer.
374          */
375         sg_init_table(sg, ARRAY_SIZE(sg));
376         ret = skb_to_sgvec(skb, sg, offset, 8);
377         if (unlikely(ret < 0))
378                 return ret;
379
380         /* start the decryption afresh */
381         memset(&iv, 0, sizeof(iv));
382
383         skcipher_request_set_sync_tfm(req, call->conn->cipher);
384         skcipher_request_set_callback(req, 0, NULL, NULL);
385         skcipher_request_set_crypt(req, sg, sg, 8, iv.x);
386         crypto_skcipher_decrypt(req);
387         skcipher_request_zero(req);
388
389         /* Extract the decrypted packet length */
390         if (skb_copy_bits(skb, offset, &sechdr, sizeof(sechdr)) < 0) {
391                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_len", "XV1",
392                                              RXKADDATALEN);
393                 goto protocol_error;
394         }
395         offset += sizeof(sechdr);
396         len -= sizeof(sechdr);
397
398         buf = ntohl(sechdr.data_size);
399         data_size = buf & 0xffff;
400
401         check = buf >> 16;
402         check ^= seq ^ call->call_id;
403         check &= 0xffff;
404         if (check != 0) {
405                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_check", "V1C",
406                                              RXKADSEALEDINCON);
407                 goto protocol_error;
408         }
409
410         if (data_size > len) {
411                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_datalen", "V1L",
412                                              RXKADDATALEN);
413                 goto protocol_error;
414         }
415
416         _leave(" = 0 [dlen=%x]", data_size);
417         return 0;
418
419 protocol_error:
420         if (aborted)
421                 rxrpc_send_abort_packet(call);
422         return -EPROTO;
423 }
424
425 /*
426  * wholly decrypt a packet (level 2 security)
427  */
428 static int rxkad_verify_packet_2(struct rxrpc_call *call, struct sk_buff *skb,
429                                  unsigned int offset, unsigned int len,
430                                  rxrpc_seq_t seq,
431                                  struct skcipher_request *req)
432 {
433         const struct rxrpc_key_token *token;
434         struct rxkad_level2_hdr sechdr;
435         struct rxrpc_crypt iv;
436         struct scatterlist _sg[4], *sg;
437         bool aborted;
438         u32 data_size, buf;
439         u16 check;
440         int nsg, ret;
441
442         _enter(",{%d}", skb->len);
443
444         if (len < 8) {
445                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_hdr", "V2H",
446                                              RXKADSEALEDINCON);
447                 goto protocol_error;
448         }
449
450         /* Decrypt the skbuff in-place.  TODO: We really want to decrypt
451          * directly into the target buffer.
452          */
453         sg = _sg;
454         nsg = skb_shinfo(skb)->nr_frags;
455         if (nsg <= 4) {
456                 nsg = 4;
457         } else {
458                 sg = kmalloc_array(nsg, sizeof(*sg), GFP_NOIO);
459                 if (!sg)
460                         goto nomem;
461         }
462
463         sg_init_table(sg, nsg);
464         ret = skb_to_sgvec(skb, sg, offset, len);
465         if (unlikely(ret < 0)) {
466                 if (sg != _sg)
467                         kfree(sg);
468                 return ret;
469         }
470
471         /* decrypt from the session key */
472         token = call->conn->params.key->payload.data[0];
473         memcpy(&iv, token->kad->session_key, sizeof(iv));
474
475         skcipher_request_set_sync_tfm(req, call->conn->cipher);
476         skcipher_request_set_callback(req, 0, NULL, NULL);
477         skcipher_request_set_crypt(req, sg, sg, len, iv.x);
478         crypto_skcipher_decrypt(req);
479         skcipher_request_zero(req);
480         if (sg != _sg)
481                 kfree(sg);
482
483         /* Extract the decrypted packet length */
484         if (skb_copy_bits(skb, offset, &sechdr, sizeof(sechdr)) < 0) {
485                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_len", "XV2",
486                                              RXKADDATALEN);
487                 goto protocol_error;
488         }
489         offset += sizeof(sechdr);
490         len -= sizeof(sechdr);
491
492         buf = ntohl(sechdr.data_size);
493         data_size = buf & 0xffff;
494
495         check = buf >> 16;
496         check ^= seq ^ call->call_id;
497         check &= 0xffff;
498         if (check != 0) {
499                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_check", "V2C",
500                                              RXKADSEALEDINCON);
501                 goto protocol_error;
502         }
503
504         if (data_size > len) {
505                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_datalen", "V2L",
506                                              RXKADDATALEN);
507                 goto protocol_error;
508         }
509
510         _leave(" = 0 [dlen=%x]", data_size);
511         return 0;
512
513 protocol_error:
514         if (aborted)
515                 rxrpc_send_abort_packet(call);
516         return -EPROTO;
517
518 nomem:
519         _leave(" = -ENOMEM");
520         return -ENOMEM;
521 }
522
523 /*
524  * Verify the security on a received packet or subpacket (if part of a
525  * jumbo packet).
526  */
527 static int rxkad_verify_packet(struct rxrpc_call *call, struct sk_buff *skb,
528                                unsigned int offset, unsigned int len,
529                                rxrpc_seq_t seq, u16 expected_cksum)
530 {
531         struct skcipher_request *req;
532         struct rxrpc_crypt iv;
533         struct scatterlist sg;
534         bool aborted;
535         u16 cksum;
536         u32 x, y;
537
538         _enter("{%d{%x}},{#%u}",
539                call->debug_id, key_serial(call->conn->params.key), seq);
540
541         if (!call->conn->cipher)
542                 return 0;
543
544         req = rxkad_get_call_crypto(call);
545         if (!req)
546                 return -ENOMEM;
547
548         /* continue encrypting from where we left off */
549         memcpy(&iv, call->conn->csum_iv.x, sizeof(iv));
550
551         /* validate the security checksum */
552         x = (call->cid & RXRPC_CHANNELMASK) << (32 - RXRPC_CIDSHIFT);
553         x |= seq & 0x3fffffff;
554         call->crypto_buf[0] = htonl(call->call_id);
555         call->crypto_buf[1] = htonl(x);
556
557         sg_init_one(&sg, call->crypto_buf, 8);
558         skcipher_request_set_sync_tfm(req, call->conn->cipher);
559         skcipher_request_set_callback(req, 0, NULL, NULL);
560         skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x);
561         crypto_skcipher_encrypt(req);
562         skcipher_request_zero(req);
563
564         y = ntohl(call->crypto_buf[1]);
565         cksum = (y >> 16) & 0xffff;
566         if (cksum == 0)
567                 cksum = 1; /* zero checksums are not permitted */
568
569         if (cksum != expected_cksum) {
570                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_csum", "VCK",
571                                              RXKADSEALEDINCON);
572                 goto protocol_error;
573         }
574
575         switch (call->conn->params.security_level) {
576         case RXRPC_SECURITY_PLAIN:
577                 return 0;
578         case RXRPC_SECURITY_AUTH:
579                 return rxkad_verify_packet_1(call, skb, offset, len, seq, req);
580         case RXRPC_SECURITY_ENCRYPT:
581                 return rxkad_verify_packet_2(call, skb, offset, len, seq, req);
582         default:
583                 return -ENOANO;
584         }
585
586 protocol_error:
587         if (aborted)
588                 rxrpc_send_abort_packet(call);
589         return -EPROTO;
590 }
591
592 /*
593  * Locate the data contained in a packet that was partially encrypted.
594  */
595 static void rxkad_locate_data_1(struct rxrpc_call *call, struct sk_buff *skb,
596                                 unsigned int *_offset, unsigned int *_len)
597 {
598         struct rxkad_level1_hdr sechdr;
599
600         if (skb_copy_bits(skb, *_offset, &sechdr, sizeof(sechdr)) < 0)
601                 BUG();
602         *_offset += sizeof(sechdr);
603         *_len = ntohl(sechdr.data_size) & 0xffff;
604 }
605
606 /*
607  * Locate the data contained in a packet that was completely encrypted.
608  */
609 static void rxkad_locate_data_2(struct rxrpc_call *call, struct sk_buff *skb,
610                                 unsigned int *_offset, unsigned int *_len)
611 {
612         struct rxkad_level2_hdr sechdr;
613
614         if (skb_copy_bits(skb, *_offset, &sechdr, sizeof(sechdr)) < 0)
615                 BUG();
616         *_offset += sizeof(sechdr);
617         *_len = ntohl(sechdr.data_size) & 0xffff;
618 }
619
620 /*
621  * Locate the data contained in an already decrypted packet.
622  */
623 static void rxkad_locate_data(struct rxrpc_call *call, struct sk_buff *skb,
624                               unsigned int *_offset, unsigned int *_len)
625 {
626         switch (call->conn->params.security_level) {
627         case RXRPC_SECURITY_AUTH:
628                 rxkad_locate_data_1(call, skb, _offset, _len);
629                 return;
630         case RXRPC_SECURITY_ENCRYPT:
631                 rxkad_locate_data_2(call, skb, _offset, _len);
632                 return;
633         default:
634                 return;
635         }
636 }
637
638 /*
639  * issue a challenge
640  */
641 static int rxkad_issue_challenge(struct rxrpc_connection *conn)
642 {
643         struct rxkad_challenge challenge;
644         struct rxrpc_wire_header whdr;
645         struct msghdr msg;
646         struct kvec iov[2];
647         size_t len;
648         u32 serial;
649         int ret;
650
651         _enter("{%d,%x}", conn->debug_id, key_serial(conn->server_key));
652
653         ret = key_validate(conn->server_key);
654         if (ret < 0)
655                 return ret;
656
657         get_random_bytes(&conn->security_nonce, sizeof(conn->security_nonce));
658
659         challenge.version       = htonl(2);
660         challenge.nonce         = htonl(conn->security_nonce);
661         challenge.min_level     = htonl(0);
662         challenge.__padding     = 0;
663
664         msg.msg_name    = &conn->params.peer->srx.transport;
665         msg.msg_namelen = conn->params.peer->srx.transport_len;
666         msg.msg_control = NULL;
667         msg.msg_controllen = 0;
668         msg.msg_flags   = 0;
669
670         whdr.epoch      = htonl(conn->proto.epoch);
671         whdr.cid        = htonl(conn->proto.cid);
672         whdr.callNumber = 0;
673         whdr.seq        = 0;
674         whdr.type       = RXRPC_PACKET_TYPE_CHALLENGE;
675         whdr.flags      = conn->out_clientflag;
676         whdr.userStatus = 0;
677         whdr.securityIndex = conn->security_ix;
678         whdr._rsvd      = 0;
679         whdr.serviceId  = htons(conn->service_id);
680
681         iov[0].iov_base = &whdr;
682         iov[0].iov_len  = sizeof(whdr);
683         iov[1].iov_base = &challenge;
684         iov[1].iov_len  = sizeof(challenge);
685
686         len = iov[0].iov_len + iov[1].iov_len;
687
688         serial = atomic_inc_return(&conn->serial);
689         whdr.serial = htonl(serial);
690         _proto("Tx CHALLENGE %%%u", serial);
691
692         ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, 2, len);
693         if (ret < 0) {
694                 trace_rxrpc_tx_fail(conn->debug_id, serial, ret,
695                                     rxrpc_tx_point_rxkad_challenge);
696                 return -EAGAIN;
697         }
698
699         conn->params.peer->last_tx_at = ktime_get_seconds();
700         trace_rxrpc_tx_packet(conn->debug_id, &whdr,
701                               rxrpc_tx_point_rxkad_challenge);
702         _leave(" = 0");
703         return 0;
704 }
705
706 /*
707  * send a Kerberos security response
708  */
709 static int rxkad_send_response(struct rxrpc_connection *conn,
710                                struct rxrpc_host_header *hdr,
711                                struct rxkad_response *resp,
712                                const struct rxkad_key *s2)
713 {
714         struct rxrpc_wire_header whdr;
715         struct msghdr msg;
716         struct kvec iov[3];
717         size_t len;
718         u32 serial;
719         int ret;
720
721         _enter("");
722
723         msg.msg_name    = &conn->params.peer->srx.transport;
724         msg.msg_namelen = conn->params.peer->srx.transport_len;
725         msg.msg_control = NULL;
726         msg.msg_controllen = 0;
727         msg.msg_flags   = 0;
728
729         memset(&whdr, 0, sizeof(whdr));
730         whdr.epoch      = htonl(hdr->epoch);
731         whdr.cid        = htonl(hdr->cid);
732         whdr.type       = RXRPC_PACKET_TYPE_RESPONSE;
733         whdr.flags      = conn->out_clientflag;
734         whdr.securityIndex = hdr->securityIndex;
735         whdr.serviceId  = htons(hdr->serviceId);
736
737         iov[0].iov_base = &whdr;
738         iov[0].iov_len  = sizeof(whdr);
739         iov[1].iov_base = resp;
740         iov[1].iov_len  = sizeof(*resp);
741         iov[2].iov_base = (void *)s2->ticket;
742         iov[2].iov_len  = s2->ticket_len;
743
744         len = iov[0].iov_len + iov[1].iov_len + iov[2].iov_len;
745
746         serial = atomic_inc_return(&conn->serial);
747         whdr.serial = htonl(serial);
748         _proto("Tx RESPONSE %%%u", serial);
749
750         ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, 3, len);
751         if (ret < 0) {
752                 trace_rxrpc_tx_fail(conn->debug_id, serial, ret,
753                                     rxrpc_tx_point_rxkad_response);
754                 return -EAGAIN;
755         }
756
757         conn->params.peer->last_tx_at = ktime_get_seconds();
758         _leave(" = 0");
759         return 0;
760 }
761
762 /*
763  * calculate the response checksum
764  */
765 static void rxkad_calc_response_checksum(struct rxkad_response *response)
766 {
767         u32 csum = 1000003;
768         int loop;
769         u8 *p = (u8 *) response;
770
771         for (loop = sizeof(*response); loop > 0; loop--)
772                 csum = csum * 0x10204081 + *p++;
773
774         response->encrypted.checksum = htonl(csum);
775 }
776
777 /*
778  * encrypt the response packet
779  */
780 static int rxkad_encrypt_response(struct rxrpc_connection *conn,
781                                   struct rxkad_response *resp,
782                                   const struct rxkad_key *s2)
783 {
784         struct skcipher_request *req;
785         struct rxrpc_crypt iv;
786         struct scatterlist sg[1];
787
788         req = skcipher_request_alloc(&conn->cipher->base, GFP_NOFS);
789         if (!req)
790                 return -ENOMEM;
791
792         /* continue encrypting from where we left off */
793         memcpy(&iv, s2->session_key, sizeof(iv));
794
795         sg_init_table(sg, 1);
796         sg_set_buf(sg, &resp->encrypted, sizeof(resp->encrypted));
797         skcipher_request_set_sync_tfm(req, conn->cipher);
798         skcipher_request_set_callback(req, 0, NULL, NULL);
799         skcipher_request_set_crypt(req, sg, sg, sizeof(resp->encrypted), iv.x);
800         crypto_skcipher_encrypt(req);
801         skcipher_request_free(req);
802         return 0;
803 }
804
805 /*
806  * respond to a challenge packet
807  */
808 static int rxkad_respond_to_challenge(struct rxrpc_connection *conn,
809                                       struct sk_buff *skb,
810                                       u32 *_abort_code)
811 {
812         const struct rxrpc_key_token *token;
813         struct rxkad_challenge challenge;
814         struct rxkad_response *resp;
815         struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
816         const char *eproto;
817         u32 version, nonce, min_level, abort_code;
818         int ret;
819
820         _enter("{%d,%x}", conn->debug_id, key_serial(conn->params.key));
821
822         eproto = tracepoint_string("chall_no_key");
823         abort_code = RX_PROTOCOL_ERROR;
824         if (!conn->params.key)
825                 goto protocol_error;
826
827         abort_code = RXKADEXPIRED;
828         ret = key_validate(conn->params.key);
829         if (ret < 0)
830                 goto other_error;
831
832         eproto = tracepoint_string("chall_short");
833         abort_code = RXKADPACKETSHORT;
834         if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header),
835                           &challenge, sizeof(challenge)) < 0)
836                 goto protocol_error;
837
838         version = ntohl(challenge.version);
839         nonce = ntohl(challenge.nonce);
840         min_level = ntohl(challenge.min_level);
841
842         _proto("Rx CHALLENGE %%%u { v=%u n=%u ml=%u }",
843                sp->hdr.serial, version, nonce, min_level);
844
845         eproto = tracepoint_string("chall_ver");
846         abort_code = RXKADINCONSISTENCY;
847         if (version != RXKAD_VERSION)
848                 goto protocol_error;
849
850         abort_code = RXKADLEVELFAIL;
851         ret = -EACCES;
852         if (conn->params.security_level < min_level)
853                 goto other_error;
854
855         token = conn->params.key->payload.data[0];
856
857         /* build the response packet */
858         resp = kzalloc(sizeof(struct rxkad_response), GFP_NOFS);
859         if (!resp)
860                 return -ENOMEM;
861
862         resp->version                   = htonl(RXKAD_VERSION);
863         resp->encrypted.epoch           = htonl(conn->proto.epoch);
864         resp->encrypted.cid             = htonl(conn->proto.cid);
865         resp->encrypted.securityIndex   = htonl(conn->security_ix);
866         resp->encrypted.inc_nonce       = htonl(nonce + 1);
867         resp->encrypted.level           = htonl(conn->params.security_level);
868         resp->kvno                      = htonl(token->kad->kvno);
869         resp->ticket_len                = htonl(token->kad->ticket_len);
870         resp->encrypted.call_id[0]      = htonl(conn->channels[0].call_counter);
871         resp->encrypted.call_id[1]      = htonl(conn->channels[1].call_counter);
872         resp->encrypted.call_id[2]      = htonl(conn->channels[2].call_counter);
873         resp->encrypted.call_id[3]      = htonl(conn->channels[3].call_counter);
874
875         /* calculate the response checksum and then do the encryption */
876         rxkad_calc_response_checksum(resp);
877         ret = rxkad_encrypt_response(conn, resp, token->kad);
878         if (ret == 0)
879                 ret = rxkad_send_response(conn, &sp->hdr, resp, token->kad);
880         kfree(resp);
881         return ret;
882
883 protocol_error:
884         trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto);
885         ret = -EPROTO;
886 other_error:
887         *_abort_code = abort_code;
888         return ret;
889 }
890
891 /*
892  * decrypt the kerberos IV ticket in the response
893  */
894 static int rxkad_decrypt_ticket(struct rxrpc_connection *conn,
895                                 struct sk_buff *skb,
896                                 void *ticket, size_t ticket_len,
897                                 struct rxrpc_crypt *_session_key,
898                                 time64_t *_expiry,
899                                 u32 *_abort_code)
900 {
901         struct skcipher_request *req;
902         struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
903         struct rxrpc_crypt iv, key;
904         struct scatterlist sg[1];
905         struct in_addr addr;
906         unsigned int life;
907         const char *eproto;
908         time64_t issue, now;
909         bool little_endian;
910         int ret;
911         u32 abort_code;
912         u8 *p, *q, *name, *end;
913
914         _enter("{%d},{%x}", conn->debug_id, key_serial(conn->server_key));
915
916         *_expiry = 0;
917
918         ret = key_validate(conn->server_key);
919         if (ret < 0) {
920                 switch (ret) {
921                 case -EKEYEXPIRED:
922                         abort_code = RXKADEXPIRED;
923                         goto other_error;
924                 default:
925                         abort_code = RXKADNOAUTH;
926                         goto other_error;
927                 }
928         }
929
930         ASSERT(conn->server_key->payload.data[0] != NULL);
931         ASSERTCMP((unsigned long) ticket & 7UL, ==, 0);
932
933         memcpy(&iv, &conn->server_key->payload.data[2], sizeof(iv));
934
935         ret = -ENOMEM;
936         req = skcipher_request_alloc(conn->server_key->payload.data[0],
937                                      GFP_NOFS);
938         if (!req)
939                 goto temporary_error;
940
941         sg_init_one(&sg[0], ticket, ticket_len);
942         skcipher_request_set_callback(req, 0, NULL, NULL);
943         skcipher_request_set_crypt(req, sg, sg, ticket_len, iv.x);
944         crypto_skcipher_decrypt(req);
945         skcipher_request_free(req);
946
947         p = ticket;
948         end = p + ticket_len;
949
950 #define Z(field)                                        \
951         ({                                              \
952                 u8 *__str = p;                          \
953                 eproto = tracepoint_string("rxkad_bad_"#field); \
954                 q = memchr(p, 0, end - p);              \
955                 if (!q || q - p > (field##_SZ))         \
956                         goto bad_ticket;                \
957                 for (; p < q; p++)                      \
958                         if (!isprint(*p))               \
959                                 goto bad_ticket;        \
960                 p++;                                    \
961                 __str;                                  \
962         })
963
964         /* extract the ticket flags */
965         _debug("KIV FLAGS: %x", *p);
966         little_endian = *p & 1;
967         p++;
968
969         /* extract the authentication name */
970         name = Z(ANAME);
971         _debug("KIV ANAME: %s", name);
972
973         /* extract the principal's instance */
974         name = Z(INST);
975         _debug("KIV INST : %s", name);
976
977         /* extract the principal's authentication domain */
978         name = Z(REALM);
979         _debug("KIV REALM: %s", name);
980
981         eproto = tracepoint_string("rxkad_bad_len");
982         if (end - p < 4 + 8 + 4 + 2)
983                 goto bad_ticket;
984
985         /* get the IPv4 address of the entity that requested the ticket */
986         memcpy(&addr, p, sizeof(addr));
987         p += 4;
988         _debug("KIV ADDR : %pI4", &addr);
989
990         /* get the session key from the ticket */
991         memcpy(&key, p, sizeof(key));
992         p += 8;
993         _debug("KIV KEY  : %08x %08x", ntohl(key.n[0]), ntohl(key.n[1]));
994         memcpy(_session_key, &key, sizeof(key));
995
996         /* get the ticket's lifetime */
997         life = *p++ * 5 * 60;
998         _debug("KIV LIFE : %u", life);
999
1000         /* get the issue time of the ticket */
1001         if (little_endian) {
1002                 __le32 stamp;
1003                 memcpy(&stamp, p, 4);
1004                 issue = rxrpc_u32_to_time64(le32_to_cpu(stamp));
1005         } else {
1006                 __be32 stamp;
1007                 memcpy(&stamp, p, 4);
1008                 issue = rxrpc_u32_to_time64(be32_to_cpu(stamp));
1009         }
1010         p += 4;
1011         now = ktime_get_real_seconds();
1012         _debug("KIV ISSUE: %llx [%llx]", issue, now);
1013
1014         /* check the ticket is in date */
1015         if (issue > now) {
1016                 abort_code = RXKADNOAUTH;
1017                 ret = -EKEYREJECTED;
1018                 goto other_error;
1019         }
1020
1021         if (issue < now - life) {
1022                 abort_code = RXKADEXPIRED;
1023                 ret = -EKEYEXPIRED;
1024                 goto other_error;
1025         }
1026
1027         *_expiry = issue + life;
1028
1029         /* get the service name */
1030         name = Z(SNAME);
1031         _debug("KIV SNAME: %s", name);
1032
1033         /* get the service instance name */
1034         name = Z(INST);
1035         _debug("KIV SINST: %s", name);
1036         return 0;
1037
1038 bad_ticket:
1039         trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto);
1040         abort_code = RXKADBADTICKET;
1041         ret = -EPROTO;
1042 other_error:
1043         *_abort_code = abort_code;
1044         return ret;
1045 temporary_error:
1046         return ret;
1047 }
1048
1049 /*
1050  * decrypt the response packet
1051  */
1052 static void rxkad_decrypt_response(struct rxrpc_connection *conn,
1053                                    struct rxkad_response *resp,
1054                                    const struct rxrpc_crypt *session_key)
1055 {
1056         struct skcipher_request *req = rxkad_ci_req;
1057         struct scatterlist sg[1];
1058         struct rxrpc_crypt iv;
1059
1060         _enter(",,%08x%08x",
1061                ntohl(session_key->n[0]), ntohl(session_key->n[1]));
1062
1063         mutex_lock(&rxkad_ci_mutex);
1064         if (crypto_sync_skcipher_setkey(rxkad_ci, session_key->x,
1065                                         sizeof(*session_key)) < 0)
1066                 BUG();
1067
1068         memcpy(&iv, session_key, sizeof(iv));
1069
1070         sg_init_table(sg, 1);
1071         sg_set_buf(sg, &resp->encrypted, sizeof(resp->encrypted));
1072         skcipher_request_set_sync_tfm(req, rxkad_ci);
1073         skcipher_request_set_callback(req, 0, NULL, NULL);
1074         skcipher_request_set_crypt(req, sg, sg, sizeof(resp->encrypted), iv.x);
1075         crypto_skcipher_decrypt(req);
1076         skcipher_request_zero(req);
1077
1078         mutex_unlock(&rxkad_ci_mutex);
1079
1080         _leave("");
1081 }
1082
1083 /*
1084  * verify a response
1085  */
1086 static int rxkad_verify_response(struct rxrpc_connection *conn,
1087                                  struct sk_buff *skb,
1088                                  u32 *_abort_code)
1089 {
1090         struct rxkad_response *response;
1091         struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
1092         struct rxrpc_crypt session_key;
1093         const char *eproto;
1094         time64_t expiry;
1095         void *ticket;
1096         u32 abort_code, version, kvno, ticket_len, level;
1097         __be32 csum;
1098         int ret, i;
1099
1100         _enter("{%d,%x}", conn->debug_id, key_serial(conn->server_key));
1101
1102         ret = -ENOMEM;
1103         response = kzalloc(sizeof(struct rxkad_response), GFP_NOFS);
1104         if (!response)
1105                 goto temporary_error;
1106
1107         eproto = tracepoint_string("rxkad_rsp_short");
1108         abort_code = RXKADPACKETSHORT;
1109         if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header),
1110                           response, sizeof(*response)) < 0)
1111                 goto protocol_error;
1112         if (!pskb_pull(skb, sizeof(*response)))
1113                 BUG();
1114
1115         version = ntohl(response->version);
1116         ticket_len = ntohl(response->ticket_len);
1117         kvno = ntohl(response->kvno);
1118         _proto("Rx RESPONSE %%%u { v=%u kv=%u tl=%u }",
1119                sp->hdr.serial, version, kvno, ticket_len);
1120
1121         eproto = tracepoint_string("rxkad_rsp_ver");
1122         abort_code = RXKADINCONSISTENCY;
1123         if (version != RXKAD_VERSION)
1124                 goto protocol_error;
1125
1126         eproto = tracepoint_string("rxkad_rsp_tktlen");
1127         abort_code = RXKADTICKETLEN;
1128         if (ticket_len < 4 || ticket_len > MAXKRB5TICKETLEN)
1129                 goto protocol_error;
1130
1131         eproto = tracepoint_string("rxkad_rsp_unkkey");
1132         abort_code = RXKADUNKNOWNKEY;
1133         if (kvno >= RXKAD_TKT_TYPE_KERBEROS_V5)
1134                 goto protocol_error;
1135
1136         /* extract the kerberos ticket and decrypt and decode it */
1137         ret = -ENOMEM;
1138         ticket = kmalloc(ticket_len, GFP_NOFS);
1139         if (!ticket)
1140                 goto temporary_error;
1141
1142         eproto = tracepoint_string("rxkad_tkt_short");
1143         abort_code = RXKADPACKETSHORT;
1144         if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header),
1145                           ticket, ticket_len) < 0)
1146                 goto protocol_error_free;
1147
1148         ret = rxkad_decrypt_ticket(conn, skb, ticket, ticket_len, &session_key,
1149                                    &expiry, _abort_code);
1150         if (ret < 0)
1151                 goto temporary_error_free_ticket;
1152
1153         /* use the session key from inside the ticket to decrypt the
1154          * response */
1155         rxkad_decrypt_response(conn, response, &session_key);
1156
1157         eproto = tracepoint_string("rxkad_rsp_param");
1158         abort_code = RXKADSEALEDINCON;
1159         if (ntohl(response->encrypted.epoch) != conn->proto.epoch)
1160                 goto protocol_error_free;
1161         if (ntohl(response->encrypted.cid) != conn->proto.cid)
1162                 goto protocol_error_free;
1163         if (ntohl(response->encrypted.securityIndex) != conn->security_ix)
1164                 goto protocol_error_free;
1165         csum = response->encrypted.checksum;
1166         response->encrypted.checksum = 0;
1167         rxkad_calc_response_checksum(response);
1168         eproto = tracepoint_string("rxkad_rsp_csum");
1169         if (response->encrypted.checksum != csum)
1170                 goto protocol_error_free;
1171
1172         spin_lock(&conn->channel_lock);
1173         for (i = 0; i < RXRPC_MAXCALLS; i++) {
1174                 struct rxrpc_call *call;
1175                 u32 call_id = ntohl(response->encrypted.call_id[i]);
1176
1177                 eproto = tracepoint_string("rxkad_rsp_callid");
1178                 if (call_id > INT_MAX)
1179                         goto protocol_error_unlock;
1180
1181                 eproto = tracepoint_string("rxkad_rsp_callctr");
1182                 if (call_id < conn->channels[i].call_counter)
1183                         goto protocol_error_unlock;
1184
1185                 eproto = tracepoint_string("rxkad_rsp_callst");
1186                 if (call_id > conn->channels[i].call_counter) {
1187                         call = rcu_dereference_protected(
1188                                 conn->channels[i].call,
1189                                 lockdep_is_held(&conn->channel_lock));
1190                         if (call && call->state < RXRPC_CALL_COMPLETE)
1191                                 goto protocol_error_unlock;
1192                         conn->channels[i].call_counter = call_id;
1193                 }
1194         }
1195         spin_unlock(&conn->channel_lock);
1196
1197         eproto = tracepoint_string("rxkad_rsp_seq");
1198         abort_code = RXKADOUTOFSEQUENCE;
1199         if (ntohl(response->encrypted.inc_nonce) != conn->security_nonce + 1)
1200                 goto protocol_error_free;
1201
1202         eproto = tracepoint_string("rxkad_rsp_level");
1203         abort_code = RXKADLEVELFAIL;
1204         level = ntohl(response->encrypted.level);
1205         if (level > RXRPC_SECURITY_ENCRYPT)
1206                 goto protocol_error_free;
1207         conn->params.security_level = level;
1208
1209         /* create a key to hold the security data and expiration time - after
1210          * this the connection security can be handled in exactly the same way
1211          * as for a client connection */
1212         ret = rxrpc_get_server_data_key(conn, &session_key, expiry, kvno);
1213         if (ret < 0)
1214                 goto temporary_error_free_ticket;
1215
1216         kfree(ticket);
1217         kfree(response);
1218         _leave(" = 0");
1219         return 0;
1220
1221 protocol_error_unlock:
1222         spin_unlock(&conn->channel_lock);
1223 protocol_error_free:
1224         kfree(ticket);
1225 protocol_error:
1226         kfree(response);
1227         trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto);
1228         *_abort_code = abort_code;
1229         return -EPROTO;
1230
1231 temporary_error_free_ticket:
1232         kfree(ticket);
1233         kfree(response);
1234 temporary_error:
1235         /* Ignore the response packet if we got a temporary error such as
1236          * ENOMEM.  We just want to send the challenge again.  Note that we
1237          * also come out this way if the ticket decryption fails.
1238          */
1239         return ret;
1240 }
1241
1242 /*
1243  * clear the connection security
1244  */
1245 static void rxkad_clear(struct rxrpc_connection *conn)
1246 {
1247         _enter("");
1248
1249         if (conn->cipher)
1250                 crypto_free_sync_skcipher(conn->cipher);
1251 }
1252
1253 /*
1254  * Initialise the rxkad security service.
1255  */
1256 static int rxkad_init(void)
1257 {
1258         struct crypto_sync_skcipher *tfm;
1259         struct skcipher_request *req;
1260
1261         /* pin the cipher we need so that the crypto layer doesn't invoke
1262          * keventd to go get it */
1263         tfm = crypto_alloc_sync_skcipher("pcbc(fcrypt)", 0, 0);
1264         if (IS_ERR(tfm))
1265                 return PTR_ERR(tfm);
1266
1267         req = skcipher_request_alloc(&tfm->base, GFP_KERNEL);
1268         if (!req)
1269                 goto nomem_tfm;
1270
1271         rxkad_ci_req = req;
1272         rxkad_ci = tfm;
1273         return 0;
1274
1275 nomem_tfm:
1276         crypto_free_sync_skcipher(tfm);
1277         return -ENOMEM;
1278 }
1279
1280 /*
1281  * Clean up the rxkad security service.
1282  */
1283 static void rxkad_exit(void)
1284 {
1285         crypto_free_sync_skcipher(rxkad_ci);
1286         skcipher_request_free(rxkad_ci_req);
1287 }
1288
1289 /*
1290  * RxRPC Kerberos-based security
1291  */
1292 const struct rxrpc_security rxkad = {
1293         .name                           = "rxkad",
1294         .security_index                 = RXRPC_SECURITY_RXKAD,
1295         .no_key_abort                   = RXKADUNKNOWNKEY,
1296         .init                           = rxkad_init,
1297         .exit                           = rxkad_exit,
1298         .init_connection_security       = rxkad_init_connection_security,
1299         .prime_packet_security          = rxkad_prime_packet_security,
1300         .secure_packet                  = rxkad_secure_packet,
1301         .verify_packet                  = rxkad_verify_packet,
1302         .free_call_crypto               = rxkad_free_call_crypto,
1303         .locate_data                    = rxkad_locate_data,
1304         .issue_challenge                = rxkad_issue_challenge,
1305         .respond_to_challenge           = rxkad_respond_to_challenge,
1306         .verify_response                = rxkad_verify_response,
1307         .clear                          = rxkad_clear,
1308 };