Merge git://git.kernel.org/pub/scm/linux/kernel/git/netdev/net
[sfrench/cifs-2.6.git] / net / mptcp / protocol.c
index 8ef2927ebca297bf60d51fae91732e09562fd496..c7af62c057bc727e456ad0e57f4271c08ce67ea2 100644 (file)
@@ -410,6 +410,7 @@ static void mptcp_close_wake_up(struct sock *sk)
                sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN);
 }
 
+/* called under the msk socket lock */
 static bool mptcp_pending_data_fin_ack(struct sock *sk)
 {
        struct mptcp_sock *msk = mptcp_sk(sk);
@@ -441,16 +442,17 @@ static void mptcp_check_data_fin_ack(struct sock *sk)
        }
 }
 
+/* can be called with no lock acquired */
 static bool mptcp_pending_data_fin(struct sock *sk, u64 *seq)
 {
        struct mptcp_sock *msk = mptcp_sk(sk);
 
        if (READ_ONCE(msk->rcv_data_fin) &&
-           ((1 << sk->sk_state) &
+           ((1 << inet_sk_state_load(sk)) &
             (TCPF_ESTABLISHED | TCPF_FIN_WAIT1 | TCPF_FIN_WAIT2))) {
                u64 rcv_data_fin_seq = READ_ONCE(msk->rcv_data_fin_seq);
 
-               if (msk->ack_seq == rcv_data_fin_seq) {
+               if (READ_ONCE(msk->ack_seq) == rcv_data_fin_seq) {
                        if (seq)
                                *seq = rcv_data_fin_seq;
 
@@ -748,7 +750,7 @@ static bool __mptcp_ofo_queue(struct mptcp_sock *msk)
                        __skb_queue_tail(&sk->sk_receive_queue, skb);
                }
                msk->bytes_received += end_seq - msk->ack_seq;
-               msk->ack_seq = end_seq;
+               WRITE_ONCE(msk->ack_seq, end_seq);
                moved = true;
        }
        return moved;
@@ -985,6 +987,7 @@ static void dfrag_clear(struct sock *sk, struct mptcp_data_frag *dfrag)
        put_page(dfrag->page);
 }
 
+/* called under both the msk socket lock and the data lock */
 static void __mptcp_clean_una(struct sock *sk)
 {
        struct mptcp_sock *msk = mptcp_sk(sk);
@@ -1033,13 +1036,15 @@ static void __mptcp_clean_una(struct sock *sk)
                msk->recovery = false;
 
 out:
-       if (snd_una == READ_ONCE(msk->snd_nxt) &&
-           snd_una == READ_ONCE(msk->write_seq)) {
+       if (snd_una == msk->snd_nxt && snd_una == msk->write_seq) {
                if (mptcp_rtx_timer_pending(sk) && !mptcp_data_fin_enabled(msk))
                        mptcp_stop_rtx_timer(sk);
        } else {
                mptcp_reset_rtx_timer(sk);
        }
+
+       if (mptcp_pending_data_fin_ack(sk))
+               mptcp_schedule_work(sk);
 }
 
 static void __mptcp_clean_una_wakeup(struct sock *sk)
@@ -1499,7 +1504,7 @@ static void mptcp_update_post_push(struct mptcp_sock *msk,
         */
        if (likely(after64(snd_nxt_new, msk->snd_nxt))) {
                msk->bytes_sent += snd_nxt_new - msk->snd_nxt;
-               msk->snd_nxt = snd_nxt_new;
+               WRITE_ONCE(msk->snd_nxt, snd_nxt_new);
        }
 }
 
@@ -2114,7 +2119,7 @@ static unsigned int mptcp_inq_hint(const struct sock *sk)
 
        skb = skb_peek(&msk->receive_queue);
        if (skb) {
-               u64 hint_val = msk->ack_seq - MPTCP_SKB_CB(skb)->map_seq;
+               u64 hint_val = READ_ONCE(msk->ack_seq) - MPTCP_SKB_CB(skb)->map_seq;
 
                if (hint_val >= INT_MAX)
                        return INT_MAX;
@@ -2758,7 +2763,7 @@ static void __mptcp_init_sock(struct sock *sk)
        __skb_queue_head_init(&msk->receive_queue);
        msk->out_of_order_queue = RB_ROOT;
        msk->first_pending = NULL;
-       msk->rmem_fwd_alloc = 0;
+       WRITE_ONCE(msk->rmem_fwd_alloc, 0);
        WRITE_ONCE(msk->rmem_released, 0);
        msk->timer_ival = TCP_RTO_MIN;
        msk->scaling_ratio = TCP_DEFAULT_SCALING_RATIO;
@@ -2974,7 +2979,7 @@ static void __mptcp_destroy_sock(struct sock *sk)
 
        sk->sk_prot->destroy(sk);
 
-       WARN_ON_ONCE(msk->rmem_fwd_alloc);
+       WARN_ON_ONCE(READ_ONCE(msk->rmem_fwd_alloc));
        WARN_ON_ONCE(msk->rmem_released);
        sk_stream_kill_queues(sk);
        xfrm_sk_free_policy(sk);
@@ -3149,16 +3154,16 @@ static int mptcp_disconnect(struct sock *sk, int flags)
        WRITE_ONCE(msk->flags, 0);
        msk->cb_flags = 0;
        msk->recovery = false;
-       msk->can_ack = false;
-       msk->fully_established = false;
-       msk->rcv_data_fin = false;
-       msk->snd_data_fin_enable = false;
-       msk->rcv_fastclose = false;
-       msk->use_64bit_ack = false;
-       msk->bytes_consumed = 0;
+       WRITE_ONCE(msk->can_ack, false);
+       WRITE_ONCE(msk->fully_established, false);
+       WRITE_ONCE(msk->rcv_data_fin, false);
+       WRITE_ONCE(msk->snd_data_fin_enable, false);
+       WRITE_ONCE(msk->rcv_fastclose, false);
+       WRITE_ONCE(msk->use_64bit_ack, false);
        WRITE_ONCE(msk->csum_enabled, mptcp_is_checksum_enabled(sock_net(sk)));
        mptcp_pm_data_reset(msk);
        mptcp_ca_reset(sk);
+       msk->bytes_consumed = 0;
        msk->bytes_acked = 0;
        msk->bytes_received = 0;
        msk->bytes_sent = 0;
@@ -3200,17 +3205,17 @@ struct sock *mptcp_sk_clone_init(const struct sock *sk,
        __mptcp_init_sock(nsk);
 
        msk = mptcp_sk(nsk);
-       msk->local_key = subflow_req->local_key;
-       msk->token = subflow_req->token;
+       WRITE_ONCE(msk->local_key, subflow_req->local_key);
+       WRITE_ONCE(msk->token, subflow_req->token);
        msk->in_accept_queue = 1;
        WRITE_ONCE(msk->fully_established, false);
        if (mp_opt->suboptions & OPTION_MPTCP_CSUMREQD)
                WRITE_ONCE(msk->csum_enabled, true);
 
-       msk->write_seq = subflow_req->idsn + 1;
-       msk->snd_nxt = msk->write_seq;
-       msk->snd_una = msk->write_seq;
-       msk->wnd_end = msk->snd_nxt + req->rsk_rcv_wnd;
+       WRITE_ONCE(msk->write_seq, subflow_req->idsn + 1);
+       WRITE_ONCE(msk->snd_nxt, msk->write_seq);
+       WRITE_ONCE(msk->snd_una, msk->write_seq);
+       WRITE_ONCE(msk->wnd_end, msk->snd_nxt + req->rsk_rcv_wnd);
        msk->setsockopt_seq = mptcp_sk(sk)->setsockopt_seq;
        mptcp_init_sched(msk, mptcp_sk(sk)->sched);
 
@@ -3313,9 +3318,6 @@ void __mptcp_data_acked(struct sock *sk)
                __mptcp_clean_una(sk);
        else
                __set_bit(MPTCP_CLEAN_UNA, &mptcp_sk(sk)->cb_flags);
-
-       if (mptcp_pending_data_fin_ack(sk))
-               mptcp_schedule_work(sk);
 }
 
 void __mptcp_check_push(struct sock *sk, struct sock *ssk)