This commit is contained in:
Yifan Gu 2024-11-14 14:01:42 +08:00 committed by GitHub
commit a6016508ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -157,6 +157,7 @@ struct vk_device_struct {
vk::PhysicalDeviceProperties properties;
std::string name;
uint64_t max_memory_allocation_size;
uint32_t force_heap_index;
bool fp16;
vk::Device device;
uint32_t vendor_id;
@ -1037,9 +1038,12 @@ static void ggml_vk_queue_cleanup(vk_device& device, vk_queue& q) {
q.cmd_buffer_idx = 0;
}
static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) {
static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags, uint32_t force_heap_index = UINT32_MAX) {
for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) {
vk::MemoryType memory_type = mem_props->memoryTypes[i];
if (force_heap_index != UINT32_MAX && memory_type.heapIndex != force_heap_index) {
continue;
}
if ((mem_req->memoryTypeBits & ((uint64_t)1 << i)) &&
(flags & memory_type.propertyFlags) == flags &&
mem_props->memoryHeaps[memory_type.heapIndex].size >= mem_req->size) {
@ -1081,11 +1085,11 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor
uint32_t memory_type_index = UINT32_MAX;
memory_type_index = find_properties(&mem_props, &mem_req, req_flags);
memory_type_index = find_properties(&mem_props, &mem_req, req_flags, device->force_heap_index);
buf->memory_property_flags = req_flags;
if (memory_type_index == UINT32_MAX && fallback_flags) {
memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags);
memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags, device->force_heap_index);
buf->memory_property_flags = fallback_flags;
}
@ -1580,6 +1584,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->max_memory_allocation_size = props3.maxMemoryAllocationSize;
}
const char* GGML_VK_FORCE_HEAP_INDEX = getenv("GGML_VK_FORCE_HEAP_INDEX");
if (GGML_VK_FORCE_HEAP_INDEX != nullptr) {
device->force_heap_index = std::stoi(GGML_VK_FORCE_HEAP_INDEX);
} else {
device->force_heap_index = UINT32_MAX;
}
device->vendor_id = device->properties.vendorID;
device->subgroup_size = subgroup_props.subgroupSize;
device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;