Only do device info print in the beginning and initialize one backend for cpu assist

Add missing cleanup code
This commit is contained in:
0cc4m 2024-02-05 20:52:47 +01:00
parent daa6a9c303
commit 087ae64e5e
3 changed files with 166 additions and 28 deletions

View file

@ -37,6 +37,8 @@
#define GGML_VK_MAX_NODES 8192 #define GGML_VK_MAX_NODES 8192
#define MAX_VK_BUFFERS 256
#ifndef K_QUANTS_PER_ITERATION #ifndef K_QUANTS_PER_ITERATION
#define K_QUANTS_PER_ITERATION 1 #define K_QUANTS_PER_ITERATION 1
#else #else
@ -74,6 +76,7 @@ struct vk_subbuffer {
struct vk_pipeline { struct vk_pipeline {
std::string name; std::string name;
vk::ShaderModule shader_module;
vk::DescriptorSetLayout dsl; vk::DescriptorSetLayout dsl;
std::vector<vk::DescriptorPool> descriptor_pools; std::vector<vk::DescriptorPool> descriptor_pools;
std::vector<vk::DescriptorSet> descriptor_sets; std::vector<vk::DescriptorSet> descriptor_sets;
@ -119,6 +122,7 @@ struct vk_device {
uint32_t vendor_id; uint32_t vendor_id;
vk_queue compute_queue; vk_queue compute_queue;
vk_queue transfer_queue; vk_queue transfer_queue;
bool single_queue;
uint32_t descriptor_set_mode; uint32_t descriptor_set_mode;
uint32_t subgroup_size; uint32_t subgroup_size;
bool uma; bool uma;
@ -259,6 +263,8 @@ struct ggml_backend_vk_context {
size_t staging_offset; size_t staging_offset;
vk_buffer sync_staging; vk_buffer sync_staging;
vk_buffer buffer_pool[MAX_VK_BUFFERS];
vk_context * compute_ctx; vk_context * compute_ctx;
vk_context * transfer_ctx; vk_context * transfer_ctx;
@ -293,6 +299,8 @@ typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context * subct
static bool vk_instance_initialized = false; static bool vk_instance_initialized = false;
static vk_instance vk_instance; static vk_instance vk_instance;
GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend);
static void ggml_vk_create_pipeline(ggml_backend_vk_context * ctx, vk_pipeline& pipeline, const std::string& name, size_t spv_size, const void* spv_data, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) { static void ggml_vk_create_pipeline(ggml_backend_vk_context * ctx, vk_pipeline& pipeline, const std::string& name, size_t spv_size, const void* spv_data, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) {
#ifdef GGML_VULKAN_DEBUG #ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_create_pipeline(" << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")" << std::endl; std::cerr << "ggml_vk_create_pipeline(" << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")" << std::endl;
@ -307,7 +315,7 @@ static void ggml_vk_create_pipeline(ggml_backend_vk_context * ctx, vk_pipeline&
pipeline.align = align; pipeline.align = align;
vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data)); vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data));
vk::ShaderModule shader_module = ctx->device.device.createShaderModule(shader_module_create_info); pipeline.shader_module = ctx->device.device.createShaderModule(shader_module_create_info);
std::vector<vk::DescriptorSetLayoutBinding> dsl_binding; std::vector<vk::DescriptorSetLayoutBinding> dsl_binding;
std::vector<vk::DescriptorBindingFlags> dsl_binding_flags; std::vector<vk::DescriptorBindingFlags> dsl_binding_flags;
@ -383,7 +391,7 @@ static void ggml_vk_create_pipeline(ggml_backend_vk_context * ctx, vk_pipeline&
vk::PipelineShaderStageCreateInfo pipeline_shader_create_info( vk::PipelineShaderStageCreateInfo pipeline_shader_create_info(
vk::PipelineShaderStageCreateFlags(), vk::PipelineShaderStageCreateFlags(),
vk::ShaderStageFlagBits::eCompute, vk::ShaderStageFlagBits::eCompute,
shader_module, pipeline.shader_module,
entrypoint.c_str(), entrypoint.c_str(),
&specialization_info); &specialization_info);
vk::ComputePipelineCreateInfo compute_pipeline_create_info( vk::ComputePipelineCreateInfo compute_pipeline_create_info(
@ -405,6 +413,10 @@ static void ggml_vk_destroy_pipeline(ggml_backend_vk_context * ctx, vk_pipeline
ctx->device.device.destroyDescriptorSetLayout(pipeline->dsl); ctx->device.device.destroyDescriptorSetLayout(pipeline->dsl);
ctx->device.device.destroyPipelineLayout(pipeline->layout);
ctx->device.device.destroyShaderModule(pipeline->shader_module);
ctx->device.device.destroyPipeline(pipeline->pipeline); ctx->device.device.destroyPipeline(pipeline->pipeline);
} }
@ -733,6 +745,10 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
buf.ctx = ctx; buf.ctx = ctx;
#ifdef GGML_VULKAN_DEBUG
std::cerr << "Created buffer " << buf.buffer << std::endl;
#endif
return buf; return buf;
} }
@ -974,6 +990,79 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
ggml_vk_create_pipeline(ctx, ctx->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 3, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 3, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);
} }
static void ggml_vk_print_gpu_info(size_t idx) {
GGML_ASSERT(idx < vk_instance.device_indices.size());
size_t dev_num = vk_instance.device_indices[idx];
#ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_print_gpu_info(" << dev_num << ")" << std::endl;
#endif
GGML_ASSERT(vk_instance.initialized);
std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
if (dev_num >= devices.size()) {
std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl;
throw std::runtime_error("Device not found");
}
vk::PhysicalDevice physical_device = devices[dev_num];
std::vector<vk::ExtensionProperties> ext_props = physical_device.enumerateDeviceExtensionProperties();
vk::PhysicalDeviceProperties2 props2;
vk::PhysicalDeviceMaintenance3Properties props3;
vk::PhysicalDeviceSubgroupProperties subgroup_props;
props2.pNext = &props3;
props3.pNext = &subgroup_props;
physical_device.getProperties2(&props2);
const size_t subgroup_size = subgroup_props.subgroupSize;
const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
bool fp16_storage = false;
bool fp16_compute = false;
for (auto properties : ext_props) {
if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
fp16_storage = true;
} else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
fp16_compute = true;
}
}
const char* GGML_VULKAN_DISABLE_F16 = getenv("GGML_VULKAN_DISABLE_F16");
bool force_disable_f16 = GGML_VULKAN_DISABLE_F16 != nullptr;
bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
vk::PhysicalDeviceFeatures device_features = physical_device.getFeatures();
VkPhysicalDeviceFeatures2 device_features2;
device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
device_features2.pNext = nullptr;
device_features2.features = (VkPhysicalDeviceFeatures)device_features;
VkPhysicalDeviceVulkan11Features vk11_features;
vk11_features.pNext = nullptr;
vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
device_features2.pNext = &vk11_features;
VkPhysicalDeviceVulkan12Features vk12_features;
vk12_features.pNext = nullptr;
vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
vk11_features.pNext = &vk12_features;
vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
fp16 = fp16 && vk12_features.shaderFloat16;
std::string device_name = props2.properties.deviceName.data();
std::cerr << GGML_VK_NAME << idx << ": " << device_name << " | uma: " << uma << " | fp16: " << fp16 << " | warp size: " << subgroup_size << std::endl;
if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
std::cerr << "ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want." << std::endl;
}
}
void ggml_vk_instance_init() { void ggml_vk_instance_init() {
if (vk_instance_initialized) { if (vk_instance_initialized) {
return; return;
@ -1094,7 +1183,7 @@ void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
} }
const char* GGML_VULKAN_DISABLE_F16 = getenv("GGML_VULKAN_DISABLE_F16"); const char* GGML_VULKAN_DISABLE_F16 = getenv("GGML_VULKAN_DISABLE_F16");
bool force_disable_f16 = GGML_VULKAN_DISABLE_F16 != NULL; bool force_disable_f16 = GGML_VULKAN_DISABLE_F16 != nullptr;
ctx->device.fp16 = !force_disable_f16 && fp16_storage && fp16_compute; ctx->device.fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
@ -1105,13 +1194,13 @@ void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1); const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1);
const float priorities[] = { 1.0f, 1.0f }; const float priorities[] = { 1.0f, 1.0f };
const bool single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1; ctx->device.single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1;
std::vector<vk::DeviceQueueCreateInfo> device_queue_create_infos; std::vector<vk::DeviceQueueCreateInfo> device_queue_create_infos;
if (compute_queue_family_index != transfer_queue_family_index) { if (compute_queue_family_index != transfer_queue_family_index) {
device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities}); device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1}); device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1});
} else if(!single_queue) { } else if(!ctx->device.single_queue) {
device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities}); device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities});
} else { } else {
device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities}); device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
@ -1140,8 +1229,8 @@ void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
ctx->device.fp16 = ctx->device.fp16 && vk12_features.shaderFloat16; ctx->device.fp16 = ctx->device.fp16 && vk12_features.shaderFloat16;
if (!vk11_features.storageBuffer16BitAccess) { if (!vk11_features.storageBuffer16BitAccess) {
std::cerr << "ggml_vulkan: device does not support 16-bit storage" << std::endl; std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
GGML_ASSERT(false); throw std::runtime_error("Unsupported device");
} }
device_extensions.push_back("VK_KHR_16bit_storage"); device_extensions.push_back("VK_KHR_16bit_storage");
@ -1154,11 +1243,6 @@ void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
device_extensions.push_back("VK_KHR_shader_float16_int8"); device_extensions.push_back("VK_KHR_shader_float16_int8");
} }
ctx->device.name = ctx->device.properties.deviceName.data(); ctx->device.name = ctx->device.properties.deviceName.data();
std::cerr << GGML_VK_NAME << idx << ": " << ctx->device.name << " | uma: " << ctx->device.uma << " | fp16: " << ctx->device.fp16 << " | warp size: " << ctx->device.subgroup_size << std::endl;
if (ctx->device.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
std::cerr << "ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want." << std::endl;
}
device_create_info = { device_create_info = {
vk::DeviceCreateFlags(), vk::DeviceCreateFlags(),
@ -1176,7 +1260,7 @@ void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
// Queues // Queues
ggml_vk_create_queue(ctx, ctx->device.compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }); ggml_vk_create_queue(ctx, ctx->device.compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer });
if (!single_queue) { if (!ctx->device.single_queue) {
const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0; const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
ggml_vk_create_queue(ctx, ctx->device.transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }); ggml_vk_create_queue(ctx, ctx->device.transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer });
} else { } else {
@ -1250,11 +1334,6 @@ static vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
return &ctx->pipeline_dequant_mul_mat_vec_f32[type]; return &ctx->pipeline_dequant_mul_mat_vec_f32[type];
} }
// buffer pool for vulkan
#define MAX_VK_BUFFERS 256
static vk_buffer g_vk_buffer_pool[MAX_VK_BUFFERS];
static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) { static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) {
#ifdef GGML_VULKAN_DEBUG #ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_pool_malloc(" << size << ")" << std::endl; std::cerr << "ggml_vk_pool_malloc(" << size << ")" << std::endl;
@ -1264,7 +1343,7 @@ static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size)
int worst_i = -1; int worst_i = -1;
size_t worst_size = 0; //largest unused buffer seen so far size_t worst_size = 0; //largest unused buffer seen so far
for (int i = 0; i < MAX_VK_BUFFERS; ++i) { for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
vk_buffer &b = g_vk_buffer_pool[i]; vk_buffer &b = ctx->buffer_pool[i];
if (b.size > 0 && b.size >= size && b.size < best_size) { if (b.size > 0 && b.size >= size && b.size < best_size) {
best_i = i; best_i = i;
best_size = b.size; best_size = b.size;
@ -1276,13 +1355,13 @@ static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size)
} }
if(best_i != -1) { if(best_i != -1) {
//found the smallest buffer that fits our needs //found the smallest buffer that fits our needs
vk_buffer b = g_vk_buffer_pool[best_i]; vk_buffer b = ctx->buffer_pool[best_i];
g_vk_buffer_pool[best_i].size = 0; ctx->buffer_pool[best_i].size = 0;
return b; return b;
} }
if(worst_i != -1) { if(worst_i != -1) {
//no buffer that fits our needs, resize largest one to save memory //no buffer that fits our needs, resize largest one to save memory
vk_buffer& b = g_vk_buffer_pool[worst_i]; vk_buffer& b = ctx->buffer_pool[worst_i];
ggml_vk_destroy_buffer(ctx, b); ggml_vk_destroy_buffer(ctx, b);
} }
@ -1294,7 +1373,7 @@ static void ggml_vk_pool_free(ggml_backend_vk_context * ctx, vk_buffer& buffer)
std::cerr << "ggml_vk_pool_free(" << buffer.size << ")" << std::endl; std::cerr << "ggml_vk_pool_free(" << buffer.size << ")" << std::endl;
#endif #endif
for (int i = 0; i < MAX_VK_BUFFERS; ++i) { for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
vk_buffer& b = g_vk_buffer_pool[i]; vk_buffer& b = ctx->buffer_pool[i];
if (b.size == 0) { if (b.size == 0) {
b = buffer; b = buffer;
// Set owning queue family index to ignored to avoid synchronization on next use // Set owning queue family index to ignored to avoid synchronization on next use
@ -4319,14 +4398,24 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
// Clean up on backend free // Clean up on backend free
static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
#ifdef GGML_VULKAN_DEBUG #ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_cleanup()" << std::endl; std::cerr << "ggml_vk_cleanup(" << ctx->idx << ")" << std::endl;
#endif #endif
ggml_vk_graph_cleanup(ctx);
ggml_vk_destroy_buffer(ctx, ctx->prealloc_qx);
ggml_vk_destroy_buffer(ctx, ctx->prealloc_qy);
ggml_vk_destroy_buffer(ctx, ctx->prealloc_x); ggml_vk_destroy_buffer(ctx, ctx->prealloc_x);
ggml_vk_destroy_buffer(ctx, ctx->prealloc_y); ggml_vk_destroy_buffer(ctx, ctx->prealloc_y);
ggml_vk_destroy_buffer(ctx, ctx->prealloc_split_k); ggml_vk_destroy_buffer(ctx, ctx->prealloc_split_k);
ggml_vk_destroy_buffer(ctx, ctx->staging); ggml_vk_destroy_buffer(ctx, ctx->staging);
ggml_vk_destroy_buffer(ctx, ctx->sync_staging); ggml_vk_destroy_buffer(ctx, ctx->sync_staging);
for (auto& buffer : ctx->buffer_pool) {
ggml_vk_destroy_buffer(ctx, buffer);
}
ctx->prealloc_size_qx = 0;
ctx->prealloc_size_qy = 0;
ctx->prealloc_size_x = 0; ctx->prealloc_size_x = 0;
ctx->prealloc_size_y = 0; ctx->prealloc_size_y = 0;
ctx->prealloc_size_split_k = 0; ctx->prealloc_size_split_k = 0;
@ -4343,6 +4432,11 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
ctx->gc.pipelines.clear(); ctx->gc.pipelines.clear();
ctx->device.device.destroyFence(ctx->fence); ctx->device.device.destroyFence(ctx->fence);
ctx->device.device.destroyCommandPool(ctx->device.compute_queue.pool);
if (!ctx->device.single_queue) {
ctx->device.device.destroyCommandPool(ctx->device.transfer_queue.pool);
}
} }
GGML_CALL int ggml_vk_get_device_count() { GGML_CALL int ggml_vk_get_device_count() {
@ -4369,45 +4463,76 @@ void ggml_vk_init_cpu_assist() {
std::cerr << "ggml_vulkan: Found " << ggml_vk_get_device_count() << " Vulkan devices:" << std::endl; std::cerr << "ggml_vulkan: Found " << ggml_vk_get_device_count() << " Vulkan devices:" << std::endl;
for (size_t i = 0; i < vk_instance.device_indices.size(); i++) { for (size_t i = 0; i < ggml_vk_get_device_count(); i++) {
ggml_backend_vk_init(i); ggml_vk_print_gpu_info(i);
} }
// Initialize the first backend to make sure CPU matrix multiplications can be offloaded.
ggml_backend_vk_init(0);
} }
void ggml_vk_preallocate_buffers_graph_cpu_assist(ggml_tensor * node) { void ggml_vk_preallocate_buffers_graph_cpu_assist(ggml_tensor * node) {
ggml_backend_vk_context * ctx = &vk_instance.contexts[0]; ggml_backend_vk_context * ctx = &vk_instance.contexts[0];
if (!ctx->initialized) {
return;
}
ggml_vk_preallocate_buffers_graph(ctx, node); ggml_vk_preallocate_buffers_graph(ctx, node);
} }
void ggml_vk_preallocate_buffers_cpu_assist() { void ggml_vk_preallocate_buffers_cpu_assist() {
ggml_backend_vk_context * ctx = &vk_instance.contexts[0]; ggml_backend_vk_context * ctx = &vk_instance.contexts[0];
if (!ctx->initialized) {
return;
}
ggml_vk_preallocate_buffers(ctx); ggml_vk_preallocate_buffers(ctx);
} }
void ggml_vk_build_graph_cpu_assist(ggml_tensor * node, bool last_node) { void ggml_vk_build_graph_cpu_assist(ggml_tensor * node, bool last_node) {
ggml_backend_vk_context * ctx = &vk_instance.contexts[0]; ggml_backend_vk_context * ctx = &vk_instance.contexts[0];
if (!ctx->initialized) {
return;
}
ggml_vk_build_graph(ctx, node, last_node); ggml_vk_build_graph(ctx, node, last_node);
} }
bool ggml_vk_compute_forward_cpu_assist(ggml_compute_params * params, ggml_tensor * tensor){ bool ggml_vk_compute_forward_cpu_assist(ggml_compute_params * params, ggml_tensor * tensor){
ggml_backend_vk_context * ctx = &vk_instance.contexts[0]; ggml_backend_vk_context * ctx = &vk_instance.contexts[0];
if (!ctx->initialized) {
return false;
}
return ggml_vk_compute_forward(ctx, params, tensor); return ggml_vk_compute_forward(ctx, params, tensor);
} }
void ggml_vk_graph_cleanup_cpu_assist() { void ggml_vk_graph_cleanup_cpu_assist() {
ggml_backend_vk_context * ctx = &vk_instance.contexts[0]; ggml_backend_vk_context * ctx = &vk_instance.contexts[0];
if (!ctx->initialized) {
return;
}
ggml_vk_graph_cleanup(ctx); ggml_vk_graph_cleanup(ctx);
} }
void ggml_vk_cleanup_cpu_assist() { void ggml_vk_cleanup_cpu_assist() {
ggml_backend_vk_context * ctx = &vk_instance.contexts[0]; ggml_backend_vk_context * ctx = &vk_instance.contexts[0];
ggml_vk_cleanup(ctx); if (!ctx->initialized) {
return;
}
// Shouldn't happen, but better check
if (vk_instance.backends[0] == nullptr) {
return;
}
ggml_backend_vk_free(vk_instance.backends[0]);
} }
// backend interface // backend interface
@ -4679,6 +4804,11 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
/* .context = */ nullptr, /* .context = */ nullptr,
}; };
if (!vk_instance.contexts[0].initialized) {
// Fall back to CPU
return ggml_backend_cpu_buffer_type();
}
return &ggml_backend_vk_buffer_type_host; return &ggml_backend_vk_buffer_type_host;
} }
@ -4698,9 +4828,12 @@ GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend) {
size_t idx = ctx->idx; size_t idx = ctx->idx;
ggml_vk_graph_cleanup(ctx);
ggml_vk_cleanup(ctx); ggml_vk_cleanup(ctx);
// Not possible until llama.cpp makes sure to destroy buffers before backends
// ctx->device.device.destroy();
ctx->initialized = false;
vk_instance.initialized[idx] = false; vk_instance.initialized[idx] = false;
vk_instance.backends[idx] = nullptr; vk_instance.backends[idx] = nullptr;
memset(&vk_instance.buffer_types[idx], 0, sizeof(ggml_backend_buffer_type)); memset(&vk_instance.buffer_types[idx], 0, sizeof(ggml_backend_buffer_type));

View file

@ -20,6 +20,7 @@ GGML_API bool ggml_vk_compute_forward_cpu_assist(struct ggml_compute_params * pa
void ggml_vk_check_results_1_cpu_assist(struct ggml_compute_params * params, struct ggml_tensor * tensor); void ggml_vk_check_results_1_cpu_assist(struct ggml_compute_params * params, struct ggml_tensor * tensor);
#endif #endif
GGML_API void ggml_vk_graph_cleanup_cpu_assist(void); GGML_API void ggml_vk_graph_cleanup_cpu_assist(void);
GGML_API void ggml_vk_cleanup_cpu_assist(void);
// backend API // backend API
GGML_API GGML_CALL ggml_backend_t ggml_backend_vk_init(size_t dev_num); GGML_API GGML_CALL ggml_backend_t ggml_backend_vk_init(size_t dev_num);

View file

@ -1764,6 +1764,10 @@ struct llama_context {
ggml_backend_free(backend); ggml_backend_free(backend);
} }
#ifdef GGML_USE_VULKAN
ggml_vk_cleanup_cpu_assist();
#endif
ggml_backend_buffer_free(buf_input); ggml_backend_buffer_free(buf_input);
ggml_free(ctx_input); ggml_free(ctx_input);
} }