drivers/vhost/vdpa.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (C) 2018-2020 Intel Corporation.
4  * Copyright (C) 2020 Red Hat, Inc.
5  *
6  * Author: Tiwei Bie <tiwei.bie@intel.com>
7  *         Jason Wang <jasowang@redhat.com>
8  *
9  * Thanks to Michael S. Tsirkin for the valuable comments and
10  * suggestions, and thanks to Cunming Liang and Zhihong Wang for all
11  * their support.
12  */
13
14 #include <linux/kernel.h>
15 #include <linux/module.h>
16 #include <linux/cdev.h>
17 #include <linux/device.h>
18 #include <linux/mm.h>
19 #include <linux/slab.h>
20 #include <linux/iommu.h>
21 #include <linux/uuid.h>
22 #include <linux/vdpa.h>
23 #include <linux/nospec.h>
24 #include <linux/vhost.h>
25
26 #include "vhost.h"
27
28 enum {
29         VHOST_VDPA_BACKEND_FEATURES =
30         (1ULL << VHOST_BACKEND_F_IOTLB_MSG_V2) |
31         (1ULL << VHOST_BACKEND_F_IOTLB_BATCH) |
32         (1ULL << VHOST_BACKEND_F_IOTLB_ASID),
33 };
34
35 #define VHOST_VDPA_DEV_MAX (1U << MINORBITS)
36
37 #define VHOST_VDPA_IOTLB_BUCKETS 16
38
39 struct vhost_vdpa_as {
40         struct hlist_node hash_link;
41         struct vhost_iotlb iotlb;
42         u32 id;
43 };
44
45 struct vhost_vdpa {
46         struct vhost_dev vdev;
47         struct iommu_domain *domain;
48         struct vhost_virtqueue *vqs;
49         struct completion completion;
50         struct vdpa_device *vdpa;
51         struct hlist_head as[VHOST_VDPA_IOTLB_BUCKETS];
52         struct device dev;
53         struct cdev cdev;
54         atomic_t opened;
55         u32 nvqs;
56         int virtio_id;
57         int minor;
58         struct eventfd_ctx *config_ctx;
59         int in_batch;
60         struct vdpa_iova_range range;
61         u32 batch_asid;
62 };
63
64 static DEFINE_IDA(vhost_vdpa_ida);
65
66 static dev_t vhost_vdpa_major;
67
68 static void vhost_vdpa_iotlb_unmap(struct vhost_vdpa *v,
69                                    struct vhost_iotlb *iotlb, u64 start,
70                                    u64 last, u32 asid);
71
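/*
 * Address space (AS) bookkeeping: each vhost_vdpa_as wraps a vhost_iotlb
 * and is hashed by ASID into v->as[]. The helpers below convert between an
 * iotlb, its containing AS and the owning ASID, and create or tear down
 * address spaces on demand.
 */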
72 static inline u32 iotlb_to_asid(struct vhost_iotlb *iotlb)
73 {
74         struct vhost_vdpa_as *as = container_of(iotlb, struct
75                                                 vhost_vdpa_as, iotlb);
76         return as->id;
77 }
78
79 static struct vhost_vdpa_as *asid_to_as(struct vhost_vdpa *v, u32 asid)
80 {
81         struct hlist_head *head = &v->as[asid % VHOST_VDPA_IOTLB_BUCKETS];
82         struct vhost_vdpa_as *as;
83
84         hlist_for_each_entry(as, head, hash_link)
85                 if (as->id == asid)
86                         return as;
87
88         return NULL;
89 }
90
91 static struct vhost_iotlb *asid_to_iotlb(struct vhost_vdpa *v, u32 asid)
92 {
93         struct vhost_vdpa_as *as = asid_to_as(v, asid);
94
95         if (!as)
96                 return NULL;
97
98         return &as->iotlb;
99 }
100
101 static struct vhost_vdpa_as *vhost_vdpa_alloc_as(struct vhost_vdpa *v, u32 asid)
102 {
103         struct hlist_head *head = &v->as[asid % VHOST_VDPA_IOTLB_BUCKETS];
104         struct vhost_vdpa_as *as;
105
106         if (asid_to_as(v, asid))
107                 return NULL;
108
109         if (asid >= v->vdpa->nas)
110                 return NULL;
111
112         as = kmalloc(sizeof(*as), GFP_KERNEL);
113         if (!as)
114                 return NULL;
115
116         vhost_iotlb_init(&as->iotlb, 0, 0);
117         as->id = asid;
118         hlist_add_head(&as->hash_link, head);
119
120         return as;
121 }
122
123 static struct vhost_vdpa_as *vhost_vdpa_find_alloc_as(struct vhost_vdpa *v,
124                                                       u32 asid)
125 {
126         struct vhost_vdpa_as *as = asid_to_as(v, asid);
127
128         if (as)
129                 return as;
130
131         return vhost_vdpa_alloc_as(v, asid);
132 }
133
134 static void vhost_vdpa_reset_map(struct vhost_vdpa *v, u32 asid)
135 {
136         struct vdpa_device *vdpa = v->vdpa;
137         const struct vdpa_config_ops *ops = vdpa->config;
138
139         if (ops->reset_map)
140                 ops->reset_map(vdpa, asid);
141 }
142
143 static int vhost_vdpa_remove_as(struct vhost_vdpa *v, u32 asid)
144 {
145         struct vhost_vdpa_as *as = asid_to_as(v, asid);
146
147         if (!as)
148                 return -EINVAL;
149
150         hlist_del(&as->hash_link);
151         vhost_vdpa_iotlb_unmap(v, &as->iotlb, 0ULL, 0ULL - 1, asid);
152         /*
153          * Devices with a vendor-specific IOMMU may need to restore
154          * the iotlb to its initial or default state, which cannot be
155          * cleaned up by the full-range unmap call above. Give them
156          * a chance to clean up or reset the map to the desired
157          * state.
158          */
159         vhost_vdpa_reset_map(v, asid);
160         kfree(as);
161
162         return 0;
163 }
164
165 static void handle_vq_kick(struct vhost_work *work)
166 {
167         struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
168                                                   poll.work);
169         struct vhost_vdpa *v = container_of(vq->dev, struct vhost_vdpa, vdev);
170         const struct vdpa_config_ops *ops = v->vdpa->config;
171
172         ops->kick_vq(v->vdpa, vq - v->vqs);
173 }
174
175 static irqreturn_t vhost_vdpa_virtqueue_cb(void *private)
176 {
177         struct vhost_virtqueue *vq = private;
178         struct eventfd_ctx *call_ctx = vq->call_ctx.ctx;
179
180         if (call_ctx)
181                 eventfd_signal(call_ctx, 1);
182
183         return IRQ_HANDLED;
184 }
185
186 static irqreturn_t vhost_vdpa_config_cb(void *private)
187 {
188         struct vhost_vdpa *v = private;
189         struct eventfd_ctx *config_ctx = v->config_ctx;
190
191         if (config_ctx)
192                 eventfd_signal(config_ctx, 1);
193
194         return IRQ_HANDLED;
195 }
196
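/*
 * Try to set up irq bypass for a virtqueue: if the parent exposes a per-vq
 * interrupt via get_vq_irq(), register the vq's call eventfd as an irq
 * bypass producer so the interrupt can be routed to the consumer (e.g. a
 * KVM irqfd) without going through the host eventfd path. Failure is
 * non-fatal and only logged.
 */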
197 static void vhost_vdpa_setup_vq_irq(struct vhost_vdpa *v, u16 qid)
198 {
199         struct vhost_virtqueue *vq = &v->vqs[qid];
200         const struct vdpa_config_ops *ops = v->vdpa->config;
201         struct vdpa_device *vdpa = v->vdpa;
202         int ret, irq;
203
204         if (!ops->get_vq_irq)
205                 return;
206
207         irq = ops->get_vq_irq(vdpa, qid);
208         if (irq < 0)
209                 return;
210
211         irq_bypass_unregister_producer(&vq->call_ctx.producer);
212         if (!vq->call_ctx.ctx)
213                 return;
214
215         vq->call_ctx.producer.token = vq->call_ctx.ctx;
216         vq->call_ctx.producer.irq = irq;
217         ret = irq_bypass_register_producer(&vq->call_ctx.producer);
218         if (unlikely(ret))
219                 dev_info(&v->dev, "vq %u, irq bypass producer (token %p) registration failed, ret = %d\n",
220                          qid, vq->call_ctx.producer.token, ret);
221 }
222
223 static void vhost_vdpa_unsetup_vq_irq(struct vhost_vdpa *v, u16 qid)
224 {
225         struct vhost_virtqueue *vq = &v->vqs[qid];
226
227         irq_bypass_unregister_producer(&vq->call_ctx.producer);
228 }
229
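/*
 * Compat reset: unless userspace negotiated VHOST_BACKEND_F_IOTLB_PERSIST,
 * pass VDPA_RESET_F_CLEAN_MAP so the parent also clears its mappings,
 * preserving the historical "reset clears the IOTLB" behaviour.
 */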
230 static int _compat_vdpa_reset(struct vhost_vdpa *v)
231 {
232         struct vdpa_device *vdpa = v->vdpa;
233         u32 flags = 0;
234
235         if (v->vdev.vqs) {
236                 flags |= !vhost_backend_has_feature(v->vdev.vqs[0],
237                                                     VHOST_BACKEND_F_IOTLB_PERSIST) ?
238                          VDPA_RESET_F_CLEAN_MAP : 0;
239         }
240
241         return vdpa_reset(vdpa, flags);
242 }
243
244 static int vhost_vdpa_reset(struct vhost_vdpa *v)
245 {
246         v->in_batch = 0;
247         return _compat_vdpa_reset(v);
248 }
249
250 static long vhost_vdpa_bind_mm(struct vhost_vdpa *v)
251 {
252         struct vdpa_device *vdpa = v->vdpa;
253         const struct vdpa_config_ops *ops = vdpa->config;
254
255         if (!vdpa->use_va || !ops->bind_mm)
256                 return 0;
257
258         return ops->bind_mm(vdpa, v->vdev.mm);
259 }
260
261 static void vhost_vdpa_unbind_mm(struct vhost_vdpa *v)
262 {
263         struct vdpa_device *vdpa = v->vdpa;
264         const struct vdpa_config_ops *ops = vdpa->config;
265
266         if (!vdpa->use_va || !ops->unbind_mm)
267                 return;
268
269         ops->unbind_mm(vdpa);
270 }
271
272 static long vhost_vdpa_get_device_id(struct vhost_vdpa *v, u8 __user *argp)
273 {
274         struct vdpa_device *vdpa = v->vdpa;
275         const struct vdpa_config_ops *ops = vdpa->config;
276         u32 device_id;
277
278         device_id = ops->get_device_id(vdpa);
279
280         if (copy_to_user(argp, &device_id, sizeof(device_id)))
281                 return -EFAULT;
282
283         return 0;
284 }
285
286 static long vhost_vdpa_get_status(struct vhost_vdpa *v, u8 __user *statusp)
287 {
288         struct vdpa_device *vdpa = v->vdpa;
289         const struct vdpa_config_ops *ops = vdpa->config;
290         u8 status;
291
292         status = ops->get_status(vdpa);
293
294         if (copy_to_user(statusp, &status, sizeof(status)))
295                 return -EFAULT;
296
297         return 0;
298 }
299
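/*
 * VHOST_VDPA_SET_STATUS: status bits may only be added, never removed,
 * unless the new status is 0, which performs a (compat) device reset.
 * Virtqueue irq bypass producers are torn down when DRIVER_OK is cleared
 * and set up again when DRIVER_OK becomes set.
 */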
300 static long vhost_vdpa_set_status(struct vhost_vdpa *v, u8 __user *statusp)
301 {
302         struct vdpa_device *vdpa = v->vdpa;
303         const struct vdpa_config_ops *ops = vdpa->config;
304         u8 status, status_old;
305         u32 nvqs = v->nvqs;
306         int ret;
307         u16 i;
308
309         if (copy_from_user(&status, statusp, sizeof(status)))
310                 return -EFAULT;
311
312         status_old = ops->get_status(vdpa);
313
314         /*
315          * Userspace shouldn't remove status bits unless it resets the
316          * status to 0.
317          */
318         if (status != 0 && (status_old & ~status) != 0)
319                 return -EINVAL;
320
321         if ((status_old & VIRTIO_CONFIG_S_DRIVER_OK) && !(status & VIRTIO_CONFIG_S_DRIVER_OK))
322                 for (i = 0; i < nvqs; i++)
323                         vhost_vdpa_unsetup_vq_irq(v, i);
324
325         if (status == 0) {
326                 ret = _compat_vdpa_reset(v);
327                 if (ret)
328                         return ret;
329         } else
330                 vdpa_set_status(vdpa, status);
331
332         if ((status & VIRTIO_CONFIG_S_DRIVER_OK) && !(status_old & VIRTIO_CONFIG_S_DRIVER_OK))
333                 for (i = 0; i < nvqs; i++)
334                         vhost_vdpa_setup_vq_irq(v, i);
335
336         return 0;
337 }
338
339 static int vhost_vdpa_config_validate(struct vhost_vdpa *v,
340                                       struct vhost_vdpa_config *c)
341 {
342         struct vdpa_device *vdpa = v->vdpa;
343         size_t size = vdpa->config->get_config_size(vdpa);
344
345         if (c->len == 0 || c->off > size)
346                 return -EINVAL;
347
348         if (c->len > size - c->off)
349                 return -E2BIG;
350
351         return 0;
352 }
353
354 static long vhost_vdpa_get_config(struct vhost_vdpa *v,
355                                   struct vhost_vdpa_config __user *c)
356 {
357         struct vdpa_device *vdpa = v->vdpa;
358         struct vhost_vdpa_config config;
359         unsigned long size = offsetof(struct vhost_vdpa_config, buf);
360         u8 *buf;
361
362         if (copy_from_user(&config, c, size))
363                 return -EFAULT;
364         if (vhost_vdpa_config_validate(v, &config))
365                 return -EINVAL;
366         buf = kvzalloc(config.len, GFP_KERNEL);
367         if (!buf)
368                 return -ENOMEM;
369
370         vdpa_get_config(vdpa, config.off, buf, config.len);
371
372         if (copy_to_user(c->buf, buf, config.len)) {
373                 kvfree(buf);
374                 return -EFAULT;
375         }
376
377         kvfree(buf);
378         return 0;
379 }
380
381 static long vhost_vdpa_set_config(struct vhost_vdpa *v,
382                                   struct vhost_vdpa_config __user *c)
383 {
384         struct vdpa_device *vdpa = v->vdpa;
385         struct vhost_vdpa_config config;
386         unsigned long size = offsetof(struct vhost_vdpa_config, buf);
387         u8 *buf;
388
389         if (copy_from_user(&config, c, size))
390                 return -EFAULT;
391         if (vhost_vdpa_config_validate(v, &config))
392                 return -EINVAL;
393
394         buf = vmemdup_user(c->buf, config.len);
395         if (IS_ERR(buf))
396                 return PTR_ERR(buf);
397
398         vdpa_set_config(vdpa, config.off, buf, config.len);
399
400         kvfree(buf);
401         return 0;
402 }
403
404 static bool vhost_vdpa_can_suspend(const struct vhost_vdpa *v)
405 {
406         struct vdpa_device *vdpa = v->vdpa;
407         const struct vdpa_config_ops *ops = vdpa->config;
408
409         return ops->suspend;
410 }
411
412 static bool vhost_vdpa_can_resume(const struct vhost_vdpa *v)
413 {
414         struct vdpa_device *vdpa = v->vdpa;
415         const struct vdpa_config_ops *ops = vdpa->config;
416
417         return ops->resume;
418 }
419
420 static bool vhost_vdpa_has_desc_group(const struct vhost_vdpa *v)
421 {
422         struct vdpa_device *vdpa = v->vdpa;
423         const struct vdpa_config_ops *ops = vdpa->config;
424
425         return ops->get_vq_desc_group;
426 }
427
428 static long vhost_vdpa_get_features(struct vhost_vdpa *v, u64 __user *featurep)
429 {
430         struct vdpa_device *vdpa = v->vdpa;
431         const struct vdpa_config_ops *ops = vdpa->config;
432         u64 features;
433
434         features = ops->get_device_features(vdpa);
435
436         if (copy_to_user(featurep, &features, sizeof(features)))
437                 return -EFAULT;
438
439         return 0;
440 }
441
442 static u64 vhost_vdpa_get_backend_features(const struct vhost_vdpa *v)
443 {
444         struct vdpa_device *vdpa = v->vdpa;
445         const struct vdpa_config_ops *ops = vdpa->config;
446
447         if (!ops->get_backend_features)
448                 return 0;
449         else
450                 return ops->get_backend_features(vdpa);
451 }
452
453 static bool vhost_vdpa_has_persistent_map(const struct vhost_vdpa *v)
454 {
455         struct vdpa_device *vdpa = v->vdpa;
456         const struct vdpa_config_ops *ops = vdpa->config;
457
458         return (!ops->set_map && !ops->dma_map) || ops->reset_map ||
459                vhost_vdpa_get_backend_features(v) & BIT_ULL(VHOST_BACKEND_F_IOTLB_PERSIST);
460 }
461
462 static long vhost_vdpa_set_features(struct vhost_vdpa *v, u64 __user *featurep)
463 {
464         struct vdpa_device *vdpa = v->vdpa;
465         const struct vdpa_config_ops *ops = vdpa->config;
466         struct vhost_dev *d = &v->vdev;
467         u64 actual_features;
468         u64 features;
469         int i;
470
471         /*
472          * It's not allowed to change the features after they have
473          * been negotiated.
474          */
475         if (ops->get_status(vdpa) & VIRTIO_CONFIG_S_FEATURES_OK)
476                 return -EBUSY;
477
478         if (copy_from_user(&features, featurep, sizeof(features)))
479                 return -EFAULT;
480
481         if (vdpa_set_features(vdpa, features))
482                 return -EINVAL;
483
484         /* let the vqs know what has been configured */
485         actual_features = ops->get_driver_features(vdpa);
486         for (i = 0; i < d->nvqs; ++i) {
487                 struct vhost_virtqueue *vq = d->vqs[i];
488
489                 mutex_lock(&vq->mutex);
490                 vq->acked_features = actual_features;
491                 mutex_unlock(&vq->mutex);
492         }
493
494         return 0;
495 }
496
497 static long vhost_vdpa_get_vring_num(struct vhost_vdpa *v, u16 __user *argp)
498 {
499         struct vdpa_device *vdpa = v->vdpa;
500         const struct vdpa_config_ops *ops = vdpa->config;
501         u16 num;
502
503         num = ops->get_vq_num_max(vdpa);
504
505         if (copy_to_user(argp, &num, sizeof(num)))
506                 return -EFAULT;
507
508         return 0;
509 }
510
511 static void vhost_vdpa_config_put(struct vhost_vdpa *v)
512 {
513         if (v->config_ctx) {
514                 eventfd_ctx_put(v->config_ctx);
515                 v->config_ctx = NULL;
516         }
517 }
518
519 static long vhost_vdpa_set_config_call(struct vhost_vdpa *v, u32 __user *argp)
520 {
521         struct vdpa_callback cb;
522         int fd;
523         struct eventfd_ctx *ctx;
524
525         cb.callback = vhost_vdpa_config_cb;
526         cb.private = v;
527         if (copy_from_user(&fd, argp, sizeof(fd)))
528                 return -EFAULT;
529
530         ctx = fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(fd);
531         swap(ctx, v->config_ctx);
532
533         if (!IS_ERR_OR_NULL(ctx))
534                 eventfd_ctx_put(ctx);
535
536         if (IS_ERR(v->config_ctx)) {
537                 long ret = PTR_ERR(v->config_ctx);
538
539                 v->config_ctx = NULL;
540                 return ret;
541         }
542
543         v->vdpa->config->set_config_cb(v->vdpa, &cb);
544
545         return 0;
546 }
547
548 static long vhost_vdpa_get_iova_range(struct vhost_vdpa *v, u32 __user *argp)
549 {
550         struct vhost_vdpa_iova_range range = {
551                 .first = v->range.first,
552                 .last = v->range.last,
553         };
554
555         if (copy_to_user(argp, &range, sizeof(range)))
556                 return -EFAULT;
557         return 0;
558 }
559
560 static long vhost_vdpa_get_config_size(struct vhost_vdpa *v, u32 __user *argp)
561 {
562         struct vdpa_device *vdpa = v->vdpa;
563         const struct vdpa_config_ops *ops = vdpa->config;
564         u32 size;
565
566         size = ops->get_config_size(vdpa);
567
568         if (copy_to_user(argp, &size, sizeof(size)))
569                 return -EFAULT;
570
571         return 0;
572 }
573
574 static long vhost_vdpa_get_vqs_count(struct vhost_vdpa *v, u32 __user *argp)
575 {
576         struct vdpa_device *vdpa = v->vdpa;
577
578         if (copy_to_user(argp, &vdpa->nvqs, sizeof(vdpa->nvqs)))
579                 return -EFAULT;
580
581         return 0;
582 }
583
584 /* After a successful return of this ioctl the device must not process more
585  * virtqueue descriptors. The device can answer reads or writes of config
586  * fields as if it were not suspended. In particular, writing to "queue_enable"
587  * with a value of 1 will not make the device start processing buffers.
588  */
589 static long vhost_vdpa_suspend(struct vhost_vdpa *v)
590 {
591         struct vdpa_device *vdpa = v->vdpa;
592         const struct vdpa_config_ops *ops = vdpa->config;
593
594         if (!ops->suspend)
595                 return -EOPNOTSUPP;
596
597         return ops->suspend(vdpa);
598 }
599
600 /* After a successful return of this ioctl the device resumes processing
601  * virtqueue descriptors. The device becomes fully operational the same way it
602  * was before it was suspended.
603  */
604 static long vhost_vdpa_resume(struct vhost_vdpa *v)
605 {
606         struct vdpa_device *vdpa = v->vdpa;
607         const struct vdpa_config_ops *ops = vdpa->config;
608
609         if (!ops->resume)
610                 return -EOPNOTSUPP;
611
612         return ops->resume(vdpa);
613 }
614
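/*
 * Per-virtqueue ioctls. The vring index is validated first, most requests
 * are then forwarded to vhost_vring_ioctl(), and the second switch mirrors
 * the resulting state into the parent device. For packed virtqueues the
 * 16-bit last_avail_idx/last_used_idx values carry the wrap counter in
 * bit 15, i.e. idx | (wrap_counter << 15).
 */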
615 static long vhost_vdpa_vring_ioctl(struct vhost_vdpa *v, unsigned int cmd,
616                                    void __user *argp)
617 {
618         struct vdpa_device *vdpa = v->vdpa;
619         const struct vdpa_config_ops *ops = vdpa->config;
620         struct vdpa_vq_state vq_state;
621         struct vdpa_callback cb;
622         struct vhost_virtqueue *vq;
623         struct vhost_vring_state s;
624         u32 idx;
625         long r;
626
627         r = get_user(idx, (u32 __user *)argp);
628         if (r < 0)
629                 return r;
630
631         if (idx >= v->nvqs)
632                 return -ENOBUFS;
633
634         idx = array_index_nospec(idx, v->nvqs);
635         vq = &v->vqs[idx];
636
637         switch (cmd) {
638         case VHOST_VDPA_SET_VRING_ENABLE:
639                 if (copy_from_user(&s, argp, sizeof(s)))
640                         return -EFAULT;
641                 ops->set_vq_ready(vdpa, idx, s.num);
642                 return 0;
643         case VHOST_VDPA_GET_VRING_GROUP:
644                 if (!ops->get_vq_group)
645                         return -EOPNOTSUPP;
646                 s.index = idx;
647                 s.num = ops->get_vq_group(vdpa, idx);
648                 if (s.num >= vdpa->ngroups)
649                         return -EIO;
650                 else if (copy_to_user(argp, &s, sizeof(s)))
651                         return -EFAULT;
652                 return 0;
653         case VHOST_VDPA_GET_VRING_DESC_GROUP:
654                 if (!vhost_vdpa_has_desc_group(v))
655                         return -EOPNOTSUPP;
656                 s.index = idx;
657                 s.num = ops->get_vq_desc_group(vdpa, idx);
658                 if (s.num >= vdpa->ngroups)
659                         return -EIO;
660                 else if (copy_to_user(argp, &s, sizeof(s)))
661                         return -EFAULT;
662                 return 0;
663         case VHOST_VDPA_SET_GROUP_ASID:
664                 if (copy_from_user(&s, argp, sizeof(s)))
665                         return -EFAULT;
666                 if (s.num >= vdpa->nas)
667                         return -EINVAL;
668                 if (!ops->set_group_asid)
669                         return -EOPNOTSUPP;
670                 return ops->set_group_asid(vdpa, idx, s.num);
671         case VHOST_GET_VRING_BASE:
672                 r = ops->get_vq_state(v->vdpa, idx, &vq_state);
673                 if (r)
674                         return r;
675
676                 if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) {
677                         vq->last_avail_idx = vq_state.packed.last_avail_idx |
678                                              (vq_state.packed.last_avail_counter << 15);
679                         vq->last_used_idx = vq_state.packed.last_used_idx |
680                                             (vq_state.packed.last_used_counter << 15);
681                 } else {
682                         vq->last_avail_idx = vq_state.split.avail_index;
683                 }
684                 break;
685         }
686
687         r = vhost_vring_ioctl(&v->vdev, cmd, argp);
688         if (r)
689                 return r;
690
691         switch (cmd) {
692         case VHOST_SET_VRING_ADDR:
693                 if (ops->set_vq_address(vdpa, idx,
694                                         (u64)(uintptr_t)vq->desc,
695                                         (u64)(uintptr_t)vq->avail,
696                                         (u64)(uintptr_t)vq->used))
697                         r = -EINVAL;
698                 break;
699
700         case VHOST_SET_VRING_BASE:
701                 if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) {
702                         vq_state.packed.last_avail_idx = vq->last_avail_idx & 0x7fff;
703                         vq_state.packed.last_avail_counter = !!(vq->last_avail_idx & 0x8000);
704                         vq_state.packed.last_used_idx = vq->last_used_idx & 0x7fff;
705                         vq_state.packed.last_used_counter = !!(vq->last_used_idx & 0x8000);
706                 } else {
707                         vq_state.split.avail_index = vq->last_avail_idx;
708                 }
709                 r = ops->set_vq_state(vdpa, idx, &vq_state);
710                 break;
711
712         case VHOST_SET_VRING_CALL:
713                 if (vq->call_ctx.ctx) {
714                         cb.callback = vhost_vdpa_virtqueue_cb;
715                         cb.private = vq;
716                         cb.trigger = vq->call_ctx.ctx;
717                 } else {
718                         cb.callback = NULL;
719                         cb.private = NULL;
720                         cb.trigger = NULL;
721                 }
722                 ops->set_vq_cb(vdpa, idx, &cb);
723                 vhost_vdpa_setup_vq_irq(v, idx);
724                 break;
725
726         case VHOST_SET_VRING_NUM:
727                 ops->set_vq_num(vdpa, idx, vq->num);
728                 break;
729         }
730
731         return r;
732 }
733
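/*
 * Main ioctl entry point. VHOST_SET_BACKEND_FEATURES is handled before the
 * vhost_dev mutex is taken, and each requested backend feature is checked
 * against what the parent device actually supports. Everything else is
 * dispatched under d->mutex; VHOST_SET_OWNER additionally binds the owner
 * mm to the parent device for use_va devices.
 */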
734 static long vhost_vdpa_unlocked_ioctl(struct file *filep,
735                                       unsigned int cmd, unsigned long arg)
736 {
737         struct vhost_vdpa *v = filep->private_data;
738         struct vhost_dev *d = &v->vdev;
739         void __user *argp = (void __user *)arg;
740         u64 __user *featurep = argp;
741         u64 features;
742         long r = 0;
743
744         if (cmd == VHOST_SET_BACKEND_FEATURES) {
745                 if (copy_from_user(&features, featurep, sizeof(features)))
746                         return -EFAULT;
747                 if (features & ~(VHOST_VDPA_BACKEND_FEATURES |
748                                  BIT_ULL(VHOST_BACKEND_F_DESC_ASID) |
749                                  BIT_ULL(VHOST_BACKEND_F_IOTLB_PERSIST) |
750                                  BIT_ULL(VHOST_BACKEND_F_SUSPEND) |
751                                  BIT_ULL(VHOST_BACKEND_F_RESUME) |
752                                  BIT_ULL(VHOST_BACKEND_F_ENABLE_AFTER_DRIVER_OK)))
753                         return -EOPNOTSUPP;
754                 if ((features & BIT_ULL(VHOST_BACKEND_F_SUSPEND)) &&
755                      !vhost_vdpa_can_suspend(v))
756                         return -EOPNOTSUPP;
757                 if ((features & BIT_ULL(VHOST_BACKEND_F_RESUME)) &&
758                      !vhost_vdpa_can_resume(v))
759                         return -EOPNOTSUPP;
760                 if ((features & BIT_ULL(VHOST_BACKEND_F_DESC_ASID)) &&
761                     !(features & BIT_ULL(VHOST_BACKEND_F_IOTLB_ASID)))
762                         return -EINVAL;
763                 if ((features & BIT_ULL(VHOST_BACKEND_F_DESC_ASID)) &&
764                      !vhost_vdpa_has_desc_group(v))
765                         return -EOPNOTSUPP;
766                 if ((features & BIT_ULL(VHOST_BACKEND_F_IOTLB_PERSIST)) &&
767                      !vhost_vdpa_has_persistent_map(v))
768                         return -EOPNOTSUPP;
769                 vhost_set_backend_features(&v->vdev, features);
770                 return 0;
771         }
772
773         mutex_lock(&d->mutex);
774
775         switch (cmd) {
776         case VHOST_VDPA_GET_DEVICE_ID:
777                 r = vhost_vdpa_get_device_id(v, argp);
778                 break;
779         case VHOST_VDPA_GET_STATUS:
780                 r = vhost_vdpa_get_status(v, argp);
781                 break;
782         case VHOST_VDPA_SET_STATUS:
783                 r = vhost_vdpa_set_status(v, argp);
784                 break;
785         case VHOST_VDPA_GET_CONFIG:
786                 r = vhost_vdpa_get_config(v, argp);
787                 break;
788         case VHOST_VDPA_SET_CONFIG:
789                 r = vhost_vdpa_set_config(v, argp);
790                 break;
791         case VHOST_GET_FEATURES:
792                 r = vhost_vdpa_get_features(v, argp);
793                 break;
794         case VHOST_SET_FEATURES:
795                 r = vhost_vdpa_set_features(v, argp);
796                 break;
797         case VHOST_VDPA_GET_VRING_NUM:
798                 r = vhost_vdpa_get_vring_num(v, argp);
799                 break;
800         case VHOST_VDPA_GET_GROUP_NUM:
801                 if (copy_to_user(argp, &v->vdpa->ngroups,
802                                  sizeof(v->vdpa->ngroups)))
803                         r = -EFAULT;
804                 break;
805         case VHOST_VDPA_GET_AS_NUM:
806                 if (copy_to_user(argp, &v->vdpa->nas, sizeof(v->vdpa->nas)))
807                         r = -EFAULT;
808                 break;
809         case VHOST_SET_LOG_BASE:
810         case VHOST_SET_LOG_FD:
811                 r = -ENOIOCTLCMD;
812                 break;
813         case VHOST_VDPA_SET_CONFIG_CALL:
814                 r = vhost_vdpa_set_config_call(v, argp);
815                 break;
816         case VHOST_GET_BACKEND_FEATURES:
817                 features = VHOST_VDPA_BACKEND_FEATURES;
818                 if (vhost_vdpa_can_suspend(v))
819                         features |= BIT_ULL(VHOST_BACKEND_F_SUSPEND);
820                 if (vhost_vdpa_can_resume(v))
821                         features |= BIT_ULL(VHOST_BACKEND_F_RESUME);
822                 if (vhost_vdpa_has_desc_group(v))
823                         features |= BIT_ULL(VHOST_BACKEND_F_DESC_ASID);
824                 if (vhost_vdpa_has_persistent_map(v))
825                         features |= BIT_ULL(VHOST_BACKEND_F_IOTLB_PERSIST);
826                 features |= vhost_vdpa_get_backend_features(v);
827                 if (copy_to_user(featurep, &features, sizeof(features)))
828                         r = -EFAULT;
829                 break;
830         case VHOST_VDPA_GET_IOVA_RANGE:
831                 r = vhost_vdpa_get_iova_range(v, argp);
832                 break;
833         case VHOST_VDPA_GET_CONFIG_SIZE:
834                 r = vhost_vdpa_get_config_size(v, argp);
835                 break;
836         case VHOST_VDPA_GET_VQS_COUNT:
837                 r = vhost_vdpa_get_vqs_count(v, argp);
838                 break;
839         case VHOST_VDPA_SUSPEND:
840                 r = vhost_vdpa_suspend(v);
841                 break;
842         case VHOST_VDPA_RESUME:
843                 r = vhost_vdpa_resume(v);
844                 break;
845         default:
846                 r = vhost_dev_ioctl(&v->vdev, cmd, argp);
847                 if (r == -ENOIOCTLCMD)
848                         r = vhost_vdpa_vring_ioctl(v, cmd, argp);
849                 break;
850         }
851
852         if (r)
853                 goto out;
854
855         switch (cmd) {
856         case VHOST_SET_OWNER:
857                 r = vhost_vdpa_bind_mm(v);
858                 if (r)
859                         vhost_dev_reset_owner(d, NULL);
860                 break;
861         }
862 out:
863         mutex_unlock(&d->mutex);
864         return r;
865 }
866 static void vhost_vdpa_general_unmap(struct vhost_vdpa *v,
867                                      struct vhost_iotlb_map *map, u32 asid)
868 {
869         struct vdpa_device *vdpa = v->vdpa;
870         const struct vdpa_config_ops *ops = vdpa->config;
871         if (ops->dma_map) {
872                 ops->dma_unmap(vdpa, asid, map->start, map->size);
873         } else if (ops->set_map == NULL) {
874                 iommu_unmap(v->domain, map->start, map->size);
875         }
876 }
877
878 static void vhost_vdpa_pa_unmap(struct vhost_vdpa *v, struct vhost_iotlb *iotlb,
879                                 u64 start, u64 last, u32 asid)
880 {
881         struct vhost_dev *dev = &v->vdev;
882         struct vhost_iotlb_map *map;
883         struct page *page;
884         unsigned long pfn, pinned;
885
886         while ((map = vhost_iotlb_itree_first(iotlb, start, last)) != NULL) {
887                 pinned = PFN_DOWN(map->size);
888                 for (pfn = PFN_DOWN(map->addr);
889                      pinned > 0; pfn++, pinned--) {
890                         page = pfn_to_page(pfn);
891                         if (map->perm & VHOST_ACCESS_WO)
892                                 set_page_dirty_lock(page);
893                         unpin_user_page(page);
894                 }
895                 atomic64_sub(PFN_DOWN(map->size), &dev->mm->pinned_vm);
896                 vhost_vdpa_general_unmap(v, map, asid);
897                 vhost_iotlb_map_free(iotlb, map);
898         }
899 }
900
901 static void vhost_vdpa_va_unmap(struct vhost_vdpa *v, struct vhost_iotlb *iotlb,
902                                 u64 start, u64 last, u32 asid)
903 {
904         struct vhost_iotlb_map *map;
905         struct vdpa_map_file *map_file;
906
907         while ((map = vhost_iotlb_itree_first(iotlb, start, last)) != NULL) {
908                 map_file = (struct vdpa_map_file *)map->opaque;
909                 fput(map_file->file);
910                 kfree(map_file);
911                 vhost_vdpa_general_unmap(v, map, asid);
912                 vhost_iotlb_map_free(iotlb, map);
913         }
914 }
915
916 static void vhost_vdpa_iotlb_unmap(struct vhost_vdpa *v,
917                                    struct vhost_iotlb *iotlb, u64 start,
918                                    u64 last, u32 asid)
919 {
920         struct vdpa_device *vdpa = v->vdpa;
921
922         if (vdpa->use_va)
923                 return vhost_vdpa_va_unmap(v, iotlb, start, last, asid);
924
925         return vhost_vdpa_pa_unmap(v, iotlb, start, last, asid);
926 }
927
928 static int perm_to_iommu_flags(u32 perm)
929 {
930         int flags = 0;
931
932         switch (perm) {
933         case VHOST_ACCESS_WO:
934                 flags |= IOMMU_WRITE;
935                 break;
936         case VHOST_ACCESS_RO:
937                 flags |= IOMMU_READ;
938                 break;
939         case VHOST_ACCESS_RW:
940                 flags |= (IOMMU_WRITE | IOMMU_READ);
941                 break;
942         default:
943                 WARN(1, "invalid vhost IOTLB permission\n");
944                 break;
945         }
946
947         return flags | IOMMU_CACHE;
948 }
949
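/*
 * Add one mapping to the per-AS iotlb and propagate it to the active
 * translation backend: the device's dma_map() op, its set_map() op
 * (deferred while an iotlb batch is in flight), or the platform IOMMU
 * domain as a fallback.
 */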
950 static int vhost_vdpa_map(struct vhost_vdpa *v, struct vhost_iotlb *iotlb,
951                           u64 iova, u64 size, u64 pa, u32 perm, void *opaque)
952 {
953         struct vhost_dev *dev = &v->vdev;
954         struct vdpa_device *vdpa = v->vdpa;
955         const struct vdpa_config_ops *ops = vdpa->config;
956         u32 asid = iotlb_to_asid(iotlb);
957         int r = 0;
958
959         r = vhost_iotlb_add_range_ctx(iotlb, iova, iova + size - 1,
960                                       pa, perm, opaque);
961         if (r)
962                 return r;
963
964         if (ops->dma_map) {
965                 r = ops->dma_map(vdpa, asid, iova, size, pa, perm, opaque);
966         } else if (ops->set_map) {
967                 if (!v->in_batch)
968                         r = ops->set_map(vdpa, asid, iotlb);
969         } else {
970                 r = iommu_map(v->domain, iova, pa, size,
971                               perm_to_iommu_flags(perm), GFP_KERNEL);
972         }
973         if (r) {
974                 vhost_iotlb_del_range(iotlb, iova, iova + size - 1);
975                 return r;
976         }
977
978         if (!vdpa->use_va)
979                 atomic64_add(PFN_DOWN(size), &dev->mm->pinned_vm);
980
981         return 0;
982 }
983
984 static void vhost_vdpa_unmap(struct vhost_vdpa *v,
985                              struct vhost_iotlb *iotlb,
986                              u64 iova, u64 size)
987 {
988         struct vdpa_device *vdpa = v->vdpa;
989         const struct vdpa_config_ops *ops = vdpa->config;
990         u32 asid = iotlb_to_asid(iotlb);
991
992         vhost_vdpa_iotlb_unmap(v, iotlb, iova, iova + size - 1, asid);
993
994         if (ops->set_map) {
995                 if (!v->in_batch)
996                         ops->set_map(vdpa, asid, iotlb);
997         }
999 }
1000
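/*
 * VA mode (vdpa->use_va): nothing is pinned. For each shared, non-IO VMA
 * covering the range, take a reference on the backing file and record its
 * offset so the parent device can translate the userspace virtual
 * addresses itself; VMAs that do not qualify are skipped.
 */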
1001 static int vhost_vdpa_va_map(struct vhost_vdpa *v,
1002                              struct vhost_iotlb *iotlb,
1003                              u64 iova, u64 size, u64 uaddr, u32 perm)
1004 {
1005         struct vhost_dev *dev = &v->vdev;
1006         u64 offset, map_size, map_iova = iova;
1007         struct vdpa_map_file *map_file;
1008         struct vm_area_struct *vma;
1009         int ret = 0;
1010
1011         mmap_read_lock(dev->mm);
1012
1013         while (size) {
1014                 vma = find_vma(dev->mm, uaddr);
1015                 if (!vma) {
1016                         ret = -EINVAL;
1017                         break;
1018                 }
1019                 map_size = min(size, vma->vm_end - uaddr);
1020                 if (!(vma->vm_file && (vma->vm_flags & VM_SHARED) &&
1021                         !(vma->vm_flags & (VM_IO | VM_PFNMAP))))
1022                         goto next;
1023
1024                 map_file = kzalloc(sizeof(*map_file), GFP_KERNEL);
1025                 if (!map_file) {
1026                         ret = -ENOMEM;
1027                         break;
1028                 }
1029                 offset = (vma->vm_pgoff << PAGE_SHIFT) + uaddr - vma->vm_start;
1030                 map_file->offset = offset;
1031                 map_file->file = get_file(vma->vm_file);
1032                 ret = vhost_vdpa_map(v, iotlb, map_iova, map_size, uaddr,
1033                                      perm, map_file);
1034                 if (ret) {
1035                         fput(map_file->file);
1036                         kfree(map_file);
1037                         break;
1038                 }
1039 next:
1040                 size -= map_size;
1041                 uaddr += map_size;
1042                 map_iova += map_size;
1043         }
1044         if (ret)
1045                 vhost_vdpa_unmap(v, iotlb, iova, map_iova - iova);
1046
1047         mmap_read_unlock(dev->mm);
1048
1049         return ret;
1050 }
1051
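/*
 * PA mode: pin the userspace range with pin_user_pages() one page-list
 * chunk at a time (accounted against RLIMIT_MEMLOCK), coalesce physically
 * contiguous pfn runs, and map each run with a single vhost_vdpa_map()
 * call. On error, pages pinned but not yet mapped are unpinned again.
 */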
1052 static int vhost_vdpa_pa_map(struct vhost_vdpa *v,
1053                              struct vhost_iotlb *iotlb,
1054                              u64 iova, u64 size, u64 uaddr, u32 perm)
1055 {
1056         struct vhost_dev *dev = &v->vdev;
1057         struct page **page_list;
1058         unsigned long list_size = PAGE_SIZE / sizeof(struct page *);
1059         unsigned int gup_flags = FOLL_LONGTERM;
1060         unsigned long npages, cur_base, map_pfn, last_pfn = 0;
1061         unsigned long lock_limit, sz2pin, nchunks, i;
1062         u64 start = iova;
1063         long pinned;
1064         int ret = 0;
1065
1066         /* Limit the use of memory for bookkeeping */
1067         page_list = (struct page **) __get_free_page(GFP_KERNEL);
1068         if (!page_list)
1069                 return -ENOMEM;
1070
1071         if (perm & VHOST_ACCESS_WO)
1072                 gup_flags |= FOLL_WRITE;
1073
1074         npages = PFN_UP(size + (iova & ~PAGE_MASK));
1075         if (!npages) {
1076                 ret = -EINVAL;
1077                 goto free;
1078         }
1079
1080         mmap_read_lock(dev->mm);
1081
1082         lock_limit = PFN_DOWN(rlimit(RLIMIT_MEMLOCK));
1083         if (npages + atomic64_read(&dev->mm->pinned_vm) > lock_limit) {
1084                 ret = -ENOMEM;
1085                 goto unlock;
1086         }
1087
1088         cur_base = uaddr & PAGE_MASK;
1089         iova &= PAGE_MASK;
1090         nchunks = 0;
1091
1092         while (npages) {
1093                 sz2pin = min_t(unsigned long, npages, list_size);
1094                 pinned = pin_user_pages(cur_base, sz2pin,
1095                                         gup_flags, page_list);
1096                 if (sz2pin != pinned) {
1097                         if (pinned < 0) {
1098                                 ret = pinned;
1099                         } else {
1100                                 unpin_user_pages(page_list, pinned);
1101                                 ret = -ENOMEM;
1102                         }
1103                         goto out;
1104                 }
1105                 nchunks++;
1106
1107                 if (!last_pfn)
1108                         map_pfn = page_to_pfn(page_list[0]);
1109
1110                 for (i = 0; i < pinned; i++) {
1111                         unsigned long this_pfn = page_to_pfn(page_list[i]);
1112                         u64 csize;
1113
1114                         if (last_pfn && (this_pfn != last_pfn + 1)) {
1115                                 /* Map the contiguous chunk of pinned memory */
1116                                 csize = PFN_PHYS(last_pfn - map_pfn + 1);
1117                                 ret = vhost_vdpa_map(v, iotlb, iova, csize,
1118                                                      PFN_PHYS(map_pfn),
1119                                                      perm, NULL);
1120                                 if (ret) {
1121                                         /*
1122                                          * Unpin the pages that are left unmapped
1123                                          * from this point on in the current
1124                                          * page_list. The remaining outstanding
1125                                          * ones which may stride across several
1126                                          * chunks will be covered in the common
1127                                          * error path subsequently.
1128                                          */
1129                                         unpin_user_pages(&page_list[i],
1130                                                          pinned - i);
1131                                         goto out;
1132                                 }
1133
1134                                 map_pfn = this_pfn;
1135                                 iova += csize;
1136                                 nchunks = 0;
1137                         }
1138
1139                         last_pfn = this_pfn;
1140                 }
1141
1142                 cur_base += PFN_PHYS(pinned);
1143                 npages -= pinned;
1144         }
1145
1146         /* Map the remaining chunk of pinned memory */
1147         ret = vhost_vdpa_map(v, iotlb, iova, PFN_PHYS(last_pfn - map_pfn + 1),
1148                              PFN_PHYS(map_pfn), perm, NULL);
1149 out:
1150         if (ret) {
1151                 if (nchunks) {
1152                         unsigned long pfn;
1153
1154                         /*
1155                          * Unpin the outstanding pages which are yet to be
1156                          * mapped but haven't due to vdpa_map() or
1157                          * pin_user_pages() failure.
1158                          *
1159                          * Mapped pages are accounted in vdpa_map(), hence
1160                          * the corresponding unpinning will be handled by
1161                          * vdpa_unmap().
1162                          */
1163                         WARN_ON(!last_pfn);
1164                         for (pfn = map_pfn; pfn <= last_pfn; pfn++)
1165                                 unpin_user_page(pfn_to_page(pfn));
1166                 }
1167                 vhost_vdpa_unmap(v, iotlb, start, size);
1168         }
1169 unlock:
1170         mmap_read_unlock(dev->mm);
1171 free:
1172         free_page((unsigned long)page_list);
1173         return ret;
1175 }
1176
1177 static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
1178                                            struct vhost_iotlb *iotlb,
1179                                            struct vhost_iotlb_msg *msg)
1180 {
1181         struct vdpa_device *vdpa = v->vdpa;
1182
1183         if (msg->iova < v->range.first || !msg->size ||
1184             msg->iova > U64_MAX - msg->size + 1 ||
1185             msg->iova + msg->size - 1 > v->range.last)
1186                 return -EINVAL;
1187
1188         if (vhost_iotlb_itree_first(iotlb, msg->iova,
1189                                     msg->iova + msg->size - 1))
1190                 return -EEXIST;
1191
1192         if (vdpa->use_va)
1193                 return vhost_vdpa_va_map(v, iotlb, msg->iova, msg->size,
1194                                          msg->uaddr, msg->perm);
1195
1196         return vhost_vdpa_pa_map(v, iotlb, msg->iova, msg->size, msg->uaddr,
1197                                  msg->perm);
1198 }
1199
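/*
 * IOTLB message handler (msg v2 carries an ASID). UPDATE and BATCH_BEGIN
 * may create the target address space on demand; a batch is bound to one
 * ASID, and for set_map() devices the accumulated iotlb is pushed to the
 * parent only at BATCH_END.
 */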
1200 static int vhost_vdpa_process_iotlb_msg(struct vhost_dev *dev, u32 asid,
1201                                         struct vhost_iotlb_msg *msg)
1202 {
1203         struct vhost_vdpa *v = container_of(dev, struct vhost_vdpa, vdev);
1204         struct vdpa_device *vdpa = v->vdpa;
1205         const struct vdpa_config_ops *ops = vdpa->config;
1206         struct vhost_iotlb *iotlb = NULL;
1207         struct vhost_vdpa_as *as = NULL;
1208         int r = 0;
1209
1210         mutex_lock(&dev->mutex);
1211
1212         r = vhost_dev_check_owner(dev);
1213         if (r)
1214                 goto unlock;
1215
1216         if (msg->type == VHOST_IOTLB_UPDATE ||
1217             msg->type == VHOST_IOTLB_BATCH_BEGIN) {
1218                 as = vhost_vdpa_find_alloc_as(v, asid);
1219                 if (!as) {
1220                         dev_err(&v->dev, "can't find or allocate asid %u\n",
1221                                 asid);
1222                         r = -EINVAL;
1223                         goto unlock;
1224                 }
1225                 iotlb = &as->iotlb;
1226         } else
1227                 iotlb = asid_to_iotlb(v, asid);
1228
1229         if ((v->in_batch && v->batch_asid != asid) || !iotlb) {
1230                 if (v->in_batch && v->batch_asid != asid) {
1231                         dev_info(&v->dev, "batch id %d asid %d\n",
1232                                  v->batch_asid, asid);
1233                 }
1234                 if (!iotlb)
1235                         dev_err(&v->dev, "no iotlb for asid %d\n", asid);
1236                 r = -EINVAL;
1237                 goto unlock;
1238         }
1239
1240         switch (msg->type) {
1241         case VHOST_IOTLB_UPDATE:
1242                 r = vhost_vdpa_process_iotlb_update(v, iotlb, msg);
1243                 break;
1244         case VHOST_IOTLB_INVALIDATE:
1245                 vhost_vdpa_unmap(v, iotlb, msg->iova, msg->size);
1246                 break;
1247         case VHOST_IOTLB_BATCH_BEGIN:
1248                 v->batch_asid = asid;
1249                 v->in_batch = true;
1250                 break;
1251         case VHOST_IOTLB_BATCH_END:
1252                 if (v->in_batch && ops->set_map)
1253                         ops->set_map(vdpa, asid, iotlb);
1254                 v->in_batch = false;
1255                 break;
1256         default:
1257                 r = -EINVAL;
1258                 break;
1259         }
1260 unlock:
1261         mutex_unlock(&dev->mutex);
1262
1263         return r;
1264 }
1265
1266 static ssize_t vhost_vdpa_chr_write_iter(struct kiocb *iocb,
1267                                          struct iov_iter *from)
1268 {
1269         struct file *file = iocb->ki_filp;
1270         struct vhost_vdpa *v = file->private_data;
1271         struct vhost_dev *dev = &v->vdev;
1272
1273         return vhost_chr_write_iter(dev, from);
1274 }
1275
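/*
 * When the parent device has no DMA translation of its own (neither
 * set_map() nor dma_map()), allocate an IOMMU domain on its DMA device and
 * attach to it so vhost_vdpa_map() can program mappings via iommu_map().
 * A cache-coherent IOMMU is required.
 */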
1276 static int vhost_vdpa_alloc_domain(struct vhost_vdpa *v)
1277 {
1278         struct vdpa_device *vdpa = v->vdpa;
1279         const struct vdpa_config_ops *ops = vdpa->config;
1280         struct device *dma_dev = vdpa_get_dma_dev(vdpa);
1281         const struct bus_type *bus;
1282         int ret;
1283
1284         /* Device wants to do DMA by itself */
1285         if (ops->set_map || ops->dma_map)
1286                 return 0;
1287
1288         bus = dma_dev->bus;
1289         if (!bus)
1290                 return -EFAULT;
1291
1292         if (!device_iommu_capable(dma_dev, IOMMU_CAP_CACHE_COHERENCY)) {
1293                 dev_warn_once(&v->dev,
1294                               "Failed to allocate domain, device is not IOMMU cache coherent capable\n");
1295                 return -ENOTSUPP;
1296         }
1297
1298         v->domain = iommu_domain_alloc(bus);
1299         if (!v->domain)
1300                 return -EIO;
1301
1302         ret = iommu_attach_device(v->domain, dma_dev);
1303         if (ret)
1304                 goto err_attach;
1305
1306         return 0;
1307
1308 err_attach:
1309         iommu_domain_free(v->domain);
1310         v->domain = NULL;
1311         return ret;
1312 }
1313
1314 static void vhost_vdpa_free_domain(struct vhost_vdpa *v)
1315 {
1316         struct vdpa_device *vdpa = v->vdpa;
1317         struct device *dma_dev = vdpa_get_dma_dev(vdpa);
1318
1319         if (v->domain) {
1320                 iommu_detach_device(v->domain, dma_dev);
1321                 iommu_domain_free(v->domain);
1322         }
1323
1324         v->domain = NULL;
1325 }
1326
1327 static void vhost_vdpa_set_iova_range(struct vhost_vdpa *v)
1328 {
1329         struct vdpa_iova_range *range = &v->range;
1330         struct vdpa_device *vdpa = v->vdpa;
1331         const struct vdpa_config_ops *ops = vdpa->config;
1332
1333         if (ops->get_iova_range) {
1334                 *range = ops->get_iova_range(vdpa);
1335         } else if (v->domain && v->domain->geometry.force_aperture) {
1336                 range->first = v->domain->geometry.aperture_start;
1337                 range->last = v->domain->geometry.aperture_end;
1338         } else {
1339                 range->first = 0;
1340                 range->last = ULLONG_MAX;
1341         }
1342 }
1343
1344 static void vhost_vdpa_cleanup(struct vhost_vdpa *v)
1345 {
1346         struct vhost_vdpa_as *as;
1347         u32 asid;
1348
1349         for (asid = 0; asid < v->vdpa->nas; asid++) {
1350                 as = asid_to_as(v, asid);
1351                 if (as)
1352                         vhost_vdpa_remove_as(v, asid);
1353         }
1354
1355         vhost_vdpa_free_domain(v);
1356         vhost_dev_cleanup(&v->vdev);
1357         kfree(v->vdev.vqs);
1358         v->vdev.vqs = NULL;
1359 }
1360
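/*
 * Only one concurrent open() of the char device is allowed (v->opened).
 * Opening resets the device, allocates the vhost_virtqueue pointers,
 * initializes the vhost_dev, sets up the platform IOMMU domain if needed
 * and queries the usable IOVA range.
 */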
1361 static int vhost_vdpa_open(struct inode *inode, struct file *filep)
1362 {
1363         struct vhost_vdpa *v;
1364         struct vhost_dev *dev;
1365         struct vhost_virtqueue **vqs;
1366         int r, opened;
1367         u32 i, nvqs;
1368
1369         v = container_of(inode->i_cdev, struct vhost_vdpa, cdev);
1370
1371         opened = atomic_cmpxchg(&v->opened, 0, 1);
1372         if (opened)
1373                 return -EBUSY;
1374
1375         nvqs = v->nvqs;
1376         r = vhost_vdpa_reset(v);
1377         if (r)
1378                 goto err;
1379
1380         vqs = kmalloc_array(nvqs, sizeof(*vqs), GFP_KERNEL);
1381         if (!vqs) {
1382                 r = -ENOMEM;
1383                 goto err;
1384         }
1385
1386         dev = &v->vdev;
1387         for (i = 0; i < nvqs; i++) {
1388                 vqs[i] = &v->vqs[i];
1389                 vqs[i]->handle_kick = handle_vq_kick;
1390         }
1391         vhost_dev_init(dev, vqs, nvqs, 0, 0, 0, false,
1392                        vhost_vdpa_process_iotlb_msg);
1393
1394         r = vhost_vdpa_alloc_domain(v);
1395         if (r)
1396                 goto err_alloc_domain;
1397
1398         vhost_vdpa_set_iova_range(v);
1399
1400         filep->private_data = v;
1401
1402         return 0;
1403
1404 err_alloc_domain:
1405         vhost_vdpa_cleanup(v);
1406 err:
1407         atomic_dec(&v->opened);
1408         return r;
1409 }
1410
1411 static void vhost_vdpa_clean_irq(struct vhost_vdpa *v)
1412 {
1413         u32 i;
1414
1415         for (i = 0; i < v->nvqs; i++)
1416                 vhost_vdpa_unsetup_vq_irq(v, i);
1417 }
1418
1419 static int vhost_vdpa_release(struct inode *inode, struct file *filep)
1420 {
1421         struct vhost_vdpa *v = filep->private_data;
1422         struct vhost_dev *d = &v->vdev;
1423
1424         mutex_lock(&d->mutex);
1425         filep->private_data = NULL;
1426         vhost_vdpa_clean_irq(v);
1427         vhost_vdpa_reset(v);
1428         vhost_dev_stop(&v->vdev);
1429         vhost_vdpa_unbind_mm(v);
1430         vhost_vdpa_config_put(v);
1431         vhost_vdpa_cleanup(v);
1432         mutex_unlock(&d->mutex);
1433
1434         atomic_dec(&v->opened);
1435         complete(&v->completion);
1436
1437         return 0;
1438 }
1439
1440 #ifdef CONFIG_MMU
1441 static vm_fault_t vhost_vdpa_fault(struct vm_fault *vmf)
1442 {
1443         struct vhost_vdpa *v = vmf->vma->vm_file->private_data;
1444         struct vdpa_device *vdpa = v->vdpa;
1445         const struct vdpa_config_ops *ops = vdpa->config;
1446         struct vdpa_notification_area notify;
1447         struct vm_area_struct *vma = vmf->vma;
1448         u16 index = vma->vm_pgoff;
1449
1450         notify = ops->get_vq_notification(vdpa, index);
1451
1452         vma->vm_page_prot = pgprot_noncached(vma->vm_page_prot);
1453         if (remap_pfn_range(vma, vmf->address & PAGE_MASK,
1454                             PFN_DOWN(notify.addr), PAGE_SIZE,
1455                             vma->vm_page_prot))
1456                 return VM_FAULT_SIGBUS;
1457
1458         return VM_FAULT_NOPAGE;
1459 }
1460
1461 static const struct vm_operations_struct vhost_vdpa_vm_ops = {
1462         .fault = vhost_vdpa_fault,
1463 };
1464
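/*
 * Doorbell mmap: vm_pgoff selects the virtqueue whose notification area is
 * mapped. Only a single, shared, write-only page is supported, and the
 * doorbell must sit alone on a page-aligned address; the PFN itself is
 * inserted lazily by the fault handler above.
 */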
1465 static int vhost_vdpa_mmap(struct file *file, struct vm_area_struct *vma)
1466 {
1467         struct vhost_vdpa *v = vma->vm_file->private_data;
1468         struct vdpa_device *vdpa = v->vdpa;
1469         const struct vdpa_config_ops *ops = vdpa->config;
1470         struct vdpa_notification_area notify;
1471         unsigned long index = vma->vm_pgoff;
1472
1473         if (vma->vm_end - vma->vm_start != PAGE_SIZE)
1474                 return -EINVAL;
1475         if ((vma->vm_flags & VM_SHARED) == 0)
1476                 return -EINVAL;
1477         if (vma->vm_flags & VM_READ)
1478                 return -EINVAL;
1479         if (index > 65535)
1480                 return -EINVAL;
1481         if (!ops->get_vq_notification)
1482                 return -ENOTSUPP;
1483
1484         /* To be safe and easily modelled by userspace, we only
1485          * support a doorbell that sits on a page boundary and
1486          * does not share the page with other registers.
1487          */
1488         notify = ops->get_vq_notification(vdpa, index);
1489         if (notify.addr & (PAGE_SIZE - 1))
1490                 return -EINVAL;
1491         if (vma->vm_end - vma->vm_start != notify.size)
1492                 return -ENOTSUPP;
1493
1494         vm_flags_set(vma, VM_IO | VM_PFNMAP | VM_DONTEXPAND | VM_DONTDUMP);
1495         vma->vm_ops = &vhost_vdpa_vm_ops;
1496         return 0;
1497 }
1498 #endif /* CONFIG_MMU */
1499
1500 static const struct file_operations vhost_vdpa_fops = {
1501         .owner          = THIS_MODULE,
1502         .open           = vhost_vdpa_open,
1503         .release        = vhost_vdpa_release,
1504         .write_iter     = vhost_vdpa_chr_write_iter,
1505         .unlocked_ioctl = vhost_vdpa_unlocked_ioctl,
1506 #ifdef CONFIG_MMU
1507         .mmap           = vhost_vdpa_mmap,
1508 #endif /* CONFIG_MMU */
1509         .compat_ioctl   = compat_ptr_ioctl,
1510 };
1511
1512 static void vhost_vdpa_release_dev(struct device *device)
1513 {
1514         struct vhost_vdpa *v =
1515                container_of(device, struct vhost_vdpa, dev);
1516
1517         ida_simple_remove(&vhost_vdpa_ida, v->minor);
1518         kfree(v->vqs);
1519         kfree(v);
1520 }
1521
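/*
 * vDPA bus probe: create a "vhost-vdpa-<minor>" character device (exposed
 * to userspace as /dev/vhost-vdpa-<minor> by udev) for each parent vDPA
 * device. Devices relying on the platform IOMMU are limited to a single
 * virtqueue group and a single address space.
 */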
1522 static int vhost_vdpa_probe(struct vdpa_device *vdpa)
1523 {
1524         const struct vdpa_config_ops *ops = vdpa->config;
1525         struct vhost_vdpa *v;
1526         int minor;
1527         int i, r;
1528
1529         /* We can't support platform IOMMU devices with more than 1
1530          * group or address space (as)
1531          */
1532         if (!ops->set_map && !ops->dma_map &&
1533             (vdpa->ngroups > 1 || vdpa->nas > 1))
1534                 return -EOPNOTSUPP;
1535
1536         v = kzalloc(sizeof(*v), GFP_KERNEL | __GFP_RETRY_MAYFAIL);
1537         if (!v)
1538                 return -ENOMEM;
1539
1540         minor = ida_simple_get(&vhost_vdpa_ida, 0,
1541                                VHOST_VDPA_DEV_MAX, GFP_KERNEL);
1542         if (minor < 0) {
1543                 kfree(v);
1544                 return minor;
1545         }
1546
1547         atomic_set(&v->opened, 0);
1548         v->minor = minor;
1549         v->vdpa = vdpa;
1550         v->nvqs = vdpa->nvqs;
1551         v->virtio_id = ops->get_device_id(vdpa);
1552
1553         device_initialize(&v->dev);
1554         v->dev.release = vhost_vdpa_release_dev;
1555         v->dev.parent = &vdpa->dev;
1556         v->dev.devt = MKDEV(MAJOR(vhost_vdpa_major), minor);
1557         v->vqs = kmalloc_array(v->nvqs, sizeof(struct vhost_virtqueue),
1558                                GFP_KERNEL);
1559         if (!v->vqs) {
1560                 r = -ENOMEM;
1561                 goto err;
1562         }
1563
1564         r = dev_set_name(&v->dev, "vhost-vdpa-%u", minor);
1565         if (r)
1566                 goto err;
1567
1568         cdev_init(&v->cdev, &vhost_vdpa_fops);
1569         v->cdev.owner = THIS_MODULE;
1570
1571         r = cdev_device_add(&v->cdev, &v->dev);
1572         if (r)
1573                 goto err;
1574
1575         init_completion(&v->completion);
1576         vdpa_set_drvdata(vdpa, v);
1577
1578         for (i = 0; i < VHOST_VDPA_IOTLB_BUCKETS; i++)
1579                 INIT_HLIST_HEAD(&v->as[i]);
1580
1581         return 0;
1582
1583 err:
1584         put_device(&v->dev);
1585         ida_simple_remove(&vhost_vdpa_ida, v->minor);
1586         return r;
1587 }
1588
1589 static void vhost_vdpa_remove(struct vdpa_device *vdpa)
1590 {
1591         struct vhost_vdpa *v = vdpa_get_drvdata(vdpa);
1592         int opened;
1593
1594         cdev_device_del(&v->cdev, &v->dev);
1595
1596         do {
1597                 opened = atomic_cmpxchg(&v->opened, 0, 1);
1598                 if (!opened)
1599                         break;
1600                 wait_for_completion(&v->completion);
1601         } while (1);
1602
1603         put_device(&v->dev);
1604 }
1605
1606 static struct vdpa_driver vhost_vdpa_driver = {
1607         .driver = {
1608                 .name   = "vhost_vdpa",
1609         },
1610         .probe  = vhost_vdpa_probe,
1611         .remove = vhost_vdpa_remove,
1612 };
1613
1614 static int __init vhost_vdpa_init(void)
1615 {
1616         int r;
1617
1618         r = alloc_chrdev_region(&vhost_vdpa_major, 0, VHOST_VDPA_DEV_MAX,
1619                                 "vhost-vdpa");
1620         if (r)
1621                 goto err_alloc_chrdev;
1622
1623         r = vdpa_register_driver(&vhost_vdpa_driver);
1624         if (r)
1625                 goto err_vdpa_register_driver;
1626
1627         return 0;
1628
1629 err_vdpa_register_driver:
1630         unregister_chrdev_region(vhost_vdpa_major, VHOST_VDPA_DEV_MAX);
1631 err_alloc_chrdev:
1632         return r;
1633 }
1634 module_init(vhost_vdpa_init);
1635
1636 static void __exit vhost_vdpa_exit(void)
1637 {
1638         vdpa_unregister_driver(&vhost_vdpa_driver);
1639         unregister_chrdev_region(vhost_vdpa_major, VHOST_VDPA_DEV_MAX);
1640 }
1641 module_exit(vhost_vdpa_exit);
1642
1643 MODULE_VERSION("0.0.1");
1644 MODULE_LICENSE("GPL v2");
1645 MODULE_AUTHOR("Intel Corporation");
1646 MODULE_DESCRIPTION("vDPA-based vhost backend for virtio");