diff --git a/drivers/iommu/amd_iommu.c b/drivers/iommu/amd_iommu.c index 3a00b5ce73ab..7fa97a5e3eaf 100644 --- a/drivers/iommu/amd_iommu.c +++ b/drivers/iommu/amd_iommu.c @@ -325,6 +325,24 @@ static struct pci_dev *get_isolation_root(struct pci_dev *pdev) return dma_pdev; } +static int use_pdev_iommu_group(struct pci_dev *pdev, struct device *dev) +{ + struct iommu_group *group = iommu_group_get(&pdev->dev); + int ret; + + if (!group) { + group = iommu_group_alloc(); + if (IS_ERR(group)) + return PTR_ERR(group); + + WARN_ON(&pdev->dev != dev); + } + + ret = iommu_group_add_device(group, dev); + iommu_group_put(group); + return ret; +} + static int init_iommu_group(struct device *dev) { struct iommu_dev_data *dev_data; @@ -353,18 +371,8 @@ static int init_iommu_group(struct device *dev) dma_pdev = pci_dev_get(to_pci_dev(dev)); dma_pdev = get_isolation_root(dma_pdev); - group = iommu_group_get(&dma_pdev->dev); + ret = use_pdev_iommu_group(dma_pdev, dev); pci_dev_put(dma_pdev); - if (!group) { - group = iommu_group_alloc(); - if (IS_ERR(group)) - return PTR_ERR(group); - } - - ret = iommu_group_add_device(group, dev); - - iommu_group_put(group); - return ret; }