diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c index 6673e7be507f..0734fbe5b651 100644 --- a/drivers/vfio/vfio_iommu_type1.c +++ b/drivers/vfio/vfio_iommu_type1.c @@ -524,7 +524,7 @@ static int vfio_iommu_map(struct vfio_iommu *iommu, dma_addr_t iova, static int vfio_dma_do_map(struct vfio_iommu *iommu, struct vfio_iommu_type1_dma_map *map) { - dma_addr_t end, iova; + dma_addr_t iova = map->iova; unsigned long vaddr = map->vaddr; size_t size = map->size; long npage; @@ -533,39 +533,30 @@ static int vfio_dma_do_map(struct vfio_iommu *iommu, struct vfio_dma *dma; unsigned long pfn; - end = map->iova + map->size; + /* Verify that none of our __u64 fields overflow */ + if (map->size != size || map->vaddr != vaddr || map->iova != iova) + return -EINVAL; mask = ((uint64_t)1 << __ffs(vfio_pgsize_bitmap(iommu))) - 1; + WARN_ON(mask & PAGE_MASK); + /* READ/WRITE from device perspective */ if (map->flags & VFIO_DMA_MAP_FLAG_WRITE) prot |= IOMMU_WRITE; if (map->flags & VFIO_DMA_MAP_FLAG_READ) prot |= IOMMU_READ; - if (!prot) - return -EINVAL; /* No READ/WRITE? */ - - if (vaddr & mask) - return -EINVAL; - if (map->iova & mask) - return -EINVAL; - if (!map->size || map->size & mask) + if (!prot || !size || (size | iova | vaddr) & mask) return -EINVAL; - WARN_ON(mask & PAGE_MASK); - - /* Don't allow IOVA wrap */ - if (end && end < map->iova) - return -EINVAL; - - /* Don't allow virtual address wrap */ - if (vaddr + map->size && vaddr + map->size < vaddr) + /* Don't allow IOVA or virtual address wrap */ + if (iova + size - 1 < iova || vaddr + size - 1 < vaddr) return -EINVAL; mutex_lock(&iommu->lock); - if (vfio_find_dma(iommu, map->iova, map->size)) { + if (vfio_find_dma(iommu, iova, size)) { mutex_unlock(&iommu->lock); return -EEXIST; } @@ -576,17 +567,17 @@ static int vfio_dma_do_map(struct vfio_iommu *iommu, return -ENOMEM; } - dma->iova = map->iova; - dma->vaddr = map->vaddr; + dma->iova = iova; + dma->vaddr = vaddr; dma->prot = prot; /* Insert zero-sized and grow as we map chunks of it */ vfio_link_dma(iommu, dma); - for (iova = map->iova; iova < end; iova += size, vaddr += size) { + while (size) { /* Pin a contiguous chunk of memory */ - npage = vfio_pin_pages(vaddr, (end - iova) >> PAGE_SHIFT, - prot, &pfn); + npage = vfio_pin_pages(vaddr + dma->size, + size >> PAGE_SHIFT, prot, &pfn); if (npage <= 0) { WARN_ON(!npage); ret = (int)npage; @@ -594,14 +585,14 @@ static int vfio_dma_do_map(struct vfio_iommu *iommu, } /* Map it! */ - ret = vfio_iommu_map(iommu, iova, pfn, npage, prot); + ret = vfio_iommu_map(iommu, iova + dma->size, pfn, npage, prot); if (ret) { vfio_unpin_pages(pfn, npage, prot, true); break; } - size = npage << PAGE_SHIFT; - dma->size += size; + size -= npage << PAGE_SHIFT; + dma->size += npage << PAGE_SHIFT; } if (ret)