diff --git a/drivers/vfio/vfio.h b/drivers/vfio/vfio.h index 56fab31f8e0f..039e3208d286 100644 --- a/drivers/vfio/vfio.h +++ b/drivers/vfio/vfio.h @@ -41,7 +41,15 @@ enum vfio_group_type { struct vfio_group { struct device dev; struct cdev cdev; + /* + * When drivers is non-zero a driver is attached to the struct device + * that provided the iommu_group and thus the iommu_group is a valid + * pointer. When drivers is 0 the driver is being detached. Once users + * reaches 0 then the iommu_group is invalid. + */ + refcount_t drivers; refcount_t users; + struct completion users_comp; unsigned int container_users; struct iommu_group *iommu_group; struct vfio_container *container; diff --git a/drivers/vfio/vfio_main.c b/drivers/vfio/vfio_main.c index af5945c71c41..f19171cad9a2 100644 --- a/drivers/vfio/vfio_main.c +++ b/drivers/vfio/vfio_main.c @@ -125,8 +125,6 @@ static void vfio_release_device_set(struct vfio_device *device) xa_unlock(&vfio_device_set_xa); } -static void vfio_group_get(struct vfio_group *group); - /* * Group objects - create, release, get, put, search */ @@ -137,7 +135,7 @@ __vfio_group_get_from_iommu(struct iommu_group *iommu_group) list_for_each_entry(group, &vfio.group_list, vfio_next) { if (group->iommu_group == iommu_group) { - vfio_group_get(group); + refcount_inc(&group->drivers); return group; } } @@ -189,6 +187,8 @@ static struct vfio_group *vfio_group_alloc(struct iommu_group *iommu_group, group->cdev.owner = THIS_MODULE; refcount_set(&group->users, 1); + refcount_set(&group->drivers, 1); + init_completion(&group->users_comp); init_rwsem(&group->group_rwsem); INIT_LIST_HEAD(&group->device_list); mutex_init(&group->device_lock); @@ -247,8 +247,41 @@ static struct vfio_group *vfio_create_group(struct iommu_group *iommu_group, static void vfio_group_put(struct vfio_group *group) { - if (!refcount_dec_and_mutex_lock(&group->users, &vfio.group_lock)) + if (refcount_dec_and_test(&group->users)) + complete(&group->users_comp); +} + +static void vfio_device_remove_group(struct vfio_device *device) +{ + struct vfio_group *group = device->group; + + if (group->type == VFIO_NO_IOMMU || group->type == VFIO_EMULATED_IOMMU) + iommu_group_remove_device(device->dev); + + /* Pairs with vfio_create_group() / vfio_group_get_from_iommu() */ + if (!refcount_dec_and_mutex_lock(&group->drivers, &vfio.group_lock)) return; + list_del(&group->vfio_next); + + /* + * We could concurrently probe another driver in the group that might + * race vfio_device_remove_group() with vfio_get_group(), so we have to + * ensure that the sysfs is all cleaned up under lock otherwise the + * cdev_device_add() will fail due to the name aready existing. + */ + cdev_device_del(&group->cdev, &group->dev); + mutex_unlock(&vfio.group_lock); + + /* Matches the get from vfio_group_alloc() */ + vfio_group_put(group); + + /* + * Before we allow the last driver in the group to be unplugged the + * group must be sanitized so nothing else is or can reference it. This + * is because the group->iommu_group pointer should only be used so long + * as a device driver is attached to a device in the group. + */ + wait_for_completion(&group->users_comp); /* * These data structures all have paired operations that can only be @@ -259,19 +292,11 @@ static void vfio_group_put(struct vfio_group *group) WARN_ON(!list_empty(&group->device_list)); WARN_ON(group->container || group->container_users); WARN_ON(group->notifier.head); - - list_del(&group->vfio_next); - cdev_device_del(&group->cdev, &group->dev); - mutex_unlock(&vfio.group_lock); + group->iommu_group = NULL; put_device(&group->dev); } -static void vfio_group_get(struct vfio_group *group) -{ - refcount_inc(&group->users); -} - /* * Device objects - create, release, get, put, search */ @@ -494,6 +519,10 @@ static int __vfio_register_dev(struct vfio_device *device, struct vfio_device *existing_device; int ret; + /* + * In all cases group is the output of one of the group allocation + * functions and we have group->drivers incremented for us. + */ if (IS_ERR(group)) return PTR_ERR(group); @@ -533,10 +562,7 @@ static int __vfio_register_dev(struct vfio_device *device, return 0; err_out: - if (group->type == VFIO_NO_IOMMU || - group->type == VFIO_EMULATED_IOMMU) - iommu_group_remove_device(device->dev); - vfio_group_put(group); + vfio_device_remove_group(device); return ret; } @@ -627,11 +653,7 @@ void vfio_unregister_group_dev(struct vfio_device *device) /* Balances device_add in register path */ device_del(&device->device); - if (group->type == VFIO_NO_IOMMU || group->type == VFIO_EMULATED_IOMMU) - iommu_group_remove_device(device->dev); - - /* Matches the get in vfio_register_group_dev() */ - vfio_group_put(group); + vfio_device_remove_group(device); } EXPORT_SYMBOL_GPL(vfio_unregister_group_dev); @@ -884,7 +906,7 @@ static int vfio_group_fops_open(struct inode *inode, struct file *filep) down_write(&group->group_rwsem); - /* users can be zero if this races with vfio_group_put() */ + /* users can be zero if this races with vfio_device_remove_group() */ if (!refcount_inc_not_zero(&group->users)) { ret = -ENODEV; goto err_unlock;