Add further missing cleanup code

This commit is contained in:
0cc4m 2024-02-04 19:11:59 +01:00
parent 5a1ad8c3e5
commit a1f9c008db

View file

@ -391,25 +391,27 @@ static void ggml_vk_create_pipeline(ggml_backend_vk_context * ctx, vk_pipeline&
pipeline_shader_create_info,
pipeline.layout);
pipeline.pipeline = ctx->device.device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
ctx->gc.pipelines.push_back(&pipeline);
}
// Destroy the Vulkan objects held by a pipeline on backend teardown.
//
// ctx      - backend context; its logical device (ctx->device.device) is used
//            for every destroy call.
// pipeline - pipeline whose descriptor pools, descriptor set layout, compute
//            pipeline and pipeline layout are destroyed. The bookkeeping
//            containers are emptied so no stale pool/set handles remain.
//
// NOTE(review): the caller is presumably responsible for making sure the
// device is idle and that the pipeline is not destroyed twice — confirm
// against the code that fills ctx->gc.pipelines.
static void ggml_vk_destroy_pipeline(ggml_backend_vk_context * ctx, vk_pipeline * pipeline) {
    for (auto& pool : pipeline->descriptor_pools) {
        ctx->device.device.destroyDescriptorPool(pool);
    }
    // Drop the now-dangling pool and set handles and restart set indexing.
    pipeline->descriptor_pools.clear();
    pipeline->descriptor_sets.clear();
    pipeline->descriptor_set_idx = 0;

    ctx->device.device.destroyDescriptorSetLayout(pipeline->dsl);

    ctx->device.device.destroyPipeline(pipeline->pipeline);

    // The pipeline layout is the one passed to the compute pipeline create
    // info at creation time; it was previously never released. Destroy it
    // after the pipeline that was built from it.
    ctx->device.device.destroyPipelineLayout(pipeline->layout);
}
static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx, vk_pipeline& pipeline, uint32_t n) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_pipeline_allocate_descriptor_sets(" << pipeline.name << ", " << n << ")" << std::endl;
#endif
// Check if gc already contains pipeline before adding it
bool gc_found = false;
for (auto * pl : ctx->gc.pipelines) {
if (&pipeline == pl) {
gc_found = true;
break;
}
}
if (!gc_found) {
ctx->gc.pipelines.push_back(&pipeline);
}
if (pipeline.descriptor_sets.size() >= pipeline.descriptor_set_idx + n) {
// Enough descriptors are available
return;
@ -4270,6 +4272,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
return true;
}
// Clean up after graph processing is done
static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
if (ctx->disable) {
return;
@ -4285,7 +4288,6 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
for (auto * pipeline : ctx->gc.pipelines) {
ggml_pipeline_cleanup(ctx, *pipeline);
}
ctx->gc.pipelines.clear();
ggml_vk_queue_cleanup(ctx, ctx->device.compute_queue);
ggml_vk_queue_cleanup(ctx, ctx->device.transfer_queue);
@ -4314,6 +4316,7 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
ctx->gc.contexts.clear();
}
// Clean up on backend free
static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_cleanup()" << std::endl;
@ -4333,6 +4336,13 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
ctx->device.device.destroyEvent(event);
}
ctx->gc.events.clear();
for (auto* pipeline : ctx->gc.pipelines) {
ggml_vk_destroy_pipeline(ctx, pipeline);
}
ctx->gc.pipelines.clear();
ctx->device.device.destroyFence(ctx->fence);
}
GGML_CALL int ggml_vk_get_device_count() {
@ -4688,8 +4698,8 @@ GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend) {
size_t idx = ctx->idx;
// ggml_vk_graph_cleanup(ctx);
// ggml_vk_cleanup(ctx);
ggml_vk_graph_cleanup(ctx);
ggml_vk_cleanup(ctx);
vk_instance.initialized[idx] = false;
vk_instance.backends[idx] = nullptr;