Merge git://git.kernel.org/pub/scm/linux/kernel/git/bpf/bpf-next
[sfrench/cifs-2.6.git] / include / linux / skmsg.h
index 8edbbf5f2f9325d120b67a661f2a33bbfec861a3..aba0f0f429bec0d1051a370714248f5d0d821776 100644 (file)
@@ -56,7 +56,8 @@ struct sk_msg {
 
 struct sk_psock_progs {
        struct bpf_prog                 *msg_parser;
-       struct bpf_prog                 *skb_parser;
+       struct bpf_prog                 *stream_parser;
+       struct bpf_prog                 *stream_verdict;
        struct bpf_prog                 *skb_verdict;
 };
 
@@ -70,12 +71,6 @@ struct sk_psock_link {
        void                            *link_raw;
 };
 
-struct sk_psock_parser {
-       struct strparser                strp;
-       bool                            enabled;
-       void (*saved_data_ready)(struct sock *sk);
-};
-
 struct sk_psock_work_state {
        struct sk_buff                  *skb;
        u32                             len;
@@ -90,9 +85,12 @@ struct sk_psock {
        u32                             eval;
        struct sk_msg                   *cork;
        struct sk_psock_progs           progs;
-       struct sk_psock_parser          parser;
+#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
+       struct strparser                strp;
+#endif
        struct sk_buff_head             ingress_skb;
        struct list_head                ingress_msg;
+       spinlock_t                      ingress_lock;
        unsigned long                   state;
        struct list_head                link;
        spinlock_t                      link_lock;
@@ -100,13 +98,14 @@ struct sk_psock {
        void (*saved_unhash)(struct sock *sk);
        void (*saved_close)(struct sock *sk, long timeout);
        void (*saved_write_space)(struct sock *sk);
+       void (*saved_data_ready)(struct sock *sk);
+       int  (*psock_update_sk_prot)(struct sock *sk, struct sk_psock *psock,
+                                    bool restore);
        struct proto                    *sk_proto;
+       struct mutex                    work_mutex;
        struct sk_psock_work_state      work_state;
        struct work_struct              work;
-       union {
-               struct rcu_head         rcu;
-               struct work_struct      gc;
-       };
+       struct rcu_work                 rwork;
 };
 
 int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
@@ -127,6 +126,10 @@ int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
                              struct sk_msg *msg, u32 bytes);
 int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
                             struct sk_msg *msg, u32 bytes);
+int sk_msg_wait_data(struct sock *sk, struct sk_psock *psock, int flags,
+                    long timeo, int *err);
+int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
+                  int len, int flags);
 
 static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
 {
@@ -287,7 +290,45 @@ static inline struct sk_psock *sk_psock(const struct sock *sk)
 static inline void sk_psock_queue_msg(struct sk_psock *psock,
                                      struct sk_msg *msg)
 {
+       spin_lock_bh(&psock->ingress_lock);
        list_add_tail(&msg->list, &psock->ingress_msg);
+       spin_unlock_bh(&psock->ingress_lock);
+}
+
+static inline struct sk_msg *sk_psock_dequeue_msg(struct sk_psock *psock)
+{
+       struct sk_msg *msg;
+
+       spin_lock_bh(&psock->ingress_lock);
+       msg = list_first_entry_or_null(&psock->ingress_msg, struct sk_msg, list);
+       if (msg)
+               list_del(&msg->list);
+       spin_unlock_bh(&psock->ingress_lock);
+       return msg;
+}
+
+static inline struct sk_msg *sk_psock_peek_msg(struct sk_psock *psock)
+{
+       struct sk_msg *msg;
+
+       spin_lock_bh(&psock->ingress_lock);
+       msg = list_first_entry_or_null(&psock->ingress_msg, struct sk_msg, list);
+       spin_unlock_bh(&psock->ingress_lock);
+       return msg;
+}
+
+static inline struct sk_msg *sk_psock_next_msg(struct sk_psock *psock,
+                                              struct sk_msg *msg)
+{
+       struct sk_msg *ret;
+
+       spin_lock_bh(&psock->ingress_lock);
+       if (list_is_last(&msg->list, &psock->ingress_msg))
+               ret = NULL;
+       else
+               ret = list_next_entry(msg, list);
+       spin_unlock_bh(&psock->ingress_lock);
+       return ret;
 }
 
 static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
@@ -295,6 +336,13 @@ static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
        return psock ? list_empty(&psock->ingress_msg) : true;
 }
 
+static inline void kfree_sk_msg(struct sk_msg *msg)
+{
+       if (msg->skb)
+               consume_skb(msg->skb);
+       kfree(msg);
+}
+
 static inline void sk_psock_report_error(struct sk_psock *psock, int err)
 {
        struct sock *sk = psock->sk;
@@ -304,10 +352,27 @@ static inline void sk_psock_report_error(struct sk_psock *psock, int err)
 }
 
 struct sk_psock *sk_psock_init(struct sock *sk, int node);
+void sk_psock_stop(struct sk_psock *psock, bool wait);
 
+#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
 int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock);
 void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock);
 void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock);
+#else
+static inline int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
+{
+       return -EOPNOTSUPP;
+}
+
+static inline void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
+{
+}
+
+static inline void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
+{
+}
+#endif
+
 void sk_psock_start_verdict(struct sock *sk, struct sk_psock *psock);
 void sk_psock_stop_verdict(struct sock *sk, struct sk_psock *psock);
 
@@ -327,8 +392,6 @@ static inline void sk_psock_free_link(struct sk_psock_link *link)
 
 struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock);
 
-void __sk_psock_purge_ingress_msg(struct sk_psock *psock);
-
 static inline void sk_psock_cork_free(struct sk_psock *psock)
 {
        if (psock->cork) {
@@ -338,25 +401,11 @@ static inline void sk_psock_cork_free(struct sk_psock *psock)
        }
 }
 
-static inline void sk_psock_update_proto(struct sock *sk,
-                                        struct sk_psock *psock,
-                                        struct proto *ops)
-{
-       /* Pairs with lockless read in sk_clone_lock() */
-       WRITE_ONCE(sk->sk_prot, ops);
-}
-
 static inline void sk_psock_restore_proto(struct sock *sk,
                                          struct sk_psock *psock)
 {
-       sk->sk_prot->unhash = psock->saved_unhash;
-       if (inet_csk_has_ulp(sk)) {
-               tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
-       } else {
-               sk->sk_write_space = psock->saved_write_space;
-               /* Pairs with lockless read in sk_clone_lock() */
-               WRITE_ONCE(sk->sk_prot, psock->sk_proto);
-       }
+       if (psock->psock_update_sk_prot)
+               psock->psock_update_sk_prot(sk, psock, true);
 }
 
 static inline void sk_psock_set_state(struct sk_psock *psock,
@@ -389,7 +438,6 @@ static inline struct sk_psock *sk_psock_get(struct sock *sk)
        return psock;
 }
 
-void sk_psock_stop(struct sock *sk, struct sk_psock *psock);
 void sk_psock_drop(struct sock *sk, struct sk_psock *psock);
 
 static inline void sk_psock_put(struct sock *sk, struct sk_psock *psock)
@@ -400,8 +448,8 @@ static inline void sk_psock_put(struct sock *sk, struct sk_psock *psock)
 
 static inline void sk_psock_data_ready(struct sock *sk, struct sk_psock *psock)
 {
-       if (psock->parser.enabled)
-               psock->parser.saved_data_ready(sk);
+       if (psock->saved_data_ready)
+               psock->saved_data_ready(sk);
        else
                sk->sk_data_ready(sk);
 }
@@ -430,7 +478,8 @@ static inline int psock_replace_prog(struct bpf_prog **pprog,
 static inline void psock_progs_drop(struct sk_psock_progs *progs)
 {
        psock_set_prog(&progs->msg_parser, NULL);
-       psock_set_prog(&progs->skb_parser, NULL);
+       psock_set_prog(&progs->stream_parser, NULL);
+       psock_set_prog(&progs->stream_verdict, NULL);
        psock_set_prog(&progs->skb_verdict, NULL);
 }
 
@@ -440,6 +489,44 @@ static inline bool sk_psock_strp_enabled(struct sk_psock *psock)
 {
        if (!psock)
                return false;
-       return psock->parser.enabled;
+       return !!psock->saved_data_ready;
+}
+
+#if IS_ENABLED(CONFIG_NET_SOCK_MSG)
+
+/* We only have one bit so far. */
+#define BPF_F_PTR_MASK ~(BPF_F_INGRESS)
+
+static inline bool skb_bpf_ingress(const struct sk_buff *skb)
+{
+       unsigned long sk_redir = skb->_sk_redir;
+
+       return sk_redir & BPF_F_INGRESS;
+}
+
+static inline void skb_bpf_set_ingress(struct sk_buff *skb)
+{
+       skb->_sk_redir |= BPF_F_INGRESS;
+}
+
+static inline void skb_bpf_set_redir(struct sk_buff *skb, struct sock *sk_redir,
+                                    bool ingress)
+{
+       skb->_sk_redir = (unsigned long)sk_redir;
+       if (ingress)
+               skb->_sk_redir |= BPF_F_INGRESS;
+}
+
+static inline struct sock *skb_bpf_redirect_fetch(const struct sk_buff *skb)
+{
+       unsigned long sk_redir = skb->_sk_redir;
+
+       return (struct sock *)(sk_redir & BPF_F_PTR_MASK);
+}
+
+static inline void skb_bpf_redirect_clear(struct sk_buff *skb)
+{
+       skb->_sk_redir = 0;
 }
+#endif /* CONFIG_NET_SOCK_MSG */
 #endif /* _LINUX_SKMSG_H */