Merge tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost
[sfrench/cifs-2.6.git] / drivers / virtio / virtio_pci_common.c
index 48d4d1cf1cb63f7b67f58f17e3d7874d14a00b09..705aebd74e560cd8d125cee28537e9734dabe518 100644 (file)
@@ -113,12 +113,13 @@ static int vp_request_msix_vectors(struct virtio_device *vdev, int nvectors,
 
        vp_dev->msix_vectors = nvectors;
 
-       vp_dev->msix_names = kmalloc(nvectors * sizeof *vp_dev->msix_names,
-                                    GFP_KERNEL);
+       vp_dev->msix_names = kmalloc_array(nvectors,
+                                          sizeof(*vp_dev->msix_names),
+                                          GFP_KERNEL);
        if (!vp_dev->msix_names)
                goto error;
        vp_dev->msix_affinity_masks
-               = kzalloc(nvectors * sizeof *vp_dev->msix_affinity_masks,
+               = kcalloc(nvectors, sizeof(*vp_dev->msix_affinity_masks),
                          GFP_KERNEL);
        if (!vp_dev->msix_affinity_masks)
                goto error;
@@ -577,6 +578,8 @@ static void virtio_pci_remove(struct pci_dev *pci_dev)
        struct virtio_pci_device *vp_dev = pci_get_drvdata(pci_dev);
        struct device *dev = get_device(&vp_dev->vdev.dev);
 
+       pci_disable_sriov(pci_dev);
+
        unregister_virtio_device(&vp_dev->vdev);
 
        if (vp_dev->ioaddr)
@@ -588,6 +591,33 @@ static void virtio_pci_remove(struct pci_dev *pci_dev)
        put_device(dev);
 }
 
+static int virtio_pci_sriov_configure(struct pci_dev *pci_dev, int num_vfs)
+{
+       struct virtio_pci_device *vp_dev = pci_get_drvdata(pci_dev);
+       struct virtio_device *vdev = &vp_dev->vdev;
+       int ret;
+
+       if (!(vdev->config->get_status(vdev) & VIRTIO_CONFIG_S_DRIVER_OK))
+               return -EBUSY;
+
+       if (!__virtio_test_bit(vdev, VIRTIO_F_SR_IOV))
+               return -EINVAL;
+
+       if (pci_vfs_assigned(pci_dev))
+               return -EPERM;
+
+       if (num_vfs == 0) {
+               pci_disable_sriov(pci_dev);
+               return 0;
+       }
+
+       ret = pci_enable_sriov(pci_dev, num_vfs);
+       if (ret < 0)
+               return ret;
+
+       return num_vfs;
+}
+
 static struct pci_driver virtio_pci_driver = {
        .name           = "virtio-pci",
        .id_table       = virtio_pci_id_table,
@@ -596,6 +626,7 @@ static struct pci_driver virtio_pci_driver = {
 #ifdef CONFIG_PM_SLEEP
        .driver.pm      = &virtio_pci_pm_ops,
 #endif
+       .sriov_configure = virtio_pci_sriov_configure,
 };
 
 module_pci_driver(virtio_pci_driver);