vfio/mdev: Synchronize device create/remove with parent removal
[sfrench/cifs-2.6.git] / drivers / vfio / mdev / mdev_core.c
index 3cc1a05fde1c9de281900d8768dfd16bf0e3dbad..ae23151442cbd727d80633c0cbc6dd1ecb7ec7ff 100644 (file)
@@ -102,56 +102,35 @@ static void mdev_put_parent(struct mdev_parent *parent)
                kref_put(&parent->ref, mdev_release_parent);
 }
 
-static int mdev_device_create_ops(struct kobject *kobj,
-                                 struct mdev_device *mdev)
+/* Caller must hold parent unreg_sem read or write lock */
+static void mdev_device_remove_common(struct mdev_device *mdev)
 {
-       struct mdev_parent *parent = mdev->parent;
-       int ret;
-
-       ret = parent->ops->create(kobj, mdev);
-       if (ret)
-               return ret;
-
-       ret = sysfs_create_groups(&mdev->dev.kobj,
-                                 parent->ops->mdev_attr_groups);
-       if (ret)
-               parent->ops->remove(mdev);
-
-       return ret;
-}
-
-/*
- * mdev_device_remove_ops gets called from sysfs's 'remove' and when parent
- * device is being unregistered from mdev device framework.
- * - 'force_remove' is set to 'false' when called from sysfs's 'remove' which
- *   indicates that if the mdev device is active, used by VMM or userspace
- *   application, vendor driver could return error then don't remove the device.
- * - 'force_remove' is set to 'true' when called from mdev_unregister_device()
- *   which indicate that parent device is being removed from mdev device
- *   framework so remove mdev device forcefully.
- */
-static int mdev_device_remove_ops(struct mdev_device *mdev, bool force_remove)
-{
-       struct mdev_parent *parent = mdev->parent;
+       struct mdev_parent *parent;
+       struct mdev_type *type;
        int ret;
 
-       /*
-        * Vendor driver can return error if VMM or userspace application is
-        * using this mdev device.
-        */
+       type = to_mdev_type(mdev->type_kobj);
+       mdev_remove_sysfs_files(&mdev->dev, type);
+       device_del(&mdev->dev);
+       parent = mdev->parent;
+       lockdep_assert_held(&parent->unreg_sem);
        ret = parent->ops->remove(mdev);
-       if (ret && !force_remove)
-               return ret;
+       if (ret)
+               dev_err(&mdev->dev, "Remove failed: err=%d\n", ret);
 
-       sysfs_remove_groups(&mdev->dev.kobj, parent->ops->mdev_attr_groups);
-       return 0;
+       /* Balances with device_initialize() */
+       put_device(&mdev->dev);
+       mdev_put_parent(parent);
 }
 
 static int mdev_device_remove_cb(struct device *dev, void *data)
 {
-       if (dev_is_mdev(dev))
-               mdev_device_remove(dev, true);
+       if (dev_is_mdev(dev)) {
+               struct mdev_device *mdev;
 
+               mdev = to_mdev_device(dev);
+               mdev_device_remove_common(mdev);
+       }
        return 0;
 }
 
@@ -193,6 +172,7 @@ int mdev_register_device(struct device *dev, const struct mdev_parent_ops *ops)
        }
 
        kref_init(&parent->ref);
+       init_rwsem(&parent->unreg_sem);
 
        parent->dev = dev;
        parent->ops = ops;
@@ -251,21 +231,23 @@ void mdev_unregister_device(struct device *dev)
        dev_info(dev, "MDEV: Unregistering\n");
 
        list_del(&parent->next);
+       mutex_unlock(&parent_list_lock);
+
+       down_write(&parent->unreg_sem);
+
        class_compat_remove_link(mdev_bus_compat_class, dev, NULL);
 
        device_for_each_child(dev, NULL, mdev_device_remove_cb);
 
        parent_remove_sysfs_files(parent);
+       up_write(&parent->unreg_sem);
 
-       mutex_unlock(&parent_list_lock);
        mdev_put_parent(parent);
 }
 EXPORT_SYMBOL(mdev_unregister_device);
 
-static void mdev_device_release(struct device *dev)
+static void mdev_device_free(struct mdev_device *mdev)
 {
-       struct mdev_device *mdev = to_mdev_device(dev);
-
        mutex_lock(&mdev_list_lock);
        list_del(&mdev->next);
        mutex_unlock(&mdev_list_lock);
@@ -274,6 +256,13 @@ static void mdev_device_release(struct device *dev)
        kfree(mdev);
 }
 
+static void mdev_device_release(struct device *dev)
+{
+       struct mdev_device *mdev = to_mdev_device(dev);
+
+       mdev_device_free(mdev);
+}
+
 int mdev_device_create(struct kobject *kobj,
                       struct device *dev, const guid_t *uuid)
 {
@@ -310,46 +299,55 @@ int mdev_device_create(struct kobject *kobj,
 
        mdev->parent = parent;
 
+       /* Check if parent unregistration has started */
+       if (!down_read_trylock(&parent->unreg_sem)) {
+               mdev_device_free(mdev);
+               ret = -ENODEV;
+               goto mdev_fail;
+       }
+
+       device_initialize(&mdev->dev);
        mdev->dev.parent  = dev;
        mdev->dev.bus     = &mdev_bus_type;
        mdev->dev.release = mdev_device_release;
        dev_set_name(&mdev->dev, "%pUl", uuid);
+       mdev->dev.groups = parent->ops->mdev_attr_groups;
+       mdev->type_kobj = kobj;
 
-       ret = device_register(&mdev->dev);
-       if (ret) {
-               put_device(&mdev->dev);
-               goto mdev_fail;
-       }
+       ret = parent->ops->create(kobj, mdev);
+       if (ret)
+               goto ops_create_fail;
 
-       ret = mdev_device_create_ops(kobj, mdev);
+       ret = device_add(&mdev->dev);
        if (ret)
-               goto create_fail;
+               goto add_fail;
 
        ret = mdev_create_sysfs_files(&mdev->dev, type);
-       if (ret) {
-               mdev_device_remove_ops(mdev, true);
-               goto create_fail;
-       }
+       if (ret)
+               goto sysfs_fail;
 
-       mdev->type_kobj = kobj;
        mdev->active = true;
        dev_dbg(&mdev->dev, "MDEV: created\n");
+       up_read(&parent->unreg_sem);
 
        return 0;
 
-create_fail:
-       device_unregister(&mdev->dev);
+sysfs_fail:
+       device_del(&mdev->dev);
+add_fail:
+       parent->ops->remove(mdev);
+ops_create_fail:
+       up_read(&parent->unreg_sem);
+       put_device(&mdev->dev);
 mdev_fail:
        mdev_put_parent(parent);
        return ret;
 }
 
-int mdev_device_remove(struct device *dev, bool force_remove)
+int mdev_device_remove(struct device *dev)
 {
        struct mdev_device *mdev, *tmp;
        struct mdev_parent *parent;
-       struct mdev_type *type;
-       int ret;
 
        mdev = to_mdev_device(dev);
 
@@ -372,19 +370,13 @@ int mdev_device_remove(struct device *dev, bool force_remove)
        mdev->active = false;
        mutex_unlock(&mdev_list_lock);
 
-       type = to_mdev_type(mdev->type_kobj);
        parent = mdev->parent;
+       /* Check if parent unregistration has started */
+       if (!down_read_trylock(&parent->unreg_sem))
+               return -ENODEV;
 
-       ret = mdev_device_remove_ops(mdev, force_remove);
-       if (ret) {
-               mdev->active = true;
-               return ret;
-       }
-
-       mdev_remove_sysfs_files(dev, type);
-       device_unregister(dev);
-       mdev_put_parent(parent);
-
+       mdev_device_remove_common(mdev);
+       up_read(&parent->unreg_sem);
        return 0;
 }