Merge git://git.kernel.org/pub/scm/linux/kernel/git/netdev/net
[sfrench/cifs-2.6.git] / net / rxrpc / rxkad.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /* Kerberos-based RxRPC security
3  *
4  * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved.
5  * Written by David Howells (dhowells@redhat.com)
6  */
7
8 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
9
10 #include <crypto/skcipher.h>
11 #include <linux/module.h>
12 #include <linux/net.h>
13 #include <linux/skbuff.h>
14 #include <linux/udp.h>
15 #include <linux/scatterlist.h>
16 #include <linux/ctype.h>
17 #include <linux/slab.h>
18 #include <net/sock.h>
19 #include <net/af_rxrpc.h>
20 #include <keys/rxrpc-type.h>
21 #include "ar-internal.h"
22
23 #define RXKAD_VERSION                   2
24 #define MAXKRB5TICKETLEN                1024
25 #define RXKAD_TKT_TYPE_KERBEROS_V5      256
26 #define ANAME_SZ                        40      /* size of authentication name */
27 #define INST_SZ                         40      /* size of principal's instance */
28 #define REALM_SZ                        40      /* size of principal's auth domain */
29 #define SNAME_SZ                        40      /* size of service name */
30
31 struct rxkad_level1_hdr {
32         __be32  data_size;      /* true data size (excluding padding) */
33 };
34
35 struct rxkad_level2_hdr {
36         __be32  data_size;      /* true data size (excluding padding) */
37         __be32  checksum;       /* decrypted data checksum */
38 };
39
40 /*
41  * this holds a pinned cipher so that keventd doesn't get called by the cipher
42  * alloc routine, but since we have it to hand, we use it to decrypt RESPONSE
43  * packets
44  */
45 static struct crypto_sync_skcipher *rxkad_ci;
46 static DEFINE_MUTEX(rxkad_ci_mutex);
47
48 /*
49  * initialise connection security
50  */
51 static int rxkad_init_connection_security(struct rxrpc_connection *conn)
52 {
53         struct crypto_sync_skcipher *ci;
54         struct rxrpc_key_token *token;
55         int ret;
56
57         _enter("{%d},{%x}", conn->debug_id, key_serial(conn->params.key));
58
59         token = conn->params.key->payload.data[0];
60         conn->security_ix = token->security_index;
61
62         ci = crypto_alloc_sync_skcipher("pcbc(fcrypt)", 0, 0);
63         if (IS_ERR(ci)) {
64                 _debug("no cipher");
65                 ret = PTR_ERR(ci);
66                 goto error;
67         }
68
69         if (crypto_sync_skcipher_setkey(ci, token->kad->session_key,
70                                    sizeof(token->kad->session_key)) < 0)
71                 BUG();
72
73         switch (conn->params.security_level) {
74         case RXRPC_SECURITY_PLAIN:
75                 break;
76         case RXRPC_SECURITY_AUTH:
77                 conn->size_align = 8;
78                 conn->security_size = sizeof(struct rxkad_level1_hdr);
79                 break;
80         case RXRPC_SECURITY_ENCRYPT:
81                 conn->size_align = 8;
82                 conn->security_size = sizeof(struct rxkad_level2_hdr);
83                 break;
84         default:
85                 ret = -EKEYREJECTED;
86                 goto error;
87         }
88
89         conn->cipher = ci;
90         ret = 0;
91 error:
92         _leave(" = %d", ret);
93         return ret;
94 }
95
96 /*
97  * prime the encryption state with the invariant parts of a connection's
98  * description
99  */
100 static int rxkad_prime_packet_security(struct rxrpc_connection *conn)
101 {
102         struct rxrpc_key_token *token;
103         SYNC_SKCIPHER_REQUEST_ON_STACK(req, conn->cipher);
104         struct scatterlist sg;
105         struct rxrpc_crypt iv;
106         __be32 *tmpbuf;
107         size_t tmpsize = 4 * sizeof(__be32);
108
109         _enter("");
110
111         if (!conn->params.key)
112                 return 0;
113
114         tmpbuf = kmalloc(tmpsize, GFP_KERNEL);
115         if (!tmpbuf)
116                 return -ENOMEM;
117
118         token = conn->params.key->payload.data[0];
119         memcpy(&iv, token->kad->session_key, sizeof(iv));
120
121         tmpbuf[0] = htonl(conn->proto.epoch);
122         tmpbuf[1] = htonl(conn->proto.cid);
123         tmpbuf[2] = 0;
124         tmpbuf[3] = htonl(conn->security_ix);
125
126         sg_init_one(&sg, tmpbuf, tmpsize);
127         skcipher_request_set_sync_tfm(req, conn->cipher);
128         skcipher_request_set_callback(req, 0, NULL, NULL);
129         skcipher_request_set_crypt(req, &sg, &sg, tmpsize, iv.x);
130         crypto_skcipher_encrypt(req);
131         skcipher_request_zero(req);
132
133         memcpy(&conn->csum_iv, tmpbuf + 2, sizeof(conn->csum_iv));
134         kfree(tmpbuf);
135         _leave(" = 0");
136         return 0;
137 }
138
139 /*
140  * partially encrypt a packet (level 1 security)
141  */
142 static int rxkad_secure_packet_auth(const struct rxrpc_call *call,
143                                     struct sk_buff *skb,
144                                     u32 data_size,
145                                     void *sechdr,
146                                     struct skcipher_request *req)
147 {
148         struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
149         struct rxkad_level1_hdr hdr;
150         struct rxrpc_crypt iv;
151         struct scatterlist sg;
152         u16 check;
153
154         _enter("");
155
156         check = sp->hdr.seq ^ call->call_id;
157         data_size |= (u32)check << 16;
158
159         hdr.data_size = htonl(data_size);
160         memcpy(sechdr, &hdr, sizeof(hdr));
161
162         /* start the encryption afresh */
163         memset(&iv, 0, sizeof(iv));
164
165         sg_init_one(&sg, sechdr, 8);
166         skcipher_request_set_sync_tfm(req, call->conn->cipher);
167         skcipher_request_set_callback(req, 0, NULL, NULL);
168         skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x);
169         crypto_skcipher_encrypt(req);
170         skcipher_request_zero(req);
171
172         _leave(" = 0");
173         return 0;
174 }
175
176 /*
177  * wholly encrypt a packet (level 2 security)
178  */
179 static int rxkad_secure_packet_encrypt(const struct rxrpc_call *call,
180                                        struct sk_buff *skb,
181                                        u32 data_size,
182                                        void *sechdr,
183                                        struct skcipher_request *req)
184 {
185         const struct rxrpc_key_token *token;
186         struct rxkad_level2_hdr rxkhdr;
187         struct rxrpc_skb_priv *sp;
188         struct rxrpc_crypt iv;
189         struct scatterlist sg[16];
190         unsigned int len;
191         u16 check;
192         int err;
193
194         sp = rxrpc_skb(skb);
195
196         _enter("");
197
198         check = sp->hdr.seq ^ call->call_id;
199
200         rxkhdr.data_size = htonl(data_size | (u32)check << 16);
201         rxkhdr.checksum = 0;
202         memcpy(sechdr, &rxkhdr, sizeof(rxkhdr));
203
204         /* encrypt from the session key */
205         token = call->conn->params.key->payload.data[0];
206         memcpy(&iv, token->kad->session_key, sizeof(iv));
207
208         sg_init_one(&sg[0], sechdr, sizeof(rxkhdr));
209         skcipher_request_set_sync_tfm(req, call->conn->cipher);
210         skcipher_request_set_callback(req, 0, NULL, NULL);
211         skcipher_request_set_crypt(req, &sg[0], &sg[0], sizeof(rxkhdr), iv.x);
212         crypto_skcipher_encrypt(req);
213
214         /* we want to encrypt the skbuff in-place */
215         err = -EMSGSIZE;
216         if (skb_shinfo(skb)->nr_frags > 16)
217                 goto out;
218
219         len = data_size + call->conn->size_align - 1;
220         len &= ~(call->conn->size_align - 1);
221
222         sg_init_table(sg, ARRAY_SIZE(sg));
223         err = skb_to_sgvec(skb, sg, 0, len);
224         if (unlikely(err < 0))
225                 goto out;
226         skcipher_request_set_crypt(req, sg, sg, len, iv.x);
227         crypto_skcipher_encrypt(req);
228
229         _leave(" = 0");
230         err = 0;
231
232 out:
233         skcipher_request_zero(req);
234         return err;
235 }
236
237 /*
238  * checksum an RxRPC packet header
239  */
240 static int rxkad_secure_packet(struct rxrpc_call *call,
241                                struct sk_buff *skb,
242                                size_t data_size,
243                                void *sechdr)
244 {
245         struct rxrpc_skb_priv *sp;
246         SYNC_SKCIPHER_REQUEST_ON_STACK(req, call->conn->cipher);
247         struct rxrpc_crypt iv;
248         struct scatterlist sg;
249         u32 x, y;
250         int ret;
251
252         sp = rxrpc_skb(skb);
253
254         _enter("{%d{%x}},{#%u},%zu,",
255                call->debug_id, key_serial(call->conn->params.key),
256                sp->hdr.seq, data_size);
257
258         if (!call->conn->cipher)
259                 return 0;
260
261         ret = key_validate(call->conn->params.key);
262         if (ret < 0)
263                 return ret;
264
265         /* continue encrypting from where we left off */
266         memcpy(&iv, call->conn->csum_iv.x, sizeof(iv));
267
268         /* calculate the security checksum */
269         x = (call->cid & RXRPC_CHANNELMASK) << (32 - RXRPC_CIDSHIFT);
270         x |= sp->hdr.seq & 0x3fffffff;
271         call->crypto_buf[0] = htonl(call->call_id);
272         call->crypto_buf[1] = htonl(x);
273
274         sg_init_one(&sg, call->crypto_buf, 8);
275         skcipher_request_set_sync_tfm(req, call->conn->cipher);
276         skcipher_request_set_callback(req, 0, NULL, NULL);
277         skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x);
278         crypto_skcipher_encrypt(req);
279         skcipher_request_zero(req);
280
281         y = ntohl(call->crypto_buf[1]);
282         y = (y >> 16) & 0xffff;
283         if (y == 0)
284                 y = 1; /* zero checksums are not permitted */
285         sp->hdr.cksum = y;
286
287         switch (call->conn->params.security_level) {
288         case RXRPC_SECURITY_PLAIN:
289                 ret = 0;
290                 break;
291         case RXRPC_SECURITY_AUTH:
292                 ret = rxkad_secure_packet_auth(call, skb, data_size, sechdr,
293                                                req);
294                 break;
295         case RXRPC_SECURITY_ENCRYPT:
296                 ret = rxkad_secure_packet_encrypt(call, skb, data_size,
297                                                   sechdr, req);
298                 break;
299         default:
300                 ret = -EPERM;
301                 break;
302         }
303
304         _leave(" = %d [set %hx]", ret, y);
305         return ret;
306 }
307
308 /*
309  * decrypt partial encryption on a packet (level 1 security)
310  */
311 static int rxkad_verify_packet_1(struct rxrpc_call *call, struct sk_buff *skb,
312                                  unsigned int offset, unsigned int len,
313                                  rxrpc_seq_t seq,
314                                  struct skcipher_request *req)
315 {
316         struct rxkad_level1_hdr sechdr;
317         struct rxrpc_crypt iv;
318         struct scatterlist sg[16];
319         bool aborted;
320         u32 data_size, buf;
321         u16 check;
322         int ret;
323
324         _enter("");
325
326         if (len < 8) {
327                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_hdr", "V1H",
328                                            RXKADSEALEDINCON);
329                 goto protocol_error;
330         }
331
332         /* Decrypt the skbuff in-place.  TODO: We really want to decrypt
333          * directly into the target buffer.
334          */
335         sg_init_table(sg, ARRAY_SIZE(sg));
336         ret = skb_to_sgvec(skb, sg, offset, 8);
337         if (unlikely(ret < 0))
338                 return ret;
339
340         /* start the decryption afresh */
341         memset(&iv, 0, sizeof(iv));
342
343         skcipher_request_set_sync_tfm(req, call->conn->cipher);
344         skcipher_request_set_callback(req, 0, NULL, NULL);
345         skcipher_request_set_crypt(req, sg, sg, 8, iv.x);
346         crypto_skcipher_decrypt(req);
347         skcipher_request_zero(req);
348
349         /* Extract the decrypted packet length */
350         if (skb_copy_bits(skb, offset, &sechdr, sizeof(sechdr)) < 0) {
351                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_len", "XV1",
352                                              RXKADDATALEN);
353                 goto protocol_error;
354         }
355         offset += sizeof(sechdr);
356         len -= sizeof(sechdr);
357
358         buf = ntohl(sechdr.data_size);
359         data_size = buf & 0xffff;
360
361         check = buf >> 16;
362         check ^= seq ^ call->call_id;
363         check &= 0xffff;
364         if (check != 0) {
365                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_check", "V1C",
366                                              RXKADSEALEDINCON);
367                 goto protocol_error;
368         }
369
370         if (data_size > len) {
371                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_datalen", "V1L",
372                                              RXKADDATALEN);
373                 goto protocol_error;
374         }
375
376         _leave(" = 0 [dlen=%x]", data_size);
377         return 0;
378
379 protocol_error:
380         if (aborted)
381                 rxrpc_send_abort_packet(call);
382         return -EPROTO;
383 }
384
385 /*
386  * wholly decrypt a packet (level 2 security)
387  */
388 static int rxkad_verify_packet_2(struct rxrpc_call *call, struct sk_buff *skb,
389                                  unsigned int offset, unsigned int len,
390                                  rxrpc_seq_t seq,
391                                  struct skcipher_request *req)
392 {
393         const struct rxrpc_key_token *token;
394         struct rxkad_level2_hdr sechdr;
395         struct rxrpc_crypt iv;
396         struct scatterlist _sg[4], *sg;
397         bool aborted;
398         u32 data_size, buf;
399         u16 check;
400         int nsg, ret;
401
402         _enter(",{%d}", skb->len);
403
404         if (len < 8) {
405                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_hdr", "V2H",
406                                              RXKADSEALEDINCON);
407                 goto protocol_error;
408         }
409
410         /* Decrypt the skbuff in-place.  TODO: We really want to decrypt
411          * directly into the target buffer.
412          */
413         sg = _sg;
414         nsg = skb_shinfo(skb)->nr_frags;
415         if (nsg <= 4) {
416                 nsg = 4;
417         } else {
418                 sg = kmalloc_array(nsg, sizeof(*sg), GFP_NOIO);
419                 if (!sg)
420                         goto nomem;
421         }
422
423         sg_init_table(sg, nsg);
424         ret = skb_to_sgvec(skb, sg, offset, len);
425         if (unlikely(ret < 0)) {
426                 if (sg != _sg)
427                         kfree(sg);
428                 return ret;
429         }
430
431         /* decrypt from the session key */
432         token = call->conn->params.key->payload.data[0];
433         memcpy(&iv, token->kad->session_key, sizeof(iv));
434
435         skcipher_request_set_sync_tfm(req, call->conn->cipher);
436         skcipher_request_set_callback(req, 0, NULL, NULL);
437         skcipher_request_set_crypt(req, sg, sg, len, iv.x);
438         crypto_skcipher_decrypt(req);
439         skcipher_request_zero(req);
440         if (sg != _sg)
441                 kfree(sg);
442
443         /* Extract the decrypted packet length */
444         if (skb_copy_bits(skb, offset, &sechdr, sizeof(sechdr)) < 0) {
445                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_len", "XV2",
446                                              RXKADDATALEN);
447                 goto protocol_error;
448         }
449         offset += sizeof(sechdr);
450         len -= sizeof(sechdr);
451
452         buf = ntohl(sechdr.data_size);
453         data_size = buf & 0xffff;
454
455         check = buf >> 16;
456         check ^= seq ^ call->call_id;
457         check &= 0xffff;
458         if (check != 0) {
459                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_check", "V2C",
460                                              RXKADSEALEDINCON);
461                 goto protocol_error;
462         }
463
464         if (data_size > len) {
465                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_datalen", "V2L",
466                                              RXKADDATALEN);
467                 goto protocol_error;
468         }
469
470         _leave(" = 0 [dlen=%x]", data_size);
471         return 0;
472
473 protocol_error:
474         if (aborted)
475                 rxrpc_send_abort_packet(call);
476         return -EPROTO;
477
478 nomem:
479         _leave(" = -ENOMEM");
480         return -ENOMEM;
481 }
482
483 /*
484  * Verify the security on a received packet or subpacket (if part of a
485  * jumbo packet).
486  */
487 static int rxkad_verify_packet(struct rxrpc_call *call, struct sk_buff *skb,
488                                unsigned int offset, unsigned int len,
489                                rxrpc_seq_t seq, u16 expected_cksum)
490 {
491         SYNC_SKCIPHER_REQUEST_ON_STACK(req, call->conn->cipher);
492         struct rxrpc_crypt iv;
493         struct scatterlist sg;
494         bool aborted;
495         u16 cksum;
496         u32 x, y;
497
498         _enter("{%d{%x}},{#%u}",
499                call->debug_id, key_serial(call->conn->params.key), seq);
500
501         if (!call->conn->cipher)
502                 return 0;
503
504         /* continue encrypting from where we left off */
505         memcpy(&iv, call->conn->csum_iv.x, sizeof(iv));
506
507         /* validate the security checksum */
508         x = (call->cid & RXRPC_CHANNELMASK) << (32 - RXRPC_CIDSHIFT);
509         x |= seq & 0x3fffffff;
510         call->crypto_buf[0] = htonl(call->call_id);
511         call->crypto_buf[1] = htonl(x);
512
513         sg_init_one(&sg, call->crypto_buf, 8);
514         skcipher_request_set_sync_tfm(req, call->conn->cipher);
515         skcipher_request_set_callback(req, 0, NULL, NULL);
516         skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x);
517         crypto_skcipher_encrypt(req);
518         skcipher_request_zero(req);
519
520         y = ntohl(call->crypto_buf[1]);
521         cksum = (y >> 16) & 0xffff;
522         if (cksum == 0)
523                 cksum = 1; /* zero checksums are not permitted */
524
525         if (cksum != expected_cksum) {
526                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_csum", "VCK",
527                                              RXKADSEALEDINCON);
528                 goto protocol_error;
529         }
530
531         switch (call->conn->params.security_level) {
532         case RXRPC_SECURITY_PLAIN:
533                 return 0;
534         case RXRPC_SECURITY_AUTH:
535                 return rxkad_verify_packet_1(call, skb, offset, len, seq, req);
536         case RXRPC_SECURITY_ENCRYPT:
537                 return rxkad_verify_packet_2(call, skb, offset, len, seq, req);
538         default:
539                 return -ENOANO;
540         }
541
542 protocol_error:
543         if (aborted)
544                 rxrpc_send_abort_packet(call);
545         return -EPROTO;
546 }
547
548 /*
549  * Locate the data contained in a packet that was partially encrypted.
550  */
551 static void rxkad_locate_data_1(struct rxrpc_call *call, struct sk_buff *skb,
552                                 unsigned int *_offset, unsigned int *_len)
553 {
554         struct rxkad_level1_hdr sechdr;
555
556         if (skb_copy_bits(skb, *_offset, &sechdr, sizeof(sechdr)) < 0)
557                 BUG();
558         *_offset += sizeof(sechdr);
559         *_len = ntohl(sechdr.data_size) & 0xffff;
560 }
561
562 /*
563  * Locate the data contained in a packet that was completely encrypted.
564  */
565 static void rxkad_locate_data_2(struct rxrpc_call *call, struct sk_buff *skb,
566                                 unsigned int *_offset, unsigned int *_len)
567 {
568         struct rxkad_level2_hdr sechdr;
569
570         if (skb_copy_bits(skb, *_offset, &sechdr, sizeof(sechdr)) < 0)
571                 BUG();
572         *_offset += sizeof(sechdr);
573         *_len = ntohl(sechdr.data_size) & 0xffff;
574 }
575
576 /*
577  * Locate the data contained in an already decrypted packet.
578  */
579 static void rxkad_locate_data(struct rxrpc_call *call, struct sk_buff *skb,
580                               unsigned int *_offset, unsigned int *_len)
581 {
582         switch (call->conn->params.security_level) {
583         case RXRPC_SECURITY_AUTH:
584                 rxkad_locate_data_1(call, skb, _offset, _len);
585                 return;
586         case RXRPC_SECURITY_ENCRYPT:
587                 rxkad_locate_data_2(call, skb, _offset, _len);
588                 return;
589         default:
590                 return;
591         }
592 }
593
594 /*
595  * issue a challenge
596  */
597 static int rxkad_issue_challenge(struct rxrpc_connection *conn)
598 {
599         struct rxkad_challenge challenge;
600         struct rxrpc_wire_header whdr;
601         struct msghdr msg;
602         struct kvec iov[2];
603         size_t len;
604         u32 serial;
605         int ret;
606
607         _enter("{%d,%x}", conn->debug_id, key_serial(conn->params.key));
608
609         ret = key_validate(conn->params.key);
610         if (ret < 0)
611                 return ret;
612
613         get_random_bytes(&conn->security_nonce, sizeof(conn->security_nonce));
614
615         challenge.version       = htonl(2);
616         challenge.nonce         = htonl(conn->security_nonce);
617         challenge.min_level     = htonl(0);
618         challenge.__padding     = 0;
619
620         msg.msg_name    = &conn->params.peer->srx.transport;
621         msg.msg_namelen = conn->params.peer->srx.transport_len;
622         msg.msg_control = NULL;
623         msg.msg_controllen = 0;
624         msg.msg_flags   = 0;
625
626         whdr.epoch      = htonl(conn->proto.epoch);
627         whdr.cid        = htonl(conn->proto.cid);
628         whdr.callNumber = 0;
629         whdr.seq        = 0;
630         whdr.type       = RXRPC_PACKET_TYPE_CHALLENGE;
631         whdr.flags      = conn->out_clientflag;
632         whdr.userStatus = 0;
633         whdr.securityIndex = conn->security_ix;
634         whdr._rsvd      = 0;
635         whdr.serviceId  = htons(conn->service_id);
636
637         iov[0].iov_base = &whdr;
638         iov[0].iov_len  = sizeof(whdr);
639         iov[1].iov_base = &challenge;
640         iov[1].iov_len  = sizeof(challenge);
641
642         len = iov[0].iov_len + iov[1].iov_len;
643
644         serial = atomic_inc_return(&conn->serial);
645         whdr.serial = htonl(serial);
646         _proto("Tx CHALLENGE %%%u", serial);
647
648         ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, 2, len);
649         if (ret < 0) {
650                 trace_rxrpc_tx_fail(conn->debug_id, serial, ret,
651                                     rxrpc_tx_point_rxkad_challenge);
652                 return -EAGAIN;
653         }
654
655         conn->params.peer->last_tx_at = ktime_get_seconds();
656         trace_rxrpc_tx_packet(conn->debug_id, &whdr,
657                               rxrpc_tx_point_rxkad_challenge);
658         _leave(" = 0");
659         return 0;
660 }
661
662 /*
663  * send a Kerberos security response
664  */
665 static int rxkad_send_response(struct rxrpc_connection *conn,
666                                struct rxrpc_host_header *hdr,
667                                struct rxkad_response *resp,
668                                const struct rxkad_key *s2)
669 {
670         struct rxrpc_wire_header whdr;
671         struct msghdr msg;
672         struct kvec iov[3];
673         size_t len;
674         u32 serial;
675         int ret;
676
677         _enter("");
678
679         msg.msg_name    = &conn->params.peer->srx.transport;
680         msg.msg_namelen = conn->params.peer->srx.transport_len;
681         msg.msg_control = NULL;
682         msg.msg_controllen = 0;
683         msg.msg_flags   = 0;
684
685         memset(&whdr, 0, sizeof(whdr));
686         whdr.epoch      = htonl(hdr->epoch);
687         whdr.cid        = htonl(hdr->cid);
688         whdr.type       = RXRPC_PACKET_TYPE_RESPONSE;
689         whdr.flags      = conn->out_clientflag;
690         whdr.securityIndex = hdr->securityIndex;
691         whdr.serviceId  = htons(hdr->serviceId);
692
693         iov[0].iov_base = &whdr;
694         iov[0].iov_len  = sizeof(whdr);
695         iov[1].iov_base = resp;
696         iov[1].iov_len  = sizeof(*resp);
697         iov[2].iov_base = (void *)s2->ticket;
698         iov[2].iov_len  = s2->ticket_len;
699
700         len = iov[0].iov_len + iov[1].iov_len + iov[2].iov_len;
701
702         serial = atomic_inc_return(&conn->serial);
703         whdr.serial = htonl(serial);
704         _proto("Tx RESPONSE %%%u", serial);
705
706         ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, 3, len);
707         if (ret < 0) {
708                 trace_rxrpc_tx_fail(conn->debug_id, serial, ret,
709                                     rxrpc_tx_point_rxkad_response);
710                 return -EAGAIN;
711         }
712
713         conn->params.peer->last_tx_at = ktime_get_seconds();
714         _leave(" = 0");
715         return 0;
716 }
717
718 /*
719  * calculate the response checksum
720  */
721 static void rxkad_calc_response_checksum(struct rxkad_response *response)
722 {
723         u32 csum = 1000003;
724         int loop;
725         u8 *p = (u8 *) response;
726
727         for (loop = sizeof(*response); loop > 0; loop--)
728                 csum = csum * 0x10204081 + *p++;
729
730         response->encrypted.checksum = htonl(csum);
731 }
732
733 /*
734  * encrypt the response packet
735  */
736 static void rxkad_encrypt_response(struct rxrpc_connection *conn,
737                                    struct rxkad_response *resp,
738                                    const struct rxkad_key *s2)
739 {
740         SYNC_SKCIPHER_REQUEST_ON_STACK(req, conn->cipher);
741         struct rxrpc_crypt iv;
742         struct scatterlist sg[1];
743
744         /* continue encrypting from where we left off */
745         memcpy(&iv, s2->session_key, sizeof(iv));
746
747         sg_init_table(sg, 1);
748         sg_set_buf(sg, &resp->encrypted, sizeof(resp->encrypted));
749         skcipher_request_set_sync_tfm(req, conn->cipher);
750         skcipher_request_set_callback(req, 0, NULL, NULL);
751         skcipher_request_set_crypt(req, sg, sg, sizeof(resp->encrypted), iv.x);
752         crypto_skcipher_encrypt(req);
753         skcipher_request_zero(req);
754 }
755
756 /*
757  * respond to a challenge packet
758  */
759 static int rxkad_respond_to_challenge(struct rxrpc_connection *conn,
760                                       struct sk_buff *skb,
761                                       u32 *_abort_code)
762 {
763         const struct rxrpc_key_token *token;
764         struct rxkad_challenge challenge;
765         struct rxkad_response *resp;
766         struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
767         const char *eproto;
768         u32 version, nonce, min_level, abort_code;
769         int ret;
770
771         _enter("{%d,%x}", conn->debug_id, key_serial(conn->params.key));
772
773         eproto = tracepoint_string("chall_no_key");
774         abort_code = RX_PROTOCOL_ERROR;
775         if (!conn->params.key)
776                 goto protocol_error;
777
778         abort_code = RXKADEXPIRED;
779         ret = key_validate(conn->params.key);
780         if (ret < 0)
781                 goto other_error;
782
783         eproto = tracepoint_string("chall_short");
784         abort_code = RXKADPACKETSHORT;
785         if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header),
786                           &challenge, sizeof(challenge)) < 0)
787                 goto protocol_error;
788
789         version = ntohl(challenge.version);
790         nonce = ntohl(challenge.nonce);
791         min_level = ntohl(challenge.min_level);
792
793         _proto("Rx CHALLENGE %%%u { v=%u n=%u ml=%u }",
794                sp->hdr.serial, version, nonce, min_level);
795
796         eproto = tracepoint_string("chall_ver");
797         abort_code = RXKADINCONSISTENCY;
798         if (version != RXKAD_VERSION)
799                 goto protocol_error;
800
801         abort_code = RXKADLEVELFAIL;
802         ret = -EACCES;
803         if (conn->params.security_level < min_level)
804                 goto other_error;
805
806         token = conn->params.key->payload.data[0];
807
808         /* build the response packet */
809         resp = kzalloc(sizeof(struct rxkad_response), GFP_NOFS);
810         if (!resp)
811                 return -ENOMEM;
812
813         resp->version                   = htonl(RXKAD_VERSION);
814         resp->encrypted.epoch           = htonl(conn->proto.epoch);
815         resp->encrypted.cid             = htonl(conn->proto.cid);
816         resp->encrypted.securityIndex   = htonl(conn->security_ix);
817         resp->encrypted.inc_nonce       = htonl(nonce + 1);
818         resp->encrypted.level           = htonl(conn->params.security_level);
819         resp->kvno                      = htonl(token->kad->kvno);
820         resp->ticket_len                = htonl(token->kad->ticket_len);
821         resp->encrypted.call_id[0]      = htonl(conn->channels[0].call_counter);
822         resp->encrypted.call_id[1]      = htonl(conn->channels[1].call_counter);
823         resp->encrypted.call_id[2]      = htonl(conn->channels[2].call_counter);
824         resp->encrypted.call_id[3]      = htonl(conn->channels[3].call_counter);
825
826         /* calculate the response checksum and then do the encryption */
827         rxkad_calc_response_checksum(resp);
828         rxkad_encrypt_response(conn, resp, token->kad);
829         ret = rxkad_send_response(conn, &sp->hdr, resp, token->kad);
830         kfree(resp);
831         return ret;
832
833 protocol_error:
834         trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto);
835         ret = -EPROTO;
836 other_error:
837         *_abort_code = abort_code;
838         return ret;
839 }
840
841 /*
842  * decrypt the kerberos IV ticket in the response
843  */
844 static int rxkad_decrypt_ticket(struct rxrpc_connection *conn,
845                                 struct sk_buff *skb,
846                                 void *ticket, size_t ticket_len,
847                                 struct rxrpc_crypt *_session_key,
848                                 time64_t *_expiry,
849                                 u32 *_abort_code)
850 {
851         struct skcipher_request *req;
852         struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
853         struct rxrpc_crypt iv, key;
854         struct scatterlist sg[1];
855         struct in_addr addr;
856         unsigned int life;
857         const char *eproto;
858         time64_t issue, now;
859         bool little_endian;
860         int ret;
861         u32 abort_code;
862         u8 *p, *q, *name, *end;
863
864         _enter("{%d},{%x}", conn->debug_id, key_serial(conn->server_key));
865
866         *_expiry = 0;
867
868         ret = key_validate(conn->server_key);
869         if (ret < 0) {
870                 switch (ret) {
871                 case -EKEYEXPIRED:
872                         abort_code = RXKADEXPIRED;
873                         goto other_error;
874                 default:
875                         abort_code = RXKADNOAUTH;
876                         goto other_error;
877                 }
878         }
879
880         ASSERT(conn->server_key->payload.data[0] != NULL);
881         ASSERTCMP((unsigned long) ticket & 7UL, ==, 0);
882
883         memcpy(&iv, &conn->server_key->payload.data[2], sizeof(iv));
884
885         ret = -ENOMEM;
886         req = skcipher_request_alloc(conn->server_key->payload.data[0],
887                                      GFP_NOFS);
888         if (!req)
889                 goto temporary_error;
890
891         sg_init_one(&sg[0], ticket, ticket_len);
892         skcipher_request_set_callback(req, 0, NULL, NULL);
893         skcipher_request_set_crypt(req, sg, sg, ticket_len, iv.x);
894         crypto_skcipher_decrypt(req);
895         skcipher_request_free(req);
896
897         p = ticket;
898         end = p + ticket_len;
899
900 #define Z(field)                                        \
901         ({                                              \
902                 u8 *__str = p;                          \
903                 eproto = tracepoint_string("rxkad_bad_"#field); \
904                 q = memchr(p, 0, end - p);              \
905                 if (!q || q - p > (field##_SZ))         \
906                         goto bad_ticket;                \
907                 for (; p < q; p++)                      \
908                         if (!isprint(*p))               \
909                                 goto bad_ticket;        \
910                 p++;                                    \
911                 __str;                                  \
912         })
913
914         /* extract the ticket flags */
915         _debug("KIV FLAGS: %x", *p);
916         little_endian = *p & 1;
917         p++;
918
919         /* extract the authentication name */
920         name = Z(ANAME);
921         _debug("KIV ANAME: %s", name);
922
923         /* extract the principal's instance */
924         name = Z(INST);
925         _debug("KIV INST : %s", name);
926
927         /* extract the principal's authentication domain */
928         name = Z(REALM);
929         _debug("KIV REALM: %s", name);
930
931         eproto = tracepoint_string("rxkad_bad_len");
932         if (end - p < 4 + 8 + 4 + 2)
933                 goto bad_ticket;
934
935         /* get the IPv4 address of the entity that requested the ticket */
936         memcpy(&addr, p, sizeof(addr));
937         p += 4;
938         _debug("KIV ADDR : %pI4", &addr);
939
940         /* get the session key from the ticket */
941         memcpy(&key, p, sizeof(key));
942         p += 8;
943         _debug("KIV KEY  : %08x %08x", ntohl(key.n[0]), ntohl(key.n[1]));
944         memcpy(_session_key, &key, sizeof(key));
945
946         /* get the ticket's lifetime */
947         life = *p++ * 5 * 60;
948         _debug("KIV LIFE : %u", life);
949
950         /* get the issue time of the ticket */
951         if (little_endian) {
952                 __le32 stamp;
953                 memcpy(&stamp, p, 4);
954                 issue = rxrpc_u32_to_time64(le32_to_cpu(stamp));
955         } else {
956                 __be32 stamp;
957                 memcpy(&stamp, p, 4);
958                 issue = rxrpc_u32_to_time64(be32_to_cpu(stamp));
959         }
960         p += 4;
961         now = ktime_get_real_seconds();
962         _debug("KIV ISSUE: %llx [%llx]", issue, now);
963
964         /* check the ticket is in date */
965         if (issue > now) {
966                 abort_code = RXKADNOAUTH;
967                 ret = -EKEYREJECTED;
968                 goto other_error;
969         }
970
971         if (issue < now - life) {
972                 abort_code = RXKADEXPIRED;
973                 ret = -EKEYEXPIRED;
974                 goto other_error;
975         }
976
977         *_expiry = issue + life;
978
979         /* get the service name */
980         name = Z(SNAME);
981         _debug("KIV SNAME: %s", name);
982
983         /* get the service instance name */
984         name = Z(INST);
985         _debug("KIV SINST: %s", name);
986         return 0;
987
988 bad_ticket:
989         trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto);
990         abort_code = RXKADBADTICKET;
991         ret = -EPROTO;
992 other_error:
993         *_abort_code = abort_code;
994         return ret;
995 temporary_error:
996         return ret;
997 }
998
999 /*
1000  * decrypt the response packet
1001  */
1002 static void rxkad_decrypt_response(struct rxrpc_connection *conn,
1003                                    struct rxkad_response *resp,
1004                                    const struct rxrpc_crypt *session_key)
1005 {
1006         SYNC_SKCIPHER_REQUEST_ON_STACK(req, rxkad_ci);
1007         struct scatterlist sg[1];
1008         struct rxrpc_crypt iv;
1009
1010         _enter(",,%08x%08x",
1011                ntohl(session_key->n[0]), ntohl(session_key->n[1]));
1012
1013         ASSERT(rxkad_ci != NULL);
1014
1015         mutex_lock(&rxkad_ci_mutex);
1016         if (crypto_sync_skcipher_setkey(rxkad_ci, session_key->x,
1017                                    sizeof(*session_key)) < 0)
1018                 BUG();
1019
1020         memcpy(&iv, session_key, sizeof(iv));
1021
1022         sg_init_table(sg, 1);
1023         sg_set_buf(sg, &resp->encrypted, sizeof(resp->encrypted));
1024         skcipher_request_set_sync_tfm(req, rxkad_ci);
1025         skcipher_request_set_callback(req, 0, NULL, NULL);
1026         skcipher_request_set_crypt(req, sg, sg, sizeof(resp->encrypted), iv.x);
1027         crypto_skcipher_decrypt(req);
1028         skcipher_request_zero(req);
1029
1030         mutex_unlock(&rxkad_ci_mutex);
1031
1032         _leave("");
1033 }
1034
1035 /*
1036  * verify a response
1037  */
1038 static int rxkad_verify_response(struct rxrpc_connection *conn,
1039                                  struct sk_buff *skb,
1040                                  u32 *_abort_code)
1041 {
1042         struct rxkad_response *response;
1043         struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
1044         struct rxrpc_crypt session_key;
1045         const char *eproto;
1046         time64_t expiry;
1047         void *ticket;
1048         u32 abort_code, version, kvno, ticket_len, level;
1049         __be32 csum;
1050         int ret, i;
1051
1052         _enter("{%d,%x}", conn->debug_id, key_serial(conn->server_key));
1053
1054         ret = -ENOMEM;
1055         response = kzalloc(sizeof(struct rxkad_response), GFP_NOFS);
1056         if (!response)
1057                 goto temporary_error;
1058
1059         eproto = tracepoint_string("rxkad_rsp_short");
1060         abort_code = RXKADPACKETSHORT;
1061         if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header),
1062                           response, sizeof(*response)) < 0)
1063                 goto protocol_error;
1064         if (!pskb_pull(skb, sizeof(*response)))
1065                 BUG();
1066
1067         version = ntohl(response->version);
1068         ticket_len = ntohl(response->ticket_len);
1069         kvno = ntohl(response->kvno);
1070         _proto("Rx RESPONSE %%%u { v=%u kv=%u tl=%u }",
1071                sp->hdr.serial, version, kvno, ticket_len);
1072
1073         eproto = tracepoint_string("rxkad_rsp_ver");
1074         abort_code = RXKADINCONSISTENCY;
1075         if (version != RXKAD_VERSION)
1076                 goto protocol_error;
1077
1078         eproto = tracepoint_string("rxkad_rsp_tktlen");
1079         abort_code = RXKADTICKETLEN;
1080         if (ticket_len < 4 || ticket_len > MAXKRB5TICKETLEN)
1081                 goto protocol_error;
1082
1083         eproto = tracepoint_string("rxkad_rsp_unkkey");
1084         abort_code = RXKADUNKNOWNKEY;
1085         if (kvno >= RXKAD_TKT_TYPE_KERBEROS_V5)
1086                 goto protocol_error;
1087
1088         /* extract the kerberos ticket and decrypt and decode it */
1089         ret = -ENOMEM;
1090         ticket = kmalloc(ticket_len, GFP_NOFS);
1091         if (!ticket)
1092                 goto temporary_error;
1093
1094         eproto = tracepoint_string("rxkad_tkt_short");
1095         abort_code = RXKADPACKETSHORT;
1096         if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header),
1097                           ticket, ticket_len) < 0)
1098                 goto protocol_error_free;
1099
1100         ret = rxkad_decrypt_ticket(conn, skb, ticket, ticket_len, &session_key,
1101                                    &expiry, _abort_code);
1102         if (ret < 0)
1103                 goto temporary_error_free_resp;
1104
1105         /* use the session key from inside the ticket to decrypt the
1106          * response */
1107         rxkad_decrypt_response(conn, response, &session_key);
1108
1109         eproto = tracepoint_string("rxkad_rsp_param");
1110         abort_code = RXKADSEALEDINCON;
1111         if (ntohl(response->encrypted.epoch) != conn->proto.epoch)
1112                 goto protocol_error_free;
1113         if (ntohl(response->encrypted.cid) != conn->proto.cid)
1114                 goto protocol_error_free;
1115         if (ntohl(response->encrypted.securityIndex) != conn->security_ix)
1116                 goto protocol_error_free;
1117         csum = response->encrypted.checksum;
1118         response->encrypted.checksum = 0;
1119         rxkad_calc_response_checksum(response);
1120         eproto = tracepoint_string("rxkad_rsp_csum");
1121         if (response->encrypted.checksum != csum)
1122                 goto protocol_error_free;
1123
1124         spin_lock(&conn->channel_lock);
1125         for (i = 0; i < RXRPC_MAXCALLS; i++) {
1126                 struct rxrpc_call *call;
1127                 u32 call_id = ntohl(response->encrypted.call_id[i]);
1128
1129                 eproto = tracepoint_string("rxkad_rsp_callid");
1130                 if (call_id > INT_MAX)
1131                         goto protocol_error_unlock;
1132
1133                 eproto = tracepoint_string("rxkad_rsp_callctr");
1134                 if (call_id < conn->channels[i].call_counter)
1135                         goto protocol_error_unlock;
1136
1137                 eproto = tracepoint_string("rxkad_rsp_callst");
1138                 if (call_id > conn->channels[i].call_counter) {
1139                         call = rcu_dereference_protected(
1140                                 conn->channels[i].call,
1141                                 lockdep_is_held(&conn->channel_lock));
1142                         if (call && call->state < RXRPC_CALL_COMPLETE)
1143                                 goto protocol_error_unlock;
1144                         conn->channels[i].call_counter = call_id;
1145                 }
1146         }
1147         spin_unlock(&conn->channel_lock);
1148
1149         eproto = tracepoint_string("rxkad_rsp_seq");
1150         abort_code = RXKADOUTOFSEQUENCE;
1151         if (ntohl(response->encrypted.inc_nonce) != conn->security_nonce + 1)
1152                 goto protocol_error_free;
1153
1154         eproto = tracepoint_string("rxkad_rsp_level");
1155         abort_code = RXKADLEVELFAIL;
1156         level = ntohl(response->encrypted.level);
1157         if (level > RXRPC_SECURITY_ENCRYPT)
1158                 goto protocol_error_free;
1159         conn->params.security_level = level;
1160
1161         /* create a key to hold the security data and expiration time - after
1162          * this the connection security can be handled in exactly the same way
1163          * as for a client connection */
1164         ret = rxrpc_get_server_data_key(conn, &session_key, expiry, kvno);
1165         if (ret < 0)
1166                 goto temporary_error_free_ticket;
1167
1168         kfree(ticket);
1169         kfree(response);
1170         _leave(" = 0");
1171         return 0;
1172
1173 protocol_error_unlock:
1174         spin_unlock(&conn->channel_lock);
1175 protocol_error_free:
1176         kfree(ticket);
1177 protocol_error:
1178         kfree(response);
1179         trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto);
1180         *_abort_code = abort_code;
1181         return -EPROTO;
1182
1183 temporary_error_free_ticket:
1184         kfree(ticket);
1185 temporary_error_free_resp:
1186         kfree(response);
1187 temporary_error:
1188         /* Ignore the response packet if we got a temporary error such as
1189          * ENOMEM.  We just want to send the challenge again.  Note that we
1190          * also come out this way if the ticket decryption fails.
1191          */
1192         return ret;
1193 }
1194
1195 /*
1196  * clear the connection security
1197  */
1198 static void rxkad_clear(struct rxrpc_connection *conn)
1199 {
1200         _enter("");
1201
1202         if (conn->cipher)
1203                 crypto_free_sync_skcipher(conn->cipher);
1204 }
1205
1206 /*
1207  * Initialise the rxkad security service.
1208  */
1209 static int rxkad_init(void)
1210 {
1211         /* pin the cipher we need so that the crypto layer doesn't invoke
1212          * keventd to go get it */
1213         rxkad_ci = crypto_alloc_sync_skcipher("pcbc(fcrypt)", 0, 0);
1214         return PTR_ERR_OR_ZERO(rxkad_ci);
1215 }
1216
1217 /*
1218  * Clean up the rxkad security service.
1219  */
1220 static void rxkad_exit(void)
1221 {
1222         if (rxkad_ci)
1223                 crypto_free_sync_skcipher(rxkad_ci);
1224 }
1225
1226 /*
1227  * RxRPC Kerberos-based security
1228  */
1229 const struct rxrpc_security rxkad = {
1230         .name                           = "rxkad",
1231         .security_index                 = RXRPC_SECURITY_RXKAD,
1232         .init                           = rxkad_init,
1233         .exit                           = rxkad_exit,
1234         .init_connection_security       = rxkad_init_connection_security,
1235         .prime_packet_security          = rxkad_prime_packet_security,
1236         .secure_packet                  = rxkad_secure_packet,
1237         .verify_packet                  = rxkad_verify_packet,
1238         .locate_data                    = rxkad_locate_data,
1239         .issue_challenge                = rxkad_issue_challenge,
1240         .respond_to_challenge           = rxkad_respond_to_challenge,
1241         .verify_response                = rxkad_verify_response,
1242         .clear                          = rxkad_clear,
1243 };