Merge git://git.kernel.org/pub/scm/linux/kernel/git/bpf/bpf-next
[sfrench/cifs-2.6.git] / net / core / filter.c
index 255aeee7240265a26f21e7d3c548ae823cb396ed..adfdad234674dc1031d24b8ae635174fcfbd0dce 100644 (file)
@@ -2083,13 +2083,13 @@ static const struct bpf_func_proto bpf_csum_level_proto = {
 
 static inline int __bpf_rx_skb(struct net_device *dev, struct sk_buff *skb)
 {
-       return dev_forward_skb(dev, skb);
+       return dev_forward_skb_nomtu(dev, skb);
 }
 
 static inline int __bpf_rx_skb_no_mac(struct net_device *dev,
                                      struct sk_buff *skb)
 {
-       int ret = ____dev_forward_skb(dev, skb);
+       int ret = ____dev_forward_skb(dev, skb, false);
 
        if (likely(!ret)) {
                skb->dev = dev;
@@ -2480,7 +2480,7 @@ int skb_do_redirect(struct sk_buff *skb)
                        goto out_drop;
                dev = ops->ndo_get_peer_dev(dev);
                if (unlikely(!dev ||
-                            !is_skb_forwardable(dev, skb) ||
+                            !(dev->flags & IFF_UP) ||
                             net_eq(net, dev_net(dev))))
                        goto out_drop;
                skb->dev = dev;
@@ -3552,11 +3552,7 @@ static int bpf_skb_net_shrink(struct sk_buff *skb, u32 off, u32 len_diff,
        return 0;
 }
 
-static u32 __bpf_skb_max_len(const struct sk_buff *skb)
-{
-       return skb->dev ? skb->dev->mtu + skb->dev->hard_header_len :
-                         SKB_MAX_ALLOC;
-}
+#define BPF_SKB_MAX_LEN SKB_MAX_ALLOC
 
 BPF_CALL_4(sk_skb_adjust_room, struct sk_buff *, skb, s32, len_diff,
           u32, mode, u64, flags)
@@ -3605,7 +3601,7 @@ BPF_CALL_4(bpf_skb_adjust_room, struct sk_buff *, skb, s32, len_diff,
 {
        u32 len_cur, len_diff_abs = abs(len_diff);
        u32 len_min = bpf_skb_net_base_len(skb);
-       u32 len_max = __bpf_skb_max_len(skb);
+       u32 len_max = BPF_SKB_MAX_LEN;
        __be16 proto = skb->protocol;
        bool shrink = len_diff < 0;
        u32 off;
@@ -3688,7 +3684,7 @@ static int bpf_skb_trim_rcsum(struct sk_buff *skb, unsigned int new_len)
 static inline int __bpf_skb_change_tail(struct sk_buff *skb, u32 new_len,
                                        u64 flags)
 {
-       u32 max_len = __bpf_skb_max_len(skb);
+       u32 max_len = BPF_SKB_MAX_LEN;
        u32 min_len = __bpf_skb_min_len(skb);
        int ret;
 
@@ -3764,7 +3760,7 @@ static const struct bpf_func_proto sk_skb_change_tail_proto = {
 static inline int __bpf_skb_change_head(struct sk_buff *skb, u32 head_room,
                                        u64 flags)
 {
-       u32 max_len = __bpf_skb_max_len(skb);
+       u32 max_len = BPF_SKB_MAX_LEN;
        u32 new_len = skb->len + head_room;
        int ret;
 
@@ -4631,6 +4627,18 @@ static const struct bpf_func_proto bpf_get_socket_cookie_sock_proto = {
        .arg1_type      = ARG_PTR_TO_CTX,
 };
 
+BPF_CALL_1(bpf_get_socket_ptr_cookie, struct sock *, sk)
+{
+       return sk ? sock_gen_cookie(sk) : 0;
+}
+
+const struct bpf_func_proto bpf_get_socket_ptr_cookie_proto = {
+       .func           = bpf_get_socket_ptr_cookie,
+       .gpl_only       = false,
+       .ret_type       = RET_INTEGER,
+       .arg1_type      = ARG_PTR_TO_BTF_ID_SOCK_COMMON,
+};
+
 BPF_CALL_1(bpf_get_socket_cookie_sock_ops, struct bpf_sock_ops_kern *, ctx)
 {
        return __sock_gen_cookie(ctx->sk);
@@ -4645,11 +4653,9 @@ static const struct bpf_func_proto bpf_get_socket_cookie_sock_ops_proto = {
 
 static u64 __bpf_get_netns_cookie(struct sock *sk)
 {
-#ifdef CONFIG_NET_NS
-       return __net_gen_cookie(sk ? sk->sk_net.net : &init_net);
-#else
-       return 0;
-#endif
+       const struct net *net = sk ? sock_net(sk) : &init_net;
+
+       return net->net_cookie;
 }
 
 BPF_CALL_1(bpf_get_netns_cookie_sock, struct sock *, ctx)
@@ -4770,6 +4776,10 @@ static int _bpf_setsockopt(struct sock *sk, int level, int optname,
                                ifindex = dev->ifindex;
                                dev_put(dev);
                        }
+                       fallthrough;
+               case SO_BINDTOIFINDEX:
+                       if (optname == SO_BINDTOIFINDEX)
+                               ifindex = val;
                        ret = sock_bindtoindex(sk, ifindex, false);
                        break;
                case SO_KEEPALIVE:
@@ -4932,8 +4942,25 @@ static int _bpf_getsockopt(struct sock *sk, int level, int optname,
 
        sock_owned_by_me(sk);
 
+       if (level == SOL_SOCKET) {
+               if (optlen != sizeof(int))
+                       goto err_clear;
+
+               switch (optname) {
+               case SO_MARK:
+                       *((int *)optval) = sk->sk_mark;
+                       break;
+               case SO_PRIORITY:
+                       *((int *)optval) = sk->sk_priority;
+                       break;
+               case SO_BINDTOIFINDEX:
+                       *((int *)optval) = sk->sk_bound_dev_if;
+                       break;
+               default:
+                       goto err_clear;
+               }
 #ifdef CONFIG_INET
-       if (level == SOL_TCP && sk->sk_prot->getsockopt == tcp_getsockopt) {
+       } else if (level == SOL_TCP && sk->sk_prot->getsockopt == tcp_getsockopt) {
                struct inet_connection_sock *icsk;
                struct tcp_sock *tp;
 
@@ -4986,12 +5013,12 @@ static int _bpf_getsockopt(struct sock *sk, int level, int optname,
                default:
                        goto err_clear;
                }
+#endif
 #endif
        } else {
                goto err_clear;
        }
        return 0;
-#endif
 err_clear:
        memset(optval, 0, optlen);
        return -EINVAL;
@@ -5272,12 +5299,14 @@ static const struct bpf_func_proto bpf_skb_get_xfrm_state_proto = {
 #if IS_ENABLED(CONFIG_INET) || IS_ENABLED(CONFIG_IPV6)
 static int bpf_fib_set_fwd_params(struct bpf_fib_lookup *params,
                                  const struct neighbour *neigh,
-                                 const struct net_device *dev)
+                                 const struct net_device *dev, u32 mtu)
 {
        memcpy(params->dmac, neigh->ha, ETH_ALEN);
        memcpy(params->smac, dev->dev_addr, ETH_ALEN);
        params->h_vlan_TCI = 0;
        params->h_vlan_proto = 0;
+       if (mtu)
+               params->mtu_result = mtu; /* union with tot_len */
 
        return 0;
 }
@@ -5293,8 +5322,8 @@ static int bpf_ipv4_fib_lookup(struct net *net, struct bpf_fib_lookup *params,
        struct net_device *dev;
        struct fib_result res;
        struct flowi4 fl4;
+       u32 mtu = 0;
        int err;
-       u32 mtu;
 
        dev = dev_get_by_index_rcu(net, params->ifindex);
        if (unlikely(!dev))
@@ -5361,8 +5390,10 @@ static int bpf_ipv4_fib_lookup(struct net *net, struct bpf_fib_lookup *params,
 
        if (check_mtu) {
                mtu = ip_mtu_from_fib_result(&res, params->ipv4_dst);
-               if (params->tot_len > mtu)
+               if (params->tot_len > mtu) {
+                       params->mtu_result = mtu; /* union with tot_len */
                        return BPF_FIB_LKUP_RET_FRAG_NEEDED;
+               }
        }
 
        nhc = res.nhc;
@@ -5396,7 +5427,7 @@ static int bpf_ipv4_fib_lookup(struct net *net, struct bpf_fib_lookup *params,
        if (!neigh)
                return BPF_FIB_LKUP_RET_NO_NEIGH;
 
-       return bpf_fib_set_fwd_params(params, neigh, dev);
+       return bpf_fib_set_fwd_params(params, neigh, dev, mtu);
 }
 #endif
 
@@ -5413,7 +5444,7 @@ static int bpf_ipv6_fib_lookup(struct net *net, struct bpf_fib_lookup *params,
        struct flowi6 fl6;
        int strict = 0;
        int oif, err;
-       u32 mtu;
+       u32 mtu = 0;
 
        /* link local addresses are never forwarded */
        if (rt6_need_strict(dst) || rt6_need_strict(src))
@@ -5488,8 +5519,10 @@ static int bpf_ipv6_fib_lookup(struct net *net, struct bpf_fib_lookup *params,
 
        if (check_mtu) {
                mtu = ipv6_stub->ip6_mtu_from_fib6(&res, dst, src);
-               if (params->tot_len > mtu)
+               if (params->tot_len > mtu) {
+                       params->mtu_result = mtu; /* union with tot_len */
                        return BPF_FIB_LKUP_RET_FRAG_NEEDED;
+               }
        }
 
        if (res.nh->fib_nh_lws)
@@ -5509,7 +5542,7 @@ static int bpf_ipv6_fib_lookup(struct net *net, struct bpf_fib_lookup *params,
        if (!neigh)
                return BPF_FIB_LKUP_RET_NO_NEIGH;
 
-       return bpf_fib_set_fwd_params(params, neigh, dev);
+       return bpf_fib_set_fwd_params(params, neigh, dev, mtu);
 }
 #endif
 
@@ -5552,6 +5585,7 @@ BPF_CALL_4(bpf_skb_fib_lookup, struct sk_buff *, skb,
 {
        struct net *net = dev_net(skb->dev);
        int rc = -EAFNOSUPPORT;
+       bool check_mtu = false;
 
        if (plen < sizeof(*params))
                return -EINVAL;
@@ -5559,25 +5593,33 @@ BPF_CALL_4(bpf_skb_fib_lookup, struct sk_buff *, skb,
        if (flags & ~(BPF_FIB_LOOKUP_DIRECT | BPF_FIB_LOOKUP_OUTPUT))
                return -EINVAL;
 
+       if (params->tot_len)
+               check_mtu = true;
+
        switch (params->family) {
 #if IS_ENABLED(CONFIG_INET)
        case AF_INET:
-               rc = bpf_ipv4_fib_lookup(net, params, flags, false);
+               rc = bpf_ipv4_fib_lookup(net, params, flags, check_mtu);
                break;
 #endif
 #if IS_ENABLED(CONFIG_IPV6)
        case AF_INET6:
-               rc = bpf_ipv6_fib_lookup(net, params, flags, false);
+               rc = bpf_ipv6_fib_lookup(net, params, flags, check_mtu);
                break;
 #endif
        }
 
-       if (!rc) {
+       if (rc == BPF_FIB_LKUP_RET_SUCCESS && !check_mtu) {
                struct net_device *dev;
 
+               /* When tot_len isn't provided by user, check skb
+                * against MTU of FIB lookup resulting net_device
+                */
                dev = dev_get_by_index_rcu(net, params->ifindex);
                if (!is_skb_forwardable(dev, skb))
                        rc = BPF_FIB_LKUP_RET_FRAG_NEEDED;
+
+               params->mtu_result = dev->mtu; /* union with tot_len */
        }
 
        return rc;
@@ -5593,6 +5635,116 @@ static const struct bpf_func_proto bpf_skb_fib_lookup_proto = {
        .arg4_type      = ARG_ANYTHING,
 };
 
+static struct net_device *__dev_via_ifindex(struct net_device *dev_curr,
+                                           u32 ifindex)
+{
+       struct net *netns = dev_net(dev_curr);
+
+       /* Non-redirect use-cases can use ifindex=0 and save ifindex lookup */
+       if (ifindex == 0)
+               return dev_curr;
+
+       return dev_get_by_index_rcu(netns, ifindex);
+}
+
+BPF_CALL_5(bpf_skb_check_mtu, struct sk_buff *, skb,
+          u32, ifindex, u32 *, mtu_len, s32, len_diff, u64, flags)
+{
+       int ret = BPF_MTU_CHK_RET_FRAG_NEEDED;
+       struct net_device *dev = skb->dev;
+       int skb_len, dev_len;
+       int mtu;
+
+       if (unlikely(flags & ~(BPF_MTU_CHK_SEGS)))
+               return -EINVAL;
+
+       if (unlikely(flags & BPF_MTU_CHK_SEGS && len_diff))
+               return -EINVAL;
+
+       dev = __dev_via_ifindex(dev, ifindex);
+       if (unlikely(!dev))
+               return -ENODEV;
+
+       mtu = READ_ONCE(dev->mtu);
+
+       dev_len = mtu + dev->hard_header_len;
+       skb_len = skb->len + len_diff; /* minus result pass check */
+       if (skb_len <= dev_len) {
+               ret = BPF_MTU_CHK_RET_SUCCESS;
+               goto out;
+       }
+       /* At this point, skb->len exceed MTU, but as it include length of all
+        * segments, it can still be below MTU.  The SKB can possibly get
+        * re-segmented in transmit path (see validate_xmit_skb).  Thus, user
+        * must choose if segs are to be MTU checked.
+        */
+       if (skb_is_gso(skb)) {
+               ret = BPF_MTU_CHK_RET_SUCCESS;
+
+               if (flags & BPF_MTU_CHK_SEGS &&
+                   !skb_gso_validate_network_len(skb, mtu))
+                       ret = BPF_MTU_CHK_RET_SEGS_TOOBIG;
+       }
+out:
+       /* BPF verifier guarantees valid pointer */
+       *mtu_len = mtu;
+
+       return ret;
+}
+
+BPF_CALL_5(bpf_xdp_check_mtu, struct xdp_buff *, xdp,
+          u32, ifindex, u32 *, mtu_len, s32, len_diff, u64, flags)
+{
+       struct net_device *dev = xdp->rxq->dev;
+       int xdp_len = xdp->data_end - xdp->data;
+       int ret = BPF_MTU_CHK_RET_SUCCESS;
+       int mtu, dev_len;
+
+       /* XDP variant doesn't support multi-buffer segment check (yet) */
+       if (unlikely(flags))
+               return -EINVAL;
+
+       dev = __dev_via_ifindex(dev, ifindex);
+       if (unlikely(!dev))
+               return -ENODEV;
+
+       mtu = READ_ONCE(dev->mtu);
+
+       /* Add L2-header as dev MTU is L3 size */
+       dev_len = mtu + dev->hard_header_len;
+
+       xdp_len += len_diff; /* minus result pass check */
+       if (xdp_len > dev_len)
+               ret = BPF_MTU_CHK_RET_FRAG_NEEDED;
+
+       /* BPF verifier guarantees valid pointer */
+       *mtu_len = mtu;
+
+       return ret;
+}
+
+static const struct bpf_func_proto bpf_skb_check_mtu_proto = {
+       .func           = bpf_skb_check_mtu,
+       .gpl_only       = true,
+       .ret_type       = RET_INTEGER,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_ANYTHING,
+       .arg3_type      = ARG_PTR_TO_INT,
+       .arg4_type      = ARG_ANYTHING,
+       .arg5_type      = ARG_ANYTHING,
+};
+
+static const struct bpf_func_proto bpf_xdp_check_mtu_proto = {
+       .func           = bpf_xdp_check_mtu,
+       .gpl_only       = true,
+       .ret_type       = RET_INTEGER,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_ANYTHING,
+       .arg3_type      = ARG_PTR_TO_INT,
+       .arg4_type      = ARG_ANYTHING,
+       .arg5_type      = ARG_ANYTHING,
+};
+
 #if IS_ENABLED(CONFIG_IPV6_SEG6_BPF)
 static int bpf_push_seg6_encap(struct sk_buff *skb, u32 type, void *hdr, u32 len)
 {
@@ -7002,6 +7154,14 @@ sock_addr_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
                case BPF_CGROUP_INET6_BIND:
                case BPF_CGROUP_INET4_CONNECT:
                case BPF_CGROUP_INET6_CONNECT:
+               case BPF_CGROUP_UDP4_RECVMSG:
+               case BPF_CGROUP_UDP6_RECVMSG:
+               case BPF_CGROUP_UDP4_SENDMSG:
+               case BPF_CGROUP_UDP6_SENDMSG:
+               case BPF_CGROUP_INET4_GETPEERNAME:
+               case BPF_CGROUP_INET6_GETPEERNAME:
+               case BPF_CGROUP_INET4_GETSOCKNAME:
+               case BPF_CGROUP_INET6_GETSOCKNAME:
                        return &bpf_sock_addr_setsockopt_proto;
                default:
                        return NULL;
@@ -7012,6 +7172,14 @@ sock_addr_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
                case BPF_CGROUP_INET6_BIND:
                case BPF_CGROUP_INET4_CONNECT:
                case BPF_CGROUP_INET6_CONNECT:
+               case BPF_CGROUP_UDP4_RECVMSG:
+               case BPF_CGROUP_UDP6_RECVMSG:
+               case BPF_CGROUP_UDP4_SENDMSG:
+               case BPF_CGROUP_UDP6_SENDMSG:
+               case BPF_CGROUP_INET4_GETPEERNAME:
+               case BPF_CGROUP_INET6_GETPEERNAME:
+               case BPF_CGROUP_INET4_GETSOCKNAME:
+               case BPF_CGROUP_INET6_GETSOCKNAME:
                        return &bpf_sock_addr_getsockopt_proto;
                default:
                        return NULL;
@@ -7162,6 +7330,8 @@ tc_cls_act_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
                return &bpf_get_socket_uid_proto;
        case BPF_FUNC_fib_lookup:
                return &bpf_skb_fib_lookup_proto;
+       case BPF_FUNC_check_mtu:
+               return &bpf_skb_check_mtu_proto;
        case BPF_FUNC_sk_fullsock:
                return &bpf_sk_fullsock_proto;
        case BPF_FUNC_sk_storage_get:
@@ -7231,6 +7401,8 @@ xdp_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
                return &bpf_xdp_adjust_tail_proto;
        case BPF_FUNC_fib_lookup:
                return &bpf_xdp_fib_lookup_proto;
+       case BPF_FUNC_check_mtu:
+               return &bpf_xdp_check_mtu_proto;
 #ifdef CONFIG_INET
        case BPF_FUNC_sk_lookup_udp:
                return &bpf_xdp_sk_lookup_udp_proto;
@@ -8795,7 +8967,7 @@ u32 bpf_sock_convert_ctx_access(enum bpf_access_type type,
                                       target_size));
                break;
        case offsetof(struct bpf_sock, rx_queue_mapping):
-#ifdef CONFIG_XPS
+#ifdef CONFIG_SOCK_RX_QUEUE_MAPPING
                *insn++ = BPF_LDX_MEM(
                        BPF_FIELD_SIZEOF(struct sock, sk_rx_queue_mapping),
                        si->dst_reg, si->src_reg,