net/tls: Fix connection stall on partial tls record
[sfrench/cifs-2.6.git] / net / tls / tls_main.c
index 0d379970960e6c2a101fbb489237f97f0edf68bc..20cd93be6236e03e2fbdbfc9e44cbb12d3261c84 100644 (file)
@@ -114,6 +114,7 @@ int tls_push_sg(struct sock *sk,
        size = sg->length - offset;
        offset += sg->offset;
 
+       ctx->in_tcp_sendpages = true;
        while (1) {
                if (sg_is_last(sg))
                        sendpage_flags = flags;
@@ -134,6 +135,7 @@ retry:
                        offset -= sg->offset;
                        ctx->partially_sent_offset = offset;
                        ctx->partially_sent_record = (void *)sg;
+                       ctx->in_tcp_sendpages = false;
                        return ret;
                }
 
@@ -148,6 +150,8 @@ retry:
        }
 
        clear_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags);
+       ctx->in_tcp_sendpages = false;
+       ctx->sk_write_space(sk);
 
        return 0;
 }
@@ -217,6 +221,10 @@ static void tls_write_space(struct sock *sk)
 {
        struct tls_context *ctx = tls_get_ctx(sk);
 
+       /* We are already sending pages, ignore notification */
+       if (ctx->in_tcp_sendpages)
+               return;
+
        if (!sk->sk_write_pending && tls_is_pending_closed_record(ctx)) {
                gfp_t sk_allocation = sk->sk_allocation;
                int rc;
@@ -241,16 +249,13 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
        struct tls_context *ctx = tls_get_ctx(sk);
        long timeo = sock_sndtimeo(sk, 0);
        void (*sk_proto_close)(struct sock *sk, long timeout);
+       bool free_ctx = false;
 
        lock_sock(sk);
        sk_proto_close = ctx->sk_proto_close;
 
-       if (ctx->conf == TLS_HW_RECORD)
-               goto skip_tx_cleanup;
-
-       if (ctx->conf == TLS_BASE) {
-               kfree(ctx);
-               ctx = NULL;
+       if (ctx->conf == TLS_BASE || ctx->conf == TLS_HW_RECORD) {
+               free_ctx = true;
                goto skip_tx_cleanup;
        }
 
@@ -287,7 +292,7 @@ skip_tx_cleanup:
        /* free ctx for TLS_HW_RECORD, used by tcp_set_state
         * for sk->sk_prot->unhash [tls_hw_unhash]
         */
-       if (ctx && ctx->conf == TLS_HW_RECORD)
+       if (free_ctx)
                kfree(ctx);
 }