drivers/iommu/amd_iommu_v2.c
1 /*
2  * Copyright (C) 2010-2012 Advanced Micro Devices, Inc.
3  * Author: Joerg Roedel <jroedel@suse.de>
4  *
5  * This program is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 as published
7  * by the Free Software Foundation.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program; if not, write to the Free Software
16  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
17  */
18
19 #define pr_fmt(fmt)     "AMD-Vi: " fmt
20
21 #include <linux/mmu_notifier.h>
22 #include <linux/amd-iommu.h>
23 #include <linux/mm_types.h>
24 #include <linux/profile.h>
25 #include <linux/module.h>
26 #include <linux/sched.h>
27 #include <linux/sched/mm.h>
28 #include <linux/iommu.h>
29 #include <linux/wait.h>
30 #include <linux/pci.h>
31 #include <linux/gfp.h>
32
33 #include "amd_iommu_types.h"
34 #include "amd_iommu_proto.h"
35
36 MODULE_LICENSE("GPL v2");
37 MODULE_AUTHOR("Joerg Roedel <jroedel@suse.de>");
38
39 #define MAX_DEVICES             0x10000
40 #define PRI_QUEUE_SIZE          512
41
42 struct pri_queue {
43         atomic_t inflight;
44         bool finish;
45         int status;
46 };
47
48 struct pasid_state {
49         struct list_head list;                  /* For global state-list */
50         atomic_t count;                         /* Reference count */
51         unsigned mmu_notifier_count;            /* Counting nested mmu_notifier
52                                                    calls */
53         struct mm_struct *mm;                   /* mm_struct for the faults */
54         struct mmu_notifier mn;                 /* mmu_notifier handle */
55         struct pri_queue pri[PRI_QUEUE_SIZE];   /* PRI tag states */
56         struct device_state *device_state;      /* Link to our device_state */
57         int pasid;                              /* PASID index */
58         bool invalid;                           /* Used during setup and
59                                                    teardown of the pasid */
60         spinlock_t lock;                        /* Protect pri_queues and
61                                                    mmu_notifier_count */
62         wait_queue_head_t wq;                   /* To wait for count == 0 */
63 };
64
65 struct device_state {
66         struct list_head list;
67         u16 devid;
68         atomic_t count;
69         struct pci_dev *pdev;
70         struct pasid_state **states;
71         struct iommu_domain *domain;
72         int pasid_levels;
73         int max_pasids;
74         amd_iommu_invalid_ppr_cb inv_ppr_cb;
75         amd_iommu_invalidate_ctx inv_ctx_cb;
76         spinlock_t lock;
77         wait_queue_head_t wq;
78 };
79
80 struct fault {
81         struct work_struct work;
82         struct device_state *dev_state;
83         struct pasid_state *state;
84         struct mm_struct *mm;
85         u64 address;
86         u16 devid;
87         u16 pasid;
88         u16 tag;
89         u16 finish;
90         u16 flags;
91 };
92
93 static LIST_HEAD(state_list);
94 static spinlock_t state_lock;
95
96 static struct workqueue_struct *iommu_wq;
97
98 static void free_pasid_states(struct device_state *dev_state);
99
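/*
 * Build the 16-bit PCI requester ID for a device: bus number in the high
 * byte, devfn in the low byte.
 */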
100 static u16 device_id(struct pci_dev *pdev)
101 {
102         u16 devid;
103
104         devid = pdev->bus->number;
105         devid = (devid << 8) | pdev->devfn;
106
107         return devid;
108 }
109
110 static struct device_state *__get_device_state(u16 devid)
111 {
112         struct device_state *dev_state;
113
114         list_for_each_entry(dev_state, &state_list, list) {
115                 if (dev_state->devid == devid)
116                         return dev_state;
117         }
118
119         return NULL;
120 }
121
122 static struct device_state *get_device_state(u16 devid)
123 {
124         struct device_state *dev_state;
125         unsigned long flags;
126
127         spin_lock_irqsave(&state_lock, flags);
128         dev_state = __get_device_state(devid);
129         if (dev_state != NULL)
130                 atomic_inc(&dev_state->count);
131         spin_unlock_irqrestore(&state_lock, flags);
132
133         return dev_state;
134 }
135
136 static void free_device_state(struct device_state *dev_state)
137 {
138         struct iommu_group *group;
139
140         /*
141          * First detach device from domain - No more PRI requests will arrive
142          * from that device after it is unbound from the IOMMUv2 domain.
143          */
144         group = iommu_group_get(&dev_state->pdev->dev);
145         if (WARN_ON(!group))
146                 return;
147
148         iommu_detach_group(dev_state->domain, group);
149
150         iommu_group_put(group);
151
152         /* Everything is down now, free the IOMMUv2 domain */
153         iommu_domain_free(dev_state->domain);
154
155         /* Finally get rid of the device-state */
156         kfree(dev_state);
157 }
158
159 static void put_device_state(struct device_state *dev_state)
160 {
161         if (atomic_dec_and_test(&dev_state->count))
162                 wake_up(&dev_state->wq);
163 }
164
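/*
 * Per-device PASID states live in a multi-level table rooted at
 * dev_state->states. Each level is one zeroed page holding 512 pointers,
 * indexed by 9 bits of the PASID; level 0 holds the pasid_state pointers
 * themselves. With alloc == true, missing intermediate levels are
 * allocated (GFP_ATOMIC) during the walk.
 */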
165 /* Must be called under dev_state->lock */
166 static struct pasid_state **__get_pasid_state_ptr(struct device_state *dev_state,
167                                                   int pasid, bool alloc)
168 {
169         struct pasid_state **root, **ptr;
170         int level, index;
171
172         level = dev_state->pasid_levels;
173         root  = dev_state->states;
174
175         while (true) {
176
177                 index = (pasid >> (9 * level)) & 0x1ff;
178                 ptr   = &root[index];
179
180                 if (level == 0)
181                         break;
182
183                 if (*ptr == NULL) {
184                         if (!alloc)
185                                 return NULL;
186
187                         *ptr = (void *)get_zeroed_page(GFP_ATOMIC);
188                         if (*ptr == NULL)
189                                 return NULL;
190                 }
191
192                 root   = (struct pasid_state **)*ptr;
193                 level -= 1;
194         }
195
196         return ptr;
197 }
198
199 static int set_pasid_state(struct device_state *dev_state,
200                            struct pasid_state *pasid_state,
201                            int pasid)
202 {
203         struct pasid_state **ptr;
204         unsigned long flags;
205         int ret;
206
207         spin_lock_irqsave(&dev_state->lock, flags);
208         ptr = __get_pasid_state_ptr(dev_state, pasid, true);
209
210         ret = -ENOMEM;
211         if (ptr == NULL)
212                 goto out_unlock;
213
214         ret = -ENOMEM;
215         if (*ptr != NULL)
216                 goto out_unlock;
217
218         *ptr = pasid_state;
219
220         ret = 0;
221
222 out_unlock:
223         spin_unlock_irqrestore(&dev_state->lock, flags);
224
225         return ret;
226 }
227
228 static void clear_pasid_state(struct device_state *dev_state, int pasid)
229 {
230         struct pasid_state **ptr;
231         unsigned long flags;
232
233         spin_lock_irqsave(&dev_state->lock, flags);
234         ptr = __get_pasid_state_ptr(dev_state, pasid, true);
235
236         if (ptr == NULL)
237                 goto out_unlock;
238
239         *ptr = NULL;
240
241 out_unlock:
242         spin_unlock_irqrestore(&dev_state->lock, flags);
243 }
244
245 static struct pasid_state *get_pasid_state(struct device_state *dev_state,
246                                            int pasid)
247 {
248         struct pasid_state **ptr, *ret = NULL;
249         unsigned long flags;
250
251         spin_lock_irqsave(&dev_state->lock, flags);
252         ptr = __get_pasid_state_ptr(dev_state, pasid, false);
253
254         if (ptr == NULL)
255                 goto out_unlock;
256
257         ret = *ptr;
258         if (ret)
259                 atomic_inc(&ret->count);
260
261 out_unlock:
262         spin_unlock_irqrestore(&dev_state->lock, flags);
263
264         return ret;
265 }
266
267 static void free_pasid_state(struct pasid_state *pasid_state)
268 {
269         kfree(pasid_state);
270 }
271
272 static void put_pasid_state(struct pasid_state *pasid_state)
273 {
274         if (atomic_dec_and_test(&pasid_state->count))
275                 wake_up(&pasid_state->wq);
276 }
277
278 static void put_pasid_state_wait(struct pasid_state *pasid_state)
279 {
280         atomic_dec(&pasid_state->count);
281         wait_event(pasid_state->wq, !atomic_read(&pasid_state->count));
282         free_pasid_state(pasid_state);
283 }
284
285 static void unbind_pasid(struct pasid_state *pasid_state)
286 {
287         struct iommu_domain *domain;
288
289         domain = pasid_state->device_state->domain;
290
291         /*
292          * Mark pasid_state as invalid, no more faults will be added to the
293          * work queue after this is visible everywhere.
294          */
295         pasid_state->invalid = true;
296
297         /* Make sure this is visible */
298         smp_wmb();
299
300         /* After this the device/pasid can't access the mm anymore */
301         amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid);
302
303         /* Make sure no more pending faults are in the queue */
304         flush_workqueue(iommu_wq);
305 }
306
307 static void free_pasid_states_level1(struct pasid_state **tbl)
308 {
309         int i;
310
311         for (i = 0; i < 512; ++i) {
312                 if (tbl[i] == NULL)
313                         continue;
314
315                 free_page((unsigned long)tbl[i]);
316         }
317 }
318
319 static void free_pasid_states_level2(struct pasid_state **tbl)
320 {
321         struct pasid_state **ptr;
322         int i;
323
324         for (i = 0; i < 512; ++i) {
325                 if (tbl[i] == NULL)
326                         continue;
327
328                 ptr = (struct pasid_state **)tbl[i];
329                 free_pasid_states_level1(ptr);
330         }
331 }
332
333 static void free_pasid_states(struct device_state *dev_state)
334 {
335         struct pasid_state *pasid_state;
336         int i;
337
338         for (i = 0; i < dev_state->max_pasids; ++i) {
339                 pasid_state = get_pasid_state(dev_state, i);
340                 if (pasid_state == NULL)
341                         continue;
342
343                 put_pasid_state(pasid_state);
344
345                 /*
346                  * This will call the mn_release function and
347                  * unbind the PASID
348                  */
349                 mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);
350
351                 put_pasid_state_wait(pasid_state); /* Reference taken in
352                                                       amd_iommu_bind_pasid */
353
354                 /* Drop reference taken in amd_iommu_bind_pasid */
355                 put_device_state(dev_state);
356         }
357
358         if (dev_state->pasid_levels == 2)
359                 free_pasid_states_level2(dev_state->states);
360         else if (dev_state->pasid_levels == 1)
361                 free_pasid_states_level1(dev_state->states);
362         else
363                 BUG_ON(dev_state->pasid_levels != 0);
364
365         free_page((unsigned long)dev_state->states);
366 }
367
368 static struct pasid_state *mn_to_state(struct mmu_notifier *mn)
369 {
370         return container_of(mn, struct pasid_state, mn);
371 }
372
373 static void mn_invalidate_range(struct mmu_notifier *mn,
374                                 struct mm_struct *mm,
375                                 unsigned long start, unsigned long end)
376 {
377         struct pasid_state *pasid_state;
378         struct device_state *dev_state;
379
380         pasid_state = mn_to_state(mn);
381         dev_state   = pasid_state->device_state;
382
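        /*
         * If start and end - 1 lie in the same page, flush only that page
         * for this PASID; otherwise flush the whole per-PASID TLB.
         */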
383         if ((start ^ (end - 1)) < PAGE_SIZE)
384                 amd_iommu_flush_page(dev_state->domain, pasid_state->pasid,
385                                      start);
386         else
387                 amd_iommu_flush_tlb(dev_state->domain, pasid_state->pasid);
388 }
389
390 static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
391 {
392         struct pasid_state *pasid_state;
393         struct device_state *dev_state;
394         bool run_inv_ctx_cb;
395
396         might_sleep();
397
398         pasid_state    = mn_to_state(mn);
399         dev_state      = pasid_state->device_state;
400         run_inv_ctx_cb = !pasid_state->invalid;
401
402         if (run_inv_ctx_cb && dev_state->inv_ctx_cb)
403                 dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid);
404
405         unbind_pasid(pasid_state);
406 }
407
408 static const struct mmu_notifier_ops iommu_mn = {
409         .release                = mn_release,
410         .invalidate_range       = mn_invalidate_range,
411 };
412
413 static void set_pri_tag_status(struct pasid_state *pasid_state,
414                                u16 tag, int status)
415 {
416         unsigned long flags;
417
418         spin_lock_irqsave(&pasid_state->lock, flags);
419         pasid_state->pri[tag].status = status;
420         spin_unlock_irqrestore(&pasid_state->lock, flags);
421 }
422
423 static void finish_pri_tag(struct device_state *dev_state,
424                            struct pasid_state *pasid_state,
425                            u16 tag)
426 {
427         unsigned long flags;
428
429         spin_lock_irqsave(&pasid_state->lock, flags);
430         if (atomic_dec_and_test(&pasid_state->pri[tag].inflight) &&
431             pasid_state->pri[tag].finish) {
432                 amd_iommu_complete_ppr(dev_state->pdev, pasid_state->pasid,
433                                        pasid_state->pri[tag].status, tag);
434                 pasid_state->pri[tag].finish = false;
435                 pasid_state->pri[tag].status = PPR_SUCCESS;
436         }
437         spin_unlock_irqrestore(&pasid_state->lock, flags);
438 }
439
440 static void handle_fault_error(struct fault *fault)
441 {
442         int status;
443
444         if (!fault->dev_state->inv_ppr_cb) {
445                 set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
446                 return;
447         }
448
449         status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev,
450                                               fault->pasid,
451                                               fault->address,
452                                               fault->flags);
453         switch (status) {
454         case AMD_IOMMU_INV_PRI_RSP_SUCCESS:
455                 set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS);
456                 break;
457         case AMD_IOMMU_INV_PRI_RSP_INVALID:
458                 set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
459                 break;
460         case AMD_IOMMU_INV_PRI_RSP_FAIL:
461                 set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE);
462                 break;
463         default:
464                 BUG();
465         }
466 }
467
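/*
 * Check whether the access rights requested by the PPR fault exceed what
 * the VMA permits.
 */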
468 static bool access_error(struct vm_area_struct *vma, struct fault *fault)
469 {
470         unsigned long requested = 0;
471
472         if (fault->flags & PPR_FAULT_EXEC)
473                 requested |= VM_EXEC;
474
475         if (fault->flags & PPR_FAULT_READ)
476                 requested |= VM_READ;
477
478         if (fault->flags & PPR_FAULT_WRITE)
479                 requested |= VM_WRITE;
480
481         return (requested & ~vma->vm_flags) != 0;
482 }
483
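/*
 * Worker function: resolve a queued PRI fault by faulting the page into
 * the bound mm via handle_mm_fault(), then complete the PRI tag.
 */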
484 static void do_fault(struct work_struct *work)
485 {
486         struct fault *fault = container_of(work, struct fault, work);
487         struct vm_area_struct *vma;
488         vm_fault_t ret = VM_FAULT_ERROR;
489         unsigned int flags = 0;
490         struct mm_struct *mm;
491         u64 address;
492
493         mm = fault->state->mm;
494         address = fault->address;
495
496         if (fault->flags & PPR_FAULT_USER)
497                 flags |= FAULT_FLAG_USER;
498         if (fault->flags & PPR_FAULT_WRITE)
499                 flags |= FAULT_FLAG_WRITE;
500         flags |= FAULT_FLAG_REMOTE;
501
502         down_read(&mm->mmap_sem);
503         vma = find_extend_vma(mm, address);
504         if (!vma || address < vma->vm_start)
505                 /* failed to get a vma in the right range */
506                 goto out;
507
508         /* Check if we have the right permissions on the vma */
509         if (access_error(vma, fault))
510                 goto out;
511
512         ret = handle_mm_fault(vma, address, flags);
513 out:
514         up_read(&mm->mmap_sem);
515
516         if (ret & VM_FAULT_ERROR)
517                 /* failed to service fault */
518                 handle_fault_error(fault);
519
520         finish_pri_tag(fault->dev_state, fault->state, fault->tag);
521
522         put_pasid_state(fault->state);
523
524         kfree(fault);
525 }
526
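/*
 * Called from the IOMMU driver's PPR notifier chain: decode the fault,
 * look up the device and PASID state and queue a struct fault on iommu_wq
 * so that do_fault() can service it in a context that may sleep.
 */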
527 static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
528 {
529         struct amd_iommu_fault *iommu_fault;
530         struct pasid_state *pasid_state;
531         struct device_state *dev_state;
532         unsigned long flags;
533         struct fault *fault;
534         bool finish;
535         u16 tag, devid;
536         int ret;
537         struct iommu_dev_data *dev_data;
538         struct pci_dev *pdev = NULL;
539
540         iommu_fault = data;
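        /*
         * The low 9 bits of the tag identify the PRI request; bit 9
         * ("finish") means a PPR completion must be sent to the device
         * once all faults queued under this tag have been handled.
         */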
541         tag         = iommu_fault->tag & 0x1ff;
542         finish      = (iommu_fault->tag >> 9) & 1;
543
544         devid = iommu_fault->device_id;
545         pdev = pci_get_domain_bus_and_slot(0, PCI_BUS_NUM(devid),
546                                            devid & 0xff);
547         if (!pdev)
548                 return -ENODEV;
549         dev_data = get_dev_data(&pdev->dev);
550
551         /* In kdump kernel pci dev is not initialized yet -> send INVALID */
552         ret = NOTIFY_DONE;
553         if (translation_pre_enabled(amd_iommu_rlookup_table[devid])
554                 && dev_data->defer_attach) {
555                 amd_iommu_complete_ppr(pdev, iommu_fault->pasid,
556                                        PPR_INVALID, tag);
557                 goto out;
558         }
559
560         dev_state = get_device_state(iommu_fault->device_id);
561         if (dev_state == NULL)
562                 goto out;
563
564         pasid_state = get_pasid_state(dev_state, iommu_fault->pasid);
565         if (pasid_state == NULL || pasid_state->invalid) {
566                 /* We know the device but not the PASID -> send INVALID */
567                 amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid,
568                                        PPR_INVALID, tag);
569                 goto out_drop_state;
570         }
571
572         spin_lock_irqsave(&pasid_state->lock, flags);
573         atomic_inc(&pasid_state->pri[tag].inflight);
574         if (finish)
575                 pasid_state->pri[tag].finish = true;
576         spin_unlock_irqrestore(&pasid_state->lock, flags);
577
578         fault = kzalloc(sizeof(*fault), GFP_ATOMIC);
579         if (fault == NULL) {
580                 /* We are OOM - send success and let the device re-fault */
581                 finish_pri_tag(dev_state, pasid_state, tag);
582                 goto out_drop_state;
583         }
584
585         fault->dev_state = dev_state;
586         fault->address   = iommu_fault->address;
587         fault->state     = pasid_state;
588         fault->tag       = tag;
589         fault->finish    = finish;
590         fault->pasid     = iommu_fault->pasid;
591         fault->flags     = iommu_fault->flags;
592         INIT_WORK(&fault->work, do_fault);
593
594         queue_work(iommu_wq, &fault->work);
595
596         ret = NOTIFY_OK;
597
598 out_drop_state:
599
600         if (ret != NOTIFY_OK && pasid_state)
601                 put_pasid_state(pasid_state);
602
603         put_device_state(dev_state);
604
605 out:
606         return ret;
607 }
608
609 static struct notifier_block ppr_nb = {
610         .notifier_call = ppr_notifier,
611 };
612
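/*
 * Bind a process address space to a PASID of this device: take a
 * reference on the task's mm, register an mmu_notifier on it and program
 * the mm's page-table root into the IOMMU GCR3 table for this PASID, so
 * that PRI faults from the device are resolved against that mm.
 */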
613 int amd_iommu_bind_pasid(struct pci_dev *pdev, int pasid,
614                          struct task_struct *task)
615 {
616         struct pasid_state *pasid_state;
617         struct device_state *dev_state;
618         struct mm_struct *mm;
619         u16 devid;
620         int ret;
621
622         might_sleep();
623
624         if (!amd_iommu_v2_supported())
625                 return -ENODEV;
626
627         devid     = device_id(pdev);
628         dev_state = get_device_state(devid);
629
630         if (dev_state == NULL)
631                 return -EINVAL;
632
633         ret = -EINVAL;
634         if (pasid < 0 || pasid >= dev_state->max_pasids)
635                 goto out;
636
637         ret = -ENOMEM;
638         pasid_state = kzalloc(sizeof(*pasid_state), GFP_KERNEL);
639         if (pasid_state == NULL)
640                 goto out;
641
642
643         atomic_set(&pasid_state->count, 1);
644         init_waitqueue_head(&pasid_state->wq);
645         spin_lock_init(&pasid_state->lock);
646
647         mm                        = get_task_mm(task);
648         pasid_state->mm           = mm;
649         pasid_state->device_state = dev_state;
650         pasid_state->pasid        = pasid;
651         pasid_state->invalid      = true; /* Mark as valid only if we are
652                                              done with setting up the pasid */
653         pasid_state->mn.ops       = &iommu_mn;
654
655         if (pasid_state->mm == NULL)
656                 goto out_free;
657
658         mmu_notifier_register(&pasid_state->mn, mm);
659
660         ret = set_pasid_state(dev_state, pasid_state, pasid);
661         if (ret)
662                 goto out_unregister;
663
664         ret = amd_iommu_domain_set_gcr3(dev_state->domain, pasid,
665                                         __pa(pasid_state->mm->pgd));
666         if (ret)
667                 goto out_clear_state;
668
669         /* Now we are ready to handle faults */
670         pasid_state->invalid = false;
671
672         /*
673          * Drop the reference to the mm_struct here. We rely on the
674          * mmu_notifier release call-back to inform us when the mm
675          * is going away.
676          */
677         mmput(mm);
678
679         return 0;
680
681 out_clear_state:
682         clear_pasid_state(dev_state, pasid);
683
684 out_unregister:
685         mmu_notifier_unregister(&pasid_state->mn, mm);
686         mmput(mm);
687
688 out_free:
689         free_pasid_state(pasid_state);
690
691 out:
692         put_device_state(dev_state);
693
694         return ret;
695 }
696 EXPORT_SYMBOL(amd_iommu_bind_pasid);
697
698 void amd_iommu_unbind_pasid(struct pci_dev *pdev, int pasid)
699 {
700         struct pasid_state *pasid_state;
701         struct device_state *dev_state;
702         u16 devid;
703
704         might_sleep();
705
706         if (!amd_iommu_v2_supported())
707                 return;
708
709         devid = device_id(pdev);
710         dev_state = get_device_state(devid);
711         if (dev_state == NULL)
712                 return;
713
714         if (pasid < 0 || pasid >= dev_state->max_pasids)
715                 goto out;
716
717         pasid_state = get_pasid_state(dev_state, pasid);
718         if (pasid_state == NULL)
719                 goto out;
720         /*
721          * Drop reference taken here. We are safe because we still hold
722          * the reference taken in the amd_iommu_bind_pasid function.
723          */
724         put_pasid_state(pasid_state);
725
726         /* Clear the pasid state so that the pasid can be re-used */
727         clear_pasid_state(dev_state, pasid_state->pasid);
728
729         /*
730          * Call mmu_notifier_unregister to drop our reference
731          * to pasid_state->mm
732          */
733         mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);
734
735         put_pasid_state_wait(pasid_state); /* Reference taken in
736                                               amd_iommu_bind_pasid */
737 out:
738         /* Drop reference taken in this function */
739         put_device_state(dev_state);
740
741         /* Drop reference taken in amd_iommu_bind_pasid */
742         put_device_state(dev_state);
743 }
744 EXPORT_SYMBOL(amd_iommu_unbind_pasid);
745
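/*
 * Set up IOMMUv2 (PASID/PRI) support for a device. An illustrative call
 * sequence for a client driver (a sketch, not mandated by this file;
 * my_ppr_cb is a placeholder):
 *
 *	amd_iommu_init_device(pdev, pasids);
 *	amd_iommu_set_invalid_ppr_cb(pdev, my_ppr_cb);  (optional)
 *	amd_iommu_bind_pasid(pdev, pasid, current);
 *	...
 *	amd_iommu_unbind_pasid(pdev, pasid);
 *	amd_iommu_free_device(pdev);
 */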
746 int amd_iommu_init_device(struct pci_dev *pdev, int pasids)
747 {
748         struct device_state *dev_state;
749         struct iommu_group *group;
750         unsigned long flags;
751         int ret, tmp;
752         u16 devid;
753
754         might_sleep();
755
756         if (!amd_iommu_v2_supported())
757                 return -ENODEV;
758
759         if (pasids <= 0 || pasids > (PASID_MASK + 1))
760                 return -EINVAL;
761
762         devid = device_id(pdev);
763
764         dev_state = kzalloc(sizeof(*dev_state), GFP_KERNEL);
765         if (dev_state == NULL)
766                 return -ENOMEM;
767
768         spin_lock_init(&dev_state->lock);
769         init_waitqueue_head(&dev_state->wq);
770         dev_state->pdev  = pdev;
771         dev_state->devid = devid;
772
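        /*
         * Work out how many table levels are needed above the leaf level:
         * each level resolves 9 bits of the PASID (512 entries per page).
         */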
773         tmp = pasids;
774         for (dev_state->pasid_levels = 0; (tmp - 1) & ~0x1ff; tmp >>= 9)
775                 dev_state->pasid_levels += 1;
776
777         atomic_set(&dev_state->count, 1);
778         dev_state->max_pasids = pasids;
779
780         ret = -ENOMEM;
781         dev_state->states = (void *)get_zeroed_page(GFP_KERNEL);
782         if (dev_state->states == NULL)
783                 goto out_free_dev_state;
784
785         dev_state->domain = iommu_domain_alloc(&pci_bus_type);
786         if (dev_state->domain == NULL)
787                 goto out_free_states;
788
789         amd_iommu_domain_direct_map(dev_state->domain);
790
791         ret = amd_iommu_domain_enable_v2(dev_state->domain, pasids);
792         if (ret)
793                 goto out_free_domain;
794
795         group = iommu_group_get(&pdev->dev);
796         if (!group) {
797                 ret = -EINVAL;
798                 goto out_free_domain;
799         }
800
801         ret = iommu_attach_group(dev_state->domain, group);
802         if (ret != 0)
803                 goto out_drop_group;
804
805         iommu_group_put(group);
806
807         spin_lock_irqsave(&state_lock, flags);
808
809         if (__get_device_state(devid) != NULL) {
810                 spin_unlock_irqrestore(&state_lock, flags);
811                 ret = -EBUSY;
812                 goto out_free_domain;
813         }
814
815         list_add_tail(&dev_state->list, &state_list);
816
817         spin_unlock_irqrestore(&state_lock, flags);
818
819         return 0;
820
821 out_drop_group:
822         iommu_group_put(group);
823
824 out_free_domain:
825         iommu_domain_free(dev_state->domain);
826
827 out_free_states:
828         free_page((unsigned long)dev_state->states);
829
830 out_free_dev_state:
831         kfree(dev_state);
832
833         return ret;
834 }
835 EXPORT_SYMBOL(amd_iommu_init_device);
836
837 void amd_iommu_free_device(struct pci_dev *pdev)
838 {
839         struct device_state *dev_state;
840         unsigned long flags;
841         u16 devid;
842
843         if (!amd_iommu_v2_supported())
844                 return;
845
846         devid = device_id(pdev);
847
848         spin_lock_irqsave(&state_lock, flags);
849
850         dev_state = __get_device_state(devid);
851         if (dev_state == NULL) {
852                 spin_unlock_irqrestore(&state_lock, flags);
853                 return;
854         }
855
856         list_del(&dev_state->list);
857
858         spin_unlock_irqrestore(&state_lock, flags);
859
860         /* Get rid of any remaining pasid states */
861         free_pasid_states(dev_state);
862
863         put_device_state(dev_state);
864         /*
865          * Wait until the last reference is dropped before freeing
866          * the device state.
867          */
868         wait_event(dev_state->wq, !atomic_read(&dev_state->count));
869         free_device_state(dev_state);
870 }
871 EXPORT_SYMBOL(amd_iommu_free_device);
872
873 int amd_iommu_set_invalid_ppr_cb(struct pci_dev *pdev,
874                                  amd_iommu_invalid_ppr_cb cb)
875 {
876         struct device_state *dev_state;
877         unsigned long flags;
878         u16 devid;
879         int ret;
880
881         if (!amd_iommu_v2_supported())
882                 return -ENODEV;
883
884         devid = device_id(pdev);
885
886         spin_lock_irqsave(&state_lock, flags);
887
888         ret = -EINVAL;
889         dev_state = __get_device_state(devid);
890         if (dev_state == NULL)
891                 goto out_unlock;
892
893         dev_state->inv_ppr_cb = cb;
894
895         ret = 0;
896
897 out_unlock:
898         spin_unlock_irqrestore(&state_lock, flags);
899
900         return ret;
901 }
902 EXPORT_SYMBOL(amd_iommu_set_invalid_ppr_cb);
903
904 int amd_iommu_set_invalidate_ctx_cb(struct pci_dev *pdev,
905                                     amd_iommu_invalidate_ctx cb)
906 {
907         struct device_state *dev_state;
908         unsigned long flags;
909         u16 devid;
910         int ret;
911
912         if (!amd_iommu_v2_supported())
913                 return -ENODEV;
914
915         devid = device_id(pdev);
916
917         spin_lock_irqsave(&state_lock, flags);
918
919         ret = -EINVAL;
920         dev_state = __get_device_state(devid);
921         if (dev_state == NULL)
922                 goto out_unlock;
923
924         dev_state->inv_ctx_cb = cb;
925
926         ret = 0;
927
928 out_unlock:
929         spin_unlock_irqrestore(&state_lock, flags);
930
931         return ret;
932 }
933 EXPORT_SYMBOL(amd_iommu_set_invalidate_ctx_cb);
934
935 static int __init amd_iommu_v2_init(void)
936 {
937         int ret;
938
939         pr_info("AMD IOMMUv2 driver by Joerg Roedel <jroedel@suse.de>\n");
940
941         if (!amd_iommu_v2_supported()) {
942                 pr_info("AMD IOMMUv2 functionality not available on this system\n");
943                 /*
944                  * Load anyway to provide the symbols to other modules
945                  * which may use AMD IOMMUv2 optionally.
946                  */
947                 return 0;
948         }
949
950         spin_lock_init(&state_lock);
951
952         ret = -ENOMEM;
953         iommu_wq = alloc_workqueue("amd_iommu_v2", WQ_MEM_RECLAIM, 0);
954         if (iommu_wq == NULL)
955                 goto out;
956
957         amd_iommu_register_ppr_notifier(&ppr_nb);
958
959         return 0;
960
961 out:
962         return ret;
963 }
964
965 static void __exit amd_iommu_v2_exit(void)
966 {
967         struct device_state *dev_state;
968         int i;
969
970         if (!amd_iommu_v2_supported())
971                 return;
972
973         amd_iommu_unregister_ppr_notifier(&ppr_nb);
974
975         flush_workqueue(iommu_wq);
976
977         /*
978          * The loop below might call flush_workqueue(), so call
979          * destroy_workqueue() after it
980          */
981         for (i = 0; i < MAX_DEVICES; ++i) {
982                 dev_state = get_device_state(i);
983
984                 if (dev_state == NULL)
985                         continue;
986
987                 WARN_ON_ONCE(1);
988
989                 put_device_state(dev_state);
990                 amd_iommu_free_device(dev_state->pdev);
991         }
992
993         destroy_workqueue(iommu_wq);
994 }
995
996 module_init(amd_iommu_v2_init);
997 module_exit(amd_iommu_v2_exit);