From 98e588c6eb54e7fe9726ad46c98d30bac5473a8b Mon Sep 17 00:00:00 2001
From: niansa
Date: Fri, 23 Jun 2023 16:50:37 +0200
Subject: [PATCH] Fix ggml_vk_h2d_tensor throwing on second call

---
 ggml-vulkan.cpp | 18 +++++++++++++-----
 llama.cpp       | 16 ++++++++--------
 2 files changed, 21 insertions(+), 13 deletions(-)

diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp
index c35509bd8..c260c59c2 100644
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -91,12 +91,20 @@ void ggml_vk_h2d_tensor(struct ggml_kompute_context * ctx, struct ggml_tensor *
     auto data = t->data;
     auto size = ggml_nbytes(t);
 
-    std::vector<byte> vec(size);
-    memcpy(vec.data(), data, size);
+    auto res = ctx->tensors.find(t);
 
-    auto tensor = mgr.tensorT<byte>(vec);
-    mgr.sequence()->eval<kp::OpTensorSyncDevice>({tensor});
-    ctx->tensors.emplace(t, std::move(tensor));
+    if (res != ctx->tensors.end()) {
+        assert(res->second->size() == size);
+        res->second->setRawData(data);
+        mgr.sequence()->eval<kp::OpTensorSyncDevice>({res->second});
+    } else {
+        std::vector<byte> vec(size);
+        memcpy(vec.data(), data, size);
+
+        auto tensor = mgr.tensorT<byte>(vec);
+        mgr.sequence()->eval<kp::OpTensorSyncDevice>({tensor});
+        ctx->tensors.emplace(t, std::move(tensor));
+    }
 }
 
 void ggml_vk_d2h_tensor(struct ggml_kompute_context * ctx, struct ggml_tensor * t) {
diff --git a/llama.cpp b/llama.cpp
index 89c7fa656..cbe285afb 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2769,7 +2769,7 @@ struct llama_context * llama_init_from_file(
     }
 #elif defined(GGML_USE_KOMPUTE)
     if (params.n_gpu_layers > 0) {
-        // this allocates all Metal resources and memory buffers
+        // this allocates all Vulkan resources and memory buffers
         ctx->ctx_kompute = ggml_vk_init();
 
         void * data_ptr  = NULL;
@@ -2787,21 +2787,21 @@ struct llama_context * llama_init_from_file(
         printf("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
 
-#define LLAMA_METAL_CHECK_BUF(result)                            \
+#define LLAMA_VK_CHECK_BUF(result)                               \
     if (!(result)) {                                             \
         fprintf(stderr, "%s: failed to add buffer\n", __func__); \
         llama_free(ctx);                                         \
         return NULL;                                             \
     }
 
-        LLAMA_METAL_CHECK_BUF(ggml_vk_add_buffer(ctx->ctx_kompute, "data", data_ptr, data_size, max_size));
+        LLAMA_VK_CHECK_BUF(ggml_vk_add_buffer(ctx->ctx_kompute, "data", data_ptr, data_size, max_size));
 
-        LLAMA_METAL_CHECK_BUF(ggml_vk_add_buffer(ctx->ctx_kompute, "eval", ctx->buf_compute.addr, ctx->buf_compute.size, 0));
-        LLAMA_METAL_CHECK_BUF(ggml_vk_add_buffer(ctx->ctx_kompute, "kv", ctx->model.kv_self.buf.addr, ctx->model.kv_self.buf.size, 0));
+        LLAMA_VK_CHECK_BUF(ggml_vk_add_buffer(ctx->ctx_kompute, "eval", ctx->buf_compute.addr, ctx->buf_compute.size, 0));
+        LLAMA_VK_CHECK_BUF(ggml_vk_add_buffer(ctx->ctx_kompute, "kv", ctx->model.kv_self.buf.addr, ctx->model.kv_self.buf.size, 0));
 
-        LLAMA_METAL_CHECK_BUF(ggml_vk_add_buffer(ctx->ctx_kompute, "scr0", ctx->buf_scratch[0].addr, ctx->buf_scratch[0].size, 0));
-        LLAMA_METAL_CHECK_BUF(ggml_vk_add_buffer(ctx->ctx_kompute, "scr1", ctx->buf_scratch[1].addr, ctx->buf_scratch[1].size, 0));
-#undef LLAMA_METAL_CHECK_BUF
+        LLAMA_VK_CHECK_BUF(ggml_vk_add_buffer(ctx->ctx_kompute, "scr0", ctx->buf_scratch[0].addr, ctx->buf_scratch[0].size, 0));
+        LLAMA_VK_CHECK_BUF(ggml_vk_add_buffer(ctx->ctx_kompute, "scr1", ctx->buf_scratch[1].addr, ctx->buf_scratch[1].size, 0));
+#undef LLAMA_VK_CHECK_BUF
     }
 #endif
 