Merge branch 'drm-next-5.1' of git://people.freedesktop.org/~agd5f/linux into drm...
[sfrench/cifs-2.6.git] / net / ipv4 / esp4.c
1 #define pr_fmt(fmt) "IPsec: " fmt
2
3 #include <crypto/aead.h>
4 #include <crypto/authenc.h>
5 #include <linux/err.h>
6 #include <linux/module.h>
7 #include <net/ip.h>
8 #include <net/xfrm.h>
9 #include <net/esp.h>
10 #include <linux/scatterlist.h>
11 #include <linux/kernel.h>
12 #include <linux/pfkeyv2.h>
13 #include <linux/rtnetlink.h>
14 #include <linux/slab.h>
15 #include <linux/spinlock.h>
16 #include <linux/in6.h>
17 #include <net/icmp.h>
18 #include <net/protocol.h>
19 #include <net/udp.h>
20
21 #include <linux/highmem.h>
22
23 struct esp_skb_cb {
24         struct xfrm_skb_cb xfrm;
25         void *tmp;
26 };
27
28 struct esp_output_extra {
29         __be32 seqhi;
30         u32 esphoff;
31 };
32
33 #define ESP_SKB_CB(__skb) ((struct esp_skb_cb *)&((__skb)->cb[0]))
34
35 static u32 esp4_get_mtu(struct xfrm_state *x, int mtu);
36
37 /*
38  * Allocate an AEAD request structure with extra space for SG and IV.
39  *
40  * For alignment considerations the IV is placed at the front, followed
41  * by the request and finally the SG list.
42  *
43  * TODO: Use spare space in skb for this where possible.
44  */
45 static void *esp_alloc_tmp(struct crypto_aead *aead, int nfrags, int extralen)
46 {
47         unsigned int len;
48
49         len = extralen;
50
51         len += crypto_aead_ivsize(aead);
52
53         if (len) {
54                 len += crypto_aead_alignmask(aead) &
55                        ~(crypto_tfm_ctx_alignment() - 1);
56                 len = ALIGN(len, crypto_tfm_ctx_alignment());
57         }
58
59         len += sizeof(struct aead_request) + crypto_aead_reqsize(aead);
60         len = ALIGN(len, __alignof__(struct scatterlist));
61
62         len += sizeof(struct scatterlist) * nfrags;
63
64         return kmalloc(len, GFP_ATOMIC);
65 }
66
67 static inline void *esp_tmp_extra(void *tmp)
68 {
69         return PTR_ALIGN(tmp, __alignof__(struct esp_output_extra));
70 }
71
72 static inline u8 *esp_tmp_iv(struct crypto_aead *aead, void *tmp, int extralen)
73 {
74         return crypto_aead_ivsize(aead) ?
75                PTR_ALIGN((u8 *)tmp + extralen,
76                          crypto_aead_alignmask(aead) + 1) : tmp + extralen;
77 }
78
79 static inline struct aead_request *esp_tmp_req(struct crypto_aead *aead, u8 *iv)
80 {
81         struct aead_request *req;
82
83         req = (void *)PTR_ALIGN(iv + crypto_aead_ivsize(aead),
84                                 crypto_tfm_ctx_alignment());
85         aead_request_set_tfm(req, aead);
86         return req;
87 }
88
89 static inline struct scatterlist *esp_req_sg(struct crypto_aead *aead,
90                                              struct aead_request *req)
91 {
92         return (void *)ALIGN((unsigned long)(req + 1) +
93                              crypto_aead_reqsize(aead),
94                              __alignof__(struct scatterlist));
95 }
96
97 static void esp_ssg_unref(struct xfrm_state *x, void *tmp)
98 {
99         struct esp_output_extra *extra = esp_tmp_extra(tmp);
100         struct crypto_aead *aead = x->data;
101         int extralen = 0;
102         u8 *iv;
103         struct aead_request *req;
104         struct scatterlist *sg;
105
106         if (x->props.flags & XFRM_STATE_ESN)
107                 extralen += sizeof(*extra);
108
109         extra = esp_tmp_extra(tmp);
110         iv = esp_tmp_iv(aead, tmp, extralen);
111         req = esp_tmp_req(aead, iv);
112
113         /* Unref skb_frag_pages in the src scatterlist if necessary.
114          * Skip the first sg which comes from skb->data.
115          */
116         if (req->src != req->dst)
117                 for (sg = sg_next(req->src); sg; sg = sg_next(sg))
118                         put_page(sg_page(sg));
119 }
120
121 static void esp_output_done(struct crypto_async_request *base, int err)
122 {
123         struct sk_buff *skb = base->data;
124         struct xfrm_offload *xo = xfrm_offload(skb);
125         void *tmp;
126         struct xfrm_state *x;
127
128         if (xo && (xo->flags & XFRM_DEV_RESUME)) {
129                 struct sec_path *sp = skb_sec_path(skb);
130
131                 x = sp->xvec[sp->len - 1];
132         } else {
133                 x = skb_dst(skb)->xfrm;
134         }
135
136         tmp = ESP_SKB_CB(skb)->tmp;
137         esp_ssg_unref(x, tmp);
138         kfree(tmp);
139
140         if (xo && (xo->flags & XFRM_DEV_RESUME)) {
141                 if (err) {
142                         XFRM_INC_STATS(xs_net(x), LINUX_MIB_XFRMOUTSTATEPROTOERROR);
143                         kfree_skb(skb);
144                         return;
145                 }
146
147                 skb_push(skb, skb->data - skb_mac_header(skb));
148                 secpath_reset(skb);
149                 xfrm_dev_resume(skb);
150         } else {
151                 xfrm_output_resume(skb, err);
152         }
153 }
154
155 /* Move ESP header back into place. */
156 static void esp_restore_header(struct sk_buff *skb, unsigned int offset)
157 {
158         struct ip_esp_hdr *esph = (void *)(skb->data + offset);
159         void *tmp = ESP_SKB_CB(skb)->tmp;
160         __be32 *seqhi = esp_tmp_extra(tmp);
161
162         esph->seq_no = esph->spi;
163         esph->spi = *seqhi;
164 }
165
166 static void esp_output_restore_header(struct sk_buff *skb)
167 {
168         void *tmp = ESP_SKB_CB(skb)->tmp;
169         struct esp_output_extra *extra = esp_tmp_extra(tmp);
170
171         esp_restore_header(skb, skb_transport_offset(skb) + extra->esphoff -
172                                 sizeof(__be32));
173 }
174
175 static struct ip_esp_hdr *esp_output_set_extra(struct sk_buff *skb,
176                                                struct xfrm_state *x,
177                                                struct ip_esp_hdr *esph,
178                                                struct esp_output_extra *extra)
179 {
180         /* For ESN we move the header forward by 4 bytes to
181          * accomodate the high bits.  We will move it back after
182          * encryption.
183          */
184         if ((x->props.flags & XFRM_STATE_ESN)) {
185                 __u32 seqhi;
186                 struct xfrm_offload *xo = xfrm_offload(skb);
187
188                 if (xo)
189                         seqhi = xo->seq.hi;
190                 else
191                         seqhi = XFRM_SKB_CB(skb)->seq.output.hi;
192
193                 extra->esphoff = (unsigned char *)esph -
194                                  skb_transport_header(skb);
195                 esph = (struct ip_esp_hdr *)((unsigned char *)esph - 4);
196                 extra->seqhi = esph->spi;
197                 esph->seq_no = htonl(seqhi);
198         }
199
200         esph->spi = x->id.spi;
201
202         return esph;
203 }
204
205 static void esp_output_done_esn(struct crypto_async_request *base, int err)
206 {
207         struct sk_buff *skb = base->data;
208
209         esp_output_restore_header(skb);
210         esp_output_done(base, err);
211 }
212
213 static void esp_output_fill_trailer(u8 *tail, int tfclen, int plen, __u8 proto)
214 {
215         /* Fill padding... */
216         if (tfclen) {
217                 memset(tail, 0, tfclen);
218                 tail += tfclen;
219         }
220         do {
221                 int i;
222                 for (i = 0; i < plen - 2; i++)
223                         tail[i] = i + 1;
224         } while (0);
225         tail[plen - 2] = plen - 2;
226         tail[plen - 1] = proto;
227 }
228
229 static void esp_output_udp_encap(struct xfrm_state *x, struct sk_buff *skb, struct esp_info *esp)
230 {
231         int encap_type;
232         struct udphdr *uh;
233         __be32 *udpdata32;
234         __be16 sport, dport;
235         struct xfrm_encap_tmpl *encap = x->encap;
236         struct ip_esp_hdr *esph = esp->esph;
237
238         spin_lock_bh(&x->lock);
239         sport = encap->encap_sport;
240         dport = encap->encap_dport;
241         encap_type = encap->encap_type;
242         spin_unlock_bh(&x->lock);
243
244         uh = (struct udphdr *)esph;
245         uh->source = sport;
246         uh->dest = dport;
247         uh->len = htons(skb->len + esp->tailen
248                   - skb_transport_offset(skb));
249         uh->check = 0;
250
251         switch (encap_type) {
252         default:
253         case UDP_ENCAP_ESPINUDP:
254                 esph = (struct ip_esp_hdr *)(uh + 1);
255                 break;
256         case UDP_ENCAP_ESPINUDP_NON_IKE:
257                 udpdata32 = (__be32 *)(uh + 1);
258                 udpdata32[0] = udpdata32[1] = 0;
259                 esph = (struct ip_esp_hdr *)(udpdata32 + 2);
260                 break;
261         }
262
263         *skb_mac_header(skb) = IPPROTO_UDP;
264         esp->esph = esph;
265 }
266
267 int esp_output_head(struct xfrm_state *x, struct sk_buff *skb, struct esp_info *esp)
268 {
269         u8 *tail;
270         u8 *vaddr;
271         int nfrags;
272         int esph_offset;
273         struct page *page;
274         struct sk_buff *trailer;
275         int tailen = esp->tailen;
276
277         /* this is non-NULL only with UDP Encapsulation */
278         if (x->encap)
279                 esp_output_udp_encap(x, skb, esp);
280
281         if (!skb_cloned(skb)) {
282                 if (tailen <= skb_tailroom(skb)) {
283                         nfrags = 1;
284                         trailer = skb;
285                         tail = skb_tail_pointer(trailer);
286
287                         goto skip_cow;
288                 } else if ((skb_shinfo(skb)->nr_frags < MAX_SKB_FRAGS)
289                            && !skb_has_frag_list(skb)) {
290                         int allocsize;
291                         struct sock *sk = skb->sk;
292                         struct page_frag *pfrag = &x->xfrag;
293
294                         esp->inplace = false;
295
296                         allocsize = ALIGN(tailen, L1_CACHE_BYTES);
297
298                         spin_lock_bh(&x->lock);
299
300                         if (unlikely(!skb_page_frag_refill(allocsize, pfrag, GFP_ATOMIC))) {
301                                 spin_unlock_bh(&x->lock);
302                                 goto cow;
303                         }
304
305                         page = pfrag->page;
306                         get_page(page);
307
308                         vaddr = kmap_atomic(page);
309
310                         tail = vaddr + pfrag->offset;
311
312                         esp_output_fill_trailer(tail, esp->tfclen, esp->plen, esp->proto);
313
314                         kunmap_atomic(vaddr);
315
316                         nfrags = skb_shinfo(skb)->nr_frags;
317
318                         __skb_fill_page_desc(skb, nfrags, page, pfrag->offset,
319                                              tailen);
320                         skb_shinfo(skb)->nr_frags = ++nfrags;
321
322                         pfrag->offset = pfrag->offset + allocsize;
323
324                         spin_unlock_bh(&x->lock);
325
326                         nfrags++;
327
328                         skb->len += tailen;
329                         skb->data_len += tailen;
330                         skb->truesize += tailen;
331                         if (sk && sk_fullsock(sk))
332                                 refcount_add(tailen, &sk->sk_wmem_alloc);
333
334                         goto out;
335                 }
336         }
337
338 cow:
339         esph_offset = (unsigned char *)esp->esph - skb_transport_header(skb);
340
341         nfrags = skb_cow_data(skb, tailen, &trailer);
342         if (nfrags < 0)
343                 goto out;
344         tail = skb_tail_pointer(trailer);
345         esp->esph = (struct ip_esp_hdr *)(skb_transport_header(skb) + esph_offset);
346
347 skip_cow:
348         esp_output_fill_trailer(tail, esp->tfclen, esp->plen, esp->proto);
349         pskb_put(skb, trailer, tailen);
350
351 out:
352         return nfrags;
353 }
354 EXPORT_SYMBOL_GPL(esp_output_head);
355
356 int esp_output_tail(struct xfrm_state *x, struct sk_buff *skb, struct esp_info *esp)
357 {
358         u8 *iv;
359         int alen;
360         void *tmp;
361         int ivlen;
362         int assoclen;
363         int extralen;
364         struct page *page;
365         struct ip_esp_hdr *esph;
366         struct crypto_aead *aead;
367         struct aead_request *req;
368         struct scatterlist *sg, *dsg;
369         struct esp_output_extra *extra;
370         int err = -ENOMEM;
371
372         assoclen = sizeof(struct ip_esp_hdr);
373         extralen = 0;
374
375         if (x->props.flags & XFRM_STATE_ESN) {
376                 extralen += sizeof(*extra);
377                 assoclen += sizeof(__be32);
378         }
379
380         aead = x->data;
381         alen = crypto_aead_authsize(aead);
382         ivlen = crypto_aead_ivsize(aead);
383
384         tmp = esp_alloc_tmp(aead, esp->nfrags + 2, extralen);
385         if (!tmp)
386                 goto error;
387
388         extra = esp_tmp_extra(tmp);
389         iv = esp_tmp_iv(aead, tmp, extralen);
390         req = esp_tmp_req(aead, iv);
391         sg = esp_req_sg(aead, req);
392
393         if (esp->inplace)
394                 dsg = sg;
395         else
396                 dsg = &sg[esp->nfrags];
397
398         esph = esp_output_set_extra(skb, x, esp->esph, extra);
399         esp->esph = esph;
400
401         sg_init_table(sg, esp->nfrags);
402         err = skb_to_sgvec(skb, sg,
403                            (unsigned char *)esph - skb->data,
404                            assoclen + ivlen + esp->clen + alen);
405         if (unlikely(err < 0))
406                 goto error_free;
407
408         if (!esp->inplace) {
409                 int allocsize;
410                 struct page_frag *pfrag = &x->xfrag;
411
412                 allocsize = ALIGN(skb->data_len, L1_CACHE_BYTES);
413
414                 spin_lock_bh(&x->lock);
415                 if (unlikely(!skb_page_frag_refill(allocsize, pfrag, GFP_ATOMIC))) {
416                         spin_unlock_bh(&x->lock);
417                         goto error_free;
418                 }
419
420                 skb_shinfo(skb)->nr_frags = 1;
421
422                 page = pfrag->page;
423                 get_page(page);
424                 /* replace page frags in skb with new page */
425                 __skb_fill_page_desc(skb, 0, page, pfrag->offset, skb->data_len);
426                 pfrag->offset = pfrag->offset + allocsize;
427                 spin_unlock_bh(&x->lock);
428
429                 sg_init_table(dsg, skb_shinfo(skb)->nr_frags + 1);
430                 err = skb_to_sgvec(skb, dsg,
431                                    (unsigned char *)esph - skb->data,
432                                    assoclen + ivlen + esp->clen + alen);
433                 if (unlikely(err < 0))
434                         goto error_free;
435         }
436
437         if ((x->props.flags & XFRM_STATE_ESN))
438                 aead_request_set_callback(req, 0, esp_output_done_esn, skb);
439         else
440                 aead_request_set_callback(req, 0, esp_output_done, skb);
441
442         aead_request_set_crypt(req, sg, dsg, ivlen + esp->clen, iv);
443         aead_request_set_ad(req, assoclen);
444
445         memset(iv, 0, ivlen);
446         memcpy(iv + ivlen - min(ivlen, 8), (u8 *)&esp->seqno + 8 - min(ivlen, 8),
447                min(ivlen, 8));
448
449         ESP_SKB_CB(skb)->tmp = tmp;
450         err = crypto_aead_encrypt(req);
451
452         switch (err) {
453         case -EINPROGRESS:
454                 goto error;
455
456         case -ENOSPC:
457                 err = NET_XMIT_DROP;
458                 break;
459
460         case 0:
461                 if ((x->props.flags & XFRM_STATE_ESN))
462                         esp_output_restore_header(skb);
463         }
464
465         if (sg != dsg)
466                 esp_ssg_unref(x, tmp);
467
468 error_free:
469         kfree(tmp);
470 error:
471         return err;
472 }
473 EXPORT_SYMBOL_GPL(esp_output_tail);
474
475 static int esp_output(struct xfrm_state *x, struct sk_buff *skb)
476 {
477         int alen;
478         int blksize;
479         struct ip_esp_hdr *esph;
480         struct crypto_aead *aead;
481         struct esp_info esp;
482
483         esp.inplace = true;
484
485         esp.proto = *skb_mac_header(skb);
486         *skb_mac_header(skb) = IPPROTO_ESP;
487
488         /* skb is pure payload to encrypt */
489
490         aead = x->data;
491         alen = crypto_aead_authsize(aead);
492
493         esp.tfclen = 0;
494         if (x->tfcpad) {
495                 struct xfrm_dst *dst = (struct xfrm_dst *)skb_dst(skb);
496                 u32 padto;
497
498                 padto = min(x->tfcpad, esp4_get_mtu(x, dst->child_mtu_cached));
499                 if (skb->len < padto)
500                         esp.tfclen = padto - skb->len;
501         }
502         blksize = ALIGN(crypto_aead_blocksize(aead), 4);
503         esp.clen = ALIGN(skb->len + 2 + esp.tfclen, blksize);
504         esp.plen = esp.clen - skb->len - esp.tfclen;
505         esp.tailen = esp.tfclen + esp.plen + alen;
506
507         esp.esph = ip_esp_hdr(skb);
508
509         esp.nfrags = esp_output_head(x, skb, &esp);
510         if (esp.nfrags < 0)
511                 return esp.nfrags;
512
513         esph = esp.esph;
514         esph->spi = x->id.spi;
515
516         esph->seq_no = htonl(XFRM_SKB_CB(skb)->seq.output.low);
517         esp.seqno = cpu_to_be64(XFRM_SKB_CB(skb)->seq.output.low +
518                                  ((u64)XFRM_SKB_CB(skb)->seq.output.hi << 32));
519
520         skb_push(skb, -skb_network_offset(skb));
521
522         return esp_output_tail(x, skb, &esp);
523 }
524
525 static inline int esp_remove_trailer(struct sk_buff *skb)
526 {
527         struct xfrm_state *x = xfrm_input_state(skb);
528         struct xfrm_offload *xo = xfrm_offload(skb);
529         struct crypto_aead *aead = x->data;
530         int alen, hlen, elen;
531         int padlen, trimlen;
532         __wsum csumdiff;
533         u8 nexthdr[2];
534         int ret;
535
536         alen = crypto_aead_authsize(aead);
537         hlen = sizeof(struct ip_esp_hdr) + crypto_aead_ivsize(aead);
538         elen = skb->len - hlen;
539
540         if (xo && (xo->flags & XFRM_ESP_NO_TRAILER)) {
541                 ret = xo->proto;
542                 goto out;
543         }
544
545         if (skb_copy_bits(skb, skb->len - alen - 2, nexthdr, 2))
546                 BUG();
547
548         ret = -EINVAL;
549         padlen = nexthdr[0];
550         if (padlen + 2 + alen >= elen) {
551                 net_dbg_ratelimited("ipsec esp packet is garbage padlen=%d, elen=%d\n",
552                                     padlen + 2, elen - alen);
553                 goto out;
554         }
555
556         trimlen = alen + padlen + 2;
557         if (skb->ip_summed == CHECKSUM_COMPLETE) {
558                 csumdiff = skb_checksum(skb, skb->len - trimlen, trimlen, 0);
559                 skb->csum = csum_block_sub(skb->csum, csumdiff,
560                                            skb->len - trimlen);
561         }
562         pskb_trim(skb, skb->len - trimlen);
563
564         ret = nexthdr[1];
565
566 out:
567         return ret;
568 }
569
570 int esp_input_done2(struct sk_buff *skb, int err)
571 {
572         const struct iphdr *iph;
573         struct xfrm_state *x = xfrm_input_state(skb);
574         struct xfrm_offload *xo = xfrm_offload(skb);
575         struct crypto_aead *aead = x->data;
576         int hlen = sizeof(struct ip_esp_hdr) + crypto_aead_ivsize(aead);
577         int ihl;
578
579         if (!xo || (xo && !(xo->flags & CRYPTO_DONE)))
580                 kfree(ESP_SKB_CB(skb)->tmp);
581
582         if (unlikely(err))
583                 goto out;
584
585         err = esp_remove_trailer(skb);
586         if (unlikely(err < 0))
587                 goto out;
588
589         iph = ip_hdr(skb);
590         ihl = iph->ihl * 4;
591
592         if (x->encap) {
593                 struct xfrm_encap_tmpl *encap = x->encap;
594                 struct udphdr *uh = (void *)(skb_network_header(skb) + ihl);
595
596                 /*
597                  * 1) if the NAT-T peer's IP or port changed then
598                  *    advertize the change to the keying daemon.
599                  *    This is an inbound SA, so just compare
600                  *    SRC ports.
601                  */
602                 if (iph->saddr != x->props.saddr.a4 ||
603                     uh->source != encap->encap_sport) {
604                         xfrm_address_t ipaddr;
605
606                         ipaddr.a4 = iph->saddr;
607                         km_new_mapping(x, &ipaddr, uh->source);
608
609                         /* XXX: perhaps add an extra
610                          * policy check here, to see
611                          * if we should allow or
612                          * reject a packet from a
613                          * different source
614                          * address/port.
615                          */
616                 }
617
618                 /*
619                  * 2) ignore UDP/TCP checksums in case
620                  *    of NAT-T in Transport Mode, or
621                  *    perform other post-processing fixes
622                  *    as per draft-ietf-ipsec-udp-encaps-06,
623                  *    section 3.1.2
624                  */
625                 if (x->props.mode == XFRM_MODE_TRANSPORT)
626                         skb->ip_summed = CHECKSUM_UNNECESSARY;
627         }
628
629         skb_pull_rcsum(skb, hlen);
630         if (x->props.mode == XFRM_MODE_TUNNEL)
631                 skb_reset_transport_header(skb);
632         else
633                 skb_set_transport_header(skb, -ihl);
634
635         /* RFC4303: Drop dummy packets without any error */
636         if (err == IPPROTO_NONE)
637                 err = -EINVAL;
638
639 out:
640         return err;
641 }
642 EXPORT_SYMBOL_GPL(esp_input_done2);
643
644 static void esp_input_done(struct crypto_async_request *base, int err)
645 {
646         struct sk_buff *skb = base->data;
647
648         xfrm_input_resume(skb, esp_input_done2(skb, err));
649 }
650
651 static void esp_input_restore_header(struct sk_buff *skb)
652 {
653         esp_restore_header(skb, 0);
654         __skb_pull(skb, 4);
655 }
656
657 static void esp_input_set_header(struct sk_buff *skb, __be32 *seqhi)
658 {
659         struct xfrm_state *x = xfrm_input_state(skb);
660         struct ip_esp_hdr *esph;
661
662         /* For ESN we move the header forward by 4 bytes to
663          * accomodate the high bits.  We will move it back after
664          * decryption.
665          */
666         if ((x->props.flags & XFRM_STATE_ESN)) {
667                 esph = skb_push(skb, 4);
668                 *seqhi = esph->spi;
669                 esph->spi = esph->seq_no;
670                 esph->seq_no = XFRM_SKB_CB(skb)->seq.input.hi;
671         }
672 }
673
674 static void esp_input_done_esn(struct crypto_async_request *base, int err)
675 {
676         struct sk_buff *skb = base->data;
677
678         esp_input_restore_header(skb);
679         esp_input_done(base, err);
680 }
681
682 /*
683  * Note: detecting truncated vs. non-truncated authentication data is very
684  * expensive, so we only support truncated data, which is the recommended
685  * and common case.
686  */
687 static int esp_input(struct xfrm_state *x, struct sk_buff *skb)
688 {
689         struct crypto_aead *aead = x->data;
690         struct aead_request *req;
691         struct sk_buff *trailer;
692         int ivlen = crypto_aead_ivsize(aead);
693         int elen = skb->len - sizeof(struct ip_esp_hdr) - ivlen;
694         int nfrags;
695         int assoclen;
696         int seqhilen;
697         __be32 *seqhi;
698         void *tmp;
699         u8 *iv;
700         struct scatterlist *sg;
701         int err = -EINVAL;
702
703         if (!pskb_may_pull(skb, sizeof(struct ip_esp_hdr) + ivlen))
704                 goto out;
705
706         if (elen <= 0)
707                 goto out;
708
709         assoclen = sizeof(struct ip_esp_hdr);
710         seqhilen = 0;
711
712         if (x->props.flags & XFRM_STATE_ESN) {
713                 seqhilen += sizeof(__be32);
714                 assoclen += seqhilen;
715         }
716
717         if (!skb_cloned(skb)) {
718                 if (!skb_is_nonlinear(skb)) {
719                         nfrags = 1;
720
721                         goto skip_cow;
722                 } else if (!skb_has_frag_list(skb)) {
723                         nfrags = skb_shinfo(skb)->nr_frags;
724                         nfrags++;
725
726                         goto skip_cow;
727                 }
728         }
729
730         err = skb_cow_data(skb, 0, &trailer);
731         if (err < 0)
732                 goto out;
733
734         nfrags = err;
735
736 skip_cow:
737         err = -ENOMEM;
738         tmp = esp_alloc_tmp(aead, nfrags, seqhilen);
739         if (!tmp)
740                 goto out;
741
742         ESP_SKB_CB(skb)->tmp = tmp;
743         seqhi = esp_tmp_extra(tmp);
744         iv = esp_tmp_iv(aead, tmp, seqhilen);
745         req = esp_tmp_req(aead, iv);
746         sg = esp_req_sg(aead, req);
747
748         esp_input_set_header(skb, seqhi);
749
750         sg_init_table(sg, nfrags);
751         err = skb_to_sgvec(skb, sg, 0, skb->len);
752         if (unlikely(err < 0)) {
753                 kfree(tmp);
754                 goto out;
755         }
756
757         skb->ip_summed = CHECKSUM_NONE;
758
759         if ((x->props.flags & XFRM_STATE_ESN))
760                 aead_request_set_callback(req, 0, esp_input_done_esn, skb);
761         else
762                 aead_request_set_callback(req, 0, esp_input_done, skb);
763
764         aead_request_set_crypt(req, sg, sg, elen + ivlen, iv);
765         aead_request_set_ad(req, assoclen);
766
767         err = crypto_aead_decrypt(req);
768         if (err == -EINPROGRESS)
769                 goto out;
770
771         if ((x->props.flags & XFRM_STATE_ESN))
772                 esp_input_restore_header(skb);
773
774         err = esp_input_done2(skb, err);
775
776 out:
777         return err;
778 }
779
780 static u32 esp4_get_mtu(struct xfrm_state *x, int mtu)
781 {
782         struct crypto_aead *aead = x->data;
783         u32 blksize = ALIGN(crypto_aead_blocksize(aead), 4);
784         unsigned int net_adj;
785
786         switch (x->props.mode) {
787         case XFRM_MODE_TRANSPORT:
788         case XFRM_MODE_BEET:
789                 net_adj = sizeof(struct iphdr);
790                 break;
791         case XFRM_MODE_TUNNEL:
792                 net_adj = 0;
793                 break;
794         default:
795                 BUG();
796         }
797
798         return ((mtu - x->props.header_len - crypto_aead_authsize(aead) -
799                  net_adj) & ~(blksize - 1)) + net_adj - 2;
800 }
801
802 static int esp4_err(struct sk_buff *skb, u32 info)
803 {
804         struct net *net = dev_net(skb->dev);
805         const struct iphdr *iph = (const struct iphdr *)skb->data;
806         struct ip_esp_hdr *esph = (struct ip_esp_hdr *)(skb->data+(iph->ihl<<2));
807         struct xfrm_state *x;
808
809         switch (icmp_hdr(skb)->type) {
810         case ICMP_DEST_UNREACH:
811                 if (icmp_hdr(skb)->code != ICMP_FRAG_NEEDED)
812                         return 0;
813         case ICMP_REDIRECT:
814                 break;
815         default:
816                 return 0;
817         }
818
819         x = xfrm_state_lookup(net, skb->mark, (const xfrm_address_t *)&iph->daddr,
820                               esph->spi, IPPROTO_ESP, AF_INET);
821         if (!x)
822                 return 0;
823
824         if (icmp_hdr(skb)->type == ICMP_DEST_UNREACH)
825                 ipv4_update_pmtu(skb, net, info, 0, IPPROTO_ESP);
826         else
827                 ipv4_redirect(skb, net, 0, IPPROTO_ESP);
828         xfrm_state_put(x);
829
830         return 0;
831 }
832
833 static void esp_destroy(struct xfrm_state *x)
834 {
835         struct crypto_aead *aead = x->data;
836
837         if (!aead)
838                 return;
839
840         crypto_free_aead(aead);
841 }
842
843 static int esp_init_aead(struct xfrm_state *x)
844 {
845         char aead_name[CRYPTO_MAX_ALG_NAME];
846         struct crypto_aead *aead;
847         int err;
848
849         err = -ENAMETOOLONG;
850         if (snprintf(aead_name, CRYPTO_MAX_ALG_NAME, "%s(%s)",
851                      x->geniv, x->aead->alg_name) >= CRYPTO_MAX_ALG_NAME)
852                 goto error;
853
854         aead = crypto_alloc_aead(aead_name, 0, 0);
855         err = PTR_ERR(aead);
856         if (IS_ERR(aead))
857                 goto error;
858
859         x->data = aead;
860
861         err = crypto_aead_setkey(aead, x->aead->alg_key,
862                                  (x->aead->alg_key_len + 7) / 8);
863         if (err)
864                 goto error;
865
866         err = crypto_aead_setauthsize(aead, x->aead->alg_icv_len / 8);
867         if (err)
868                 goto error;
869
870 error:
871         return err;
872 }
873
874 static int esp_init_authenc(struct xfrm_state *x)
875 {
876         struct crypto_aead *aead;
877         struct crypto_authenc_key_param *param;
878         struct rtattr *rta;
879         char *key;
880         char *p;
881         char authenc_name[CRYPTO_MAX_ALG_NAME];
882         unsigned int keylen;
883         int err;
884
885         err = -EINVAL;
886         if (!x->ealg)
887                 goto error;
888
889         err = -ENAMETOOLONG;
890
891         if ((x->props.flags & XFRM_STATE_ESN)) {
892                 if (snprintf(authenc_name, CRYPTO_MAX_ALG_NAME,
893                              "%s%sauthencesn(%s,%s)%s",
894                              x->geniv ?: "", x->geniv ? "(" : "",
895                              x->aalg ? x->aalg->alg_name : "digest_null",
896                              x->ealg->alg_name,
897                              x->geniv ? ")" : "") >= CRYPTO_MAX_ALG_NAME)
898                         goto error;
899         } else {
900                 if (snprintf(authenc_name, CRYPTO_MAX_ALG_NAME,
901                              "%s%sauthenc(%s,%s)%s",
902                              x->geniv ?: "", x->geniv ? "(" : "",
903                              x->aalg ? x->aalg->alg_name : "digest_null",
904                              x->ealg->alg_name,
905                              x->geniv ? ")" : "") >= CRYPTO_MAX_ALG_NAME)
906                         goto error;
907         }
908
909         aead = crypto_alloc_aead(authenc_name, 0, 0);
910         err = PTR_ERR(aead);
911         if (IS_ERR(aead))
912                 goto error;
913
914         x->data = aead;
915
916         keylen = (x->aalg ? (x->aalg->alg_key_len + 7) / 8 : 0) +
917                  (x->ealg->alg_key_len + 7) / 8 + RTA_SPACE(sizeof(*param));
918         err = -ENOMEM;
919         key = kmalloc(keylen, GFP_KERNEL);
920         if (!key)
921                 goto error;
922
923         p = key;
924         rta = (void *)p;
925         rta->rta_type = CRYPTO_AUTHENC_KEYA_PARAM;
926         rta->rta_len = RTA_LENGTH(sizeof(*param));
927         param = RTA_DATA(rta);
928         p += RTA_SPACE(sizeof(*param));
929
930         if (x->aalg) {
931                 struct xfrm_algo_desc *aalg_desc;
932
933                 memcpy(p, x->aalg->alg_key, (x->aalg->alg_key_len + 7) / 8);
934                 p += (x->aalg->alg_key_len + 7) / 8;
935
936                 aalg_desc = xfrm_aalg_get_byname(x->aalg->alg_name, 0);
937                 BUG_ON(!aalg_desc);
938
939                 err = -EINVAL;
940                 if (aalg_desc->uinfo.auth.icv_fullbits / 8 !=
941                     crypto_aead_authsize(aead)) {
942                         pr_info("ESP: %s digestsize %u != %hu\n",
943                                 x->aalg->alg_name,
944                                 crypto_aead_authsize(aead),
945                                 aalg_desc->uinfo.auth.icv_fullbits / 8);
946                         goto free_key;
947                 }
948
949                 err = crypto_aead_setauthsize(
950                         aead, x->aalg->alg_trunc_len / 8);
951                 if (err)
952                         goto free_key;
953         }
954
955         param->enckeylen = cpu_to_be32((x->ealg->alg_key_len + 7) / 8);
956         memcpy(p, x->ealg->alg_key, (x->ealg->alg_key_len + 7) / 8);
957
958         err = crypto_aead_setkey(aead, key, keylen);
959
960 free_key:
961         kfree(key);
962
963 error:
964         return err;
965 }
966
967 static int esp_init_state(struct xfrm_state *x)
968 {
969         struct crypto_aead *aead;
970         u32 align;
971         int err;
972
973         x->data = NULL;
974
975         if (x->aead)
976                 err = esp_init_aead(x);
977         else
978                 err = esp_init_authenc(x);
979
980         if (err)
981                 goto error;
982
983         aead = x->data;
984
985         x->props.header_len = sizeof(struct ip_esp_hdr) +
986                               crypto_aead_ivsize(aead);
987         if (x->props.mode == XFRM_MODE_TUNNEL)
988                 x->props.header_len += sizeof(struct iphdr);
989         else if (x->props.mode == XFRM_MODE_BEET && x->sel.family != AF_INET6)
990                 x->props.header_len += IPV4_BEET_PHMAXLEN;
991         if (x->encap) {
992                 struct xfrm_encap_tmpl *encap = x->encap;
993
994                 switch (encap->encap_type) {
995                 default:
996                         err = -EINVAL;
997                         goto error;
998                 case UDP_ENCAP_ESPINUDP:
999                         x->props.header_len += sizeof(struct udphdr);
1000                         break;
1001                 case UDP_ENCAP_ESPINUDP_NON_IKE:
1002                         x->props.header_len += sizeof(struct udphdr) + 2 * sizeof(u32);
1003                         break;
1004                 }
1005         }
1006
1007         align = ALIGN(crypto_aead_blocksize(aead), 4);
1008         x->props.trailer_len = align + 1 + crypto_aead_authsize(aead);
1009
1010 error:
1011         return err;
1012 }
1013
1014 static int esp4_rcv_cb(struct sk_buff *skb, int err)
1015 {
1016         return 0;
1017 }
1018
1019 static const struct xfrm_type esp_type =
1020 {
1021         .description    = "ESP4",
1022         .owner          = THIS_MODULE,
1023         .proto          = IPPROTO_ESP,
1024         .flags          = XFRM_TYPE_REPLAY_PROT,
1025         .init_state     = esp_init_state,
1026         .destructor     = esp_destroy,
1027         .get_mtu        = esp4_get_mtu,
1028         .input          = esp_input,
1029         .output         = esp_output,
1030 };
1031
1032 static struct xfrm4_protocol esp4_protocol = {
1033         .handler        =       xfrm4_rcv,
1034         .input_handler  =       xfrm_input,
1035         .cb_handler     =       esp4_rcv_cb,
1036         .err_handler    =       esp4_err,
1037         .priority       =       0,
1038 };
1039
1040 static int __init esp4_init(void)
1041 {
1042         if (xfrm_register_type(&esp_type, AF_INET) < 0) {
1043                 pr_info("%s: can't add xfrm type\n", __func__);
1044                 return -EAGAIN;
1045         }
1046         if (xfrm4_protocol_register(&esp4_protocol, IPPROTO_ESP) < 0) {
1047                 pr_info("%s: can't add protocol\n", __func__);
1048                 xfrm_unregister_type(&esp_type, AF_INET);
1049                 return -EAGAIN;
1050         }
1051         return 0;
1052 }
1053
1054 static void __exit esp4_fini(void)
1055 {
1056         if (xfrm4_protocol_deregister(&esp4_protocol, IPPROTO_ESP) < 0)
1057                 pr_info("%s: can't remove protocol\n", __func__);
1058         if (xfrm_unregister_type(&esp_type, AF_INET) < 0)
1059                 pr_info("%s: can't remove xfrm type\n", __func__);
1060 }
1061
1062 module_init(esp4_init);
1063 module_exit(esp4_fini);
1064 MODULE_LICENSE("GPL");
1065 MODULE_ALIAS_XFRM_TYPE(AF_INET, XFRM_PROTO_ESP);