diff --git a/drivers/misc/pci_endpoint_test.c b/drivers/misc/pci_endpoint_test.c index 5998df1c84e9..8682867ac14a 100644 --- a/drivers/misc/pci_endpoint_test.c +++ b/drivers/misc/pci_endpoint_test.c @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -52,6 +53,7 @@ #define STATUS_SRC_ADDR_INVALID BIT(7) #define STATUS_DST_ADDR_INVALID BIT(8) +#define PCI_ENDPOINT_TEST_STATUS 0x8 #define PCI_ENDPOINT_TEST_LOWER_SRC_ADDR 0x0c #define PCI_ENDPOINT_TEST_UPPER_SRC_ADDR 0x10 @@ -64,6 +66,9 @@ #define PCI_ENDPOINT_TEST_IRQ_TYPE 0x24 #define PCI_ENDPOINT_TEST_IRQ_NUMBER 0x28 +#define PCI_ENDPOINT_TEST_FLAGS 0x2c +#define FLAG_USE_DMA BIT(0) + #define PCI_DEVICE_ID_TI_AM654 0xb00c #define is_am654_pci_dev(pdev) \ @@ -315,11 +320,16 @@ static bool pci_endpoint_test_msi_irq(struct pci_endpoint_test *test, return false; } -static bool pci_endpoint_test_copy(struct pci_endpoint_test *test, size_t size) +static bool pci_endpoint_test_copy(struct pci_endpoint_test *test, + unsigned long arg) { + struct pci_endpoint_test_xfer_param param; bool ret = false; void *src_addr; void *dst_addr; + u32 flags = 0; + bool use_dma; + size_t size; dma_addr_t src_phys_addr; dma_addr_t dst_phys_addr; struct pci_dev *pdev = test->pdev; @@ -332,10 +342,22 @@ static bool pci_endpoint_test_copy(struct pci_endpoint_test *test, size_t size) size_t alignment = test->alignment; u32 src_crc32; u32 dst_crc32; + int err; + err = copy_from_user(¶m, (void __user *)arg, sizeof(param)); + if (err) { + dev_err(dev, "Failed to get transfer param\n"); + return false; + } + + size = param.size; if (size > SIZE_MAX - alignment) goto err; + use_dma = !!(param.flags & PCITEST_FLAGS_USE_DMA); + if (use_dma) + flags |= FLAG_USE_DMA; + if (irq_type < IRQ_TYPE_LEGACY || irq_type > IRQ_TYPE_MSIX) { dev_err(dev, "Invalid IRQ type option\n"); goto err; @@ -406,6 +428,7 @@ static bool pci_endpoint_test_copy(struct pci_endpoint_test *test, size_t size) pci_endpoint_test_writel(test, PCI_ENDPOINT_TEST_SIZE, size); + pci_endpoint_test_writel(test, PCI_ENDPOINT_TEST_FLAGS, flags); pci_endpoint_test_writel(test, PCI_ENDPOINT_TEST_IRQ_TYPE, irq_type); pci_endpoint_test_writel(test, PCI_ENDPOINT_TEST_IRQ_NUMBER, 1); pci_endpoint_test_writel(test, PCI_ENDPOINT_TEST_COMMAND, @@ -434,9 +457,13 @@ static bool pci_endpoint_test_copy(struct pci_endpoint_test *test, size_t size) return ret; } -static bool pci_endpoint_test_write(struct pci_endpoint_test *test, size_t size) +static bool pci_endpoint_test_write(struct pci_endpoint_test *test, + unsigned long arg) { + struct pci_endpoint_test_xfer_param param; bool ret = false; + u32 flags = 0; + bool use_dma; u32 reg; void *addr; dma_addr_t phys_addr; @@ -446,11 +473,24 @@ static bool pci_endpoint_test_write(struct pci_endpoint_test *test, size_t size) dma_addr_t orig_phys_addr; size_t offset; size_t alignment = test->alignment; + size_t size; u32 crc32; + int err; + err = copy_from_user(¶m, (void __user *)arg, sizeof(param)); + if (err != 0) { + dev_err(dev, "Failed to get transfer param\n"); + return false; + } + + size = param.size; if (size > SIZE_MAX - alignment) goto err; + use_dma = !!(param.flags & PCITEST_FLAGS_USE_DMA); + if (use_dma) + flags |= FLAG_USE_DMA; + if (irq_type < IRQ_TYPE_LEGACY || irq_type > IRQ_TYPE_MSIX) { dev_err(dev, "Invalid IRQ type option\n"); goto err; @@ -493,6 +533,7 @@ static bool pci_endpoint_test_write(struct pci_endpoint_test *test, size_t size) pci_endpoint_test_writel(test, PCI_ENDPOINT_TEST_SIZE, size); + pci_endpoint_test_writel(test, PCI_ENDPOINT_TEST_FLAGS, flags); pci_endpoint_test_writel(test, PCI_ENDPOINT_TEST_IRQ_TYPE, irq_type); pci_endpoint_test_writel(test, PCI_ENDPOINT_TEST_IRQ_NUMBER, 1); pci_endpoint_test_writel(test, PCI_ENDPOINT_TEST_COMMAND, @@ -514,9 +555,14 @@ static bool pci_endpoint_test_write(struct pci_endpoint_test *test, size_t size) return ret; } -static bool pci_endpoint_test_read(struct pci_endpoint_test *test, size_t size) +static bool pci_endpoint_test_read(struct pci_endpoint_test *test, + unsigned long arg) { + struct pci_endpoint_test_xfer_param param; bool ret = false; + u32 flags = 0; + bool use_dma; + size_t size; void *addr; dma_addr_t phys_addr; struct pci_dev *pdev = test->pdev; @@ -526,10 +572,22 @@ static bool pci_endpoint_test_read(struct pci_endpoint_test *test, size_t size) size_t offset; size_t alignment = test->alignment; u32 crc32; + int err; + err = copy_from_user(¶m, (void __user *)arg, sizeof(param)); + if (err) { + dev_err(dev, "Failed to get transfer param\n"); + return false; + } + + size = param.size; if (size > SIZE_MAX - alignment) goto err; + use_dma = !!(param.flags & PCITEST_FLAGS_USE_DMA); + if (use_dma) + flags |= FLAG_USE_DMA; + if (irq_type < IRQ_TYPE_LEGACY || irq_type > IRQ_TYPE_MSIX) { dev_err(dev, "Invalid IRQ type option\n"); goto err; @@ -566,6 +624,7 @@ static bool pci_endpoint_test_read(struct pci_endpoint_test *test, size_t size) pci_endpoint_test_writel(test, PCI_ENDPOINT_TEST_SIZE, size); + pci_endpoint_test_writel(test, PCI_ENDPOINT_TEST_FLAGS, flags); pci_endpoint_test_writel(test, PCI_ENDPOINT_TEST_IRQ_TYPE, irq_type); pci_endpoint_test_writel(test, PCI_ENDPOINT_TEST_IRQ_NUMBER, 1); pci_endpoint_test_writel(test, PCI_ENDPOINT_TEST_COMMAND,