Merge branch 'for-5.4/ish' into for-linus
[sfrench/cifs-2.6.git] / net / core / sock_map.c
index 52d4faeee18b0cecc8432ee71db16efb852b8644..1330a7442e5b1e54d80d0b675f7356742bcdfbec 100644 (file)
@@ -247,6 +247,8 @@ static void sock_map_free(struct bpf_map *map)
        raw_spin_unlock_bh(&stab->lock);
        rcu_read_unlock();
 
+       synchronize_rcu();
+
        bpf_map_area_free(stab->sks);
        kfree(stab);
 }
@@ -276,16 +278,20 @@ static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
                             struct sock **psk)
 {
        struct sock *sk;
+       int err = 0;
 
        raw_spin_lock_bh(&stab->lock);
        sk = *psk;
        if (!sk_test || sk_test == sk)
-               *psk = NULL;
+               sk = xchg(psk, NULL);
+
+       if (likely(sk))
+               sock_map_unref(sk, psk);
+       else
+               err = -EINVAL;
+
        raw_spin_unlock_bh(&stab->lock);
-       if (unlikely(!sk))
-               return -EINVAL;
-       sock_map_unref(sk, psk);
-       return 0;
+       return err;
 }
 
 static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
@@ -328,6 +334,7 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx,
                                  struct sock *sk, u64 flags)
 {
        struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
+       struct inet_connection_sock *icsk = inet_csk(sk);
        struct sk_psock_link *link;
        struct sk_psock *psock;
        struct sock *osk;
@@ -338,6 +345,8 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx,
                return -EINVAL;
        if (unlikely(idx >= map->max_entries))
                return -E2BIG;
+       if (unlikely(icsk->icsk_ulp_data))
+               return -EINVAL;
 
        link = sk_psock_init_link();
        if (!link)