sock: Introduce sk->sk_prot->psock_update_sk_prot()
authorCong Wang <cong.wang@bytedance.com>
Wed, 31 Mar 2021 02:32:31 +0000 (19:32 -0700)
committerAlexei Starovoitov <ast@kernel.org>
Thu, 1 Apr 2021 17:56:14 +0000 (10:56 -0700)
Currently sockmap calls into each protocol to update the struct
proto and replace it. This certainly won't work when the protocol
is implemented as a module, for example, AF_UNIX.

Introduce a new ops sk->sk_prot->psock_update_sk_prot(), so each
protocol can implement its own way to replace the struct proto.
This also helps get rid of symbol dependencies on CONFIG_INET.

Signed-off-by: Cong Wang <cong.wang@bytedance.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Link: https://lore.kernel.org/bpf/20210331023237.41094-11-xiyou.wangcong@gmail.com
12 files changed:
include/linux/skmsg.h
include/net/sock.h
include/net/tcp.h
include/net/udp.h
net/core/skmsg.c
net/core/sock_map.c
net/ipv4/tcp_bpf.c
net/ipv4/tcp_ipv4.c
net/ipv4/udp.c
net/ipv4/udp_bpf.c
net/ipv6/tcp_ipv6.c
net/ipv6/udp.c

index c83dbc2d81d96a3bf48b1dbf2d093e6eec82591a..5e800ddc2dc6cdd78eb4d42a6cdfb77518422699 100644 (file)
@@ -99,6 +99,7 @@ struct sk_psock {
        void (*saved_close)(struct sock *sk, long timeout);
        void (*saved_write_space)(struct sock *sk);
        void (*saved_data_ready)(struct sock *sk);
+       int  (*psock_update_sk_prot)(struct sock *sk, bool restore);
        struct proto                    *sk_proto;
        struct mutex                    work_mutex;
        struct sk_psock_work_state      work_state;
@@ -395,25 +396,12 @@ static inline void sk_psock_cork_free(struct sk_psock *psock)
        }
 }
 
-static inline void sk_psock_update_proto(struct sock *sk,
-                                        struct sk_psock *psock,
-                                        struct proto *ops)
-{
-       /* Pairs with lockless read in sk_clone_lock() */
-       WRITE_ONCE(sk->sk_prot, ops);
-}
-
 static inline void sk_psock_restore_proto(struct sock *sk,
                                          struct sk_psock *psock)
 {
        sk->sk_prot->unhash = psock->saved_unhash;
-       if (inet_csk_has_ulp(sk)) {
-               tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
-       } else {
-               sk->sk_write_space = psock->saved_write_space;
-               /* Pairs with lockless read in sk_clone_lock() */
-               WRITE_ONCE(sk->sk_prot, psock->sk_proto);
-       }
+       if (psock->psock_update_sk_prot)
+               psock->psock_update_sk_prot(sk, true);
 }
 
 static inline void sk_psock_set_state(struct sk_psock *psock,
index 0b6266fd6bf6f4496b09dd170869ff4db38dfeb9..8b4155e756c20320bc0ea5f427eb00a84fa4ff64 100644 (file)
@@ -1184,6 +1184,9 @@ struct proto {
        void                    (*unhash)(struct sock *sk);
        void                    (*rehash)(struct sock *sk);
        int                     (*get_port)(struct sock *sk, unsigned short snum);
+#ifdef CONFIG_BPF_SYSCALL
+       int                     (*psock_update_sk_prot)(struct sock *sk, bool restore);
+#endif
 
        /* Keeping track of sockets in use */
 #ifdef CONFIG_PROC_FS
index 075de26f449d27093ec6eeb114d7f53c328b2136..2efa4e5ea23dee148083abed7e542e6e3291940e 100644 (file)
@@ -2203,6 +2203,7 @@ struct sk_psock;
 
 #ifdef CONFIG_BPF_SYSCALL
 struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
+int tcp_bpf_update_proto(struct sock *sk, bool restore);
 void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
 #endif /* CONFIG_BPF_SYSCALL */
 
index d4d064c59232876d54532c6ead05dbd911017f62..df7cc1edc2002c647469a8f94aa8a42b80baba4f 100644 (file)
@@ -518,6 +518,7 @@ static inline struct sk_buff *udp_rcv_segment(struct sock *sk,
 #ifdef CONFIG_BPF_SYSCALL
 struct sk_psock;
 struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
+int udp_bpf_update_proto(struct sock *sk, bool restore);
 #endif
 
 #endif /* _UDP_H */
index a045812d7c78cacd4b9e693d9dd7f3eeec7d7714..9fc83f7cc1a050c3541c11a8121406345dce213f 100644 (file)
@@ -562,11 +562,6 @@ struct sk_psock *sk_psock_init(struct sock *sk, int node)
 
        write_lock_bh(&sk->sk_callback_lock);
 
-       if (inet_csk_has_ulp(sk)) {
-               psock = ERR_PTR(-EINVAL);
-               goto out;
-       }
-
        if (sk->sk_user_data) {
                psock = ERR_PTR(-EBUSY);
                goto out;
index c2a0411e08a820ed09cb6fd0d460fe65fe03ba5d..2915c7c8778bc4f721a1c49643924a418dfaa59e 100644 (file)
@@ -185,26 +185,10 @@ static void sock_map_unref(struct sock *sk, void *link_raw)
 
 static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
 {
-       struct proto *prot;
-
-       switch (sk->sk_type) {
-       case SOCK_STREAM:
-               prot = tcp_bpf_get_proto(sk, psock);
-               break;
-
-       case SOCK_DGRAM:
-               prot = udp_bpf_get_proto(sk, psock);
-               break;
-
-       default:
+       if (!sk->sk_prot->psock_update_sk_prot)
                return -EINVAL;
-       }
-
-       if (IS_ERR(prot))
-               return PTR_ERR(prot);
-
-       sk_psock_update_proto(sk, psock, prot);
-       return 0;
+       psock->psock_update_sk_prot = sk->sk_prot->psock_update_sk_prot;
+       return sk->sk_prot->psock_update_sk_prot(sk, false);
 }
 
 static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
@@ -556,7 +540,7 @@ static bool sock_map_redirect_allowed(const struct sock *sk)
 
 static bool sock_map_sk_is_suitable(const struct sock *sk)
 {
-       return sk_is_tcp(sk) || sk_is_udp(sk);
+       return !!sk->sk_prot->psock_update_sk_prot;
 }
 
 static bool sock_map_sk_state_allowed(const struct sock *sk)
index ae980716d896ce520767ee105cd266f3af361c8c..ac8cfbaeacd24213b22c8f1ae0012fe8d9e73507 100644 (file)
@@ -595,20 +595,38 @@ static int tcp_bpf_assert_proto_ops(struct proto *ops)
               ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
 }
 
-struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
+int tcp_bpf_update_proto(struct sock *sk, bool restore)
 {
+       struct sk_psock *psock = sk_psock(sk);
        int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
        int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
 
+       if (restore) {
+               if (inet_csk_has_ulp(sk)) {
+                       tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
+               } else {
+                       sk->sk_write_space = psock->saved_write_space;
+                       /* Pairs with lockless read in sk_clone_lock() */
+                       WRITE_ONCE(sk->sk_prot, psock->sk_proto);
+               }
+               return 0;
+       }
+
+       if (inet_csk_has_ulp(sk))
+               return -EINVAL;
+
        if (sk->sk_family == AF_INET6) {
                if (tcp_bpf_assert_proto_ops(psock->sk_proto))
-                       return ERR_PTR(-EINVAL);
+                       return -EINVAL;
 
                tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
        }
 
-       return &tcp_bpf_prots[family][config];
+       /* Pairs with lockless read in sk_clone_lock() */
+       WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]);
+       return 0;
 }
+EXPORT_SYMBOL_GPL(tcp_bpf_update_proto);
 
 /* If a child got cloned from a listening socket that had tcp_bpf
  * protocol callbacks installed, we need to restore the callbacks to
index daad4f99db32839b9a6e0680c5472c5641228fd9..dfc6d1c0e710fc521fb831b72f843f8c7e1e602f 100644 (file)
@@ -2806,6 +2806,9 @@ struct proto tcp_prot = {
        .hash                   = inet_hash,
        .unhash                 = inet_unhash,
        .get_port               = inet_csk_get_port,
+#ifdef CONFIG_BPF_SYSCALL
+       .psock_update_sk_prot   = tcp_bpf_update_proto,
+#endif
        .enter_memory_pressure  = tcp_enter_memory_pressure,
        .leave_memory_pressure  = tcp_leave_memory_pressure,
        .stream_memory_free     = tcp_stream_memory_free,
index 4a0478b17243aca6eff4a783132546f37c08d524..38952aaee3a138edd34833a470f76cbf16740840 100644 (file)
@@ -2849,6 +2849,9 @@ struct proto udp_prot = {
        .unhash                 = udp_lib_unhash,
        .rehash                 = udp_v4_rehash,
        .get_port               = udp_v4_get_port,
+#ifdef CONFIG_BPF_SYSCALL
+       .psock_update_sk_prot   = udp_bpf_update_proto,
+#endif
        .memory_allocated       = &udp_memory_allocated,
        .sysctl_mem             = sysctl_udp_mem,
        .sysctl_wmem_offset     = offsetof(struct net, ipv4.sysctl_udp_wmem_min),
index 7a94791efc1abe6948cbea7cd48616958332c395..6001f93cd3a0868648e8b041b0f7c2877c142810 100644 (file)
@@ -41,12 +41,23 @@ static int __init udp_bpf_v4_build_proto(void)
 }
 core_initcall(udp_bpf_v4_build_proto);
 
-struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
+int udp_bpf_update_proto(struct sock *sk, bool restore)
 {
        int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
+       struct sk_psock *psock = sk_psock(sk);
+
+       if (restore) {
+               sk->sk_write_space = psock->saved_write_space;
+               /* Pairs with lockless read in sk_clone_lock() */
+               WRITE_ONCE(sk->sk_prot, psock->sk_proto);
+               return 0;
+       }
 
        if (sk->sk_family == AF_INET6)
                udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
 
-       return &udp_bpf_prots[family];
+       /* Pairs with lockless read in sk_clone_lock() */
+       WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[family]);
+       return 0;
 }
+EXPORT_SYMBOL_GPL(udp_bpf_update_proto);
index d0f007741e8ed3638999823ca0637f870a3a7c2c..bff22d6ef516acbe0a7d8cdbe1f94d0f5adc9f69 100644 (file)
@@ -2139,6 +2139,9 @@ struct proto tcpv6_prot = {
        .hash                   = inet6_hash,
        .unhash                 = inet_unhash,
        .get_port               = inet_csk_get_port,
+#ifdef CONFIG_BPF_SYSCALL
+       .psock_update_sk_prot   = tcp_bpf_update_proto,
+#endif
        .enter_memory_pressure  = tcp_enter_memory_pressure,
        .leave_memory_pressure  = tcp_leave_memory_pressure,
        .stream_memory_free     = tcp_stream_memory_free,
index d25e5a9252fdbdc8f42d3b0aa5346fe81cc3c613..ef2c75bb4771a9911a3968f9990fd468a1825799 100644 (file)
@@ -1713,6 +1713,9 @@ struct proto udpv6_prot = {
        .unhash                 = udp_lib_unhash,
        .rehash                 = udp_v6_rehash,
        .get_port               = udp_v6_get_port,
+#ifdef CONFIG_BPF_SYSCALL
+       .psock_update_sk_prot   = udp_bpf_update_proto,
+#endif
        .memory_allocated       = &udp_memory_allocated,
        .sysctl_mem             = sysctl_udp_mem,
        .sysctl_wmem_offset     = offsetof(struct net, ipv4.sysctl_udp_wmem_min),