// NOTE(review): this text is a unified diff against ggml-vulkan.cpp whose line
// breaks have been collapsed. Its hunks touch the include block, struct
// vk_buffer, ggml_vk_pool_malloc and ggml_vk_pool_free. Going by the hunk
// headers, new-file lines 156-173 (inside ggml_vk_pool_malloc) are not shown.
//
// What the patch changes, as visible in the hunks:
//  - defines VMA_IMPLEMENTATION ahead of #include "external/vk_mem_alloc.h".
//  - adds a vk::DeviceMemory `memory` member to vk_buffer.
//  - ggml_vk_pool_malloc now returns void and reports its result through a
//    vk_buffer* parameter instead of returning a vk::Buffer and writing
//    *actual_size. On a best-fit hit it copies buffer/memory/size out of the
//    pool slot and marks the slot empty (size = 0). Otherwise it releases the
//    largest pooled buffer (if any) and creates a new one.
//  - ggml_vk_pool_free now takes a vk_buffer*. It stores the buffer in the
//    first empty slot; if none is empty it prints the warning and releases the
//    buffer with freeMemory/destroyBuffer, replacing vkReleaseMemObject(mem).
//
// Points to confirm before merging:
//  - `buf = new vk_buffer;` reassigns the parameter itself, so every later
//    write (size, buffer, memory) goes to the new object rather than the
//    caller's struct. Nothing in the visible hunks returns or deletes it.
//  - vmaCreateBuffer receives &buffer_allocation, yet the function afterwards
//    also calls vk_device.allocateMemory and keeps only that handle in
//    buf->memory. No bind call and no store of buffer_allocation is visible.
//    Presumably VMA has already allocated and bound memory for the buffer;
//    TODO confirm against the VMA documentation and the lines not shown.
//  - memory_type_index keeps its uint32_t(~0) sentinel when no memory type is
//    both eHostVisible and eHostCoherent, and is passed on unchecked.
//  - the memory-type loop does not consult
//    buffer_memory_requirements.memoryTypeBits.
//  - pooled buffers are released with vk_device.freeMemory/destroyBuffer even
//    though they were created through vmaCreateBuffer; presumably
//    vmaDestroyBuffer is the matching release. TODO confirm.
//  - two `#include` directives have no header name and the removed
//    `static_cast(` lines have no template argument; text inside angle
//    brackets looks to have been lost during extraction.
diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 6c18061a9..5b0affbae 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -1,6 +1,7 @@ #include "ggml-vulkan.h" #include +#define VMA_IMPLEMENTATION #include "external/vk_mem_alloc.h" #include @@ -106,13 +107,14 @@ struct scoped_spin_lock { struct vk_buffer { vk::Buffer buffer; + vk::DeviceMemory memory; size_t size = 0; }; static vk_buffer g_vk_buffer_pool[MAX_VK_BUFFERS]; static std::atomic_flag g_vk_pool_lock = ATOMIC_FLAG_INIT; -static vk::Buffer ggml_vk_pool_malloc(size_t size, size_t * actual_size) { +static void ggml_vk_pool_malloc(size_t size, vk_buffer* buf) { scoped_spin_lock lock(g_vk_pool_lock); int best_i = -1; @@ -121,33 +123,33 @@ static vk::Buffer ggml_vk_pool_malloc(size_t size, size_t * actual_size) { size_t worst_size = 0; //largest unused buffer seen so far for (int i = 0; i < MAX_VK_BUFFERS; ++i) { vk_buffer &b = g_vk_buffer_pool[i]; - if (b.size > 0 && b.size >= size && b.size < best_size) - { + if (b.size > 0 && b.size >= size && b.size < best_size) { best_i = i; best_size = b.size; } - if (b.size > 0 && b.size > worst_size) - { + if (b.size > 0 && b.size > worst_size) { worst_i = i; worst_size = b.size; } } - if(best_i!=-1) //found the smallest buffer that fits our needs - { + if(best_i != -1) { + //found the smallest buffer that fits our needs vk_buffer& b = g_vk_buffer_pool[best_i]; - vk::Buffer buffer = b.buffer; - *actual_size = b.size; + buf->buffer = b.buffer; + buf->memory = b.memory; + buf->size = b.size; b.size = 0; - return buffer; + return; } - if(worst_i!=-1) //no buffer that fits our needs, resize largest one to save memory - { - vk_buffer& b = g_vk_buffer_pool[worst_i]; - vk::Buffer buffer = b.buffer; - b.size = 0; - // vkReleaseMemObject(buffer); + if(worst_i != -1) { + //no buffer that fits our needs, resize largest one to save memory + vk_buffer& b = g_vk_buffer_pool[worst_i]; + b.size = 0; + vk_device.freeMemory(b.memory); + vk_device.destroyBuffer(b.buffer); } - 
vk::Buffer buffer; + buf = new vk_buffer; + buf->size = size; vk::BufferCreateInfo buffer_create_info{ vk::BufferCreateFlags(), @@ -172,27 +174,46 @@ static vk::Buffer ggml_vk_pool_malloc(size_t size, size_t * actual_size) { VmaAllocation buffer_allocation; vmaCreateBuffer(allocator, - &static_cast(buffer_create_info), + (VkBufferCreateInfo*)&buffer_create_info, &allocation_info, - &static_cast(buffer), + (VkBuffer*)&buf->buffer, &buffer_allocation, nullptr); - *actual_size = size; - return buffer; + vk::MemoryRequirements buffer_memory_requirements = vk_device.getBufferMemoryRequirements(buf->buffer); + vk::PhysicalDeviceMemoryProperties memory_properties = vk_physical_device.getMemoryProperties(); + + uint32_t memory_type_index = uint32_t(~0); + + for (uint32_t current_memory_type_index = 0; current_memory_type_index < memory_properties.memoryTypeCount; current_memory_type_index++) { + vk::MemoryType memory_type = memory_properties.memoryTypes[current_memory_type_index]; + if ((vk::MemoryPropertyFlagBits::eHostVisible & memory_type.propertyFlags) && + (vk::MemoryPropertyFlagBits::eHostCoherent & memory_type.propertyFlags)) + { + memory_type_index = current_memory_type_index; + break; + } + } + + vk::MemoryAllocateInfo buffer_memory_allocate_info(buffer_memory_requirements.size, memory_type_index); + + buf->memory = vk_device.allocateMemory(buffer_memory_allocate_info); } -static void ggml_vk_pool_free(vk::Buffer buffer, size_t size) { +static void ggml_vk_pool_free(vk_buffer* buffer) { scoped_spin_lock lock(g_vk_pool_lock); for (int i = 0; i < MAX_VK_BUFFERS; ++i) { vk_buffer& b = g_vk_buffer_pool[i]; if (b.size == 0) { - b.buffer = buffer; - b.size = size; + b.buffer = buffer->buffer; + b.memory = buffer->memory; + b.size = buffer->size; return; } } fprintf(stderr, "WARNING: vk buffer pool full, increase MAX_VK_BUFFERS\n"); - vkReleaseMemObject(mem); + buffer->size = 0; + vk_device.freeMemory(buffer->memory); + vk_device.destroyBuffer(buffer->buffer); }