Abstract into GGML

Alan Gray 2024-08-14 01:38:00 -07:00
parent a3d48e448a
commit 38f4863a24
3 changed files with 21 additions and 34 deletions

ggml-backend.h

@@ -233,8 +233,7 @@ extern "C" {
     GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor);

     // Copy K and V cache pointers to backend
-    GGML_API void ggml_backend_copy_k_cache_ptrs(const char ** host_cache_ptrs, size_t size);
-    GGML_API void ggml_backend_copy_v_cache_ptrs(const char ** host_cache_ptrs, size_t size);
+    GGML_API void ggml_backend_copy_kv_cache_ptrs(const int64_t n_layer, const int64_t kv_self_head, struct ggml_tensor ** kv_kl, struct ggml_tensor ** kv_vl, const int64_t n_embd_k_gqa, const int64_t n_embd_v_gqa, const bool flash_attn);

 #ifdef __cplusplus
 }
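
The two per-cache entry points are replaced by a single call that receives the layer tensors and model dimensions and derives the per-layer pointers inside the backend. The K and V offsets differ because, without flash attention, the V cache is stored transposed: advancing by one cell (kv_head) skips a whole row of K but only a single element of V. A minimal sketch of that arithmetic, with made-up dimensions and a hard-coded F16 element size standing in for the real ggml_row_size/ggml_element_size helpers:

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t n_embd_k_gqa = 1024; // K row width in elements (assumed)
        const int64_t kv_head      = 37;   // first cache cell written this step (assumed)
        const size_t  elt          = 2;    // F16 stand-in for ggml_element_size()

        // K cache: one contiguous row per cell, so skip kv_head whole rows;
        // mirrors ggml_row_size(type, n_embd_k_gqa) * kv_head in the diff.
        const size_t k_off = elt * n_embd_k_gqa * kv_head;

        // V cache, non-flash-attn layout: transposed, so the next cell is one
        // element further along each row; mirrors ggml_element_size(v) * kv_head.
        const size_t v_off = elt * kv_head;

        std::printf("k_off = %zu bytes, v_off = %zu bytes\n", k_off, v_off);
        return 0;
    }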

ggml-cuda.cu

@@ -302,12 +302,26 @@ static void ggml_backend_copy_cache_ptrs(char **& backend_cache_ptrs, const char ** host_cache_ptrs, size_t size) {
     cudaMemcpy(backend_cache_ptrs, host_cache_ptrs, size*sizeof(char *), cudaMemcpyHostToDevice);
 }

-void ggml_backend_copy_k_cache_ptrs(const char ** host_cache_ptrs, size_t size) {
-    ggml_backend_copy_cache_ptrs(k_cache_ptrs, host_cache_ptrs, size);
-}
-
-void ggml_backend_copy_v_cache_ptrs(const char ** host_cache_ptrs, size_t size) {
-    ggml_backend_copy_cache_ptrs(v_cache_ptrs, host_cache_ptrs, size);
+void ggml_backend_copy_kv_cache_ptrs(const int64_t n_layer, const int64_t kv_head, struct ggml_tensor ** kv_kl, struct ggml_tensor ** kv_vl, const int64_t n_embd_k_gqa, const int64_t n_embd_v_gqa, const bool flash_attn) {
+    std::vector<const char *> host_k_cache_ptrs;
+    std::vector<const char *> host_v_cache_ptrs;
+    for (int il = 0; il < n_layer; ++il) {
+        // K cache pointer for this layer
+        ggml_tensor * tmp_tensor = kv_kl[il];
+        size_t tmp_offset = (ggml_row_size(kv_kl[il]->type, n_embd_k_gqa))*kv_head;
+        host_k_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+        // V cache pointer for this layer
+        tmp_tensor = kv_vl[il];
+        if (flash_attn) {
+            tmp_offset = (kv_head)*ggml_row_size(kv_vl[il]->type, n_embd_v_gqa);
+        } else {
+            tmp_offset = (kv_head)*ggml_element_size(kv_vl[il]);
+        }
+        host_v_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+    }
+
+    ggml_backend_copy_cache_ptrs(k_cache_ptrs, host_k_cache_ptrs.data(), host_k_cache_ptrs.size());
+    ggml_backend_copy_cache_ptrs(v_cache_ptrs, host_v_cache_ptrs.data(), host_v_cache_ptrs.size());
 }

 static void ggml_cpy_f16_f32_cuda(
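
The helper in the hunk header, ggml_backend_copy_cache_ptrs, is unchanged: it cudaMemcpy's the staged host array into a device-side char ** (k_cache_ptrs / v_cache_ptrs). Each element is already a device address, because tensor->data for these tensors lives in GPU memory, so only the pointer values cross the bus, never the cache contents themselves. A self-contained sketch of the same pattern; dev_ptrs and upload_ptrs are illustrative names, not identifiers from this codebase:

    #include <cuda_runtime.h>
    #include <vector>

    static char ** dev_ptrs = nullptr; // device-side array of per-layer pointers

    static void upload_ptrs(const std::vector<const char *> & host_ptrs) {
        if (dev_ptrs == nullptr) {
            // One slot per layer, allocated once and reused every decode step.
            cudaMalloc((void **) &dev_ptrs, host_ptrs.size() * sizeof(char *));
        }
        // Copies only the pointer values; device code can later read
        // dev_ptrs[il] to find layer il's cache base.
        cudaMemcpy(dev_ptrs, host_ptrs.data(), host_ptrs.size() * sizeof(char *),
                   cudaMemcpyHostToDevice);
    }

Rewriting the contents of a once-allocated array keeps the array's own address stable, so anything that holds a pointer to it never needs updating.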

llama.cpp

@@ -14735,33 +14735,7 @@ static int llama_decode_internal(
         ggml_backend_sched_alloc_graph(lctx.sched, gf);

 #ifdef GGML_USE_CUDA
-        // Copy K and V cache pointers to backend
-        // Stage pointers for each layer in host vectors
-        std::vector<const char *> k_cache_ptrs;
-        std::vector<const char *> v_cache_ptrs;
-        const int64_t n_layer = model.hparams.n_layer;
-        const int64_t kv_head = kv_self.head;
-
-        for (int il = 0; il < n_layer; ++il) {
-            // K cache pointer for this layer
-            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-            ggml_tensor * tmp_tensor = kv_self.k_l[il];
-            size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
-            k_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
-            // V cache pointer for this layer
-            const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-            tmp_tensor = kv_self.v_l[il];
-            if (cparams.flash_attn) {
-                tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
-            } else {
-                tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
-            }
-            v_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
-        }
-
-        // copy host vector data to backend
-        ggml_backend_copy_k_cache_ptrs(k_cache_ptrs.data(), k_cache_ptrs.size());
-        ggml_backend_copy_v_cache_ptrs(v_cache_ptrs.data(), v_cache_ptrs.size());
+        ggml_backend_copy_kv_cache_ptrs(model.hparams.n_layer, kv_self.head, kv_self.k_l.data(), kv_self.v_l.data(), hparams.n_embd_k_gqa(), hparams.n_embd_v_gqa(), cparams.flash_attn);
 #endif

         llama_set_inputs(lctx, u_batch);
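
At the call site, the staging loop collapses to a single backend call, which is the point of the abstraction: llama.cpp no longer needs to know how the CUDA backend stores the pointers. One nuance of the move: the old loop queried hparams.n_embd_v_gqa(il) per layer, whereas the new call passes one hparams.n_embd_v_gqa() value for all layers. As for why the pointers end up in device memory at all, a plausible reading is that copy kernels index the array by layer at run time, so fresh cache addresses can be installed each step without touching any kernel arguments. A hypothetical consumer along those lines, not code from this commit:

    // cache_ptrs is the device array that ggml_backend_copy_cache_ptrs() fills;
    // the kernel dereferences it at run time, so its contents may change
    // between launches without the launch parameters changing.
    __global__ void cpy_row_to_cache(char ** cache_ptrs, int il,
                                     const char * src, size_t nbytes) {
        char * dst = cache_ptrs[il]; // per-layer base, read from device memory
        const size_t i = (size_t) blockIdx.x * blockDim.x + threadIdx.x;
        if (i < nbytes) {
            dst[i] = src[i];
        }
    }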