// drivers/iommu/amd/iommu_v2.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2010-2012 Advanced Micro Devices, Inc.
 * Author: Joerg Roedel <jroedel@suse.de>
 */

#define pr_fmt(fmt)     "AMD-Vi: " fmt

#include <linux/mmu_notifier.h>
#include <linux/amd-iommu.h>
#include <linux/mm_types.h>
#include <linux/profile.h>
#include <linux/module.h>
#include <linux/sched.h>
#include <linux/sched/mm.h>
#include <linux/wait.h>
#include <linux/pci.h>
#include <linux/gfp.h>

#include "amd_iommu.h"

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Joerg Roedel <jroedel@suse.de>");

#define MAX_DEVICES             0x10000
#define PRI_QUEUE_SIZE          512

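/*
 * Per-tag state for Page Request Interface (PRI) requests: inflight
 * counts the fault handlers still working on this tag, finish records
 * whether the device asked for a completion message, and status holds
 * the PPR response code to send back when the tag is finished.
 */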
struct pri_queue {
        atomic_t inflight;
        bool finish;
        int status;
};

struct pasid_state {
        struct list_head list;                  /* For global state-list */
        atomic_t count;                         /* Reference count */
        unsigned mmu_notifier_count;            /* Counting nested mmu_notifier
                                                   calls */
        struct mm_struct *mm;                   /* mm_struct for the faults */
        struct mmu_notifier mn;                 /* mmu_notifier handle */
        struct pri_queue pri[PRI_QUEUE_SIZE];   /* PRI tag states */
        struct device_state *device_state;      /* Link to our device_state */
        int pasid;                              /* PASID index */
        bool invalid;                           /* Used during setup and
                                                   teardown of the pasid */
        spinlock_t lock;                        /* Protect pri_queues and
                                                   mmu_notifier_count */
        wait_queue_head_t wq;                   /* To wait for count == 0 */
};

struct device_state {
        struct list_head list;
        u16 devid;
        atomic_t count;
        struct pci_dev *pdev;
        struct pasid_state **states;
        struct iommu_domain *domain;
        int pasid_levels;
        int max_pasids;
        amd_iommu_invalid_ppr_cb inv_ppr_cb;
        amd_iommu_invalidate_ctx inv_ctx_cb;
        spinlock_t lock;
        wait_queue_head_t wq;
};

struct fault {
        struct work_struct work;
        struct device_state *dev_state;
        struct pasid_state *state;
        struct mm_struct *mm;
        u64 address;
        u16 devid;
        u16 pasid;
        u16 tag;
        u16 finish;
        u16 flags;
};

static LIST_HEAD(state_list);
static spinlock_t state_lock;

static struct workqueue_struct *iommu_wq;

static void free_pasid_states(struct device_state *dev_state);

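/*
 * Build the 16-bit PCI requester ID (bus number in bits 15:8, devfn in
 * bits 7:0) that identifies the device to the IOMMU.
 */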
static u16 device_id(struct pci_dev *pdev)
{
        u16 devid;

        devid = pdev->bus->number;
        devid = (devid << 8) | pdev->devfn;

        return devid;
}

static struct device_state *__get_device_state(u16 devid)
{
        struct device_state *dev_state;

        list_for_each_entry(dev_state, &state_list, list) {
                if (dev_state->devid == devid)
                        return dev_state;
        }

        return NULL;
}

static struct device_state *get_device_state(u16 devid)
{
        struct device_state *dev_state;
        unsigned long flags;

        spin_lock_irqsave(&state_lock, flags);
        dev_state = __get_device_state(devid);
        if (dev_state != NULL)
                atomic_inc(&dev_state->count);
        spin_unlock_irqrestore(&state_lock, flags);

        return dev_state;
}

static void free_device_state(struct device_state *dev_state)
{
        struct iommu_group *group;

        /*
         * First detach device from domain - No more PRI requests will arrive
         * from that device after it is unbound from the IOMMUv2 domain.
         */
        group = iommu_group_get(&dev_state->pdev->dev);
        if (WARN_ON(!group))
                return;

        iommu_detach_group(dev_state->domain, group);

        iommu_group_put(group);

        /* Everything is down now, free the IOMMUv2 domain */
        iommu_domain_free(dev_state->domain);

        /* Finally get rid of the device-state */
        kfree(dev_state);
}

static void put_device_state(struct device_state *dev_state)
{
        if (atomic_dec_and_test(&dev_state->count))
                wake_up(&dev_state->wq);
}

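/*
 * The pasid_state pointers are kept in a radix-tree-like table with 512
 * (2^9) entries per level, one page per table node. Each level resolves
 * nine bits of the PASID; with alloc == true, missing intermediate
 * pages are allocated on the fly (GFP_ATOMIC, since callers hold
 * dev_state->lock with interrupts disabled).
 */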
/* Must be called under dev_state->lock */
static struct pasid_state **__get_pasid_state_ptr(struct device_state *dev_state,
                                                  int pasid, bool alloc)
{
        struct pasid_state **root, **ptr;
        int level, index;

        level = dev_state->pasid_levels;
        root  = dev_state->states;

        while (true) {

                index = (pasid >> (9 * level)) & 0x1ff;
                ptr   = &root[index];

                if (level == 0)
                        break;

                if (*ptr == NULL) {
                        if (!alloc)
                                return NULL;

                        *ptr = (void *)get_zeroed_page(GFP_ATOMIC);
                        if (*ptr == NULL)
                                return NULL;
                }

                root   = (struct pasid_state **)*ptr;
                level -= 1;
        }

        return ptr;
}

static int set_pasid_state(struct device_state *dev_state,
                           struct pasid_state *pasid_state,
                           int pasid)
{
        struct pasid_state **ptr;
        unsigned long flags;
        int ret;

        spin_lock_irqsave(&dev_state->lock, flags);
        ptr = __get_pasid_state_ptr(dev_state, pasid, true);

        ret = -ENOMEM;
        if (ptr == NULL)
                goto out_unlock;

        ret = -ENOMEM;
        if (*ptr != NULL)
                goto out_unlock;

        *ptr = pasid_state;

        ret = 0;

out_unlock:
        spin_unlock_irqrestore(&dev_state->lock, flags);

        return ret;
}

static void clear_pasid_state(struct device_state *dev_state, int pasid)
{
        struct pasid_state **ptr;
        unsigned long flags;

        spin_lock_irqsave(&dev_state->lock, flags);
        ptr = __get_pasid_state_ptr(dev_state, pasid, true);

        if (ptr == NULL)
                goto out_unlock;

        *ptr = NULL;

out_unlock:
        spin_unlock_irqrestore(&dev_state->lock, flags);
}

static struct pasid_state *get_pasid_state(struct device_state *dev_state,
                                           int pasid)
{
        struct pasid_state **ptr, *ret = NULL;
        unsigned long flags;

        spin_lock_irqsave(&dev_state->lock, flags);
        ptr = __get_pasid_state_ptr(dev_state, pasid, false);

        if (ptr == NULL)
                goto out_unlock;

        ret = *ptr;
        if (ret)
                atomic_inc(&ret->count);

out_unlock:
        spin_unlock_irqrestore(&dev_state->lock, flags);

        return ret;
}

static void free_pasid_state(struct pasid_state *pasid_state)
{
        kfree(pasid_state);
}

static void put_pasid_state(struct pasid_state *pasid_state)
{
        if (atomic_dec_and_test(&pasid_state->count))
                wake_up(&pasid_state->wq);
}

static void put_pasid_state_wait(struct pasid_state *pasid_state)
{
        atomic_dec(&pasid_state->count);
        wait_event(pasid_state->wq, !atomic_read(&pasid_state->count));
        free_pasid_state(pasid_state);
}

static void unbind_pasid(struct pasid_state *pasid_state)
{
        struct iommu_domain *domain;

        domain = pasid_state->device_state->domain;

        /*
         * Mark pasid_state as invalid; no more faults will be added to the
         * work queue after this is visible everywhere.
         */
        pasid_state->invalid = true;

        /* Make sure this is visible */
        smp_wmb();

        /* After this the device/pasid can't access the mm anymore */
        amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid);

        /* Make sure no more pending faults are in the queue */
        flush_workqueue(iommu_wq);
}

static void free_pasid_states_level1(struct pasid_state **tbl)
{
        int i;

        for (i = 0; i < 512; ++i) {
                if (tbl[i] == NULL)
                        continue;

                free_page((unsigned long)tbl[i]);
        }
}

static void free_pasid_states_level2(struct pasid_state **tbl)
{
        struct pasid_state **ptr;
        int i;

        for (i = 0; i < 512; ++i) {
                if (tbl[i] == NULL)
                        continue;

                ptr = (struct pasid_state **)tbl[i];
                free_pasid_states_level1(ptr);
        }
}

static void free_pasid_states(struct device_state *dev_state)
{
        struct pasid_state *pasid_state;
        int i;

        for (i = 0; i < dev_state->max_pasids; ++i) {
                pasid_state = get_pasid_state(dev_state, i);
                if (pasid_state == NULL)
                        continue;

                put_pasid_state(pasid_state);

                /*
                 * This will call the mn_release function and
                 * unbind the PASID
                 */
                mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

                put_pasid_state_wait(pasid_state); /* Reference taken in
                                                      amd_iommu_bind_pasid */

                /* Drop reference taken in amd_iommu_bind_pasid */
                put_device_state(dev_state);
        }

        if (dev_state->pasid_levels == 2)
                free_pasid_states_level2(dev_state->states);
        else if (dev_state->pasid_levels == 1)
                free_pasid_states_level1(dev_state->states);
        else
                BUG_ON(dev_state->pasid_levels != 0);

        free_page((unsigned long)dev_state->states);
}

static struct pasid_state *mn_to_state(struct mmu_notifier *mn)
{
        return container_of(mn, struct pasid_state, mn);
}

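/*
 * Flush the IOTLB for the invalidated range: when start and end-1 lie
 * in the same page their XOR only has bits below PAGE_SHIFT set, so a
 * single-page flush suffices; otherwise flush the whole TLB for this
 * PASID.
 */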
static void mn_invalidate_range(struct mmu_notifier *mn,
                                struct mm_struct *mm,
                                unsigned long start, unsigned long end)
{
        struct pasid_state *pasid_state;
        struct device_state *dev_state;

        pasid_state = mn_to_state(mn);
        dev_state   = pasid_state->device_state;

        if ((start ^ (end - 1)) < PAGE_SIZE)
                amd_iommu_flush_page(dev_state->domain, pasid_state->pasid,
                                     start);
        else
                amd_iommu_flush_tlb(dev_state->domain, pasid_state->pasid);
}

static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
{
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
        bool run_inv_ctx_cb;

        might_sleep();

        pasid_state    = mn_to_state(mn);
        dev_state      = pasid_state->device_state;
        run_inv_ctx_cb = !pasid_state->invalid;

        if (run_inv_ctx_cb && dev_state->inv_ctx_cb)
                dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid);

        unbind_pasid(pasid_state);
}

static const struct mmu_notifier_ops iommu_mn = {
        .release                = mn_release,
        .invalidate_range       = mn_invalidate_range,
};

static void set_pri_tag_status(struct pasid_state *pasid_state,
                               u16 tag, int status)
{
        unsigned long flags;

        spin_lock_irqsave(&pasid_state->lock, flags);
        pasid_state->pri[tag].status = status;
        spin_unlock_irqrestore(&pasid_state->lock, flags);
}

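/*
 * Drop one in-flight reference on a PRI tag. When the last handler for
 * the tag completes and the device requested a response, send the
 * accumulated PPR status to the hardware and reset the tag state.
 */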
static void finish_pri_tag(struct device_state *dev_state,
                           struct pasid_state *pasid_state,
                           u16 tag)
{
        unsigned long flags;

        spin_lock_irqsave(&pasid_state->lock, flags);
        if (atomic_dec_and_test(&pasid_state->pri[tag].inflight) &&
            pasid_state->pri[tag].finish) {
                amd_iommu_complete_ppr(dev_state->pdev, pasid_state->pasid,
                                       pasid_state->pri[tag].status, tag);
                pasid_state->pri[tag].finish = false;
                pasid_state->pri[tag].status = PPR_SUCCESS;
        }
        spin_unlock_irqrestore(&pasid_state->lock, flags);
}

static void handle_fault_error(struct fault *fault)
{
        int status;

        if (!fault->dev_state->inv_ppr_cb) {
                set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
                return;
        }

        status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev,
                                              fault->pasid,
                                              fault->address,
                                              fault->flags);
        switch (status) {
        case AMD_IOMMU_INV_PRI_RSP_SUCCESS:
                set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS);
                break;
        case AMD_IOMMU_INV_PRI_RSP_INVALID:
                set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
                break;
        case AMD_IOMMU_INV_PRI_RSP_FAIL:
                set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE);
                break;
        default:
                BUG();
        }
}

static bool access_error(struct vm_area_struct *vma, struct fault *fault)
{
        unsigned long requested = 0;

        if (fault->flags & PPR_FAULT_EXEC)
                requested |= VM_EXEC;

        if (fault->flags & PPR_FAULT_READ)
                requested |= VM_READ;

        if (fault->flags & PPR_FAULT_WRITE)
                requested |= VM_WRITE;

        return (requested & ~vma->vm_flags) != 0;
}

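/*
 * Worker for a queued page request: resolve the faulting address with
 * handle_mm_fault() under the mmap read lock, after checking that the
 * VMA exists and grants the requested access. Any failure is reported
 * back to the device via handle_fault_error().
 */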
static void do_fault(struct work_struct *work)
{
        struct fault *fault = container_of(work, struct fault, work);
        struct vm_area_struct *vma;
        vm_fault_t ret = VM_FAULT_ERROR;
        unsigned int flags = 0;
        struct mm_struct *mm;
        u64 address;

        mm = fault->state->mm;
        address = fault->address;

        if (fault->flags & PPR_FAULT_USER)
                flags |= FAULT_FLAG_USER;
        if (fault->flags & PPR_FAULT_WRITE)
                flags |= FAULT_FLAG_WRITE;
        flags |= FAULT_FLAG_REMOTE;

        mmap_read_lock(mm);
        vma = find_extend_vma(mm, address);
        if (!vma || address < vma->vm_start)
                /* failed to get a vma in the right range */
                goto out;

        /* Check if we have the right permissions on the vma */
        if (access_error(vma, fault))
                goto out;

        ret = handle_mm_fault(vma, address, flags);
out:
        mmap_read_unlock(mm);

        if (ret & VM_FAULT_ERROR)
                /* failed to service fault */
                handle_fault_error(fault);

        finish_pri_tag(fault->dev_state, fault->state, fault->tag);

        put_pasid_state(fault->state);

        kfree(fault);
}

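/*
 * Notifier called by the AMD IOMMU driver for each peripheral page
 * request (PPR). It decodes the 9-bit tag and the finish bit, looks up
 * the device and PASID state, and queues the actual fault handling to
 * iommu_wq. Unknown devices or PASIDs get an immediate INVALID reply.
 */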
static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
{
        struct amd_iommu_fault *iommu_fault;
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
        struct pci_dev *pdev = NULL;
        unsigned long flags;
        struct fault *fault;
        bool finish;
        u16 tag, devid;
        int ret;

        iommu_fault = data;
        tag         = iommu_fault->tag & 0x1ff;
        finish      = (iommu_fault->tag >> 9) & 1;

        devid = iommu_fault->device_id;
        pdev = pci_get_domain_bus_and_slot(0, PCI_BUS_NUM(devid),
                                           devid & 0xff);
        if (!pdev)
                return -ENODEV;

        ret = NOTIFY_DONE;

        /* In a kdump kernel the PCI device is not initialized yet -> send INVALID */
        if (amd_iommu_is_attach_deferred(NULL, &pdev->dev)) {
                amd_iommu_complete_ppr(pdev, iommu_fault->pasid,
                                       PPR_INVALID, tag);
                goto out;
        }

        dev_state = get_device_state(iommu_fault->device_id);
        if (dev_state == NULL)
                goto out;

        pasid_state = get_pasid_state(dev_state, iommu_fault->pasid);
        if (pasid_state == NULL || pasid_state->invalid) {
                /* We know the device but not the PASID -> send INVALID */
                amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid,
                                       PPR_INVALID, tag);
                goto out_drop_state;
        }

        spin_lock_irqsave(&pasid_state->lock, flags);
        atomic_inc(&pasid_state->pri[tag].inflight);
        if (finish)
                pasid_state->pri[tag].finish = true;
        spin_unlock_irqrestore(&pasid_state->lock, flags);

        fault = kzalloc(sizeof(*fault), GFP_ATOMIC);
        if (fault == NULL) {
                /* We are OOM - send success and let the device re-fault */
                finish_pri_tag(dev_state, pasid_state, tag);
                goto out_drop_state;
        }

        fault->dev_state = dev_state;
        fault->address   = iommu_fault->address;
        fault->state     = pasid_state;
        fault->tag       = tag;
        fault->finish    = finish;
        fault->pasid     = iommu_fault->pasid;
        fault->flags     = iommu_fault->flags;
        INIT_WORK(&fault->work, do_fault);

        queue_work(iommu_wq, &fault->work);

        ret = NOTIFY_OK;

out_drop_state:

        if (ret != NOTIFY_OK && pasid_state)
                put_pasid_state(pasid_state);

        put_device_state(dev_state);

out:
        /* Drop the reference taken by pci_get_domain_bus_and_slot() */
        pci_dev_put(pdev);

        return ret;
}

static struct notifier_block ppr_nb = {
        .notifier_call = ppr_notifier,
};

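/*
 * Bind the address space of @task to @pasid on @pdev: take a reference
 * on the mm, register an mmu_notifier for it, and install the mm's page
 * table root as the GCR3 table for this PASID so the device can fault
 * pages in via PRI.
 */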
int amd_iommu_bind_pasid(struct pci_dev *pdev, int pasid,
                         struct task_struct *task)
{
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
        struct mm_struct *mm;
        u16 devid;
        int ret;

        might_sleep();

        if (!amd_iommu_v2_supported())
                return -ENODEV;

        devid     = device_id(pdev);
        dev_state = get_device_state(devid);

        if (dev_state == NULL)
                return -EINVAL;

        ret = -EINVAL;
        if (pasid < 0 || pasid >= dev_state->max_pasids)
                goto out;

        ret = -ENOMEM;
        pasid_state = kzalloc(sizeof(*pasid_state), GFP_KERNEL);
        if (pasid_state == NULL)
                goto out;

        atomic_set(&pasid_state->count, 1);
        init_waitqueue_head(&pasid_state->wq);
        spin_lock_init(&pasid_state->lock);

        mm                        = get_task_mm(task);
        pasid_state->mm           = mm;
        pasid_state->device_state = dev_state;
        pasid_state->pasid        = pasid;
        pasid_state->invalid      = true; /* Marked valid only once setup
                                             of the pasid is complete */
        pasid_state->mn.ops       = &iommu_mn;

        if (pasid_state->mm == NULL)
                goto out_free;

        mmu_notifier_register(&pasid_state->mn, mm);

        ret = set_pasid_state(dev_state, pasid_state, pasid);
        if (ret)
                goto out_unregister;

        ret = amd_iommu_domain_set_gcr3(dev_state->domain, pasid,
                                        __pa(pasid_state->mm->pgd));
        if (ret)
                goto out_clear_state;

        /* Now we are ready to handle faults */
        pasid_state->invalid = false;

        /*
         * Drop the reference to the mm_struct here. We rely on the
         * mmu_notifier release call-back to inform us when the mm
         * is going away.
         */
        mmput(mm);

        return 0;

out_clear_state:
        clear_pasid_state(dev_state, pasid);

out_unregister:
        mmu_notifier_unregister(&pasid_state->mn, mm);
        mmput(mm);

out_free:
        free_pasid_state(pasid_state);

out:
        put_device_state(dev_state);

        return ret;
}
EXPORT_SYMBOL(amd_iommu_bind_pasid);

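/*
 * Undo amd_iommu_bind_pasid(): clear the PASID table entry, unregister
 * the mmu_notifier (which tears down the GCR3 entry via mn_release) and
 * drop the references taken at bind time.
 */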
void amd_iommu_unbind_pasid(struct pci_dev *pdev, int pasid)
{
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
        u16 devid;

        might_sleep();

        if (!amd_iommu_v2_supported())
                return;

        devid = device_id(pdev);
        dev_state = get_device_state(devid);
        if (dev_state == NULL)
                return;

        if (pasid < 0 || pasid >= dev_state->max_pasids)
                goto out;

        pasid_state = get_pasid_state(dev_state, pasid);
        if (pasid_state == NULL)
                goto out;
        /*
         * Drop reference taken here. We are safe because we still hold
         * the reference taken in the amd_iommu_bind_pasid function.
         */
        put_pasid_state(pasid_state);

        /* Clear the pasid state so that the pasid can be re-used */
        clear_pasid_state(dev_state, pasid_state->pasid);

        /*
         * Call mmu_notifier_unregister to drop our reference
         * to pasid_state->mm
         */
        mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

        put_pasid_state_wait(pasid_state); /* Reference taken in
                                              amd_iommu_bind_pasid */
out:
        /* Drop reference taken in this function */
        put_device_state(dev_state);

        /* Drop reference taken in amd_iommu_bind_pasid */
        put_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_unbind_pasid);

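/*
 * Set up IOMMUv2 (PRI/PASID) support for a device. A driver would
 * typically call this once at probe time and tear everything down again
 * with amd_iommu_free_device(); roughly (a sketch, not lifted from an
 * in-tree user):
 *
 *	err = amd_iommu_init_device(pdev, num_pasids);
 *	...
 *	err = amd_iommu_bind_pasid(pdev, pasid, current);
 *	...	device issues PRI traffic tagged with pasid	...
 *	amd_iommu_unbind_pasid(pdev, pasid);
 *	amd_iommu_free_device(pdev);
 */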
int amd_iommu_init_device(struct pci_dev *pdev, int pasids)
{
        struct device_state *dev_state;
        struct iommu_group *group;
        unsigned long flags;
        int ret, tmp;
        u16 devid;

        might_sleep();

        if (!amd_iommu_v2_supported())
                return -ENODEV;

        if (pasids <= 0 || pasids > (PASID_MASK + 1))
                return -EINVAL;

        devid = device_id(pdev);

        dev_state = kzalloc(sizeof(*dev_state), GFP_KERNEL);
        if (dev_state == NULL)
                return -ENOMEM;

        spin_lock_init(&dev_state->lock);
        init_waitqueue_head(&dev_state->wq);
        dev_state->pdev  = pdev;
        dev_state->devid = devid;

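        /*
         * Each table level resolves nine bits of the PASID; count how
         * many levels are needed to cover the requested PASID range.
         */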
        tmp = pasids;
        for (dev_state->pasid_levels = 0; (tmp - 1) & ~0x1ff; tmp >>= 9)
                dev_state->pasid_levels += 1;

        atomic_set(&dev_state->count, 1);
        dev_state->max_pasids = pasids;

        ret = -ENOMEM;
        dev_state->states = (void *)get_zeroed_page(GFP_KERNEL);
        if (dev_state->states == NULL)
                goto out_free_dev_state;

        dev_state->domain = iommu_domain_alloc(&pci_bus_type);
        if (dev_state->domain == NULL)
                goto out_free_states;

        amd_iommu_domain_direct_map(dev_state->domain);

        ret = amd_iommu_domain_enable_v2(dev_state->domain, pasids);
        if (ret)
                goto out_free_domain;

        group = iommu_group_get(&pdev->dev);
        if (!group) {
                ret = -EINVAL;
                goto out_free_domain;
        }

        ret = iommu_attach_group(dev_state->domain, group);
        if (ret != 0)
                goto out_drop_group;

        iommu_group_put(group);

        spin_lock_irqsave(&state_lock, flags);

        if (__get_device_state(devid) != NULL) {
                spin_unlock_irqrestore(&state_lock, flags);
                ret = -EBUSY;
                goto out_free_domain;
        }

        list_add_tail(&dev_state->list, &state_list);

        spin_unlock_irqrestore(&state_lock, flags);

        return 0;

out_drop_group:
        iommu_group_put(group);

out_free_domain:
        iommu_domain_free(dev_state->domain);

out_free_states:
        free_page((unsigned long)dev_state->states);

out_free_dev_state:
        kfree(dev_state);

        return ret;
}
EXPORT_SYMBOL(amd_iommu_init_device);

void amd_iommu_free_device(struct pci_dev *pdev)
{
        struct device_state *dev_state;
        unsigned long flags;
        u16 devid;

        if (!amd_iommu_v2_supported())
                return;

        devid = device_id(pdev);

        spin_lock_irqsave(&state_lock, flags);

        dev_state = __get_device_state(devid);
        if (dev_state == NULL) {
                spin_unlock_irqrestore(&state_lock, flags);
                return;
        }

        list_del(&dev_state->list);

        spin_unlock_irqrestore(&state_lock, flags);

        /* Get rid of any remaining pasid states */
        free_pasid_states(dev_state);

        put_device_state(dev_state);
        /*
         * Wait until the last reference is dropped before freeing
         * the device state.
         */
        wait_event(dev_state->wq, !atomic_read(&dev_state->count));
        free_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_free_device);

int amd_iommu_set_invalid_ppr_cb(struct pci_dev *pdev,
                                 amd_iommu_invalid_ppr_cb cb)
{
        struct device_state *dev_state;
        unsigned long flags;
        u16 devid;
        int ret;

        if (!amd_iommu_v2_supported())
                return -ENODEV;

        devid = device_id(pdev);

        spin_lock_irqsave(&state_lock, flags);

        ret = -EINVAL;
        dev_state = __get_device_state(devid);
        if (dev_state == NULL)
                goto out_unlock;

        dev_state->inv_ppr_cb = cb;

        ret = 0;

out_unlock:
        spin_unlock_irqrestore(&state_lock, flags);

        return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalid_ppr_cb);

int amd_iommu_set_invalidate_ctx_cb(struct pci_dev *pdev,
                                    amd_iommu_invalidate_ctx cb)
{
        struct device_state *dev_state;
        unsigned long flags;
        u16 devid;
        int ret;

        if (!amd_iommu_v2_supported())
                return -ENODEV;

        devid = device_id(pdev);

        spin_lock_irqsave(&state_lock, flags);

        ret = -EINVAL;
        dev_state = __get_device_state(devid);
        if (dev_state == NULL)
                goto out_unlock;

        dev_state->inv_ctx_cb = cb;

        ret = 0;

out_unlock:
        spin_unlock_irqrestore(&state_lock, flags);

        return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalidate_ctx_cb);

static int __init amd_iommu_v2_init(void)
{
        int ret;

        pr_info("AMD IOMMUv2 driver by Joerg Roedel <jroedel@suse.de>\n");

        if (!amd_iommu_v2_supported()) {
                pr_info("AMD IOMMUv2 functionality not available on this system\n");
                /*
                 * Load anyway to provide the symbols to other modules
                 * which may use AMD IOMMUv2 optionally.
                 */
                return 0;
        }

        spin_lock_init(&state_lock);

        ret = -ENOMEM;
        iommu_wq = alloc_workqueue("amd_iommu_v2", WQ_MEM_RECLAIM, 0);
        if (iommu_wq == NULL)
                goto out;

        amd_iommu_register_ppr_notifier(&ppr_nb);

        return 0;

out:
        return ret;
}

static void __exit amd_iommu_v2_exit(void)
{
        struct device_state *dev_state;
        int i;

        if (!amd_iommu_v2_supported())
                return;

        amd_iommu_unregister_ppr_notifier(&ppr_nb);

        flush_workqueue(iommu_wq);

        /*
         * The loop below might call flush_workqueue(), so call
         * destroy_workqueue() after it
         */
        for (i = 0; i < MAX_DEVICES; ++i) {
                dev_state = get_device_state(i);

                if (dev_state == NULL)
                        continue;

                WARN_ON_ONCE(1);

                put_device_state(dev_state);
                amd_iommu_free_device(dev_state->pdev);
        }

        destroy_workqueue(iommu_wq);
}

module_init(amd_iommu_v2_init);
module_exit(amd_iommu_v2_exit);