Merge branch 'locking-urgent-for-linus' of git://git.kernel.org/pub/scm/linux/kernel...
[sfrench/cifs-2.6.git] / net / rxrpc / rxkad.c
1 /* Kerberos-based RxRPC security
2  *
3  * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved.
4  * Written by David Howells (dhowells@redhat.com)
5  *
6  * This program is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU General Public License
8  * as published by the Free Software Foundation; either version
9  * 2 of the License, or (at your option) any later version.
10  */
11
12 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
13
14 #include <crypto/skcipher.h>
15 #include <linux/module.h>
16 #include <linux/net.h>
17 #include <linux/skbuff.h>
18 #include <linux/udp.h>
19 #include <linux/scatterlist.h>
20 #include <linux/ctype.h>
21 #include <linux/slab.h>
22 #include <net/sock.h>
23 #include <net/af_rxrpc.h>
24 #include <keys/rxrpc-type.h>
25 #include "ar-internal.h"
26
27 #define RXKAD_VERSION                   2
28 #define MAXKRB5TICKETLEN                1024
29 #define RXKAD_TKT_TYPE_KERBEROS_V5      256
30 #define ANAME_SZ                        40      /* size of authentication name */
31 #define INST_SZ                         40      /* size of principal's instance */
32 #define REALM_SZ                        40      /* size of principal's auth domain */
33 #define SNAME_SZ                        40      /* size of service name */
34
35 struct rxkad_level1_hdr {
36         __be32  data_size;      /* true data size (excluding padding) */
37 };
38
39 struct rxkad_level2_hdr {
40         __be32  data_size;      /* true data size (excluding padding) */
41         __be32  checksum;       /* decrypted data checksum */
42 };
43
44 /*
45  * this holds a pinned cipher so that keventd doesn't get called by the cipher
46  * alloc routine, but since we have it to hand, we use it to decrypt RESPONSE
47  * packets
48  */
49 static struct crypto_sync_skcipher *rxkad_ci;
50 static DEFINE_MUTEX(rxkad_ci_mutex);
51
52 /*
53  * initialise connection security
54  */
55 static int rxkad_init_connection_security(struct rxrpc_connection *conn)
56 {
57         struct crypto_sync_skcipher *ci;
58         struct rxrpc_key_token *token;
59         int ret;
60
61         _enter("{%d},{%x}", conn->debug_id, key_serial(conn->params.key));
62
63         token = conn->params.key->payload.data[0];
64         conn->security_ix = token->security_index;
65
66         ci = crypto_alloc_sync_skcipher("pcbc(fcrypt)", 0, 0);
67         if (IS_ERR(ci)) {
68                 _debug("no cipher");
69                 ret = PTR_ERR(ci);
70                 goto error;
71         }
72
73         if (crypto_sync_skcipher_setkey(ci, token->kad->session_key,
74                                    sizeof(token->kad->session_key)) < 0)
75                 BUG();
76
77         switch (conn->params.security_level) {
78         case RXRPC_SECURITY_PLAIN:
79                 break;
80         case RXRPC_SECURITY_AUTH:
81                 conn->size_align = 8;
82                 conn->security_size = sizeof(struct rxkad_level1_hdr);
83                 break;
84         case RXRPC_SECURITY_ENCRYPT:
85                 conn->size_align = 8;
86                 conn->security_size = sizeof(struct rxkad_level2_hdr);
87                 break;
88         default:
89                 ret = -EKEYREJECTED;
90                 goto error;
91         }
92
93         conn->cipher = ci;
94         ret = 0;
95 error:
96         _leave(" = %d", ret);
97         return ret;
98 }
99
100 /*
101  * prime the encryption state with the invariant parts of a connection's
102  * description
103  */
104 static int rxkad_prime_packet_security(struct rxrpc_connection *conn)
105 {
106         struct rxrpc_key_token *token;
107         SYNC_SKCIPHER_REQUEST_ON_STACK(req, conn->cipher);
108         struct scatterlist sg;
109         struct rxrpc_crypt iv;
110         __be32 *tmpbuf;
111         size_t tmpsize = 4 * sizeof(__be32);
112
113         _enter("");
114
115         if (!conn->params.key)
116                 return 0;
117
118         tmpbuf = kmalloc(tmpsize, GFP_KERNEL);
119         if (!tmpbuf)
120                 return -ENOMEM;
121
122         token = conn->params.key->payload.data[0];
123         memcpy(&iv, token->kad->session_key, sizeof(iv));
124
125         tmpbuf[0] = htonl(conn->proto.epoch);
126         tmpbuf[1] = htonl(conn->proto.cid);
127         tmpbuf[2] = 0;
128         tmpbuf[3] = htonl(conn->security_ix);
129
130         sg_init_one(&sg, tmpbuf, tmpsize);
131         skcipher_request_set_sync_tfm(req, conn->cipher);
132         skcipher_request_set_callback(req, 0, NULL, NULL);
133         skcipher_request_set_crypt(req, &sg, &sg, tmpsize, iv.x);
134         crypto_skcipher_encrypt(req);
135         skcipher_request_zero(req);
136
137         memcpy(&conn->csum_iv, tmpbuf + 2, sizeof(conn->csum_iv));
138         kfree(tmpbuf);
139         _leave(" = 0");
140         return 0;
141 }
142
143 /*
144  * partially encrypt a packet (level 1 security)
145  */
146 static int rxkad_secure_packet_auth(const struct rxrpc_call *call,
147                                     struct sk_buff *skb,
148                                     u32 data_size,
149                                     void *sechdr,
150                                     struct skcipher_request *req)
151 {
152         struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
153         struct rxkad_level1_hdr hdr;
154         struct rxrpc_crypt iv;
155         struct scatterlist sg;
156         u16 check;
157
158         _enter("");
159
160         check = sp->hdr.seq ^ call->call_id;
161         data_size |= (u32)check << 16;
162
163         hdr.data_size = htonl(data_size);
164         memcpy(sechdr, &hdr, sizeof(hdr));
165
166         /* start the encryption afresh */
167         memset(&iv, 0, sizeof(iv));
168
169         sg_init_one(&sg, sechdr, 8);
170         skcipher_request_set_sync_tfm(req, call->conn->cipher);
171         skcipher_request_set_callback(req, 0, NULL, NULL);
172         skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x);
173         crypto_skcipher_encrypt(req);
174         skcipher_request_zero(req);
175
176         _leave(" = 0");
177         return 0;
178 }
179
180 /*
181  * wholly encrypt a packet (level 2 security)
182  */
183 static int rxkad_secure_packet_encrypt(const struct rxrpc_call *call,
184                                        struct sk_buff *skb,
185                                        u32 data_size,
186                                        void *sechdr,
187                                        struct skcipher_request *req)
188 {
189         const struct rxrpc_key_token *token;
190         struct rxkad_level2_hdr rxkhdr;
191         struct rxrpc_skb_priv *sp;
192         struct rxrpc_crypt iv;
193         struct scatterlist sg[16];
194         struct sk_buff *trailer;
195         unsigned int len;
196         u16 check;
197         int nsg;
198         int err;
199
200         sp = rxrpc_skb(skb);
201
202         _enter("");
203
204         check = sp->hdr.seq ^ call->call_id;
205
206         rxkhdr.data_size = htonl(data_size | (u32)check << 16);
207         rxkhdr.checksum = 0;
208         memcpy(sechdr, &rxkhdr, sizeof(rxkhdr));
209
210         /* encrypt from the session key */
211         token = call->conn->params.key->payload.data[0];
212         memcpy(&iv, token->kad->session_key, sizeof(iv));
213
214         sg_init_one(&sg[0], sechdr, sizeof(rxkhdr));
215         skcipher_request_set_sync_tfm(req, call->conn->cipher);
216         skcipher_request_set_callback(req, 0, NULL, NULL);
217         skcipher_request_set_crypt(req, &sg[0], &sg[0], sizeof(rxkhdr), iv.x);
218         crypto_skcipher_encrypt(req);
219
220         /* we want to encrypt the skbuff in-place */
221         nsg = skb_cow_data(skb, 0, &trailer);
222         err = -ENOMEM;
223         if (nsg < 0 || nsg > 16)
224                 goto out;
225
226         len = data_size + call->conn->size_align - 1;
227         len &= ~(call->conn->size_align - 1);
228
229         sg_init_table(sg, nsg);
230         err = skb_to_sgvec(skb, sg, 0, len);
231         if (unlikely(err < 0))
232                 goto out;
233         skcipher_request_set_crypt(req, sg, sg, len, iv.x);
234         crypto_skcipher_encrypt(req);
235
236         _leave(" = 0");
237         err = 0;
238
239 out:
240         skcipher_request_zero(req);
241         return err;
242 }
243
244 /*
245  * checksum an RxRPC packet header
246  */
247 static int rxkad_secure_packet(struct rxrpc_call *call,
248                                struct sk_buff *skb,
249                                size_t data_size,
250                                void *sechdr)
251 {
252         struct rxrpc_skb_priv *sp;
253         SYNC_SKCIPHER_REQUEST_ON_STACK(req, call->conn->cipher);
254         struct rxrpc_crypt iv;
255         struct scatterlist sg;
256         u32 x, y;
257         int ret;
258
259         sp = rxrpc_skb(skb);
260
261         _enter("{%d{%x}},{#%u},%zu,",
262                call->debug_id, key_serial(call->conn->params.key),
263                sp->hdr.seq, data_size);
264
265         if (!call->conn->cipher)
266                 return 0;
267
268         ret = key_validate(call->conn->params.key);
269         if (ret < 0)
270                 return ret;
271
272         /* continue encrypting from where we left off */
273         memcpy(&iv, call->conn->csum_iv.x, sizeof(iv));
274
275         /* calculate the security checksum */
276         x = (call->cid & RXRPC_CHANNELMASK) << (32 - RXRPC_CIDSHIFT);
277         x |= sp->hdr.seq & 0x3fffffff;
278         call->crypto_buf[0] = htonl(call->call_id);
279         call->crypto_buf[1] = htonl(x);
280
281         sg_init_one(&sg, call->crypto_buf, 8);
282         skcipher_request_set_sync_tfm(req, call->conn->cipher);
283         skcipher_request_set_callback(req, 0, NULL, NULL);
284         skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x);
285         crypto_skcipher_encrypt(req);
286         skcipher_request_zero(req);
287
288         y = ntohl(call->crypto_buf[1]);
289         y = (y >> 16) & 0xffff;
290         if (y == 0)
291                 y = 1; /* zero checksums are not permitted */
292         sp->hdr.cksum = y;
293
294         switch (call->conn->params.security_level) {
295         case RXRPC_SECURITY_PLAIN:
296                 ret = 0;
297                 break;
298         case RXRPC_SECURITY_AUTH:
299                 ret = rxkad_secure_packet_auth(call, skb, data_size, sechdr,
300                                                req);
301                 break;
302         case RXRPC_SECURITY_ENCRYPT:
303                 ret = rxkad_secure_packet_encrypt(call, skb, data_size,
304                                                   sechdr, req);
305                 break;
306         default:
307                 ret = -EPERM;
308                 break;
309         }
310
311         _leave(" = %d [set %hx]", ret, y);
312         return ret;
313 }
314
315 /*
316  * decrypt partial encryption on a packet (level 1 security)
317  */
318 static int rxkad_verify_packet_1(struct rxrpc_call *call, struct sk_buff *skb,
319                                  unsigned int offset, unsigned int len,
320                                  rxrpc_seq_t seq,
321                                  struct skcipher_request *req)
322 {
323         struct rxkad_level1_hdr sechdr;
324         struct rxrpc_crypt iv;
325         struct scatterlist sg[16];
326         struct sk_buff *trailer;
327         bool aborted;
328         u32 data_size, buf;
329         u16 check;
330         int nsg, ret;
331
332         _enter("");
333
334         if (len < 8) {
335                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_hdr", "V1H",
336                                            RXKADSEALEDINCON);
337                 goto protocol_error;
338         }
339
340         /* Decrypt the skbuff in-place.  TODO: We really want to decrypt
341          * directly into the target buffer.
342          */
343         nsg = skb_cow_data(skb, 0, &trailer);
344         if (nsg < 0 || nsg > 16)
345                 goto nomem;
346
347         sg_init_table(sg, nsg);
348         ret = skb_to_sgvec(skb, sg, offset, 8);
349         if (unlikely(ret < 0))
350                 return ret;
351
352         /* start the decryption afresh */
353         memset(&iv, 0, sizeof(iv));
354
355         skcipher_request_set_sync_tfm(req, call->conn->cipher);
356         skcipher_request_set_callback(req, 0, NULL, NULL);
357         skcipher_request_set_crypt(req, sg, sg, 8, iv.x);
358         crypto_skcipher_decrypt(req);
359         skcipher_request_zero(req);
360
361         /* Extract the decrypted packet length */
362         if (skb_copy_bits(skb, offset, &sechdr, sizeof(sechdr)) < 0) {
363                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_len", "XV1",
364                                              RXKADDATALEN);
365                 goto protocol_error;
366         }
367         offset += sizeof(sechdr);
368         len -= sizeof(sechdr);
369
370         buf = ntohl(sechdr.data_size);
371         data_size = buf & 0xffff;
372
373         check = buf >> 16;
374         check ^= seq ^ call->call_id;
375         check &= 0xffff;
376         if (check != 0) {
377                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_check", "V1C",
378                                              RXKADSEALEDINCON);
379                 goto protocol_error;
380         }
381
382         if (data_size > len) {
383                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_datalen", "V1L",
384                                              RXKADDATALEN);
385                 goto protocol_error;
386         }
387
388         _leave(" = 0 [dlen=%x]", data_size);
389         return 0;
390
391 protocol_error:
392         if (aborted)
393                 rxrpc_send_abort_packet(call);
394         return -EPROTO;
395
396 nomem:
397         _leave(" = -ENOMEM");
398         return -ENOMEM;
399 }
400
401 /*
402  * wholly decrypt a packet (level 2 security)
403  */
404 static int rxkad_verify_packet_2(struct rxrpc_call *call, struct sk_buff *skb,
405                                  unsigned int offset, unsigned int len,
406                                  rxrpc_seq_t seq,
407                                  struct skcipher_request *req)
408 {
409         const struct rxrpc_key_token *token;
410         struct rxkad_level2_hdr sechdr;
411         struct rxrpc_crypt iv;
412         struct scatterlist _sg[4], *sg;
413         struct sk_buff *trailer;
414         bool aborted;
415         u32 data_size, buf;
416         u16 check;
417         int nsg, ret;
418
419         _enter(",{%d}", skb->len);
420
421         if (len < 8) {
422                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_hdr", "V2H",
423                                              RXKADSEALEDINCON);
424                 goto protocol_error;
425         }
426
427         /* Decrypt the skbuff in-place.  TODO: We really want to decrypt
428          * directly into the target buffer.
429          */
430         nsg = skb_cow_data(skb, 0, &trailer);
431         if (nsg < 0)
432                 goto nomem;
433
434         sg = _sg;
435         if (unlikely(nsg > 4)) {
436                 sg = kmalloc_array(nsg, sizeof(*sg), GFP_NOIO);
437                 if (!sg)
438                         goto nomem;
439         }
440
441         sg_init_table(sg, nsg);
442         ret = skb_to_sgvec(skb, sg, offset, len);
443         if (unlikely(ret < 0)) {
444                 if (sg != _sg)
445                         kfree(sg);
446                 return ret;
447         }
448
449         /* decrypt from the session key */
450         token = call->conn->params.key->payload.data[0];
451         memcpy(&iv, token->kad->session_key, sizeof(iv));
452
453         skcipher_request_set_sync_tfm(req, call->conn->cipher);
454         skcipher_request_set_callback(req, 0, NULL, NULL);
455         skcipher_request_set_crypt(req, sg, sg, len, iv.x);
456         crypto_skcipher_decrypt(req);
457         skcipher_request_zero(req);
458         if (sg != _sg)
459                 kfree(sg);
460
461         /* Extract the decrypted packet length */
462         if (skb_copy_bits(skb, offset, &sechdr, sizeof(sechdr)) < 0) {
463                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_len", "XV2",
464                                              RXKADDATALEN);
465                 goto protocol_error;
466         }
467         offset += sizeof(sechdr);
468         len -= sizeof(sechdr);
469
470         buf = ntohl(sechdr.data_size);
471         data_size = buf & 0xffff;
472
473         check = buf >> 16;
474         check ^= seq ^ call->call_id;
475         check &= 0xffff;
476         if (check != 0) {
477                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_check", "V2C",
478                                              RXKADSEALEDINCON);
479                 goto protocol_error;
480         }
481
482         if (data_size > len) {
483                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_datalen", "V2L",
484                                              RXKADDATALEN);
485                 goto protocol_error;
486         }
487
488         _leave(" = 0 [dlen=%x]", data_size);
489         return 0;
490
491 protocol_error:
492         if (aborted)
493                 rxrpc_send_abort_packet(call);
494         return -EPROTO;
495
496 nomem:
497         _leave(" = -ENOMEM");
498         return -ENOMEM;
499 }
500
501 /*
502  * Verify the security on a received packet or subpacket (if part of a
503  * jumbo packet).
504  */
505 static int rxkad_verify_packet(struct rxrpc_call *call, struct sk_buff *skb,
506                                unsigned int offset, unsigned int len,
507                                rxrpc_seq_t seq, u16 expected_cksum)
508 {
509         SYNC_SKCIPHER_REQUEST_ON_STACK(req, call->conn->cipher);
510         struct rxrpc_crypt iv;
511         struct scatterlist sg;
512         bool aborted;
513         u16 cksum;
514         u32 x, y;
515
516         _enter("{%d{%x}},{#%u}",
517                call->debug_id, key_serial(call->conn->params.key), seq);
518
519         if (!call->conn->cipher)
520                 return 0;
521
522         /* continue encrypting from where we left off */
523         memcpy(&iv, call->conn->csum_iv.x, sizeof(iv));
524
525         /* validate the security checksum */
526         x = (call->cid & RXRPC_CHANNELMASK) << (32 - RXRPC_CIDSHIFT);
527         x |= seq & 0x3fffffff;
528         call->crypto_buf[0] = htonl(call->call_id);
529         call->crypto_buf[1] = htonl(x);
530
531         sg_init_one(&sg, call->crypto_buf, 8);
532         skcipher_request_set_sync_tfm(req, call->conn->cipher);
533         skcipher_request_set_callback(req, 0, NULL, NULL);
534         skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x);
535         crypto_skcipher_encrypt(req);
536         skcipher_request_zero(req);
537
538         y = ntohl(call->crypto_buf[1]);
539         cksum = (y >> 16) & 0xffff;
540         if (cksum == 0)
541                 cksum = 1; /* zero checksums are not permitted */
542
543         if (cksum != expected_cksum) {
544                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_csum", "VCK",
545                                              RXKADSEALEDINCON);
546                 goto protocol_error;
547         }
548
549         switch (call->conn->params.security_level) {
550         case RXRPC_SECURITY_PLAIN:
551                 return 0;
552         case RXRPC_SECURITY_AUTH:
553                 return rxkad_verify_packet_1(call, skb, offset, len, seq, req);
554         case RXRPC_SECURITY_ENCRYPT:
555                 return rxkad_verify_packet_2(call, skb, offset, len, seq, req);
556         default:
557                 return -ENOANO;
558         }
559
560 protocol_error:
561         if (aborted)
562                 rxrpc_send_abort_packet(call);
563         return -EPROTO;
564 }
565
566 /*
567  * Locate the data contained in a packet that was partially encrypted.
568  */
569 static void rxkad_locate_data_1(struct rxrpc_call *call, struct sk_buff *skb,
570                                 unsigned int *_offset, unsigned int *_len)
571 {
572         struct rxkad_level1_hdr sechdr;
573
574         if (skb_copy_bits(skb, *_offset, &sechdr, sizeof(sechdr)) < 0)
575                 BUG();
576         *_offset += sizeof(sechdr);
577         *_len = ntohl(sechdr.data_size) & 0xffff;
578 }
579
580 /*
581  * Locate the data contained in a packet that was completely encrypted.
582  */
583 static void rxkad_locate_data_2(struct rxrpc_call *call, struct sk_buff *skb,
584                                 unsigned int *_offset, unsigned int *_len)
585 {
586         struct rxkad_level2_hdr sechdr;
587
588         if (skb_copy_bits(skb, *_offset, &sechdr, sizeof(sechdr)) < 0)
589                 BUG();
590         *_offset += sizeof(sechdr);
591         *_len = ntohl(sechdr.data_size) & 0xffff;
592 }
593
594 /*
595  * Locate the data contained in an already decrypted packet.
596  */
597 static void rxkad_locate_data(struct rxrpc_call *call, struct sk_buff *skb,
598                               unsigned int *_offset, unsigned int *_len)
599 {
600         switch (call->conn->params.security_level) {
601         case RXRPC_SECURITY_AUTH:
602                 rxkad_locate_data_1(call, skb, _offset, _len);
603                 return;
604         case RXRPC_SECURITY_ENCRYPT:
605                 rxkad_locate_data_2(call, skb, _offset, _len);
606                 return;
607         default:
608                 return;
609         }
610 }
611
612 /*
613  * issue a challenge
614  */
615 static int rxkad_issue_challenge(struct rxrpc_connection *conn)
616 {
617         struct rxkad_challenge challenge;
618         struct rxrpc_wire_header whdr;
619         struct msghdr msg;
620         struct kvec iov[2];
621         size_t len;
622         u32 serial;
623         int ret;
624
625         _enter("{%d,%x}", conn->debug_id, key_serial(conn->params.key));
626
627         ret = key_validate(conn->params.key);
628         if (ret < 0)
629                 return ret;
630
631         get_random_bytes(&conn->security_nonce, sizeof(conn->security_nonce));
632
633         challenge.version       = htonl(2);
634         challenge.nonce         = htonl(conn->security_nonce);
635         challenge.min_level     = htonl(0);
636         challenge.__padding     = 0;
637
638         msg.msg_name    = &conn->params.peer->srx.transport;
639         msg.msg_namelen = conn->params.peer->srx.transport_len;
640         msg.msg_control = NULL;
641         msg.msg_controllen = 0;
642         msg.msg_flags   = 0;
643
644         whdr.epoch      = htonl(conn->proto.epoch);
645         whdr.cid        = htonl(conn->proto.cid);
646         whdr.callNumber = 0;
647         whdr.seq        = 0;
648         whdr.type       = RXRPC_PACKET_TYPE_CHALLENGE;
649         whdr.flags      = conn->out_clientflag;
650         whdr.userStatus = 0;
651         whdr.securityIndex = conn->security_ix;
652         whdr._rsvd      = 0;
653         whdr.serviceId  = htons(conn->service_id);
654
655         iov[0].iov_base = &whdr;
656         iov[0].iov_len  = sizeof(whdr);
657         iov[1].iov_base = &challenge;
658         iov[1].iov_len  = sizeof(challenge);
659
660         len = iov[0].iov_len + iov[1].iov_len;
661
662         serial = atomic_inc_return(&conn->serial);
663         whdr.serial = htonl(serial);
664         _proto("Tx CHALLENGE %%%u", serial);
665
666         ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, 2, len);
667         if (ret < 0) {
668                 trace_rxrpc_tx_fail(conn->debug_id, serial, ret,
669                                     rxrpc_tx_point_rxkad_challenge);
670                 return -EAGAIN;
671         }
672
673         conn->params.peer->last_tx_at = ktime_get_seconds();
674         trace_rxrpc_tx_packet(conn->debug_id, &whdr,
675                               rxrpc_tx_point_rxkad_challenge);
676         _leave(" = 0");
677         return 0;
678 }
679
680 /*
681  * send a Kerberos security response
682  */
683 static int rxkad_send_response(struct rxrpc_connection *conn,
684                                struct rxrpc_host_header *hdr,
685                                struct rxkad_response *resp,
686                                const struct rxkad_key *s2)
687 {
688         struct rxrpc_wire_header whdr;
689         struct msghdr msg;
690         struct kvec iov[3];
691         size_t len;
692         u32 serial;
693         int ret;
694
695         _enter("");
696
697         msg.msg_name    = &conn->params.peer->srx.transport;
698         msg.msg_namelen = conn->params.peer->srx.transport_len;
699         msg.msg_control = NULL;
700         msg.msg_controllen = 0;
701         msg.msg_flags   = 0;
702
703         memset(&whdr, 0, sizeof(whdr));
704         whdr.epoch      = htonl(hdr->epoch);
705         whdr.cid        = htonl(hdr->cid);
706         whdr.type       = RXRPC_PACKET_TYPE_RESPONSE;
707         whdr.flags      = conn->out_clientflag;
708         whdr.securityIndex = hdr->securityIndex;
709         whdr.serviceId  = htons(hdr->serviceId);
710
711         iov[0].iov_base = &whdr;
712         iov[0].iov_len  = sizeof(whdr);
713         iov[1].iov_base = resp;
714         iov[1].iov_len  = sizeof(*resp);
715         iov[2].iov_base = (void *)s2->ticket;
716         iov[2].iov_len  = s2->ticket_len;
717
718         len = iov[0].iov_len + iov[1].iov_len + iov[2].iov_len;
719
720         serial = atomic_inc_return(&conn->serial);
721         whdr.serial = htonl(serial);
722         _proto("Tx RESPONSE %%%u", serial);
723
724         ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, 3, len);
725         if (ret < 0) {
726                 trace_rxrpc_tx_fail(conn->debug_id, serial, ret,
727                                     rxrpc_tx_point_rxkad_response);
728                 return -EAGAIN;
729         }
730
731         conn->params.peer->last_tx_at = ktime_get_seconds();
732         _leave(" = 0");
733         return 0;
734 }
735
736 /*
737  * calculate the response checksum
738  */
739 static void rxkad_calc_response_checksum(struct rxkad_response *response)
740 {
741         u32 csum = 1000003;
742         int loop;
743         u8 *p = (u8 *) response;
744
745         for (loop = sizeof(*response); loop > 0; loop--)
746                 csum = csum * 0x10204081 + *p++;
747
748         response->encrypted.checksum = htonl(csum);
749 }
750
751 /*
752  * encrypt the response packet
753  */
754 static void rxkad_encrypt_response(struct rxrpc_connection *conn,
755                                    struct rxkad_response *resp,
756                                    const struct rxkad_key *s2)
757 {
758         SYNC_SKCIPHER_REQUEST_ON_STACK(req, conn->cipher);
759         struct rxrpc_crypt iv;
760         struct scatterlist sg[1];
761
762         /* continue encrypting from where we left off */
763         memcpy(&iv, s2->session_key, sizeof(iv));
764
765         sg_init_table(sg, 1);
766         sg_set_buf(sg, &resp->encrypted, sizeof(resp->encrypted));
767         skcipher_request_set_sync_tfm(req, conn->cipher);
768         skcipher_request_set_callback(req, 0, NULL, NULL);
769         skcipher_request_set_crypt(req, sg, sg, sizeof(resp->encrypted), iv.x);
770         crypto_skcipher_encrypt(req);
771         skcipher_request_zero(req);
772 }
773
774 /*
775  * respond to a challenge packet
776  */
777 static int rxkad_respond_to_challenge(struct rxrpc_connection *conn,
778                                       struct sk_buff *skb,
779                                       u32 *_abort_code)
780 {
781         const struct rxrpc_key_token *token;
782         struct rxkad_challenge challenge;
783         struct rxkad_response *resp;
784         struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
785         const char *eproto;
786         u32 version, nonce, min_level, abort_code;
787         int ret;
788
789         _enter("{%d,%x}", conn->debug_id, key_serial(conn->params.key));
790
791         eproto = tracepoint_string("chall_no_key");
792         abort_code = RX_PROTOCOL_ERROR;
793         if (!conn->params.key)
794                 goto protocol_error;
795
796         abort_code = RXKADEXPIRED;
797         ret = key_validate(conn->params.key);
798         if (ret < 0)
799                 goto other_error;
800
801         eproto = tracepoint_string("chall_short");
802         abort_code = RXKADPACKETSHORT;
803         if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header),
804                           &challenge, sizeof(challenge)) < 0)
805                 goto protocol_error;
806
807         version = ntohl(challenge.version);
808         nonce = ntohl(challenge.nonce);
809         min_level = ntohl(challenge.min_level);
810
811         _proto("Rx CHALLENGE %%%u { v=%u n=%u ml=%u }",
812                sp->hdr.serial, version, nonce, min_level);
813
814         eproto = tracepoint_string("chall_ver");
815         abort_code = RXKADINCONSISTENCY;
816         if (version != RXKAD_VERSION)
817                 goto protocol_error;
818
819         abort_code = RXKADLEVELFAIL;
820         ret = -EACCES;
821         if (conn->params.security_level < min_level)
822                 goto other_error;
823
824         token = conn->params.key->payload.data[0];
825
826         /* build the response packet */
827         resp = kzalloc(sizeof(struct rxkad_response), GFP_NOFS);
828         if (!resp)
829                 return -ENOMEM;
830
831         resp->version                   = htonl(RXKAD_VERSION);
832         resp->encrypted.epoch           = htonl(conn->proto.epoch);
833         resp->encrypted.cid             = htonl(conn->proto.cid);
834         resp->encrypted.securityIndex   = htonl(conn->security_ix);
835         resp->encrypted.inc_nonce       = htonl(nonce + 1);
836         resp->encrypted.level           = htonl(conn->params.security_level);
837         resp->kvno                      = htonl(token->kad->kvno);
838         resp->ticket_len                = htonl(token->kad->ticket_len);
839         resp->encrypted.call_id[0]      = htonl(conn->channels[0].call_counter);
840         resp->encrypted.call_id[1]      = htonl(conn->channels[1].call_counter);
841         resp->encrypted.call_id[2]      = htonl(conn->channels[2].call_counter);
842         resp->encrypted.call_id[3]      = htonl(conn->channels[3].call_counter);
843
844         /* calculate the response checksum and then do the encryption */
845         rxkad_calc_response_checksum(resp);
846         rxkad_encrypt_response(conn, resp, token->kad);
847         ret = rxkad_send_response(conn, &sp->hdr, resp, token->kad);
848         kfree(resp);
849         return ret;
850
851 protocol_error:
852         trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto);
853         ret = -EPROTO;
854 other_error:
855         *_abort_code = abort_code;
856         return ret;
857 }
858
859 /*
860  * decrypt the kerberos IV ticket in the response
861  */
862 static int rxkad_decrypt_ticket(struct rxrpc_connection *conn,
863                                 struct sk_buff *skb,
864                                 void *ticket, size_t ticket_len,
865                                 struct rxrpc_crypt *_session_key,
866                                 time64_t *_expiry,
867                                 u32 *_abort_code)
868 {
869         struct skcipher_request *req;
870         struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
871         struct rxrpc_crypt iv, key;
872         struct scatterlist sg[1];
873         struct in_addr addr;
874         unsigned int life;
875         const char *eproto;
876         time64_t issue, now;
877         bool little_endian;
878         int ret;
879         u32 abort_code;
880         u8 *p, *q, *name, *end;
881
882         _enter("{%d},{%x}", conn->debug_id, key_serial(conn->server_key));
883
884         *_expiry = 0;
885
886         ret = key_validate(conn->server_key);
887         if (ret < 0) {
888                 switch (ret) {
889                 case -EKEYEXPIRED:
890                         abort_code = RXKADEXPIRED;
891                         goto other_error;
892                 default:
893                         abort_code = RXKADNOAUTH;
894                         goto other_error;
895                 }
896         }
897
898         ASSERT(conn->server_key->payload.data[0] != NULL);
899         ASSERTCMP((unsigned long) ticket & 7UL, ==, 0);
900
901         memcpy(&iv, &conn->server_key->payload.data[2], sizeof(iv));
902
903         ret = -ENOMEM;
904         req = skcipher_request_alloc(conn->server_key->payload.data[0],
905                                      GFP_NOFS);
906         if (!req)
907                 goto temporary_error;
908
909         sg_init_one(&sg[0], ticket, ticket_len);
910         skcipher_request_set_callback(req, 0, NULL, NULL);
911         skcipher_request_set_crypt(req, sg, sg, ticket_len, iv.x);
912         crypto_skcipher_decrypt(req);
913         skcipher_request_free(req);
914
915         p = ticket;
916         end = p + ticket_len;
917
918 #define Z(field)                                        \
919         ({                                              \
920                 u8 *__str = p;                          \
921                 eproto = tracepoint_string("rxkad_bad_"#field); \
922                 q = memchr(p, 0, end - p);              \
923                 if (!q || q - p > (field##_SZ))         \
924                         goto bad_ticket;                \
925                 for (; p < q; p++)                      \
926                         if (!isprint(*p))               \
927                                 goto bad_ticket;        \
928                 p++;                                    \
929                 __str;                                  \
930         })
931
932         /* extract the ticket flags */
933         _debug("KIV FLAGS: %x", *p);
934         little_endian = *p & 1;
935         p++;
936
937         /* extract the authentication name */
938         name = Z(ANAME);
939         _debug("KIV ANAME: %s", name);
940
941         /* extract the principal's instance */
942         name = Z(INST);
943         _debug("KIV INST : %s", name);
944
945         /* extract the principal's authentication domain */
946         name = Z(REALM);
947         _debug("KIV REALM: %s", name);
948
949         eproto = tracepoint_string("rxkad_bad_len");
950         if (end - p < 4 + 8 + 4 + 2)
951                 goto bad_ticket;
952
953         /* get the IPv4 address of the entity that requested the ticket */
954         memcpy(&addr, p, sizeof(addr));
955         p += 4;
956         _debug("KIV ADDR : %pI4", &addr);
957
958         /* get the session key from the ticket */
959         memcpy(&key, p, sizeof(key));
960         p += 8;
961         _debug("KIV KEY  : %08x %08x", ntohl(key.n[0]), ntohl(key.n[1]));
962         memcpy(_session_key, &key, sizeof(key));
963
964         /* get the ticket's lifetime */
965         life = *p++ * 5 * 60;
966         _debug("KIV LIFE : %u", life);
967
968         /* get the issue time of the ticket */
969         if (little_endian) {
970                 __le32 stamp;
971                 memcpy(&stamp, p, 4);
972                 issue = rxrpc_u32_to_time64(le32_to_cpu(stamp));
973         } else {
974                 __be32 stamp;
975                 memcpy(&stamp, p, 4);
976                 issue = rxrpc_u32_to_time64(be32_to_cpu(stamp));
977         }
978         p += 4;
979         now = ktime_get_real_seconds();
980         _debug("KIV ISSUE: %llx [%llx]", issue, now);
981
982         /* check the ticket is in date */
983         if (issue > now) {
984                 abort_code = RXKADNOAUTH;
985                 ret = -EKEYREJECTED;
986                 goto other_error;
987         }
988
989         if (issue < now - life) {
990                 abort_code = RXKADEXPIRED;
991                 ret = -EKEYEXPIRED;
992                 goto other_error;
993         }
994
995         *_expiry = issue + life;
996
997         /* get the service name */
998         name = Z(SNAME);
999         _debug("KIV SNAME: %s", name);
1000
1001         /* get the service instance name */
1002         name = Z(INST);
1003         _debug("KIV SINST: %s", name);
1004         return 0;
1005
1006 bad_ticket:
1007         trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto);
1008         abort_code = RXKADBADTICKET;
1009         ret = -EPROTO;
1010 other_error:
1011         *_abort_code = abort_code;
1012         return ret;
1013 temporary_error:
1014         return ret;
1015 }
1016
1017 /*
1018  * decrypt the response packet
1019  */
1020 static void rxkad_decrypt_response(struct rxrpc_connection *conn,
1021                                    struct rxkad_response *resp,
1022                                    const struct rxrpc_crypt *session_key)
1023 {
1024         SYNC_SKCIPHER_REQUEST_ON_STACK(req, rxkad_ci);
1025         struct scatterlist sg[1];
1026         struct rxrpc_crypt iv;
1027
1028         _enter(",,%08x%08x",
1029                ntohl(session_key->n[0]), ntohl(session_key->n[1]));
1030
1031         ASSERT(rxkad_ci != NULL);
1032
1033         mutex_lock(&rxkad_ci_mutex);
1034         if (crypto_sync_skcipher_setkey(rxkad_ci, session_key->x,
1035                                    sizeof(*session_key)) < 0)
1036                 BUG();
1037
1038         memcpy(&iv, session_key, sizeof(iv));
1039
1040         sg_init_table(sg, 1);
1041         sg_set_buf(sg, &resp->encrypted, sizeof(resp->encrypted));
1042         skcipher_request_set_sync_tfm(req, rxkad_ci);
1043         skcipher_request_set_callback(req, 0, NULL, NULL);
1044         skcipher_request_set_crypt(req, sg, sg, sizeof(resp->encrypted), iv.x);
1045         crypto_skcipher_decrypt(req);
1046         skcipher_request_zero(req);
1047
1048         mutex_unlock(&rxkad_ci_mutex);
1049
1050         _leave("");
1051 }
1052
1053 /*
1054  * verify a response
1055  */
1056 static int rxkad_verify_response(struct rxrpc_connection *conn,
1057                                  struct sk_buff *skb,
1058                                  u32 *_abort_code)
1059 {
1060         struct rxkad_response *response;
1061         struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
1062         struct rxrpc_crypt session_key;
1063         const char *eproto;
1064         time64_t expiry;
1065         void *ticket;
1066         u32 abort_code, version, kvno, ticket_len, level;
1067         __be32 csum;
1068         int ret, i;
1069
1070         _enter("{%d,%x}", conn->debug_id, key_serial(conn->server_key));
1071
1072         ret = -ENOMEM;
1073         response = kzalloc(sizeof(struct rxkad_response), GFP_NOFS);
1074         if (!response)
1075                 goto temporary_error;
1076
1077         eproto = tracepoint_string("rxkad_rsp_short");
1078         abort_code = RXKADPACKETSHORT;
1079         if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header),
1080                           response, sizeof(*response)) < 0)
1081                 goto protocol_error;
1082         if (!pskb_pull(skb, sizeof(*response)))
1083                 BUG();
1084
1085         version = ntohl(response->version);
1086         ticket_len = ntohl(response->ticket_len);
1087         kvno = ntohl(response->kvno);
1088         _proto("Rx RESPONSE %%%u { v=%u kv=%u tl=%u }",
1089                sp->hdr.serial, version, kvno, ticket_len);
1090
1091         eproto = tracepoint_string("rxkad_rsp_ver");
1092         abort_code = RXKADINCONSISTENCY;
1093         if (version != RXKAD_VERSION)
1094                 goto protocol_error;
1095
1096         eproto = tracepoint_string("rxkad_rsp_tktlen");
1097         abort_code = RXKADTICKETLEN;
1098         if (ticket_len < 4 || ticket_len > MAXKRB5TICKETLEN)
1099                 goto protocol_error;
1100
1101         eproto = tracepoint_string("rxkad_rsp_unkkey");
1102         abort_code = RXKADUNKNOWNKEY;
1103         if (kvno >= RXKAD_TKT_TYPE_KERBEROS_V5)
1104                 goto protocol_error;
1105
1106         /* extract the kerberos ticket and decrypt and decode it */
1107         ret = -ENOMEM;
1108         ticket = kmalloc(ticket_len, GFP_NOFS);
1109         if (!ticket)
1110                 goto temporary_error;
1111
1112         eproto = tracepoint_string("rxkad_tkt_short");
1113         abort_code = RXKADPACKETSHORT;
1114         if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header),
1115                           ticket, ticket_len) < 0)
1116                 goto protocol_error_free;
1117
1118         ret = rxkad_decrypt_ticket(conn, skb, ticket, ticket_len, &session_key,
1119                                    &expiry, _abort_code);
1120         if (ret < 0)
1121                 goto temporary_error_free_resp;
1122
1123         /* use the session key from inside the ticket to decrypt the
1124          * response */
1125         rxkad_decrypt_response(conn, response, &session_key);
1126
1127         eproto = tracepoint_string("rxkad_rsp_param");
1128         abort_code = RXKADSEALEDINCON;
1129         if (ntohl(response->encrypted.epoch) != conn->proto.epoch)
1130                 goto protocol_error_free;
1131         if (ntohl(response->encrypted.cid) != conn->proto.cid)
1132                 goto protocol_error_free;
1133         if (ntohl(response->encrypted.securityIndex) != conn->security_ix)
1134                 goto protocol_error_free;
1135         csum = response->encrypted.checksum;
1136         response->encrypted.checksum = 0;
1137         rxkad_calc_response_checksum(response);
1138         eproto = tracepoint_string("rxkad_rsp_csum");
1139         if (response->encrypted.checksum != csum)
1140                 goto protocol_error_free;
1141
1142         spin_lock(&conn->channel_lock);
1143         for (i = 0; i < RXRPC_MAXCALLS; i++) {
1144                 struct rxrpc_call *call;
1145                 u32 call_id = ntohl(response->encrypted.call_id[i]);
1146
1147                 eproto = tracepoint_string("rxkad_rsp_callid");
1148                 if (call_id > INT_MAX)
1149                         goto protocol_error_unlock;
1150
1151                 eproto = tracepoint_string("rxkad_rsp_callctr");
1152                 if (call_id < conn->channels[i].call_counter)
1153                         goto protocol_error_unlock;
1154
1155                 eproto = tracepoint_string("rxkad_rsp_callst");
1156                 if (call_id > conn->channels[i].call_counter) {
1157                         call = rcu_dereference_protected(
1158                                 conn->channels[i].call,
1159                                 lockdep_is_held(&conn->channel_lock));
1160                         if (call && call->state < RXRPC_CALL_COMPLETE)
1161                                 goto protocol_error_unlock;
1162                         conn->channels[i].call_counter = call_id;
1163                 }
1164         }
1165         spin_unlock(&conn->channel_lock);
1166
1167         eproto = tracepoint_string("rxkad_rsp_seq");
1168         abort_code = RXKADOUTOFSEQUENCE;
1169         if (ntohl(response->encrypted.inc_nonce) != conn->security_nonce + 1)
1170                 goto protocol_error_free;
1171
1172         eproto = tracepoint_string("rxkad_rsp_level");
1173         abort_code = RXKADLEVELFAIL;
1174         level = ntohl(response->encrypted.level);
1175         if (level > RXRPC_SECURITY_ENCRYPT)
1176                 goto protocol_error_free;
1177         conn->params.security_level = level;
1178
1179         /* create a key to hold the security data and expiration time - after
1180          * this the connection security can be handled in exactly the same way
1181          * as for a client connection */
1182         ret = rxrpc_get_server_data_key(conn, &session_key, expiry, kvno);
1183         if (ret < 0)
1184                 goto temporary_error_free_ticket;
1185
1186         kfree(ticket);
1187         kfree(response);
1188         _leave(" = 0");
1189         return 0;
1190
1191 protocol_error_unlock:
1192         spin_unlock(&conn->channel_lock);
1193 protocol_error_free:
1194         kfree(ticket);
1195 protocol_error:
1196         kfree(response);
1197         trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto);
1198         *_abort_code = abort_code;
1199         return -EPROTO;
1200
1201 temporary_error_free_ticket:
1202         kfree(ticket);
1203 temporary_error_free_resp:
1204         kfree(response);
1205 temporary_error:
1206         /* Ignore the response packet if we got a temporary error such as
1207          * ENOMEM.  We just want to send the challenge again.  Note that we
1208          * also come out this way if the ticket decryption fails.
1209          */
1210         return ret;
1211 }
1212
1213 /*
1214  * clear the connection security
1215  */
1216 static void rxkad_clear(struct rxrpc_connection *conn)
1217 {
1218         _enter("");
1219
1220         if (conn->cipher)
1221                 crypto_free_sync_skcipher(conn->cipher);
1222 }
1223
1224 /*
1225  * Initialise the rxkad security service.
1226  */
1227 static int rxkad_init(void)
1228 {
1229         /* pin the cipher we need so that the crypto layer doesn't invoke
1230          * keventd to go get it */
1231         rxkad_ci = crypto_alloc_sync_skcipher("pcbc(fcrypt)", 0, 0);
1232         return PTR_ERR_OR_ZERO(rxkad_ci);
1233 }
1234
1235 /*
1236  * Clean up the rxkad security service.
1237  */
1238 static void rxkad_exit(void)
1239 {
1240         if (rxkad_ci)
1241                 crypto_free_sync_skcipher(rxkad_ci);
1242 }
1243
1244 /*
1245  * RxRPC Kerberos-based security
1246  */
1247 const struct rxrpc_security rxkad = {
1248         .name                           = "rxkad",
1249         .security_index                 = RXRPC_SECURITY_RXKAD,
1250         .init                           = rxkad_init,
1251         .exit                           = rxkad_exit,
1252         .init_connection_security       = rxkad_init_connection_security,
1253         .prime_packet_security          = rxkad_prime_packet_security,
1254         .secure_packet                  = rxkad_secure_packet,
1255         .verify_packet                  = rxkad_verify_packet,
1256         .locate_data                    = rxkad_locate_data,
1257         .issue_challenge                = rxkad_issue_challenge,
1258         .respond_to_challenge           = rxkad_respond_to_challenge,
1259         .verify_response                = rxkad_verify_response,
1260         .clear                          = rxkad_clear,
1261 };