Merge git://git.kernel.org/pub/scm/linux/kernel/git/netdev/net
[sfrench/cifs-2.6.git] / net / tls / tls_sw.c
index e23f94a5549b878ee3f9c75860425fbf3c632dbe..2d399b6c407564b5da022ab358d92468359c0ece 100644 (file)
@@ -780,7 +780,7 @@ static int tls_push_record(struct sock *sk, int flags,
 
 static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
                               bool full_record, u8 record_type,
-                              size_t *copied, int flags)
+                              ssize_t *copied, int flags)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
@@ -796,9 +796,10 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
        psock = sk_psock_get(sk);
        if (!psock || !policy) {
                err = tls_push_record(sk, flags, record_type);
-               if (err && err != -EINPROGRESS) {
+               if (err && sk->sk_err == EBADMSG) {
                        *copied -= sk_msg_free(sk, msg);
                        tls_free_open_rec(sk);
+                       err = -sk->sk_err;
                }
                if (psock)
                        sk_psock_put(sk, psock);
@@ -824,9 +825,10 @@ more_data:
        switch (psock->eval) {
        case __SK_PASS:
                err = tls_push_record(sk, flags, record_type);
-               if (err && err != -EINPROGRESS) {
+               if (err && sk->sk_err == EBADMSG) {
                        *copied -= sk_msg_free(sk, msg);
                        tls_free_open_rec(sk);
+                       err = -sk->sk_err;
                        goto out_err;
                }
                break;
@@ -916,7 +918,8 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
        unsigned char record_type = TLS_RECORD_TYPE_DATA;
        bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
        bool eor = !(msg->msg_flags & MSG_MORE);
-       size_t try_to_copy, copied = 0;
+       size_t try_to_copy;
+       ssize_t copied = 0;
        struct sk_msg *msg_pl, *msg_en;
        struct tls_rec *rec;
        int required_size;
@@ -1118,7 +1121,7 @@ send_end:
 
        release_sock(sk);
        mutex_unlock(&tls_ctx->tx_lock);
-       return copied ? copied : ret;
+       return copied > 0 ? copied : ret;
 }
 
 static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
@@ -1132,7 +1135,7 @@ static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
        struct sk_msg *msg_pl;
        struct tls_rec *rec;
        int num_async = 0;
-       size_t copied = 0;
+       ssize_t copied = 0;
        bool full_record;
        int record_room;
        int ret = 0;
@@ -1234,7 +1237,7 @@ wait_for_memory:
        }
 sendpage_end:
        ret = sk_stream_error(sk, flags, ret);
-       return copied ? copied : ret;
+       return copied > 0 ? copied : ret;
 }
 
 int tls_sw_sendpage_locked(struct sock *sk, struct page *page,