mptcp: add annotations around sk->sk_shutdown accesses
[sfrench/cifs-2.6.git] / net / mptcp / protocol.c
index a7dd7d8c9af2db8abef9c6b86fa75419e3e5f33e..af54a878ac277aeb0619ed47d55b700b518ace06 100644 (file)
@@ -603,7 +603,7 @@ static bool mptcp_check_data_fin(struct sock *sk)
                WRITE_ONCE(msk->ack_seq, msk->ack_seq + 1);
                WRITE_ONCE(msk->rcv_data_fin, 0);
 
-               sk->sk_shutdown |= RCV_SHUTDOWN;
+               WRITE_ONCE(sk->sk_shutdown, sk->sk_shutdown | RCV_SHUTDOWN);
                smp_mb__before_atomic(); /* SHUTDOWN must be visible first */
 
                switch (sk->sk_state) {
@@ -910,7 +910,7 @@ static void mptcp_check_for_eof(struct mptcp_sock *msk)
                /* hopefully temporary hack: propagate shutdown status
                 * to msk, when all subflows agree on it
                 */
-               sk->sk_shutdown |= RCV_SHUTDOWN;
+               WRITE_ONCE(sk->sk_shutdown, sk->sk_shutdown | RCV_SHUTDOWN);
 
                smp_mb__before_atomic(); /* SHUTDOWN must be visible first */
                sk->sk_data_ready(sk);
@@ -2526,7 +2526,7 @@ static void mptcp_check_fastclose(struct mptcp_sock *msk)
        }
 
        inet_sk_state_store(sk, TCP_CLOSE);
-       sk->sk_shutdown = SHUTDOWN_MASK;
+       WRITE_ONCE(sk->sk_shutdown, SHUTDOWN_MASK);
        smp_mb__before_atomic(); /* SHUTDOWN must be visible first */
        set_bit(MPTCP_WORK_CLOSE_SUBFLOW, &msk->flags);
 
@@ -2958,7 +2958,7 @@ bool __mptcp_close(struct sock *sk, long timeout)
        bool do_cancel_work = false;
        int subflows_alive = 0;
 
-       sk->sk_shutdown = SHUTDOWN_MASK;
+       WRITE_ONCE(sk->sk_shutdown, SHUTDOWN_MASK);
 
        if ((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) {
                mptcp_listen_inuse_dec(sk);
@@ -3101,7 +3101,7 @@ static int mptcp_disconnect(struct sock *sk, int flags)
        mptcp_pm_data_reset(msk);
        mptcp_ca_reset(sk);
 
-       sk->sk_shutdown = 0;
+       WRITE_ONCE(sk->sk_shutdown, 0);
        sk_error_report(sk);
        return 0;
 }
@@ -3806,9 +3806,6 @@ static __poll_t mptcp_check_writeable(struct mptcp_sock *msk)
 {
        struct sock *sk = (struct sock *)msk;
 
-       if (unlikely(sk->sk_shutdown & SEND_SHUTDOWN))
-               return EPOLLOUT | EPOLLWRNORM;
-
        if (sk_stream_is_writeable(sk))
                return EPOLLOUT | EPOLLWRNORM;
 
@@ -3826,6 +3823,7 @@ static __poll_t mptcp_poll(struct file *file, struct socket *sock,
        struct sock *sk = sock->sk;
        struct mptcp_sock *msk;
        __poll_t mask = 0;
+       u8 shutdown;
        int state;
 
        msk = mptcp_sk(sk);
@@ -3842,17 +3840,22 @@ static __poll_t mptcp_poll(struct file *file, struct socket *sock,
                return inet_csk_listen_poll(ssock->sk);
        }
 
+       shutdown = READ_ONCE(sk->sk_shutdown);
+       if (shutdown == SHUTDOWN_MASK || state == TCP_CLOSE)
+               mask |= EPOLLHUP;
+       if (shutdown & RCV_SHUTDOWN)
+               mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP;
+
        if (state != TCP_SYN_SENT && state != TCP_SYN_RECV) {
                mask |= mptcp_check_readable(msk);
-               mask |= mptcp_check_writeable(msk);
+               if (shutdown & SEND_SHUTDOWN)
+                       mask |= EPOLLOUT | EPOLLWRNORM;
+               else
+                       mask |= mptcp_check_writeable(msk);
        } else if (state == TCP_SYN_SENT && inet_sk(sk)->defer_connect) {
                /* cf tcp_poll() note about TFO */
                mask |= EPOLLOUT | EPOLLWRNORM;
        }
-       if (sk->sk_shutdown == SHUTDOWN_MASK || state == TCP_CLOSE)
-               mask |= EPOLLHUP;
-       if (sk->sk_shutdown & RCV_SHUTDOWN)
-               mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP;
 
        /* This barrier is coupled with smp_wmb() in __mptcp_error_report() */
        smp_rmb();