bd7140885e60ec06c269fd7b1b014dc832aedf16
[sfrench/cifs-2.6.git] / net / ipv6 / seg6_local.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  *  SR-IPv6 implementation
4  *
5  *  Authors:
6  *  David Lebrun <david.lebrun@uclouvain.be>
7  *  eBPF support: Mathieu Xhonneux <m.xhonneux@gmail.com>
8  */
9
10 #include <linux/types.h>
11 #include <linux/skbuff.h>
12 #include <linux/net.h>
13 #include <linux/module.h>
14 #include <net/ip.h>
15 #include <net/lwtunnel.h>
16 #include <net/netevent.h>
17 #include <net/netns/generic.h>
18 #include <net/ip6_fib.h>
19 #include <net/route.h>
20 #include <net/seg6.h>
21 #include <linux/seg6.h>
22 #include <linux/seg6_local.h>
23 #include <net/addrconf.h>
24 #include <net/ip6_route.h>
25 #include <net/dst_cache.h>
26 #include <net/ip_tunnels.h>
27 #ifdef CONFIG_IPV6_SEG6_HMAC
28 #include <net/seg6_hmac.h>
29 #endif
30 #include <net/seg6_local.h>
31 #include <linux/etherdevice.h>
32 #include <linux/bpf.h>
33
34 #define SEG6_F_ATTR(i)          BIT(i)
35
36 struct seg6_local_lwt;
37
38 /* callbacks used for customizing the creation and destruction of a behavior */
39 struct seg6_local_lwtunnel_ops {
40         int (*build_state)(struct seg6_local_lwt *slwt, const void *cfg,
41                            struct netlink_ext_ack *extack);
42         void (*destroy_state)(struct seg6_local_lwt *slwt);
43 };
44
45 struct seg6_action_desc {
46         int action;
47         unsigned long attrs;
48
49         /* The optattrs field is used for specifying all the optional
50          * attributes supported by a specific behavior.
51          * It means that if one of these attributes is not provided in the
52          * netlink message during the behavior creation, no errors will be
53          * returned to the userspace.
54          *
55          * Each attribute can be only of two types (mutually exclusive):
56          * 1) required or 2) optional.
57          * Every user MUST obey to this rule! If you set an attribute as
58          * required the same attribute CANNOT be set as optional and vice
59          * versa.
60          */
61         unsigned long optattrs;
62
63         int (*input)(struct sk_buff *skb, struct seg6_local_lwt *slwt);
64         int static_headroom;
65
66         struct seg6_local_lwtunnel_ops slwt_ops;
67 };
68
69 struct bpf_lwt_prog {
70         struct bpf_prog *prog;
71         char *name;
72 };
73
74 enum seg6_end_dt_mode {
75         DT_INVALID_MODE = -EINVAL,
76         DT_LEGACY_MODE  = 0,
77         DT_VRF_MODE     = 1,
78 };
79
80 struct seg6_end_dt_info {
81         enum seg6_end_dt_mode mode;
82
83         struct net *net;
84         /* VRF device associated to the routing table used by the SRv6
85          * End.DT4/DT6 behavior for routing IPv4/IPv6 packets.
86          */
87         int vrf_ifindex;
88         int vrf_table;
89
90         /* tunneled packet proto and family (IPv4 or IPv6) */
91         __be16 proto;
92         u16 family;
93         int hdrlen;
94 };
95
96 struct seg6_local_lwt {
97         int action;
98         struct ipv6_sr_hdr *srh;
99         int table;
100         struct in_addr nh4;
101         struct in6_addr nh6;
102         int iif;
103         int oif;
104         struct bpf_lwt_prog bpf;
105 #ifdef CONFIG_NET_L3_MASTER_DEV
106         struct seg6_end_dt_info dt_info;
107 #endif
108
109         int headroom;
110         struct seg6_action_desc *desc;
111         /* unlike the required attrs, we have to track the optional attributes
112          * that have been effectively parsed.
113          */
114         unsigned long parsed_optattrs;
115 };
116
117 static struct seg6_local_lwt *seg6_local_lwtunnel(struct lwtunnel_state *lwt)
118 {
119         return (struct seg6_local_lwt *)lwt->data;
120 }
121
122 static struct ipv6_sr_hdr *get_srh(struct sk_buff *skb, int flags)
123 {
124         struct ipv6_sr_hdr *srh;
125         int len, srhoff = 0;
126
127         if (ipv6_find_hdr(skb, &srhoff, IPPROTO_ROUTING, NULL, &flags) < 0)
128                 return NULL;
129
130         if (!pskb_may_pull(skb, srhoff + sizeof(*srh)))
131                 return NULL;
132
133         srh = (struct ipv6_sr_hdr *)(skb->data + srhoff);
134
135         len = (srh->hdrlen + 1) << 3;
136
137         if (!pskb_may_pull(skb, srhoff + len))
138                 return NULL;
139
140         /* note that pskb_may_pull may change pointers in header;
141          * for this reason it is necessary to reload them when needed.
142          */
143         srh = (struct ipv6_sr_hdr *)(skb->data + srhoff);
144
145         if (!seg6_validate_srh(srh, len, true))
146                 return NULL;
147
148         return srh;
149 }
150
151 static struct ipv6_sr_hdr *get_and_validate_srh(struct sk_buff *skb)
152 {
153         struct ipv6_sr_hdr *srh;
154
155         srh = get_srh(skb, IP6_FH_F_SKIP_RH);
156         if (!srh)
157                 return NULL;
158
159 #ifdef CONFIG_IPV6_SEG6_HMAC
160         if (!seg6_hmac_validate_skb(skb))
161                 return NULL;
162 #endif
163
164         return srh;
165 }
166
167 static bool decap_and_validate(struct sk_buff *skb, int proto)
168 {
169         struct ipv6_sr_hdr *srh;
170         unsigned int off = 0;
171
172         srh = get_srh(skb, 0);
173         if (srh && srh->segments_left > 0)
174                 return false;
175
176 #ifdef CONFIG_IPV6_SEG6_HMAC
177         if (srh && !seg6_hmac_validate_skb(skb))
178                 return false;
179 #endif
180
181         if (ipv6_find_hdr(skb, &off, proto, NULL, NULL) < 0)
182                 return false;
183
184         if (!pskb_pull(skb, off))
185                 return false;
186
187         skb_postpull_rcsum(skb, skb_network_header(skb), off);
188
189         skb_reset_network_header(skb);
190         skb_reset_transport_header(skb);
191         if (iptunnel_pull_offloads(skb))
192                 return false;
193
194         return true;
195 }
196
197 static void advance_nextseg(struct ipv6_sr_hdr *srh, struct in6_addr *daddr)
198 {
199         struct in6_addr *addr;
200
201         srh->segments_left--;
202         addr = srh->segments + srh->segments_left;
203         *daddr = *addr;
204 }
205
206 static int
207 seg6_lookup_any_nexthop(struct sk_buff *skb, struct in6_addr *nhaddr,
208                         u32 tbl_id, bool local_delivery)
209 {
210         struct net *net = dev_net(skb->dev);
211         struct ipv6hdr *hdr = ipv6_hdr(skb);
212         int flags = RT6_LOOKUP_F_HAS_SADDR;
213         struct dst_entry *dst = NULL;
214         struct rt6_info *rt;
215         struct flowi6 fl6;
216         int dev_flags = 0;
217
218         fl6.flowi6_iif = skb->dev->ifindex;
219         fl6.daddr = nhaddr ? *nhaddr : hdr->daddr;
220         fl6.saddr = hdr->saddr;
221         fl6.flowlabel = ip6_flowinfo(hdr);
222         fl6.flowi6_mark = skb->mark;
223         fl6.flowi6_proto = hdr->nexthdr;
224
225         if (nhaddr)
226                 fl6.flowi6_flags = FLOWI_FLAG_KNOWN_NH;
227
228         if (!tbl_id) {
229                 dst = ip6_route_input_lookup(net, skb->dev, &fl6, skb, flags);
230         } else {
231                 struct fib6_table *table;
232
233                 table = fib6_get_table(net, tbl_id);
234                 if (!table)
235                         goto out;
236
237                 rt = ip6_pol_route(net, table, 0, &fl6, skb, flags);
238                 dst = &rt->dst;
239         }
240
241         /* we want to discard traffic destined for local packet processing,
242          * if @local_delivery is set to false.
243          */
244         if (!local_delivery)
245                 dev_flags |= IFF_LOOPBACK;
246
247         if (dst && (dst->dev->flags & dev_flags) && !dst->error) {
248                 dst_release(dst);
249                 dst = NULL;
250         }
251
252 out:
253         if (!dst) {
254                 rt = net->ipv6.ip6_blk_hole_entry;
255                 dst = &rt->dst;
256                 dst_hold(dst);
257         }
258
259         skb_dst_drop(skb);
260         skb_dst_set(skb, dst);
261         return dst->error;
262 }
263
264 int seg6_lookup_nexthop(struct sk_buff *skb,
265                         struct in6_addr *nhaddr, u32 tbl_id)
266 {
267         return seg6_lookup_any_nexthop(skb, nhaddr, tbl_id, false);
268 }
269
270 /* regular endpoint function */
271 static int input_action_end(struct sk_buff *skb, struct seg6_local_lwt *slwt)
272 {
273         struct ipv6_sr_hdr *srh;
274
275         srh = get_and_validate_srh(skb);
276         if (!srh)
277                 goto drop;
278
279         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
280
281         seg6_lookup_nexthop(skb, NULL, 0);
282
283         return dst_input(skb);
284
285 drop:
286         kfree_skb(skb);
287         return -EINVAL;
288 }
289
290 /* regular endpoint, and forward to specified nexthop */
291 static int input_action_end_x(struct sk_buff *skb, struct seg6_local_lwt *slwt)
292 {
293         struct ipv6_sr_hdr *srh;
294
295         srh = get_and_validate_srh(skb);
296         if (!srh)
297                 goto drop;
298
299         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
300
301         seg6_lookup_nexthop(skb, &slwt->nh6, 0);
302
303         return dst_input(skb);
304
305 drop:
306         kfree_skb(skb);
307         return -EINVAL;
308 }
309
310 static int input_action_end_t(struct sk_buff *skb, struct seg6_local_lwt *slwt)
311 {
312         struct ipv6_sr_hdr *srh;
313
314         srh = get_and_validate_srh(skb);
315         if (!srh)
316                 goto drop;
317
318         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
319
320         seg6_lookup_nexthop(skb, NULL, slwt->table);
321
322         return dst_input(skb);
323
324 drop:
325         kfree_skb(skb);
326         return -EINVAL;
327 }
328
329 /* decapsulate and forward inner L2 frame on specified interface */
330 static int input_action_end_dx2(struct sk_buff *skb,
331                                 struct seg6_local_lwt *slwt)
332 {
333         struct net *net = dev_net(skb->dev);
334         struct net_device *odev;
335         struct ethhdr *eth;
336
337         if (!decap_and_validate(skb, IPPROTO_ETHERNET))
338                 goto drop;
339
340         if (!pskb_may_pull(skb, ETH_HLEN))
341                 goto drop;
342
343         skb_reset_mac_header(skb);
344         eth = (struct ethhdr *)skb->data;
345
346         /* To determine the frame's protocol, we assume it is 802.3. This avoids
347          * a call to eth_type_trans(), which is not really relevant for our
348          * use case.
349          */
350         if (!eth_proto_is_802_3(eth->h_proto))
351                 goto drop;
352
353         odev = dev_get_by_index_rcu(net, slwt->oif);
354         if (!odev)
355                 goto drop;
356
357         /* As we accept Ethernet frames, make sure the egress device is of
358          * the correct type.
359          */
360         if (odev->type != ARPHRD_ETHER)
361                 goto drop;
362
363         if (!(odev->flags & IFF_UP) || !netif_carrier_ok(odev))
364                 goto drop;
365
366         skb_orphan(skb);
367
368         if (skb_warn_if_lro(skb))
369                 goto drop;
370
371         skb_forward_csum(skb);
372
373         if (skb->len - ETH_HLEN > odev->mtu)
374                 goto drop;
375
376         skb->dev = odev;
377         skb->protocol = eth->h_proto;
378
379         return dev_queue_xmit(skb);
380
381 drop:
382         kfree_skb(skb);
383         return -EINVAL;
384 }
385
386 /* decapsulate and forward to specified nexthop */
387 static int input_action_end_dx6(struct sk_buff *skb,
388                                 struct seg6_local_lwt *slwt)
389 {
390         struct in6_addr *nhaddr = NULL;
391
392         /* this function accepts IPv6 encapsulated packets, with either
393          * an SRH with SL=0, or no SRH.
394          */
395
396         if (!decap_and_validate(skb, IPPROTO_IPV6))
397                 goto drop;
398
399         if (!pskb_may_pull(skb, sizeof(struct ipv6hdr)))
400                 goto drop;
401
402         /* The inner packet is not associated to any local interface,
403          * so we do not call netif_rx().
404          *
405          * If slwt->nh6 is set to ::, then lookup the nexthop for the
406          * inner packet's DA. Otherwise, use the specified nexthop.
407          */
408
409         if (!ipv6_addr_any(&slwt->nh6))
410                 nhaddr = &slwt->nh6;
411
412         skb_set_transport_header(skb, sizeof(struct ipv6hdr));
413
414         seg6_lookup_nexthop(skb, nhaddr, 0);
415
416         return dst_input(skb);
417 drop:
418         kfree_skb(skb);
419         return -EINVAL;
420 }
421
422 static int input_action_end_dx4(struct sk_buff *skb,
423                                 struct seg6_local_lwt *slwt)
424 {
425         struct iphdr *iph;
426         __be32 nhaddr;
427         int err;
428
429         if (!decap_and_validate(skb, IPPROTO_IPIP))
430                 goto drop;
431
432         if (!pskb_may_pull(skb, sizeof(struct iphdr)))
433                 goto drop;
434
435         skb->protocol = htons(ETH_P_IP);
436
437         iph = ip_hdr(skb);
438
439         nhaddr = slwt->nh4.s_addr ?: iph->daddr;
440
441         skb_dst_drop(skb);
442
443         skb_set_transport_header(skb, sizeof(struct iphdr));
444
445         err = ip_route_input(skb, nhaddr, iph->saddr, 0, skb->dev);
446         if (err)
447                 goto drop;
448
449         return dst_input(skb);
450
451 drop:
452         kfree_skb(skb);
453         return -EINVAL;
454 }
455
456 #ifdef CONFIG_NET_L3_MASTER_DEV
457 static struct net *fib6_config_get_net(const struct fib6_config *fib6_cfg)
458 {
459         const struct nl_info *nli = &fib6_cfg->fc_nlinfo;
460
461         return nli->nl_net;
462 }
463
464 static int __seg6_end_dt_vrf_build(struct seg6_local_lwt *slwt, const void *cfg,
465                                    u16 family, struct netlink_ext_ack *extack)
466 {
467         struct seg6_end_dt_info *info = &slwt->dt_info;
468         int vrf_ifindex;
469         struct net *net;
470
471         net = fib6_config_get_net(cfg);
472
473         /* note that vrf_table was already set by parse_nla_vrftable() */
474         vrf_ifindex = l3mdev_ifindex_lookup_by_table_id(L3MDEV_TYPE_VRF, net,
475                                                         info->vrf_table);
476         if (vrf_ifindex < 0) {
477                 if (vrf_ifindex == -EPERM) {
478                         NL_SET_ERR_MSG(extack,
479                                        "Strict mode for VRF is disabled");
480                 } else if (vrf_ifindex == -ENODEV) {
481                         NL_SET_ERR_MSG(extack,
482                                        "Table has no associated VRF device");
483                 } else {
484                         pr_debug("seg6local: SRv6 End.DT* creation error=%d\n",
485                                  vrf_ifindex);
486                 }
487
488                 return vrf_ifindex;
489         }
490
491         info->net = net;
492         info->vrf_ifindex = vrf_ifindex;
493
494         switch (family) {
495         case AF_INET:
496                 info->proto = htons(ETH_P_IP);
497                 info->hdrlen = sizeof(struct iphdr);
498                 break;
499         case AF_INET6:
500                 info->proto = htons(ETH_P_IPV6);
501                 info->hdrlen = sizeof(struct ipv6hdr);
502                 break;
503         default:
504                 return -EINVAL;
505         }
506
507         info->family = family;
508         info->mode = DT_VRF_MODE;
509
510         return 0;
511 }
512
513 /* The SRv6 End.DT4/DT6 behavior extracts the inner (IPv4/IPv6) packet and
514  * routes the IPv4/IPv6 packet by looking at the configured routing table.
515  *
516  * In the SRv6 End.DT4/DT6 use case, we can receive traffic (IPv6+Segment
517  * Routing Header packets) from several interfaces and the outer IPv6
518  * destination address (DA) is used for retrieving the specific instance of the
519  * End.DT4/DT6 behavior that should process the packets.
520  *
521  * However, the inner IPv4/IPv6 packet is not really bound to any receiving
522  * interface and thus the End.DT4/DT6 sets the VRF (associated with the
523  * corresponding routing table) as the *receiving* interface.
524  * In other words, the End.DT4/DT6 processes a packet as if it has been received
525  * directly by the VRF (and not by one of its slave devices, if any).
526  * In this way, the VRF interface is used for routing the IPv4/IPv6 packet in
527  * according to the routing table configured by the End.DT4/DT6 instance.
528  *
529  * This design allows you to get some interesting features like:
530  *  1) the statistics on rx packets;
531  *  2) the possibility to install a packet sniffer on the receiving interface
532  *     (the VRF one) for looking at the incoming packets;
533  *  3) the possibility to leverage the netfilter prerouting hook for the inner
534  *     IPv4 packet.
535  *
536  * This function returns:
537  *  - the sk_buff* when the VRF rcv handler has processed the packet correctly;
538  *  - NULL when the skb is consumed by the VRF rcv handler;
539  *  - a pointer which encodes a negative error number in case of error.
540  *    Note that in this case, the function takes care of freeing the skb.
541  */
542 static struct sk_buff *end_dt_vrf_rcv(struct sk_buff *skb, u16 family,
543                                       struct net_device *dev)
544 {
545         /* based on l3mdev_ip_rcv; we are only interested in the master */
546         if (unlikely(!netif_is_l3_master(dev) && !netif_has_l3_rx_handler(dev)))
547                 goto drop;
548
549         if (unlikely(!dev->l3mdev_ops->l3mdev_l3_rcv))
550                 goto drop;
551
552         /* the decap packet IPv4/IPv6 does not come with any mac header info.
553          * We must unset the mac header to allow the VRF device to rebuild it,
554          * just in case there is a sniffer attached on the device.
555          */
556         skb_unset_mac_header(skb);
557
558         skb = dev->l3mdev_ops->l3mdev_l3_rcv(dev, skb, family);
559         if (!skb)
560                 /* the skb buffer was consumed by the handler */
561                 return NULL;
562
563         /* when a packet is received by a VRF or by one of its slaves, the
564          * master device reference is set into the skb.
565          */
566         if (unlikely(skb->dev != dev || skb->skb_iif != dev->ifindex))
567                 goto drop;
568
569         return skb;
570
571 drop:
572         kfree_skb(skb);
573         return ERR_PTR(-EINVAL);
574 }
575
576 static struct net_device *end_dt_get_vrf_rcu(struct sk_buff *skb,
577                                              struct seg6_end_dt_info *info)
578 {
579         int vrf_ifindex = info->vrf_ifindex;
580         struct net *net = info->net;
581
582         if (unlikely(vrf_ifindex < 0))
583                 goto error;
584
585         if (unlikely(!net_eq(dev_net(skb->dev), net)))
586                 goto error;
587
588         return dev_get_by_index_rcu(net, vrf_ifindex);
589
590 error:
591         return NULL;
592 }
593
594 static struct sk_buff *end_dt_vrf_core(struct sk_buff *skb,
595                                        struct seg6_local_lwt *slwt)
596 {
597         struct seg6_end_dt_info *info = &slwt->dt_info;
598         struct net_device *vrf;
599
600         vrf = end_dt_get_vrf_rcu(skb, info);
601         if (unlikely(!vrf))
602                 goto drop;
603
604         skb->protocol = info->proto;
605
606         skb_dst_drop(skb);
607
608         skb_set_transport_header(skb, info->hdrlen);
609
610         return end_dt_vrf_rcv(skb, info->family, vrf);
611
612 drop:
613         kfree_skb(skb);
614         return ERR_PTR(-EINVAL);
615 }
616
617 static int input_action_end_dt4(struct sk_buff *skb,
618                                 struct seg6_local_lwt *slwt)
619 {
620         struct iphdr *iph;
621         int err;
622
623         if (!decap_and_validate(skb, IPPROTO_IPIP))
624                 goto drop;
625
626         if (!pskb_may_pull(skb, sizeof(struct iphdr)))
627                 goto drop;
628
629         skb = end_dt_vrf_core(skb, slwt);
630         if (!skb)
631                 /* packet has been processed and consumed by the VRF */
632                 return 0;
633
634         if (IS_ERR(skb))
635                 return PTR_ERR(skb);
636
637         iph = ip_hdr(skb);
638
639         err = ip_route_input(skb, iph->daddr, iph->saddr, 0, skb->dev);
640         if (unlikely(err))
641                 goto drop;
642
643         return dst_input(skb);
644
645 drop:
646         kfree_skb(skb);
647         return -EINVAL;
648 }
649
650 static int seg6_end_dt4_build(struct seg6_local_lwt *slwt, const void *cfg,
651                               struct netlink_ext_ack *extack)
652 {
653         return __seg6_end_dt_vrf_build(slwt, cfg, AF_INET, extack);
654 }
655
656 static enum
657 seg6_end_dt_mode seg6_end_dt6_parse_mode(struct seg6_local_lwt *slwt)
658 {
659         unsigned long parsed_optattrs = slwt->parsed_optattrs;
660         bool legacy, vrfmode;
661
662         legacy  = !!(parsed_optattrs & SEG6_F_ATTR(SEG6_LOCAL_TABLE));
663         vrfmode = !!(parsed_optattrs & SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE));
664
665         if (!(legacy ^ vrfmode))
666                 /* both are absent or present: invalid DT6 mode */
667                 return DT_INVALID_MODE;
668
669         return legacy ? DT_LEGACY_MODE : DT_VRF_MODE;
670 }
671
672 static enum seg6_end_dt_mode seg6_end_dt6_get_mode(struct seg6_local_lwt *slwt)
673 {
674         struct seg6_end_dt_info *info = &slwt->dt_info;
675
676         return info->mode;
677 }
678
679 static int seg6_end_dt6_build(struct seg6_local_lwt *slwt, const void *cfg,
680                               struct netlink_ext_ack *extack)
681 {
682         enum seg6_end_dt_mode mode = seg6_end_dt6_parse_mode(slwt);
683         struct seg6_end_dt_info *info = &slwt->dt_info;
684
685         switch (mode) {
686         case DT_LEGACY_MODE:
687                 info->mode = DT_LEGACY_MODE;
688                 return 0;
689         case DT_VRF_MODE:
690                 return __seg6_end_dt_vrf_build(slwt, cfg, AF_INET6, extack);
691         default:
692                 NL_SET_ERR_MSG(extack, "table or vrftable must be specified");
693                 return -EINVAL;
694         }
695 }
696 #endif
697
698 static int input_action_end_dt6(struct sk_buff *skb,
699                                 struct seg6_local_lwt *slwt)
700 {
701         if (!decap_and_validate(skb, IPPROTO_IPV6))
702                 goto drop;
703
704         if (!pskb_may_pull(skb, sizeof(struct ipv6hdr)))
705                 goto drop;
706
707 #ifdef CONFIG_NET_L3_MASTER_DEV
708         if (seg6_end_dt6_get_mode(slwt) == DT_LEGACY_MODE)
709                 goto legacy_mode;
710
711         /* DT6_VRF_MODE */
712         skb = end_dt_vrf_core(skb, slwt);
713         if (!skb)
714                 /* packet has been processed and consumed by the VRF */
715                 return 0;
716
717         if (IS_ERR(skb))
718                 return PTR_ERR(skb);
719
720         /* note: this time we do not need to specify the table because the VRF
721          * takes care of selecting the correct table.
722          */
723         seg6_lookup_any_nexthop(skb, NULL, 0, true);
724
725         return dst_input(skb);
726
727 legacy_mode:
728 #endif
729         skb_set_transport_header(skb, sizeof(struct ipv6hdr));
730
731         seg6_lookup_any_nexthop(skb, NULL, slwt->table, true);
732
733         return dst_input(skb);
734
735 drop:
736         kfree_skb(skb);
737         return -EINVAL;
738 }
739
740 /* push an SRH on top of the current one */
741 static int input_action_end_b6(struct sk_buff *skb, struct seg6_local_lwt *slwt)
742 {
743         struct ipv6_sr_hdr *srh;
744         int err = -EINVAL;
745
746         srh = get_and_validate_srh(skb);
747         if (!srh)
748                 goto drop;
749
750         err = seg6_do_srh_inline(skb, slwt->srh);
751         if (err)
752                 goto drop;
753
754         ipv6_hdr(skb)->payload_len = htons(skb->len - sizeof(struct ipv6hdr));
755         skb_set_transport_header(skb, sizeof(struct ipv6hdr));
756
757         seg6_lookup_nexthop(skb, NULL, 0);
758
759         return dst_input(skb);
760
761 drop:
762         kfree_skb(skb);
763         return err;
764 }
765
766 /* encapsulate within an outer IPv6 header and a specified SRH */
767 static int input_action_end_b6_encap(struct sk_buff *skb,
768                                      struct seg6_local_lwt *slwt)
769 {
770         struct ipv6_sr_hdr *srh;
771         int err = -EINVAL;
772
773         srh = get_and_validate_srh(skb);
774         if (!srh)
775                 goto drop;
776
777         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
778
779         skb_reset_inner_headers(skb);
780         skb->encapsulation = 1;
781
782         err = seg6_do_srh_encap(skb, slwt->srh, IPPROTO_IPV6);
783         if (err)
784                 goto drop;
785
786         ipv6_hdr(skb)->payload_len = htons(skb->len - sizeof(struct ipv6hdr));
787         skb_set_transport_header(skb, sizeof(struct ipv6hdr));
788
789         seg6_lookup_nexthop(skb, NULL, 0);
790
791         return dst_input(skb);
792
793 drop:
794         kfree_skb(skb);
795         return err;
796 }
797
798 DEFINE_PER_CPU(struct seg6_bpf_srh_state, seg6_bpf_srh_states);
799
800 bool seg6_bpf_has_valid_srh(struct sk_buff *skb)
801 {
802         struct seg6_bpf_srh_state *srh_state =
803                 this_cpu_ptr(&seg6_bpf_srh_states);
804         struct ipv6_sr_hdr *srh = srh_state->srh;
805
806         if (unlikely(srh == NULL))
807                 return false;
808
809         if (unlikely(!srh_state->valid)) {
810                 if ((srh_state->hdrlen & 7) != 0)
811                         return false;
812
813                 srh->hdrlen = (u8)(srh_state->hdrlen >> 3);
814                 if (!seg6_validate_srh(srh, (srh->hdrlen + 1) << 3, true))
815                         return false;
816
817                 srh_state->valid = true;
818         }
819
820         return true;
821 }
822
823 static int input_action_end_bpf(struct sk_buff *skb,
824                                 struct seg6_local_lwt *slwt)
825 {
826         struct seg6_bpf_srh_state *srh_state =
827                 this_cpu_ptr(&seg6_bpf_srh_states);
828         struct ipv6_sr_hdr *srh;
829         int ret;
830
831         srh = get_and_validate_srh(skb);
832         if (!srh) {
833                 kfree_skb(skb);
834                 return -EINVAL;
835         }
836         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
837
838         /* preempt_disable is needed to protect the per-CPU buffer srh_state,
839          * which is also accessed by the bpf_lwt_seg6_* helpers
840          */
841         preempt_disable();
842         srh_state->srh = srh;
843         srh_state->hdrlen = srh->hdrlen << 3;
844         srh_state->valid = true;
845
846         rcu_read_lock();
847         bpf_compute_data_pointers(skb);
848         ret = bpf_prog_run_save_cb(slwt->bpf.prog, skb);
849         rcu_read_unlock();
850
851         switch (ret) {
852         case BPF_OK:
853         case BPF_REDIRECT:
854                 break;
855         case BPF_DROP:
856                 goto drop;
857         default:
858                 pr_warn_once("bpf-seg6local: Illegal return value %u\n", ret);
859                 goto drop;
860         }
861
862         if (srh_state->srh && !seg6_bpf_has_valid_srh(skb))
863                 goto drop;
864
865         preempt_enable();
866         if (ret != BPF_REDIRECT)
867                 seg6_lookup_nexthop(skb, NULL, 0);
868
869         return dst_input(skb);
870
871 drop:
872         preempt_enable();
873         kfree_skb(skb);
874         return -EINVAL;
875 }
876
877 static struct seg6_action_desc seg6_action_table[] = {
878         {
879                 .action         = SEG6_LOCAL_ACTION_END,
880                 .attrs          = 0,
881                 .input          = input_action_end,
882         },
883         {
884                 .action         = SEG6_LOCAL_ACTION_END_X,
885                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_NH6),
886                 .input          = input_action_end_x,
887         },
888         {
889                 .action         = SEG6_LOCAL_ACTION_END_T,
890                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_TABLE),
891                 .input          = input_action_end_t,
892         },
893         {
894                 .action         = SEG6_LOCAL_ACTION_END_DX2,
895                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_OIF),
896                 .input          = input_action_end_dx2,
897         },
898         {
899                 .action         = SEG6_LOCAL_ACTION_END_DX6,
900                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_NH6),
901                 .input          = input_action_end_dx6,
902         },
903         {
904                 .action         = SEG6_LOCAL_ACTION_END_DX4,
905                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_NH4),
906                 .input          = input_action_end_dx4,
907         },
908         {
909                 .action         = SEG6_LOCAL_ACTION_END_DT4,
910                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE),
911 #ifdef CONFIG_NET_L3_MASTER_DEV
912                 .input          = input_action_end_dt4,
913                 .slwt_ops       = {
914                                         .build_state = seg6_end_dt4_build,
915                                   },
916 #endif
917         },
918         {
919                 .action         = SEG6_LOCAL_ACTION_END_DT6,
920 #ifdef CONFIG_NET_L3_MASTER_DEV
921                 .attrs          = 0,
922                 .optattrs       = SEG6_F_ATTR(SEG6_LOCAL_TABLE) |
923                                   SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE),
924                 .slwt_ops       = {
925                                         .build_state = seg6_end_dt6_build,
926                                   },
927 #else
928                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_TABLE),
929 #endif
930                 .input          = input_action_end_dt6,
931         },
932         {
933                 .action         = SEG6_LOCAL_ACTION_END_B6,
934                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_SRH),
935                 .input          = input_action_end_b6,
936         },
937         {
938                 .action         = SEG6_LOCAL_ACTION_END_B6_ENCAP,
939                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_SRH),
940                 .input          = input_action_end_b6_encap,
941                 .static_headroom        = sizeof(struct ipv6hdr),
942         },
943         {
944                 .action         = SEG6_LOCAL_ACTION_END_BPF,
945                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_BPF),
946                 .input          = input_action_end_bpf,
947         },
948
949 };
950
951 static struct seg6_action_desc *__get_action_desc(int action)
952 {
953         struct seg6_action_desc *desc;
954         int i, count;
955
956         count = ARRAY_SIZE(seg6_action_table);
957         for (i = 0; i < count; i++) {
958                 desc = &seg6_action_table[i];
959                 if (desc->action == action)
960                         return desc;
961         }
962
963         return NULL;
964 }
965
966 static int seg6_local_input(struct sk_buff *skb)
967 {
968         struct dst_entry *orig_dst = skb_dst(skb);
969         struct seg6_action_desc *desc;
970         struct seg6_local_lwt *slwt;
971
972         if (skb->protocol != htons(ETH_P_IPV6)) {
973                 kfree_skb(skb);
974                 return -EINVAL;
975         }
976
977         slwt = seg6_local_lwtunnel(orig_dst->lwtstate);
978         desc = slwt->desc;
979
980         return desc->input(skb, slwt);
981 }
982
983 static const struct nla_policy seg6_local_policy[SEG6_LOCAL_MAX + 1] = {
984         [SEG6_LOCAL_ACTION]     = { .type = NLA_U32 },
985         [SEG6_LOCAL_SRH]        = { .type = NLA_BINARY },
986         [SEG6_LOCAL_TABLE]      = { .type = NLA_U32 },
987         [SEG6_LOCAL_VRFTABLE]   = { .type = NLA_U32 },
988         [SEG6_LOCAL_NH4]        = { .type = NLA_BINARY,
989                                     .len = sizeof(struct in_addr) },
990         [SEG6_LOCAL_NH6]        = { .type = NLA_BINARY,
991                                     .len = sizeof(struct in6_addr) },
992         [SEG6_LOCAL_IIF]        = { .type = NLA_U32 },
993         [SEG6_LOCAL_OIF]        = { .type = NLA_U32 },
994         [SEG6_LOCAL_BPF]        = { .type = NLA_NESTED },
995 };
996
997 static int parse_nla_srh(struct nlattr **attrs, struct seg6_local_lwt *slwt)
998 {
999         struct ipv6_sr_hdr *srh;
1000         int len;
1001
1002         srh = nla_data(attrs[SEG6_LOCAL_SRH]);
1003         len = nla_len(attrs[SEG6_LOCAL_SRH]);
1004
1005         /* SRH must contain at least one segment */
1006         if (len < sizeof(*srh) + sizeof(struct in6_addr))
1007                 return -EINVAL;
1008
1009         if (!seg6_validate_srh(srh, len, false))
1010                 return -EINVAL;
1011
1012         slwt->srh = kmemdup(srh, len, GFP_KERNEL);
1013         if (!slwt->srh)
1014                 return -ENOMEM;
1015
1016         slwt->headroom += len;
1017
1018         return 0;
1019 }
1020
1021 static int put_nla_srh(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1022 {
1023         struct ipv6_sr_hdr *srh;
1024         struct nlattr *nla;
1025         int len;
1026
1027         srh = slwt->srh;
1028         len = (srh->hdrlen + 1) << 3;
1029
1030         nla = nla_reserve(skb, SEG6_LOCAL_SRH, len);
1031         if (!nla)
1032                 return -EMSGSIZE;
1033
1034         memcpy(nla_data(nla), srh, len);
1035
1036         return 0;
1037 }
1038
1039 static int cmp_nla_srh(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1040 {
1041         int len = (a->srh->hdrlen + 1) << 3;
1042
1043         if (len != ((b->srh->hdrlen + 1) << 3))
1044                 return 1;
1045
1046         return memcmp(a->srh, b->srh, len);
1047 }
1048
1049 static void destroy_attr_srh(struct seg6_local_lwt *slwt)
1050 {
1051         kfree(slwt->srh);
1052 }
1053
1054 static int parse_nla_table(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1055 {
1056         slwt->table = nla_get_u32(attrs[SEG6_LOCAL_TABLE]);
1057
1058         return 0;
1059 }
1060
1061 static int put_nla_table(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1062 {
1063         if (nla_put_u32(skb, SEG6_LOCAL_TABLE, slwt->table))
1064                 return -EMSGSIZE;
1065
1066         return 0;
1067 }
1068
1069 static int cmp_nla_table(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1070 {
1071         if (a->table != b->table)
1072                 return 1;
1073
1074         return 0;
1075 }
1076
1077 static struct
1078 seg6_end_dt_info *seg6_possible_end_dt_info(struct seg6_local_lwt *slwt)
1079 {
1080 #ifdef CONFIG_NET_L3_MASTER_DEV
1081         return &slwt->dt_info;
1082 #else
1083         return ERR_PTR(-EOPNOTSUPP);
1084 #endif
1085 }
1086
1087 static int parse_nla_vrftable(struct nlattr **attrs,
1088                               struct seg6_local_lwt *slwt)
1089 {
1090         struct seg6_end_dt_info *info = seg6_possible_end_dt_info(slwt);
1091
1092         if (IS_ERR(info))
1093                 return PTR_ERR(info);
1094
1095         info->vrf_table = nla_get_u32(attrs[SEG6_LOCAL_VRFTABLE]);
1096
1097         return 0;
1098 }
1099
1100 static int put_nla_vrftable(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1101 {
1102         struct seg6_end_dt_info *info = seg6_possible_end_dt_info(slwt);
1103
1104         if (IS_ERR(info))
1105                 return PTR_ERR(info);
1106
1107         if (nla_put_u32(skb, SEG6_LOCAL_VRFTABLE, info->vrf_table))
1108                 return -EMSGSIZE;
1109
1110         return 0;
1111 }
1112
1113 static int cmp_nla_vrftable(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1114 {
1115         struct seg6_end_dt_info *info_a = seg6_possible_end_dt_info(a);
1116         struct seg6_end_dt_info *info_b = seg6_possible_end_dt_info(b);
1117
1118         if (info_a->vrf_table != info_b->vrf_table)
1119                 return 1;
1120
1121         return 0;
1122 }
1123
1124 static int parse_nla_nh4(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1125 {
1126         memcpy(&slwt->nh4, nla_data(attrs[SEG6_LOCAL_NH4]),
1127                sizeof(struct in_addr));
1128
1129         return 0;
1130 }
1131
1132 static int put_nla_nh4(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1133 {
1134         struct nlattr *nla;
1135
1136         nla = nla_reserve(skb, SEG6_LOCAL_NH4, sizeof(struct in_addr));
1137         if (!nla)
1138                 return -EMSGSIZE;
1139
1140         memcpy(nla_data(nla), &slwt->nh4, sizeof(struct in_addr));
1141
1142         return 0;
1143 }
1144
1145 static int cmp_nla_nh4(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1146 {
1147         return memcmp(&a->nh4, &b->nh4, sizeof(struct in_addr));
1148 }
1149
1150 static int parse_nla_nh6(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1151 {
1152         memcpy(&slwt->nh6, nla_data(attrs[SEG6_LOCAL_NH6]),
1153                sizeof(struct in6_addr));
1154
1155         return 0;
1156 }
1157
1158 static int put_nla_nh6(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1159 {
1160         struct nlattr *nla;
1161
1162         nla = nla_reserve(skb, SEG6_LOCAL_NH6, sizeof(struct in6_addr));
1163         if (!nla)
1164                 return -EMSGSIZE;
1165
1166         memcpy(nla_data(nla), &slwt->nh6, sizeof(struct in6_addr));
1167
1168         return 0;
1169 }
1170
1171 static int cmp_nla_nh6(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1172 {
1173         return memcmp(&a->nh6, &b->nh6, sizeof(struct in6_addr));
1174 }
1175
1176 static int parse_nla_iif(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1177 {
1178         slwt->iif = nla_get_u32(attrs[SEG6_LOCAL_IIF]);
1179
1180         return 0;
1181 }
1182
1183 static int put_nla_iif(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1184 {
1185         if (nla_put_u32(skb, SEG6_LOCAL_IIF, slwt->iif))
1186                 return -EMSGSIZE;
1187
1188         return 0;
1189 }
1190
1191 static int cmp_nla_iif(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1192 {
1193         if (a->iif != b->iif)
1194                 return 1;
1195
1196         return 0;
1197 }
1198
1199 static int parse_nla_oif(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1200 {
1201         slwt->oif = nla_get_u32(attrs[SEG6_LOCAL_OIF]);
1202
1203         return 0;
1204 }
1205
1206 static int put_nla_oif(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1207 {
1208         if (nla_put_u32(skb, SEG6_LOCAL_OIF, slwt->oif))
1209                 return -EMSGSIZE;
1210
1211         return 0;
1212 }
1213
1214 static int cmp_nla_oif(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1215 {
1216         if (a->oif != b->oif)
1217                 return 1;
1218
1219         return 0;
1220 }
1221
1222 #define MAX_PROG_NAME 256
1223 static const struct nla_policy bpf_prog_policy[SEG6_LOCAL_BPF_PROG_MAX + 1] = {
1224         [SEG6_LOCAL_BPF_PROG]      = { .type = NLA_U32, },
1225         [SEG6_LOCAL_BPF_PROG_NAME] = { .type = NLA_NUL_STRING,
1226                                        .len = MAX_PROG_NAME },
1227 };
1228
1229 static int parse_nla_bpf(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1230 {
1231         struct nlattr *tb[SEG6_LOCAL_BPF_PROG_MAX + 1];
1232         struct bpf_prog *p;
1233         int ret;
1234         u32 fd;
1235
1236         ret = nla_parse_nested_deprecated(tb, SEG6_LOCAL_BPF_PROG_MAX,
1237                                           attrs[SEG6_LOCAL_BPF],
1238                                           bpf_prog_policy, NULL);
1239         if (ret < 0)
1240                 return ret;
1241
1242         if (!tb[SEG6_LOCAL_BPF_PROG] || !tb[SEG6_LOCAL_BPF_PROG_NAME])
1243                 return -EINVAL;
1244
1245         slwt->bpf.name = nla_memdup(tb[SEG6_LOCAL_BPF_PROG_NAME], GFP_KERNEL);
1246         if (!slwt->bpf.name)
1247                 return -ENOMEM;
1248
1249         fd = nla_get_u32(tb[SEG6_LOCAL_BPF_PROG]);
1250         p = bpf_prog_get_type(fd, BPF_PROG_TYPE_LWT_SEG6LOCAL);
1251         if (IS_ERR(p)) {
1252                 kfree(slwt->bpf.name);
1253                 return PTR_ERR(p);
1254         }
1255
1256         slwt->bpf.prog = p;
1257         return 0;
1258 }
1259
1260 static int put_nla_bpf(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1261 {
1262         struct nlattr *nest;
1263
1264         if (!slwt->bpf.prog)
1265                 return 0;
1266
1267         nest = nla_nest_start_noflag(skb, SEG6_LOCAL_BPF);
1268         if (!nest)
1269                 return -EMSGSIZE;
1270
1271         if (nla_put_u32(skb, SEG6_LOCAL_BPF_PROG, slwt->bpf.prog->aux->id))
1272                 return -EMSGSIZE;
1273
1274         if (slwt->bpf.name &&
1275             nla_put_string(skb, SEG6_LOCAL_BPF_PROG_NAME, slwt->bpf.name))
1276                 return -EMSGSIZE;
1277
1278         return nla_nest_end(skb, nest);
1279 }
1280
1281 static int cmp_nla_bpf(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1282 {
1283         if (!a->bpf.name && !b->bpf.name)
1284                 return 0;
1285
1286         if (!a->bpf.name || !b->bpf.name)
1287                 return 1;
1288
1289         return strcmp(a->bpf.name, b->bpf.name);
1290 }
1291
1292 static void destroy_attr_bpf(struct seg6_local_lwt *slwt)
1293 {
1294         kfree(slwt->bpf.name);
1295         if (slwt->bpf.prog)
1296                 bpf_prog_put(slwt->bpf.prog);
1297 }
1298
1299 struct seg6_action_param {
1300         int (*parse)(struct nlattr **attrs, struct seg6_local_lwt *slwt);
1301         int (*put)(struct sk_buff *skb, struct seg6_local_lwt *slwt);
1302         int (*cmp)(struct seg6_local_lwt *a, struct seg6_local_lwt *b);
1303
1304         /* optional destroy() callback useful for releasing resources which
1305          * have been previously acquired in the corresponding parse()
1306          * function.
1307          */
1308         void (*destroy)(struct seg6_local_lwt *slwt);
1309 };
1310
1311 static struct seg6_action_param seg6_action_params[SEG6_LOCAL_MAX + 1] = {
1312         [SEG6_LOCAL_SRH]        = { .parse = parse_nla_srh,
1313                                     .put = put_nla_srh,
1314                                     .cmp = cmp_nla_srh,
1315                                     .destroy = destroy_attr_srh },
1316
1317         [SEG6_LOCAL_TABLE]      = { .parse = parse_nla_table,
1318                                     .put = put_nla_table,
1319                                     .cmp = cmp_nla_table },
1320
1321         [SEG6_LOCAL_NH4]        = { .parse = parse_nla_nh4,
1322                                     .put = put_nla_nh4,
1323                                     .cmp = cmp_nla_nh4 },
1324
1325         [SEG6_LOCAL_NH6]        = { .parse = parse_nla_nh6,
1326                                     .put = put_nla_nh6,
1327                                     .cmp = cmp_nla_nh6 },
1328
1329         [SEG6_LOCAL_IIF]        = { .parse = parse_nla_iif,
1330                                     .put = put_nla_iif,
1331                                     .cmp = cmp_nla_iif },
1332
1333         [SEG6_LOCAL_OIF]        = { .parse = parse_nla_oif,
1334                                     .put = put_nla_oif,
1335                                     .cmp = cmp_nla_oif },
1336
1337         [SEG6_LOCAL_BPF]        = { .parse = parse_nla_bpf,
1338                                     .put = put_nla_bpf,
1339                                     .cmp = cmp_nla_bpf,
1340                                     .destroy = destroy_attr_bpf },
1341
1342         [SEG6_LOCAL_VRFTABLE]   = { .parse = parse_nla_vrftable,
1343                                     .put = put_nla_vrftable,
1344                                     .cmp = cmp_nla_vrftable },
1345
1346 };
1347
1348 /* call the destroy() callback (if available) for each set attribute in
1349  * @parsed_attrs, starting from the first attribute up to the @max_parsed
1350  * (excluded) attribute.
1351  */
1352 static void __destroy_attrs(unsigned long parsed_attrs, int max_parsed,
1353                             struct seg6_local_lwt *slwt)
1354 {
1355         struct seg6_action_param *param;
1356         int i;
1357
1358         /* Every required seg6local attribute is identified by an ID which is
1359          * encoded as a flag (i.e: 1 << ID) in the 'attrs' bitmask;
1360          *
1361          * We scan the 'parsed_attrs' bitmask, starting from the first attribute
1362          * up to the @max_parsed (excluded) attribute.
1363          * For each set attribute, we retrieve the corresponding destroy()
1364          * callback. If the callback is not available, then we skip to the next
1365          * attribute; otherwise, we call the destroy() callback.
1366          */
1367         for (i = 0; i < max_parsed; ++i) {
1368                 if (!(parsed_attrs & SEG6_F_ATTR(i)))
1369                         continue;
1370
1371                 param = &seg6_action_params[i];
1372
1373                 if (param->destroy)
1374                         param->destroy(slwt);
1375         }
1376 }
1377
1378 /* release all the resources that may have been acquired during parsing
1379  * operations.
1380  */
1381 static void destroy_attrs(struct seg6_local_lwt *slwt)
1382 {
1383         unsigned long attrs = slwt->desc->attrs | slwt->parsed_optattrs;
1384
1385         __destroy_attrs(attrs, SEG6_LOCAL_MAX + 1, slwt);
1386 }
1387
1388 static int parse_nla_optional_attrs(struct nlattr **attrs,
1389                                     struct seg6_local_lwt *slwt)
1390 {
1391         struct seg6_action_desc *desc = slwt->desc;
1392         unsigned long parsed_optattrs = 0;
1393         struct seg6_action_param *param;
1394         int err, i;
1395
1396         for (i = 0; i < SEG6_LOCAL_MAX + 1; ++i) {
1397                 if (!(desc->optattrs & SEG6_F_ATTR(i)) || !attrs[i])
1398                         continue;
1399
1400                 /* once here, the i-th attribute is provided by the
1401                  * userspace AND it is identified optional as well.
1402                  */
1403                 param = &seg6_action_params[i];
1404
1405                 err = param->parse(attrs, slwt);
1406                 if (err < 0)
1407                         goto parse_optattrs_err;
1408
1409                 /* current attribute has been correctly parsed */
1410                 parsed_optattrs |= SEG6_F_ATTR(i);
1411         }
1412
1413         /* store in the tunnel state all the optional attributed successfully
1414          * parsed.
1415          */
1416         slwt->parsed_optattrs = parsed_optattrs;
1417
1418         return 0;
1419
1420 parse_optattrs_err:
1421         __destroy_attrs(parsed_optattrs, i, slwt);
1422
1423         return err;
1424 }
1425
1426 /* call the custom constructor of the behavior during its initialization phase
1427  * and after that all its attributes have been parsed successfully.
1428  */
1429 static int
1430 seg6_local_lwtunnel_build_state(struct seg6_local_lwt *slwt, const void *cfg,
1431                                 struct netlink_ext_ack *extack)
1432 {
1433         struct seg6_action_desc *desc = slwt->desc;
1434         struct seg6_local_lwtunnel_ops *ops;
1435
1436         ops = &desc->slwt_ops;
1437         if (!ops->build_state)
1438                 return 0;
1439
1440         return ops->build_state(slwt, cfg, extack);
1441 }
1442
1443 /* call the custom destructor of the behavior which is invoked before the
1444  * tunnel is going to be destroyed.
1445  */
1446 static void seg6_local_lwtunnel_destroy_state(struct seg6_local_lwt *slwt)
1447 {
1448         struct seg6_action_desc *desc = slwt->desc;
1449         struct seg6_local_lwtunnel_ops *ops;
1450
1451         ops = &desc->slwt_ops;
1452         if (!ops->destroy_state)
1453                 return;
1454
1455         ops->destroy_state(slwt);
1456 }
1457
1458 static int parse_nla_action(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1459 {
1460         struct seg6_action_param *param;
1461         struct seg6_action_desc *desc;
1462         unsigned long invalid_attrs;
1463         int i, err;
1464
1465         desc = __get_action_desc(slwt->action);
1466         if (!desc)
1467                 return -EINVAL;
1468
1469         if (!desc->input)
1470                 return -EOPNOTSUPP;
1471
1472         slwt->desc = desc;
1473         slwt->headroom += desc->static_headroom;
1474
1475         /* Forcing the desc->optattrs *set* and the desc->attrs *set* to be
1476          * disjoined, this allow us to release acquired resources by optional
1477          * attributes and by required attributes independently from each other
1478          * without any interference.
1479          * In other terms, we are sure that we do not release some the acquired
1480          * resources twice.
1481          *
1482          * Note that if an attribute is configured both as required and as
1483          * optional, it means that the user has messed something up in the
1484          * seg6_action_table. Therefore, this check is required for SRv6
1485          * behaviors to work properly.
1486          */
1487         invalid_attrs = desc->attrs & desc->optattrs;
1488         if (invalid_attrs) {
1489                 WARN_ONCE(1,
1490                           "An attribute cannot be both required AND optional");
1491                 return -EINVAL;
1492         }
1493
1494         /* parse the required attributes */
1495         for (i = 0; i < SEG6_LOCAL_MAX + 1; i++) {
1496                 if (desc->attrs & SEG6_F_ATTR(i)) {
1497                         if (!attrs[i])
1498                                 return -EINVAL;
1499
1500                         param = &seg6_action_params[i];
1501
1502                         err = param->parse(attrs, slwt);
1503                         if (err < 0)
1504                                 goto parse_attrs_err;
1505                 }
1506         }
1507
1508         /* parse the optional attributes, if any */
1509         err = parse_nla_optional_attrs(attrs, slwt);
1510         if (err < 0)
1511                 goto parse_attrs_err;
1512
1513         return 0;
1514
1515 parse_attrs_err:
1516         /* release any resource that may have been acquired during the i-1
1517          * parse() operations.
1518          */
1519         __destroy_attrs(desc->attrs, i, slwt);
1520
1521         return err;
1522 }
1523
1524 static int seg6_local_build_state(struct net *net, struct nlattr *nla,
1525                                   unsigned int family, const void *cfg,
1526                                   struct lwtunnel_state **ts,
1527                                   struct netlink_ext_ack *extack)
1528 {
1529         struct nlattr *tb[SEG6_LOCAL_MAX + 1];
1530         struct lwtunnel_state *newts;
1531         struct seg6_local_lwt *slwt;
1532         int err;
1533
1534         if (family != AF_INET6)
1535                 return -EINVAL;
1536
1537         err = nla_parse_nested_deprecated(tb, SEG6_LOCAL_MAX, nla,
1538                                           seg6_local_policy, extack);
1539
1540         if (err < 0)
1541                 return err;
1542
1543         if (!tb[SEG6_LOCAL_ACTION])
1544                 return -EINVAL;
1545
1546         newts = lwtunnel_state_alloc(sizeof(*slwt));
1547         if (!newts)
1548                 return -ENOMEM;
1549
1550         slwt = seg6_local_lwtunnel(newts);
1551         slwt->action = nla_get_u32(tb[SEG6_LOCAL_ACTION]);
1552
1553         err = parse_nla_action(tb, slwt);
1554         if (err < 0)
1555                 goto out_free;
1556
1557         err = seg6_local_lwtunnel_build_state(slwt, cfg, extack);
1558         if (err < 0)
1559                 goto out_destroy_attrs;
1560
1561         newts->type = LWTUNNEL_ENCAP_SEG6_LOCAL;
1562         newts->flags = LWTUNNEL_STATE_INPUT_REDIRECT;
1563         newts->headroom = slwt->headroom;
1564
1565         *ts = newts;
1566
1567         return 0;
1568
1569 out_destroy_attrs:
1570         destroy_attrs(slwt);
1571 out_free:
1572         kfree(newts);
1573         return err;
1574 }
1575
1576 static void seg6_local_destroy_state(struct lwtunnel_state *lwt)
1577 {
1578         struct seg6_local_lwt *slwt = seg6_local_lwtunnel(lwt);
1579
1580         seg6_local_lwtunnel_destroy_state(slwt);
1581
1582         destroy_attrs(slwt);
1583
1584         return;
1585 }
1586
1587 static int seg6_local_fill_encap(struct sk_buff *skb,
1588                                  struct lwtunnel_state *lwt)
1589 {
1590         struct seg6_local_lwt *slwt = seg6_local_lwtunnel(lwt);
1591         struct seg6_action_param *param;
1592         unsigned long attrs;
1593         int i, err;
1594
1595         if (nla_put_u32(skb, SEG6_LOCAL_ACTION, slwt->action))
1596                 return -EMSGSIZE;
1597
1598         attrs = slwt->desc->attrs | slwt->parsed_optattrs;
1599
1600         for (i = 0; i < SEG6_LOCAL_MAX + 1; i++) {
1601                 if (attrs & SEG6_F_ATTR(i)) {
1602                         param = &seg6_action_params[i];
1603                         err = param->put(skb, slwt);
1604                         if (err < 0)
1605                                 return err;
1606                 }
1607         }
1608
1609         return 0;
1610 }
1611
1612 static int seg6_local_get_encap_size(struct lwtunnel_state *lwt)
1613 {
1614         struct seg6_local_lwt *slwt = seg6_local_lwtunnel(lwt);
1615         unsigned long attrs;
1616         int nlsize;
1617
1618         nlsize = nla_total_size(4); /* action */
1619
1620         attrs = slwt->desc->attrs | slwt->parsed_optattrs;
1621
1622         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_SRH))
1623                 nlsize += nla_total_size((slwt->srh->hdrlen + 1) << 3);
1624
1625         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_TABLE))
1626                 nlsize += nla_total_size(4);
1627
1628         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_NH4))
1629                 nlsize += nla_total_size(4);
1630
1631         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_NH6))
1632                 nlsize += nla_total_size(16);
1633
1634         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_IIF))
1635                 nlsize += nla_total_size(4);
1636
1637         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_OIF))
1638                 nlsize += nla_total_size(4);
1639
1640         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_BPF))
1641                 nlsize += nla_total_size(sizeof(struct nlattr)) +
1642                        nla_total_size(MAX_PROG_NAME) +
1643                        nla_total_size(4);
1644
1645         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE))
1646                 nlsize += nla_total_size(4);
1647
1648         return nlsize;
1649 }
1650
1651 static int seg6_local_cmp_encap(struct lwtunnel_state *a,
1652                                 struct lwtunnel_state *b)
1653 {
1654         struct seg6_local_lwt *slwt_a, *slwt_b;
1655         struct seg6_action_param *param;
1656         unsigned long attrs_a, attrs_b;
1657         int i;
1658
1659         slwt_a = seg6_local_lwtunnel(a);
1660         slwt_b = seg6_local_lwtunnel(b);
1661
1662         if (slwt_a->action != slwt_b->action)
1663                 return 1;
1664
1665         attrs_a = slwt_a->desc->attrs | slwt_a->parsed_optattrs;
1666         attrs_b = slwt_b->desc->attrs | slwt_b->parsed_optattrs;
1667
1668         if (attrs_a != attrs_b)
1669                 return 1;
1670
1671         for (i = 0; i < SEG6_LOCAL_MAX + 1; i++) {
1672                 if (attrs_a & SEG6_F_ATTR(i)) {
1673                         param = &seg6_action_params[i];
1674                         if (param->cmp(slwt_a, slwt_b))
1675                                 return 1;
1676                 }
1677         }
1678
1679         return 0;
1680 }
1681
1682 static const struct lwtunnel_encap_ops seg6_local_ops = {
1683         .build_state    = seg6_local_build_state,
1684         .destroy_state  = seg6_local_destroy_state,
1685         .input          = seg6_local_input,
1686         .fill_encap     = seg6_local_fill_encap,
1687         .get_encap_size = seg6_local_get_encap_size,
1688         .cmp_encap      = seg6_local_cmp_encap,
1689         .owner          = THIS_MODULE,
1690 };
1691
1692 int __init seg6_local_init(void)
1693 {
1694         /* If the max total number of defined attributes is reached, then your
1695          * kernel build stops here.
1696          *
1697          * This check is required to avoid arithmetic overflows when processing
1698          * behavior attributes and the maximum number of defined attributes
1699          * exceeds the allowed value.
1700          */
1701         BUILD_BUG_ON(SEG6_LOCAL_MAX + 1 > BITS_PER_TYPE(unsigned long));
1702
1703         return lwtunnel_encap_add_ops(&seg6_local_ops,
1704                                       LWTUNNEL_ENCAP_SEG6_LOCAL);
1705 }
1706
1707 void seg6_local_exit(void)
1708 {
1709         lwtunnel_encap_del_ops(&seg6_local_ops, LWTUNNEL_ENCAP_SEG6_LOCAL);
1710 }