diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index 9e8e004bb1c3..71bb46813031 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -25,6 +25,7 @@ #include #include #include +#include #include "vhost.h" @@ -663,6 +664,16 @@ int vhost_vq_access_ok(struct vhost_virtqueue *vq) } EXPORT_SYMBOL_GPL(vhost_vq_access_ok); +static int vhost_memory_reg_sort_cmp(const void *p1, const void *p2) +{ + const struct vhost_memory_region *r1 = p1, *r2 = p2; + if (r1->guest_phys_addr < r2->guest_phys_addr) + return 1; + if (r1->guest_phys_addr > r2->guest_phys_addr) + return -1; + return 0; +} + static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) { struct vhost_memory mem, *newmem, *oldmem; @@ -682,9 +693,11 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) memcpy(newmem, &mem, size); if (copy_from_user(newmem->regions, m->regions, mem.nregions * sizeof *m->regions)) { - kfree(newmem); + kvfree(newmem); return -EFAULT; } + sort(newmem->regions, newmem->nregions, sizeof(*newmem->regions), + vhost_memory_reg_sort_cmp, NULL); if (!memory_access_ok(d, newmem, 0)) { kfree(newmem); @@ -992,17 +1005,22 @@ EXPORT_SYMBOL_GPL(vhost_dev_ioctl); static const struct vhost_memory_region *find_region(struct vhost_memory *mem, __u64 addr, __u32 len) { - struct vhost_memory_region *reg; - int i; + const struct vhost_memory_region *reg; + int start = 0, end = mem->nregions; - /* linear search is not brilliant, but we really have on the order of 6 - * regions in practice */ - for (i = 0; i < mem->nregions; ++i) { - reg = mem->regions + i; - if (reg->guest_phys_addr <= addr && - reg->guest_phys_addr + reg->memory_size - 1 >= addr) - return reg; + while (start < end) { + int slot = start + (end - start) / 2; + reg = mem->regions + slot; + if (addr >= reg->guest_phys_addr) + end = slot; + else + start = slot + 1; } + + reg = mem->regions + start; + if (addr >= reg->guest_phys_addr && + reg->guest_phys_addr + reg->memory_size > addr) + return reg; return NULL; }