Merge branch 'work.afs' of git://git.kernel.org/pub/scm/linux/kernel/git/viro/vfs
[sfrench/cifs-2.6.git] / net / 9p / trans_virtio.c
index 7728b0acde09aa904f4cc564de23f264abd8d2b6..b1d39cabf125a7f90b7029f4355a0a7025af526d 100644 (file)
@@ -155,7 +155,7 @@ static void req_done(struct virtqueue *vq)
                }
 
                if (len) {
-                       req->rc->size = len;
+                       req->rc.size = len;
                        p9_client_cb(chan->client, req, REQ_STATUS_RCVD);
                }
        }
@@ -207,6 +207,13 @@ static int p9_virtio_cancel(struct p9_client *client, struct p9_req_t *req)
        return 1;
 }
 
+/* Reply won't come, so drop req ref */
+static int p9_virtio_cancelled(struct p9_client *client, struct p9_req_t *req)
+{
+       p9_req_put(req);
+       return 0;
+}
+
 /**
  * pack_sg_list_p - Just like pack_sg_list. Instead of taking a buffer,
  * this takes a list of pages.
@@ -273,12 +280,12 @@ req_retry:
        out_sgs = in_sgs = 0;
        /* Handle out VirtIO ring buffers */
        out = pack_sg_list(chan->sg, 0,
-                          VIRTQUEUE_NUM, req->tc->sdata, req->tc->size);
+                          VIRTQUEUE_NUM, req->tc.sdata, req->tc.size);
        if (out)
                sgs[out_sgs++] = chan->sg;
 
        in = pack_sg_list(chan->sg, out,
-                         VIRTQUEUE_NUM, req->rc->sdata, req->rc->capacity);
+                         VIRTQUEUE_NUM, req->rc.sdata, req->rc.capacity);
        if (in)
                sgs[out_sgs + in_sgs++] = chan->sg + out;
 
@@ -322,7 +329,7 @@ static int p9_get_mapped_pages(struct virtio_chan *chan,
        if (!iov_iter_count(data))
                return 0;
 
-       if (!(data->type & ITER_KVEC)) {
+       if (!iov_iter_is_kvec(data)) {
                int n;
                /*
                 * We allow only p9_max_pages pinned. We wait for the
@@ -404,6 +411,7 @@ p9_virtio_zc_request(struct p9_client *client, struct p9_req_t *req,
        struct scatterlist *sgs[4];
        size_t offs;
        int need_drop = 0;
+       int kicked = 0;
 
        p9_debug(P9_DEBUG_TRANS, "virtio request\n");
 
@@ -411,29 +419,33 @@ p9_virtio_zc_request(struct p9_client *client, struct p9_req_t *req,
                __le32 sz;
                int n = p9_get_mapped_pages(chan, &out_pages, uodata,
                                            outlen, &offs, &need_drop);
-               if (n < 0)
-                       return n;
+               if (n < 0) {
+                       err = n;
+                       goto err_out;
+               }
                out_nr_pages = DIV_ROUND_UP(n + offs, PAGE_SIZE);
                if (n != outlen) {
                        __le32 v = cpu_to_le32(n);
-                       memcpy(&req->tc->sdata[req->tc->size - 4], &v, 4);
+                       memcpy(&req->tc.sdata[req->tc.size - 4], &v, 4);
                        outlen = n;
                }
                /* The size field of the message must include the length of the
                 * header and the length of the data.  We didn't actually know
                 * the length of the data until this point so add it in now.
                 */
-               sz = cpu_to_le32(req->tc->size + outlen);
-               memcpy(&req->tc->sdata[0], &sz, sizeof(sz));
+               sz = cpu_to_le32(req->tc.size + outlen);
+               memcpy(&req->tc.sdata[0], &sz, sizeof(sz));
        } else if (uidata) {
                int n = p9_get_mapped_pages(chan, &in_pages, uidata,
                                            inlen, &offs, &need_drop);
-               if (n < 0)
-                       return n;
+               if (n < 0) {
+                       err = n;
+                       goto err_out;
+               }
                in_nr_pages = DIV_ROUND_UP(n + offs, PAGE_SIZE);
                if (n != inlen) {
                        __le32 v = cpu_to_le32(n);
-                       memcpy(&req->tc->sdata[req->tc->size - 4], &v, 4);
+                       memcpy(&req->tc.sdata[req->tc.size - 4], &v, 4);
                        inlen = n;
                }
        }
@@ -445,7 +457,7 @@ req_retry_pinned:
 
        /* out data */
        out = pack_sg_list(chan->sg, 0,
-                          VIRTQUEUE_NUM, req->tc->sdata, req->tc->size);
+                          VIRTQUEUE_NUM, req->tc.sdata, req->tc.size);
 
        if (out)
                sgs[out_sgs++] = chan->sg;
@@ -464,7 +476,7 @@ req_retry_pinned:
         * alloced memory and payload onto the user buffer.
         */
        in = pack_sg_list(chan->sg, out,
-                         VIRTQUEUE_NUM, req->rc->sdata, in_hdr_len);
+                         VIRTQUEUE_NUM, req->rc.sdata, in_hdr_len);
        if (in)
                sgs[out_sgs + in_sgs++] = chan->sg + out;
 
@@ -498,6 +510,7 @@ req_retry_pinned:
        }
        virtqueue_kick(chan->vq);
        spin_unlock_irqrestore(&chan->lock, flags);
+       kicked = 1;
        p9_debug(P9_DEBUG_TRANS, "virtio request kicked\n");
        err = wait_event_killable(req->wq, req->status >= REQ_STATUS_RCVD);
        /*
@@ -518,6 +531,10 @@ err_out:
        }
        kvfree(in_pages);
        kvfree(out_pages);
+       if (!kicked) {
+               /* reply won't come */
+               p9_req_put(req);
+       }
        return err;
 }
 
@@ -750,6 +767,7 @@ static struct p9_trans_module p9_virtio_trans = {
        .request = p9_virtio_request,
        .zc_request = p9_virtio_zc_request,
        .cancel = p9_virtio_cancel,
+       .cancelled = p9_virtio_cancelled,
        /*
         * We leave one entry for input and one entry for response
         * headers. We also skip one more entry to accomodate, address