diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp
index 9e4224309..2b2e3378c 100644
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -106,10 +106,6 @@ std::shared_ptr<kp::Tensor> ggml_vk_get_buffer(struct ggml_kompute_context * ctx
 void ggml_vk_h2d_tensor(struct ggml_kompute_context * ctx, struct ggml_tensor * t) {
     printf("%s: Context: %p Tensor: %p\n", __func__, ctx, t);
 
-    if (t->backend != GGML_BACKEND_GPU) {
-        return;
-    }
-
     auto data = t->data;
     auto size = ggml_nbytes(t);
 
@@ -121,6 +117,7 @@ void ggml_vk_h2d_tensor(struct ggml_kompute_context * ctx, struct ggml_tensor *
         GGML_ASSERT(res->second->size() != size);
         res->second->setRawData(data);
         mgr.sequence()->eval<kp::OpTensorSyncDevice>({res->second});
+        printf("%s: Updating Host->GPU tensor: %p\n", __func__, t);
     } else {
         std::vector<byte> vec(size);
         memcpy(vec.data(), data, size);
@@ -130,16 +127,13 @@ void ggml_vk_h2d_tensor(struct ggml_kompute_context * ctx, struct ggml_tensor *
         ctx->tensors_mutex.lock();
         ctx->tensors.emplace(t, std::move(tensor));
         ctx->tensors_mutex.unlock();
+        printf("%s: Creating Host->GPU tensor: %p\n", __func__, t);
     }
 }
 
 void ggml_vk_d2h_tensor(struct ggml_kompute_context * ctx, struct ggml_tensor * t) {
     printf("%s: Context: %p Tensor: %p\n", __func__, ctx, t);
 
-    if (t->backend != GGML_BACKEND_GPU) {
-        return;
-    }
-
     auto data = t->data;
     auto size = ggml_nbytes(t);
 
@@ -151,18 +145,21 @@ void ggml_vk_d2h_tensor(struct ggml_kompute_context * ctx, struct ggml_tensor *
     auto tensor = res->second;
     mgr.sequence()->eval<kp::OpTensorSyncLocal>({tensor});
     memcpy(data, tensor->data(), size);
+    printf("%s: Updating GPU->Host tensor: %p\n", __func__, t);
 }
 
 static
 const std::shared_ptr<kp::Tensor> & ggml_vk_get_tensor(struct ggml_kompute_context * ctx, struct ggml_tensor * t) {
     printf("%s: Context: %p Tensor: %p\n", __func__, ctx, t);
 
-    GGML_ASSERT(t->backend != GGML_BACKEND_GPU);
-
     ctx->tensors_mutex.lock();
     auto res = ctx->tensors.find(t);
     ctx->tensors_mutex.unlock();
-    GGML_ASSERT(res != ctx->tensors.end());
+
+    if (res == ctx->tensors.end()) {
+        ggml_vk_h2d_tensor(ctx, t);
+        return ggml_vk_get_tensor(ctx, t);
+    }
 
     return res->second;
 }
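
For context, this change turns ggml_vk_get_tensor from a hard assertion into a lazy get-or-create lookup: on a cache miss it uploads the tensor via ggml_vk_h2d_tensor and retries. Below is a minimal standalone sketch of that pattern. The simplified stand-ins (GpuTensor, Context, h2d_tensor, get_tensor) are hypothetical replacements for the real kp::Tensor, Kompute manager, and ggml_kompute_context; this illustrates the control flow only, not the backend's actual implementation.

// Minimal sketch of the lazy get-or-create pattern introduced above.
// GpuTensor/Context/h2d_tensor/get_tensor are simplified stand-ins,
// not the real Kompute types used in ggml-vulkan.cpp.
#include <cstdio>
#include <cstring>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <vector>

struct GpuTensor {                        // stand-in for kp::Tensor
    std::vector<unsigned char> data;
};

struct Context {                          // stand-in for ggml_kompute_context
    std::mutex tensors_mutex;
    std::unordered_map<const void *, std::shared_ptr<GpuTensor>> tensors;
};

// Host -> device upload: update the cached tensor in place if present,
// otherwise create a new one (mirrors the update/create paths in the diff).
static void h2d_tensor(Context & ctx, const void * key, const void * data, size_t size) {
    std::lock_guard<std::mutex> lock(ctx.tensors_mutex);
    auto it = ctx.tensors.find(key);
    if (it != ctx.tensors.end()) {
        // Assumes the cached size matches; a real backend would check.
        std::memcpy(it->second->data.data(), data, size);
        std::printf("updating tensor %p\n", key);
    } else {
        auto t = std::make_shared<GpuTensor>();
        t->data.assign((const unsigned char *) data, (const unsigned char *) data + size);
        ctx.tensors.emplace(key, std::move(t));
        std::printf("creating tensor %p\n", key);
    }
}

// Lookup that uploads lazily on a cache miss instead of asserting.
static std::shared_ptr<GpuTensor> get_tensor(Context & ctx, const void * key, const void * data, size_t size) {
    {
        std::lock_guard<std::mutex> lock(ctx.tensors_mutex);
        auto it = ctx.tensors.find(key);
        if (it != ctx.tensors.end()) {
            return it->second;          // hit: hand out the cached tensor
        }
    }
    h2d_tensor(ctx, key, data, size);   // miss: upload, then retry the lookup
    return get_tensor(ctx, key, data, size);
}

int main() {
    Context ctx;
    float host[4] = {1, 2, 3, 4};
    auto t = get_tensor(ctx, host, host, sizeof(host));  // first call creates
    get_tensor(ctx, host, host, sizeof(host));           // second call hits the cache
    std::printf("cached size: %zu bytes\n", t->data.size());
}

As in the diff, the miss path performs the upload and then re-runs the lookup rather than inserting and returning directly, so the hit path remains the single place that hands out cached tensors.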