diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 48bd15ef4..1fad24fd1 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -707,7 +707,7 @@ static void ggml_vk_queue_cleanup(ggml_backend_vk_context * ctx, vk_queue& q) { q.cmd_buffer_idx = 0; } -static int32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) { +static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) { for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) { vk::MemoryType memory_type = mem_props->memoryTypes[i]; if ((mem_req->memoryTypeBits & ((uint64_t)1 << i)) && @@ -716,7 +716,7 @@ static int32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_pro return static_cast(i); } } - return -1; + return UINT32_MAX; } static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) { @@ -746,22 +746,17 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz vk::PhysicalDeviceMemoryProperties mem_props = ctx->device.lock()->physical_device.getMemoryProperties(); - uint32_t memory_type_index = -1; + uint32_t memory_type_index = UINT32_MAX; memory_type_index = find_properties(&mem_props, &mem_req, req_flags); buf->memory_property_flags = req_flags; - // Failed to find memory type matching req_flags, but we can try again with fallback_flags if specified... - if (memory_type_index == -1 && fallback_flags && ( - // ...as long as req_flags was either: 1) not DEVICE_LOCAL; or 2) DEVICE_LOCAL and device has UMA. - !(req_flags & vk::MemoryPropertyFlagBits::eDeviceLocal) || - (req_flags & vk::MemoryPropertyFlagBits::eDeviceLocal && ctx->device.lock()->uma)) - ) { + if (memory_type_index == UINT32_MAX && fallback_flags) { memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags); buf->memory_property_flags = fallback_flags; } - if (memory_type_index == -1) { + if (memory_type_index == UINT32_MAX) { ctx->device.lock()->device.destroyBuffer(buf->buffer); buf->size = 0; throw vk::OutOfDeviceMemoryError("No suitable memory type found"); @@ -807,9 +802,12 @@ static vk_buffer ggml_vk_create_buffer_check(ggml_backend_vk_context * ctx, size static vk_buffer ggml_vk_create_buffer_device(ggml_backend_vk_context * ctx, size_t size) { vk_buffer buf; try { - buf = ggml_vk_create_buffer(ctx, size, - vk::MemoryPropertyFlagBits::eDeviceLocal, - vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); + if (ctx->device.lock()->uma) { + // Fall back to host memory type + buf = ggml_vk_create_buffer(ctx, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); + } else { + buf = ggml_vk_create_buffer(ctx, size, vk::MemoryPropertyFlagBits::eDeviceLocal); + } } catch (const vk::SystemError& e) { std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl; std::cerr << "ggml_vulkan: " << e.what() << std::endl;