mm/memremap_pages: convert to 'struct range'
drivers/dax/device.c
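For context on the conversion named in the commit subject: dev_dax->range and pgmap->range in the code below are 'struct range' instances, and range_len() is the helper used to size them. A minimal sketch of the relevant definitions from include/linux/range.h (reproduced here for reference, not part of this file) is:

        /* sketch of struct range and range_len() from include/linux/range.h */
        struct range {
                u64   start;
                u64   end;      /* inclusive last byte of the range */
        };

        static inline u64 range_len(const struct range *range)
        {
                return range->end - range->start + 1;   /* inclusive end, hence +1 */
        }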
// SPDX-License-Identifier: GPL-2.0
/* Copyright(c) 2016-2018 Intel Corporation. All rights reserved. */
#include <linux/memremap.h>
#include <linux/pagemap.h>
#include <linux/module.h>
#include <linux/device.h>
#include <linux/pfn_t.h>
#include <linux/cdev.h>
#include <linux/slab.h>
#include <linux/dax.h>
#include <linux/fs.h>
#include <linux/mm.h>
#include <linux/mman.h>
#include "dax-private.h"
#include "bus.h"

static int check_vma(struct dev_dax *dev_dax, struct vm_area_struct *vma,
                const char *func)
{
        struct dax_region *dax_region = dev_dax->region;
        struct device *dev = &dev_dax->dev;
        unsigned long mask;

        if (!dax_alive(dev_dax->dax_dev))
                return -ENXIO;

        /* prevent private mappings from being established */
        if ((vma->vm_flags & VM_MAYSHARE) != VM_MAYSHARE) {
                dev_info_ratelimited(dev,
                                "%s: %s: fail, attempted private mapping\n",
                                current->comm, func);
                return -EINVAL;
        }

        mask = dax_region->align - 1;
        if (vma->vm_start & mask || vma->vm_end & mask) {
                dev_info_ratelimited(dev,
                                "%s: %s: fail, unaligned vma (%#lx - %#lx, %#lx)\n",
                                current->comm, func, vma->vm_start, vma->vm_end,
                                mask);
                return -EINVAL;
        }

        if (!vma_is_dax(vma)) {
                dev_info_ratelimited(dev,
                                "%s: %s: fail, vma is not DAX capable\n",
                                current->comm, func);
                return -EINVAL;
        }

        return 0;
}

/* see "strong" declaration in tools/testing/nvdimm/dax-dev.c */
__weak phys_addr_t dax_pgoff_to_phys(struct dev_dax *dev_dax, pgoff_t pgoff,
                unsigned long size)
{
        struct range *range = &dev_dax->range;
        phys_addr_t phys;

        phys = pgoff * PAGE_SIZE + range->start;
        if (phys >= range->start && phys <= range->end) {
                if (phys + size - 1 <= range->end)
                        return phys;
        }

        return -1;
}

static vm_fault_t __dev_dax_pte_fault(struct dev_dax *dev_dax,
                                struct vm_fault *vmf, pfn_t *pfn)
{
        struct device *dev = &dev_dax->dev;
        struct dax_region *dax_region;
        phys_addr_t phys;
        unsigned int fault_size = PAGE_SIZE;

        if (check_vma(dev_dax, vmf->vma, __func__))
                return VM_FAULT_SIGBUS;

        dax_region = dev_dax->region;
        if (dax_region->align > PAGE_SIZE) {
                dev_dbg(dev, "alignment (%#x) > fault size (%#x)\n",
                        dax_region->align, fault_size);
                return VM_FAULT_SIGBUS;
        }

        if (fault_size != dax_region->align)
                return VM_FAULT_SIGBUS;

        phys = dax_pgoff_to_phys(dev_dax, vmf->pgoff, PAGE_SIZE);
        if (phys == -1) {
                dev_dbg(dev, "pgoff_to_phys(%#lx) failed\n", vmf->pgoff);
                return VM_FAULT_SIGBUS;
        }

        *pfn = phys_to_pfn_t(phys, PFN_DEV|PFN_MAP);

        return vmf_insert_mixed(vmf->vma, vmf->address, *pfn);
}

static vm_fault_t __dev_dax_pmd_fault(struct dev_dax *dev_dax,
                                struct vm_fault *vmf, pfn_t *pfn)
{
        unsigned long pmd_addr = vmf->address & PMD_MASK;
        struct device *dev = &dev_dax->dev;
        struct dax_region *dax_region;
        phys_addr_t phys;
        pgoff_t pgoff;
        unsigned int fault_size = PMD_SIZE;

        if (check_vma(dev_dax, vmf->vma, __func__))
                return VM_FAULT_SIGBUS;

        dax_region = dev_dax->region;
        if (dax_region->align > PMD_SIZE) {
                dev_dbg(dev, "alignment (%#x) > fault size (%#x)\n",
                        dax_region->align, fault_size);
                return VM_FAULT_SIGBUS;
        }

        if (fault_size < dax_region->align)
                return VM_FAULT_SIGBUS;
        else if (fault_size > dax_region->align)
                return VM_FAULT_FALLBACK;

        /* if we are outside of the VMA */
        if (pmd_addr < vmf->vma->vm_start ||
                        (pmd_addr + PMD_SIZE) > vmf->vma->vm_end)
                return VM_FAULT_SIGBUS;

        pgoff = linear_page_index(vmf->vma, pmd_addr);
        phys = dax_pgoff_to_phys(dev_dax, pgoff, PMD_SIZE);
        if (phys == -1) {
                dev_dbg(dev, "pgoff_to_phys(%#lx) failed\n", pgoff);
                return VM_FAULT_SIGBUS;
        }

        *pfn = phys_to_pfn_t(phys, PFN_DEV|PFN_MAP);

        return vmf_insert_pfn_pmd(vmf, *pfn, vmf->flags & FAULT_FLAG_WRITE);
}

#ifdef CONFIG_HAVE_ARCH_TRANSPARENT_HUGEPAGE_PUD
static vm_fault_t __dev_dax_pud_fault(struct dev_dax *dev_dax,
                                struct vm_fault *vmf, pfn_t *pfn)
{
        unsigned long pud_addr = vmf->address & PUD_MASK;
        struct device *dev = &dev_dax->dev;
        struct dax_region *dax_region;
        phys_addr_t phys;
        pgoff_t pgoff;
        unsigned int fault_size = PUD_SIZE;

        if (check_vma(dev_dax, vmf->vma, __func__))
                return VM_FAULT_SIGBUS;

        dax_region = dev_dax->region;
        if (dax_region->align > PUD_SIZE) {
                dev_dbg(dev, "alignment (%#x) > fault size (%#x)\n",
                        dax_region->align, fault_size);
                return VM_FAULT_SIGBUS;
        }

        if (fault_size < dax_region->align)
                return VM_FAULT_SIGBUS;
        else if (fault_size > dax_region->align)
                return VM_FAULT_FALLBACK;

        /* if we are outside of the VMA */
        if (pud_addr < vmf->vma->vm_start ||
                        (pud_addr + PUD_SIZE) > vmf->vma->vm_end)
                return VM_FAULT_SIGBUS;

        pgoff = linear_page_index(vmf->vma, pud_addr);
        phys = dax_pgoff_to_phys(dev_dax, pgoff, PUD_SIZE);
        if (phys == -1) {
                dev_dbg(dev, "pgoff_to_phys(%#lx) failed\n", pgoff);
                return VM_FAULT_SIGBUS;
        }

        *pfn = phys_to_pfn_t(phys, PFN_DEV|PFN_MAP);

        return vmf_insert_pfn_pud(vmf, *pfn, vmf->flags & FAULT_FLAG_WRITE);
}
#else
static vm_fault_t __dev_dax_pud_fault(struct dev_dax *dev_dax,
                                struct vm_fault *vmf, pfn_t *pfn)
{
        return VM_FAULT_FALLBACK;
}
#endif /* !CONFIG_HAVE_ARCH_TRANSPARENT_HUGEPAGE_PUD */

static vm_fault_t dev_dax_huge_fault(struct vm_fault *vmf,
                enum page_entry_size pe_size)
{
        struct file *filp = vmf->vma->vm_file;
        unsigned long fault_size;
        vm_fault_t rc = VM_FAULT_SIGBUS;
        int id;
        pfn_t pfn;
        struct dev_dax *dev_dax = filp->private_data;

        dev_dbg(&dev_dax->dev, "%s: %s (%#lx - %#lx) size = %d\n", current->comm,
                        (vmf->flags & FAULT_FLAG_WRITE) ? "write" : "read",
                        vmf->vma->vm_start, vmf->vma->vm_end, pe_size);

        id = dax_read_lock();
        switch (pe_size) {
        case PE_SIZE_PTE:
                fault_size = PAGE_SIZE;
                rc = __dev_dax_pte_fault(dev_dax, vmf, &pfn);
                break;
        case PE_SIZE_PMD:
                fault_size = PMD_SIZE;
                rc = __dev_dax_pmd_fault(dev_dax, vmf, &pfn);
                break;
        case PE_SIZE_PUD:
                fault_size = PUD_SIZE;
                rc = __dev_dax_pud_fault(dev_dax, vmf, &pfn);
                break;
        default:
                rc = VM_FAULT_SIGBUS;
        }

        if (rc == VM_FAULT_NOPAGE) {
                unsigned long i;
                pgoff_t pgoff;

                /*
                 * In the device-dax case the only possibility for a
                 * VM_FAULT_NOPAGE result is when device-dax capacity is
                 * mapped. No need to consider the zero page, or racing
                 * conflicting mappings.
                 */
                pgoff = linear_page_index(vmf->vma, vmf->address
                                & ~(fault_size - 1));
                for (i = 0; i < fault_size / PAGE_SIZE; i++) {
                        struct page *page;

                        page = pfn_to_page(pfn_t_to_pfn(pfn) + i);
                        if (page->mapping)
                                continue;
                        page->mapping = filp->f_mapping;
                        page->index = pgoff + i;
                }
        }
        dax_read_unlock(id);

        return rc;
}

static vm_fault_t dev_dax_fault(struct vm_fault *vmf)
{
        return dev_dax_huge_fault(vmf, PE_SIZE_PTE);
}

static int dev_dax_split(struct vm_area_struct *vma, unsigned long addr)
{
        struct file *filp = vma->vm_file;
        struct dev_dax *dev_dax = filp->private_data;
        struct dax_region *dax_region = dev_dax->region;

        if (!IS_ALIGNED(addr, dax_region->align))
                return -EINVAL;
        return 0;
}

static unsigned long dev_dax_pagesize(struct vm_area_struct *vma)
{
        struct file *filp = vma->vm_file;
        struct dev_dax *dev_dax = filp->private_data;
        struct dax_region *dax_region = dev_dax->region;

        return dax_region->align;
}

static const struct vm_operations_struct dax_vm_ops = {
        .fault = dev_dax_fault,
        .huge_fault = dev_dax_huge_fault,
        .split = dev_dax_split,
        .pagesize = dev_dax_pagesize,
};

static int dax_mmap(struct file *filp, struct vm_area_struct *vma)
{
        struct dev_dax *dev_dax = filp->private_data;
        int rc, id;

        dev_dbg(&dev_dax->dev, "trace\n");

        /*
         * We lock to check dax_dev liveness and will re-check at
         * fault time.
         */
        id = dax_read_lock();
        rc = check_vma(dev_dax, vma, __func__);
        dax_read_unlock(id);
        if (rc)
                return rc;

        vma->vm_ops = &dax_vm_ops;
        vma->vm_flags |= VM_HUGEPAGE;
        return 0;
}

/* return an unmapped area aligned to the dax region specified alignment */
static unsigned long dax_get_unmapped_area(struct file *filp,
                unsigned long addr, unsigned long len, unsigned long pgoff,
                unsigned long flags)
{
        unsigned long off, off_end, off_align, len_align, addr_align, align;
        struct dev_dax *dev_dax = filp ? filp->private_data : NULL;
        struct dax_region *dax_region;

        if (!dev_dax || addr)
                goto out;

        dax_region = dev_dax->region;
        align = dax_region->align;
        off = pgoff << PAGE_SHIFT;
        off_end = off + len;
        off_align = round_up(off, align);

        if ((off_end <= off_align) || ((off_end - off_align) < align))
                goto out;

        len_align = len + align;
        if ((off + len_align) < off)
                goto out;

        addr_align = current->mm->get_unmapped_area(filp, addr, len_align,
                        pgoff, flags);
        if (!IS_ERR_VALUE(addr_align)) {
                addr_align += (off - addr_align) & (align - 1);
                return addr_align;
        }
 out:
        return current->mm->get_unmapped_area(filp, addr, len, pgoff, flags);
}

static const struct address_space_operations dev_dax_aops = {
        .set_page_dirty         = noop_set_page_dirty,
        .invalidatepage         = noop_invalidatepage,
};

static int dax_open(struct inode *inode, struct file *filp)
{
        struct dax_device *dax_dev = inode_dax(inode);
        struct inode *__dax_inode = dax_inode(dax_dev);
        struct dev_dax *dev_dax = dax_get_private(dax_dev);

        dev_dbg(&dev_dax->dev, "trace\n");
        inode->i_mapping = __dax_inode->i_mapping;
        inode->i_mapping->host = __dax_inode;
        inode->i_mapping->a_ops = &dev_dax_aops;
        filp->f_mapping = inode->i_mapping;
        filp->f_wb_err = filemap_sample_wb_err(filp->f_mapping);
        filp->f_sb_err = file_sample_sb_err(filp);
        filp->private_data = dev_dax;
        inode->i_flags = S_DAX;

        return 0;
}

static int dax_release(struct inode *inode, struct file *filp)
{
        struct dev_dax *dev_dax = filp->private_data;

        dev_dbg(&dev_dax->dev, "trace\n");
        return 0;
}

static const struct file_operations dax_fops = {
        .llseek = noop_llseek,
        .owner = THIS_MODULE,
        .open = dax_open,
        .release = dax_release,
        .get_unmapped_area = dax_get_unmapped_area,
        .mmap = dax_mmap,
        .mmap_supported_flags = MAP_SYNC,
};

static void dev_dax_cdev_del(void *cdev)
{
        cdev_del(cdev);
}

static void dev_dax_kill(void *dev_dax)
{
        kill_dev_dax(dev_dax);
}

int dev_dax_probe(struct dev_dax *dev_dax)
{
        struct dax_device *dax_dev = dev_dax->dax_dev;
        struct range *range = &dev_dax->range;
        struct device *dev = &dev_dax->dev;
        struct dev_pagemap *pgmap;
        struct inode *inode;
        struct cdev *cdev;
        void *addr;
        int rc;

        /* 1:1 map region resource range to device-dax instance range */
        if (!devm_request_mem_region(dev, range->start, range_len(range),
                                dev_name(dev))) {
                dev_warn(dev, "could not reserve range: %#llx - %#llx\n",
                                range->start, range->end);
                return -EBUSY;
        }

        pgmap = dev_dax->pgmap;
        if (!pgmap) {
                pgmap = devm_kzalloc(dev, sizeof(*pgmap), GFP_KERNEL);
                if (!pgmap)
                        return -ENOMEM;
                pgmap->range = *range;
        }
        pgmap->type = MEMORY_DEVICE_GENERIC;
        addr = devm_memremap_pages(dev, pgmap);
        if (IS_ERR(addr))
                return PTR_ERR(addr);

        inode = dax_inode(dax_dev);
        cdev = inode->i_cdev;
        cdev_init(cdev, &dax_fops);
        if (dev->class) {
                /* for the CONFIG_DEV_DAX_PMEM_COMPAT case */
                cdev->owner = dev->parent->driver->owner;
        } else
                cdev->owner = dev->driver->owner;
        cdev_set_parent(cdev, &dev->kobj);
        rc = cdev_add(cdev, dev->devt, 1);
        if (rc)
                return rc;

        rc = devm_add_action_or_reset(dev, dev_dax_cdev_del, cdev);
        if (rc)
                return rc;

        run_dax(dax_dev);
        return devm_add_action_or_reset(dev, dev_dax_kill, dev_dax);
}
EXPORT_SYMBOL_GPL(dev_dax_probe);

static int dev_dax_remove(struct dev_dax *dev_dax)
{
        /* all probe actions are unwound by devm */
        return 0;
}

static struct dax_device_driver device_dax_driver = {
        .probe = dev_dax_probe,
        .remove = dev_dax_remove,
        .match_always = 1,
};

static int __init dax_init(void)
{
        return dax_driver_register(&device_dax_driver);
}

static void __exit dax_exit(void)
{
        dax_driver_unregister(&device_dax_driver);
}

MODULE_AUTHOR("Intel Corporation");
MODULE_LICENSE("GPL v2");
module_init(dax_init);
module_exit(dax_exit);
MODULE_ALIAS_DAX_DEVICE(0);