diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c index d0f731c9920a..c956b85264ae 100644 --- a/drivers/vfio/vfio_iommu_type1.c +++ b/drivers/vfio/vfio_iommu_type1.c @@ -97,6 +97,7 @@ struct vfio_dma { struct vfio_group { struct iommu_group *iommu_group; struct list_head next; + bool mdev_group; /* An mdev group */ }; /* @@ -1311,6 +1312,75 @@ static bool vfio_iommu_has_sw_msi(struct iommu_group *group, phys_addr_t *base) return ret; } +static struct device *vfio_mdev_get_iommu_device(struct device *dev) +{ + struct device *(*fn)(struct device *dev); + struct device *iommu_device; + + fn = symbol_get(mdev_get_iommu_device); + if (fn) { + iommu_device = fn(dev); + symbol_put(mdev_get_iommu_device); + + return iommu_device; + } + + return NULL; +} + +static int vfio_mdev_attach_domain(struct device *dev, void *data) +{ + struct iommu_domain *domain = data; + struct device *iommu_device; + + iommu_device = vfio_mdev_get_iommu_device(dev); + if (iommu_device) { + if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX)) + return iommu_aux_attach_device(domain, iommu_device); + else + return iommu_attach_device(domain, iommu_device); + } + + return -EINVAL; +} + +static int vfio_mdev_detach_domain(struct device *dev, void *data) +{ + struct iommu_domain *domain = data; + struct device *iommu_device; + + iommu_device = vfio_mdev_get_iommu_device(dev); + if (iommu_device) { + if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX)) + iommu_aux_detach_device(domain, iommu_device); + else + iommu_detach_device(domain, iommu_device); + } + + return 0; +} + +static int vfio_iommu_attach_group(struct vfio_domain *domain, + struct vfio_group *group) +{ + if (group->mdev_group) + return iommu_group_for_each_dev(group->iommu_group, + domain->domain, + vfio_mdev_attach_domain); + else + return iommu_attach_group(domain->domain, group->iommu_group); +} + +static void vfio_iommu_detach_group(struct vfio_domain *domain, + struct vfio_group *group) +{ + if (group->mdev_group) + iommu_group_for_each_dev(group->iommu_group, domain->domain, + vfio_mdev_detach_domain); + else + iommu_detach_group(domain->domain, group->iommu_group); +} + static int vfio_iommu_type1_attach_group(void *iommu_data, struct iommu_group *iommu_group) { @@ -1386,7 +1456,7 @@ static int vfio_iommu_type1_attach_group(void *iommu_data, goto out_domain; } - ret = iommu_attach_group(domain->domain, iommu_group); + ret = vfio_iommu_attach_group(domain, group); if (ret) goto out_domain; @@ -1418,8 +1488,8 @@ static int vfio_iommu_type1_attach_group(void *iommu_data, list_for_each_entry(d, &iommu->domain_list, next) { if (d->domain->ops == domain->domain->ops && d->prot == domain->prot) { - iommu_detach_group(domain->domain, iommu_group); - if (!iommu_attach_group(d->domain, iommu_group)) { + vfio_iommu_detach_group(domain, group); + if (!vfio_iommu_attach_group(d, group)) { list_add(&group->next, &d->group_list); iommu_domain_free(domain->domain); kfree(domain); @@ -1427,7 +1497,7 @@ static int vfio_iommu_type1_attach_group(void *iommu_data, return 0; } - ret = iommu_attach_group(domain->domain, iommu_group); + ret = vfio_iommu_attach_group(domain, group); if (ret) goto out_domain; } @@ -1453,7 +1523,7 @@ static int vfio_iommu_type1_attach_group(void *iommu_data, return 0; out_detach: - iommu_detach_group(domain->domain, iommu_group); + vfio_iommu_detach_group(domain, group); out_domain: iommu_domain_free(domain->domain); out_free: @@ -1544,7 +1614,7 @@ static void vfio_iommu_type1_detach_group(void *iommu_data, if (!group) continue; - iommu_detach_group(domain->domain, iommu_group); + vfio_iommu_detach_group(domain, group); list_del(&group->next); kfree(group); /* @@ -1610,7 +1680,7 @@ static void vfio_release_domain(struct vfio_domain *domain, bool external) list_for_each_entry_safe(group, group_tmp, &domain->group_list, next) { if (!external) - iommu_detach_group(domain->domain, group->iommu_group); + vfio_iommu_detach_group(domain, group); list_del(&group->next); kfree(group); }