vfio/mdev: Synchronize device create/remove with parent removal
[sfrench/cifs-2.6.git] / drivers / vfio / mdev / mdev_core.c
1 /*
2  * Mediated device Core Driver
3  *
4  * Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
5  *     Author: Neo Jia <cjia@nvidia.com>
6  *             Kirti Wankhede <kwankhede@nvidia.com>
7  *
8  * This program is free software; you can redistribute it and/or modify
9  * it under the terms of the GNU General Public License version 2 as
10  * published by the Free Software Foundation.
11  */
12
13 #include <linux/module.h>
14 #include <linux/device.h>
15 #include <linux/slab.h>
16 #include <linux/uuid.h>
17 #include <linux/sysfs.h>
18 #include <linux/mdev.h>
19
20 #include "mdev_private.h"
21
22 #define DRIVER_VERSION          "0.1"
23 #define DRIVER_AUTHOR           "NVIDIA Corporation"
24 #define DRIVER_DESC             "Mediated device Core Driver"
25
26 static LIST_HEAD(parent_list);
27 static DEFINE_MUTEX(parent_list_lock);
28 static struct class_compat *mdev_bus_compat_class;
29
30 static LIST_HEAD(mdev_list);
31 static DEFINE_MUTEX(mdev_list_lock);
32
33 struct device *mdev_parent_dev(struct mdev_device *mdev)
34 {
35         return mdev->parent->dev;
36 }
37 EXPORT_SYMBOL(mdev_parent_dev);
38
39 void *mdev_get_drvdata(struct mdev_device *mdev)
40 {
41         return mdev->driver_data;
42 }
43 EXPORT_SYMBOL(mdev_get_drvdata);
44
45 void mdev_set_drvdata(struct mdev_device *mdev, void *data)
46 {
47         mdev->driver_data = data;
48 }
49 EXPORT_SYMBOL(mdev_set_drvdata);
50
51 struct device *mdev_dev(struct mdev_device *mdev)
52 {
53         return &mdev->dev;
54 }
55 EXPORT_SYMBOL(mdev_dev);
56
57 struct mdev_device *mdev_from_dev(struct device *dev)
58 {
59         return dev_is_mdev(dev) ? to_mdev_device(dev) : NULL;
60 }
61 EXPORT_SYMBOL(mdev_from_dev);
62
63 const guid_t *mdev_uuid(struct mdev_device *mdev)
64 {
65         return &mdev->uuid;
66 }
67 EXPORT_SYMBOL(mdev_uuid);
68
69 /* Should be called holding parent_list_lock */
70 static struct mdev_parent *__find_parent_device(struct device *dev)
71 {
72         struct mdev_parent *parent;
73
74         list_for_each_entry(parent, &parent_list, next) {
75                 if (parent->dev == dev)
76                         return parent;
77         }
78         return NULL;
79 }
80
81 static void mdev_release_parent(struct kref *kref)
82 {
83         struct mdev_parent *parent = container_of(kref, struct mdev_parent,
84                                                   ref);
85         struct device *dev = parent->dev;
86
87         kfree(parent);
88         put_device(dev);
89 }
90
91 static struct mdev_parent *mdev_get_parent(struct mdev_parent *parent)
92 {
93         if (parent)
94                 kref_get(&parent->ref);
95
96         return parent;
97 }
98
99 static void mdev_put_parent(struct mdev_parent *parent)
100 {
101         if (parent)
102                 kref_put(&parent->ref, mdev_release_parent);
103 }
104
105 /* Caller must hold parent unreg_sem read or write lock */
106 static void mdev_device_remove_common(struct mdev_device *mdev)
107 {
108         struct mdev_parent *parent;
109         struct mdev_type *type;
110         int ret;
111
112         type = to_mdev_type(mdev->type_kobj);
113         mdev_remove_sysfs_files(&mdev->dev, type);
114         device_del(&mdev->dev);
115         parent = mdev->parent;
116         lockdep_assert_held(&parent->unreg_sem);
117         ret = parent->ops->remove(mdev);
118         if (ret)
119                 dev_err(&mdev->dev, "Remove failed: err=%d\n", ret);
120
121         /* Balances with device_initialize() */
122         put_device(&mdev->dev);
123         mdev_put_parent(parent);
124 }
125
126 static int mdev_device_remove_cb(struct device *dev, void *data)
127 {
128         if (dev_is_mdev(dev)) {
129                 struct mdev_device *mdev;
130
131                 mdev = to_mdev_device(dev);
132                 mdev_device_remove_common(mdev);
133         }
134         return 0;
135 }
136
137 /*
138  * mdev_register_device : Register a device
139  * @dev: device structure representing parent device.
140  * @ops: Parent device operation structure to be registered.
141  *
142  * Add device to list of registered parent devices.
143  * Returns a negative value on error, otherwise 0.
144  */
145 int mdev_register_device(struct device *dev, const struct mdev_parent_ops *ops)
146 {
147         int ret;
148         struct mdev_parent *parent;
149
150         /* check for mandatory ops */
151         if (!ops || !ops->create || !ops->remove || !ops->supported_type_groups)
152                 return -EINVAL;
153
154         dev = get_device(dev);
155         if (!dev)
156                 return -EINVAL;
157
158         mutex_lock(&parent_list_lock);
159
160         /* Check for duplicate */
161         parent = __find_parent_device(dev);
162         if (parent) {
163                 parent = NULL;
164                 ret = -EEXIST;
165                 goto add_dev_err;
166         }
167
168         parent = kzalloc(sizeof(*parent), GFP_KERNEL);
169         if (!parent) {
170                 ret = -ENOMEM;
171                 goto add_dev_err;
172         }
173
174         kref_init(&parent->ref);
175         init_rwsem(&parent->unreg_sem);
176
177         parent->dev = dev;
178         parent->ops = ops;
179
180         if (!mdev_bus_compat_class) {
181                 mdev_bus_compat_class = class_compat_register("mdev_bus");
182                 if (!mdev_bus_compat_class) {
183                         ret = -ENOMEM;
184                         goto add_dev_err;
185                 }
186         }
187
188         ret = parent_create_sysfs_files(parent);
189         if (ret)
190                 goto add_dev_err;
191
192         ret = class_compat_create_link(mdev_bus_compat_class, dev, NULL);
193         if (ret)
194                 dev_warn(dev, "Failed to create compatibility class link\n");
195
196         list_add(&parent->next, &parent_list);
197         mutex_unlock(&parent_list_lock);
198
199         dev_info(dev, "MDEV: Registered\n");
200         return 0;
201
202 add_dev_err:
203         mutex_unlock(&parent_list_lock);
204         if (parent)
205                 mdev_put_parent(parent);
206         else
207                 put_device(dev);
208         return ret;
209 }
210 EXPORT_SYMBOL(mdev_register_device);
211
212 /*
213  * mdev_unregister_device : Unregister a parent device
214  * @dev: device structure representing parent device.
215  *
216  * Remove device from list of registered parent devices. Give a chance to free
217  * existing mediated devices for given device.
218  */
219
220 void mdev_unregister_device(struct device *dev)
221 {
222         struct mdev_parent *parent;
223
224         mutex_lock(&parent_list_lock);
225         parent = __find_parent_device(dev);
226
227         if (!parent) {
228                 mutex_unlock(&parent_list_lock);
229                 return;
230         }
231         dev_info(dev, "MDEV: Unregistering\n");
232
233         list_del(&parent->next);
234         mutex_unlock(&parent_list_lock);
235
236         down_write(&parent->unreg_sem);
237
238         class_compat_remove_link(mdev_bus_compat_class, dev, NULL);
239
240         device_for_each_child(dev, NULL, mdev_device_remove_cb);
241
242         parent_remove_sysfs_files(parent);
243         up_write(&parent->unreg_sem);
244
245         mdev_put_parent(parent);
246 }
247 EXPORT_SYMBOL(mdev_unregister_device);
248
249 static void mdev_device_free(struct mdev_device *mdev)
250 {
251         mutex_lock(&mdev_list_lock);
252         list_del(&mdev->next);
253         mutex_unlock(&mdev_list_lock);
254
255         dev_dbg(&mdev->dev, "MDEV: destroying\n");
256         kfree(mdev);
257 }
258
259 static void mdev_device_release(struct device *dev)
260 {
261         struct mdev_device *mdev = to_mdev_device(dev);
262
263         mdev_device_free(mdev);
264 }
265
266 int mdev_device_create(struct kobject *kobj,
267                        struct device *dev, const guid_t *uuid)
268 {
269         int ret;
270         struct mdev_device *mdev, *tmp;
271         struct mdev_parent *parent;
272         struct mdev_type *type = to_mdev_type(kobj);
273
274         parent = mdev_get_parent(type->parent);
275         if (!parent)
276                 return -EINVAL;
277
278         mutex_lock(&mdev_list_lock);
279
280         /* Check for duplicate */
281         list_for_each_entry(tmp, &mdev_list, next) {
282                 if (guid_equal(&tmp->uuid, uuid)) {
283                         mutex_unlock(&mdev_list_lock);
284                         ret = -EEXIST;
285                         goto mdev_fail;
286                 }
287         }
288
289         mdev = kzalloc(sizeof(*mdev), GFP_KERNEL);
290         if (!mdev) {
291                 mutex_unlock(&mdev_list_lock);
292                 ret = -ENOMEM;
293                 goto mdev_fail;
294         }
295
296         guid_copy(&mdev->uuid, uuid);
297         list_add(&mdev->next, &mdev_list);
298         mutex_unlock(&mdev_list_lock);
299
300         mdev->parent = parent;
301
302         /* Check if parent unregistration has started */
303         if (!down_read_trylock(&parent->unreg_sem)) {
304                 mdev_device_free(mdev);
305                 ret = -ENODEV;
306                 goto mdev_fail;
307         }
308
309         device_initialize(&mdev->dev);
310         mdev->dev.parent  = dev;
311         mdev->dev.bus     = &mdev_bus_type;
312         mdev->dev.release = mdev_device_release;
313         dev_set_name(&mdev->dev, "%pUl", uuid);
314         mdev->dev.groups = parent->ops->mdev_attr_groups;
315         mdev->type_kobj = kobj;
316
317         ret = parent->ops->create(kobj, mdev);
318         if (ret)
319                 goto ops_create_fail;
320
321         ret = device_add(&mdev->dev);
322         if (ret)
323                 goto add_fail;
324
325         ret = mdev_create_sysfs_files(&mdev->dev, type);
326         if (ret)
327                 goto sysfs_fail;
328
329         mdev->active = true;
330         dev_dbg(&mdev->dev, "MDEV: created\n");
331         up_read(&parent->unreg_sem);
332
333         return 0;
334
335 sysfs_fail:
336         device_del(&mdev->dev);
337 add_fail:
338         parent->ops->remove(mdev);
339 ops_create_fail:
340         up_read(&parent->unreg_sem);
341         put_device(&mdev->dev);
342 mdev_fail:
343         mdev_put_parent(parent);
344         return ret;
345 }
346
347 int mdev_device_remove(struct device *dev)
348 {
349         struct mdev_device *mdev, *tmp;
350         struct mdev_parent *parent;
351
352         mdev = to_mdev_device(dev);
353
354         mutex_lock(&mdev_list_lock);
355         list_for_each_entry(tmp, &mdev_list, next) {
356                 if (tmp == mdev)
357                         break;
358         }
359
360         if (tmp != mdev) {
361                 mutex_unlock(&mdev_list_lock);
362                 return -ENODEV;
363         }
364
365         if (!mdev->active) {
366                 mutex_unlock(&mdev_list_lock);
367                 return -EAGAIN;
368         }
369
370         mdev->active = false;
371         mutex_unlock(&mdev_list_lock);
372
373         parent = mdev->parent;
374         /* Check if parent unregistration has started */
375         if (!down_read_trylock(&parent->unreg_sem))
376                 return -ENODEV;
377
378         mdev_device_remove_common(mdev);
379         up_read(&parent->unreg_sem);
380         return 0;
381 }
382
383 int mdev_set_iommu_device(struct device *dev, struct device *iommu_device)
384 {
385         struct mdev_device *mdev = to_mdev_device(dev);
386
387         mdev->iommu_device = iommu_device;
388
389         return 0;
390 }
391 EXPORT_SYMBOL(mdev_set_iommu_device);
392
393 struct device *mdev_get_iommu_device(struct device *dev)
394 {
395         struct mdev_device *mdev = to_mdev_device(dev);
396
397         return mdev->iommu_device;
398 }
399 EXPORT_SYMBOL(mdev_get_iommu_device);
400
401 static int __init mdev_init(void)
402 {
403         return mdev_bus_register();
404 }
405
406 static void __exit mdev_exit(void)
407 {
408         if (mdev_bus_compat_class)
409                 class_compat_unregister(mdev_bus_compat_class);
410
411         mdev_bus_unregister();
412 }
413
414 module_init(mdev_init)
415 module_exit(mdev_exit)
416
417 MODULE_VERSION(DRIVER_VERSION);
418 MODULE_LICENSE("GPL v2");
419 MODULE_AUTHOR(DRIVER_AUTHOR);
420 MODULE_DESCRIPTION(DRIVER_DESC);
421 MODULE_SOFTDEP("post: vfio_mdev");