diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index efbcf8290..ae67d2d3a 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -221,26 +221,26 @@ static void ggml_vk_dispatch_pipeline(vk_pipeline& pipeline, std::vector& queue_family_props, const vk::QueueFlags& required, const vk::QueueFlags& avoid, int32_t compute_index) { +static uint32_t ggml_vk_find_queue_family_index(std::vector& queue_family_props, const vk::QueueFlags& required, const vk::QueueFlags& avoid, int32_t compute_index, int32_t min_num_queues) { const uint32_t qfsize = queue_family_props.size(); // Try with avoid preferences first for (uint32_t i = 0; i < qfsize; i++) { - if ((compute_index < 0 || i != compute_index) && queue_family_props[i].queueFlags & required && !(queue_family_props[i].queueFlags & avoid)) { + if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != compute_index) && queue_family_props[i].queueFlags & required && !(queue_family_props[i].queueFlags & avoid)) { return i; } } // Fall back to only required for (size_t i = 0; i < qfsize; i++) { - if ((compute_index < 0 || i != compute_index) && queue_family_props[i].queueFlags & required) { + if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != compute_index) && queue_family_props[i].queueFlags & required) { return i; } } // Fall back to reusing compute queue for (size_t i = 0; i < qfsize; i++) { - if (queue_family_props[i].queueFlags & required) { + if (queue_family_props[i].queueCount >= min_num_queues && queue_family_props[i].queueFlags & required) { return i; } } @@ -373,14 +373,18 @@ void ggml_vk_init(void) { std::vector queue_family_props = vk_physical_device.getQueueFamilyProperties(); // Try to find a non-graphics compute queue and transfer-focused queues - uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1); - uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics | vk::QueueFlagBits::eVideoDecodeKHR | vk::QueueFlagBits::eProtected | vk::QueueFlagBits::eOpticalFlowNV, compute_queue_family_index); + uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1); + uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics | vk::QueueFlagBits::eVideoDecodeKHR | vk::QueueFlagBits::eProtected | vk::QueueFlagBits::eOpticalFlowNV, compute_queue_family_index, 2); const float compute_queue_priority = 1.0f; - const float transfer_queue_priority[] = { 1.0f, 1.0f }; + const float transfer_queue_priority[] = { 1.0f, 1.0f, 1.0f }; std::vector device_queue_create_infos; - device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, &compute_queue_priority}); - device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, VK_TRANSFER_QUEUE_COUNT, transfer_queue_priority}); + if (compute_queue_family_index != transfer_queue_family_index) { + device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, &compute_queue_priority}); + device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, VK_TRANSFER_QUEUE_COUNT, transfer_queue_priority}); + } else { + device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1 + VK_TRANSFER_QUEUE_COUNT, transfer_queue_priority}); + } vk::DeviceCreateInfo device_create_info; std::vector device_extensions; vk::PhysicalDeviceFeatures device_features = vk_physical_device.getFeatures();