diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index 40a4196a5..95c4f584e 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -233,8 +233,7 @@ extern "C" {
     GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor);
 
     // Copy K and V cache pointers to backend
-    GGML_API void ggml_backend_copy_k_cache_ptrs(const char ** host_cache_ptrs, size_t size);
-    GGML_API void ggml_backend_copy_v_cache_ptrs(const char ** host_cache_ptrs, size_t size);
+    GGML_API void ggml_backend_copy_kv_cache_ptrs(const int64_t n_layer, const int64_t kv_self_head, struct ggml_tensor ** kv_kl, struct ggml_tensor ** kv_vl, const int64_t n_embd_k_gqa, const int64_t n_embd_v_gqa, const bool flash_attn);
 
 #ifdef __cplusplus
 }
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index c8090f502..613a98bf9 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -302,12 +302,26 @@ static void ggml_backend_copy_cache_ptrs(char **& backend_cache_ptrs, const char
     cudaMemcpy(backend_cache_ptrs, host_cache_ptrs, size*sizeof(char *), cudaMemcpyHostToDevice);
 }
 
-void ggml_backend_copy_k_cache_ptrs(const char ** host_cache_ptrs, size_t size) {
-    ggml_backend_copy_cache_ptrs(k_cache_ptrs, host_cache_ptrs, size);
-}
+void ggml_backend_copy_kv_cache_ptrs(const int64_t n_layer, const int64_t kv_head, struct ggml_tensor ** kv_kl, struct ggml_tensor ** kv_vl, const int64_t n_embd_k_gqa, const int64_t n_embd_v_gqa, const bool flash_attn) {
 
-void ggml_backend_copy_v_cache_ptrs(const char ** host_cache_ptrs, size_t size) {
-    ggml_backend_copy_cache_ptrs(v_cache_ptrs, host_cache_ptrs, size);
+    std::vector<const char *> host_k_cache_ptrs;
+    std::vector<const char *> host_v_cache_ptrs;
+    for (int il = 0; il < n_layer; ++il) {
+        // K cache pointer for this layer
+        ggml_tensor * tmp_tensor = kv_kl[il];
+        size_t tmp_offset = (ggml_row_size(kv_kl[il]->type, n_embd_k_gqa))*kv_head;
+        host_k_cache_ptrs.push_back(static_cast<const char *>(tmp_tensor->data) + tmp_offset);
+        // V cache pointer for this layer
+        tmp_tensor = kv_vl[il];
+        if (flash_attn) {
+            tmp_offset = (kv_head)*ggml_row_size(kv_vl[il]->type, n_embd_v_gqa);
+        } else {
+            tmp_offset = (kv_head)*ggml_element_size(kv_vl[il]);
+        }
+        host_v_cache_ptrs.push_back(static_cast<const char *>(tmp_tensor->data) + tmp_offset);
+    }
+    ggml_backend_copy_cache_ptrs(k_cache_ptrs, host_k_cache_ptrs.data(), host_k_cache_ptrs.size());
+    ggml_backend_copy_cache_ptrs(v_cache_ptrs, host_v_cache_ptrs.data(), host_v_cache_ptrs.size());
 }
 
 static void ggml_cpy_f16_f32_cuda(
diff --git a/src/llama.cpp b/src/llama.cpp
index 9dace659e..d4875a84f 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -14735,33 +14735,7 @@ static int llama_decode_internal(
     ggml_backend_sched_alloc_graph(lctx.sched, gf);
 
 #ifdef GGML_USE_CUDA
-    // Copy K and V cache pointers to backend
-
-    // Stage pointers for each layer in host vectors
-    std::vector<const char *> k_cache_ptrs;
-    std::vector<const char *> v_cache_ptrs;
-    const int64_t n_layer = model.hparams.n_layer;
-    const int64_t kv_head = kv_self.head;
-    for (int il = 0; il < n_layer; ++il) {
-        // K cache pointer for this layer
-        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-        ggml_tensor * tmp_tensor = kv_self.k_l[il];
-        size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
-        k_cache_ptrs.push_back(static_cast<const char *>(tmp_tensor->data) + tmp_offset);
-        // V cache pointer for this layer
-        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-        tmp_tensor = kv_self.v_l[il];
-        if (cparams.flash_attn) {
-            tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
-        } else {
-            tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
-        }
-        v_cache_ptrs.push_back(static_cast<const char *>(tmp_tensor->data) + tmp_offset);
-    }
-
-    // copy host vector data to backend
-    ggml_backend_copy_k_cache_ptrs(k_cache_ptrs.data(), k_cache_ptrs.size());
-    ggml_backend_copy_v_cache_ptrs(v_cache_ptrs.data(), v_cache_ptrs.size());
+    ggml_backend_copy_kv_cache_ptrs(model.hparams.n_layer, kv_self.head, kv_self.k_l.data(), kv_self.v_l.data(), hparams.n_embd_k_gqa(), hparams.n_embd_v_gqa(), cparams.flash_attn);
 #endif
 
     llama_set_inputs(lctx, u_batch);
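
Note (not part of the patch): the per-layer pointer arithmetic that the consolidated ggml_backend_copy_kv_cache_ptrs helper performs can be sketched in isolation as below. The helper names layer_k_base/layer_v_base are illustrative only; the sketch assumes the K cache (and, with flash attention, the V cache) stores one contiguous row of n_embd_*_gqa elements per cached token, while the non-flash-attention V cache is stored transposed, so advancing by kv_head tokens advances by kv_head elements within a row.

    // Sketch only (hypothetical helpers): the offset math used to stage the
    // per-layer K/V cache base pointers before copying them to the CUDA backend.
    static const char * layer_k_base(const struct ggml_tensor * k, int64_t n_embd_k_gqa, int64_t kv_head) {
        // One row of n_embd_k_gqa elements per cached token: skip kv_head rows.
        return (const char *) k->data + ggml_row_size(k->type, n_embd_k_gqa)*kv_head;
    }

    static const char * layer_v_base(const struct ggml_tensor * v, int64_t n_embd_v_gqa, int64_t kv_head, bool flash_attn) {
        if (flash_attn) {
            // Same row-per-token layout as the K cache.
            return (const char *) v->data + ggml_row_size(v->type, n_embd_v_gqa)*kv_head;
        }
        // Transposed layout: skipping kv_head tokens moves kv_head elements within a row.
        return (const char *) v->data + ggml_element_size(v)*kv_head;
    }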