Abstract into GGML
commit 38f4863a24
parent a3d48e448a
3 changed files with 21 additions and 34 deletions
@@ -233,8 +233,7 @@ extern "C" {
     GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor);
 
     // Copy K and V cache pointers to backend
-    GGML_API void ggml_backend_copy_k_cache_ptrs(const char ** host_cache_ptrs, size_t size);
-    GGML_API void ggml_backend_copy_v_cache_ptrs(const char ** host_cache_ptrs, size_t size);
+    GGML_API void ggml_backend_copy_kv_cache_ptrs(const int64_t n_layer, const int64_t kv_self_head, struct ggml_tensor ** kv_kl, struct ggml_tensor ** kv_vl, const int64_t n_embd_k_gqa, const int64_t n_embd_v_gqa, const bool flash_attn);
 
 #ifdef __cplusplus
 }
@@ -302,12 +302,26 @@ static void ggml_backend_copy_cache_ptrs(char **& backend_cache_ptrs, const char
     cudaMemcpy(backend_cache_ptrs, host_cache_ptrs, size*sizeof(char *), cudaMemcpyHostToDevice);
 }
 
-void ggml_backend_copy_k_cache_ptrs(const char ** host_cache_ptrs, size_t size) {
-    ggml_backend_copy_cache_ptrs(k_cache_ptrs, host_cache_ptrs, size);
-}
+void ggml_backend_copy_kv_cache_ptrs(const int64_t n_layer, const int64_t kv_head, struct ggml_tensor ** kv_kl, struct ggml_tensor ** kv_vl, const int64_t n_embd_k_gqa, const int64_t n_embd_v_gqa, const bool flash_attn) {
 
-void ggml_backend_copy_v_cache_ptrs(const char ** host_cache_ptrs, size_t size) {
-    ggml_backend_copy_cache_ptrs(v_cache_ptrs, host_cache_ptrs, size);
+    std::vector<const char *> host_k_cache_ptrs;
+    std::vector<const char *> host_v_cache_ptrs;
+    for (int il = 0; il < n_layer; ++il) {
+        // K cache pointer for this layer
+        ggml_tensor * tmp_tensor = kv_kl[il];
+        size_t tmp_offset = (ggml_row_size(kv_kl[il]->type, n_embd_k_gqa))*kv_head;
+        host_k_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+        // V cache pointer for this layer
+        tmp_tensor = kv_vl[il];
+        if (flash_attn) {
+            tmp_offset = (kv_head)*ggml_row_size(kv_vl[il]->type, n_embd_v_gqa);
+        } else {
+            tmp_offset = (kv_head)*ggml_element_size(kv_vl[il]);
+        }
+        host_v_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+    }
+    ggml_backend_copy_cache_ptrs(k_cache_ptrs, host_k_cache_ptrs.data(), host_k_cache_ptrs.size());
+    ggml_backend_copy_cache_ptrs(v_cache_ptrs, host_v_cache_ptrs.data(), host_v_cache_ptrs.size());
 }
 
 static void ggml_cpy_f16_f32_cuda(
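The only non-obvious step in the staging above is the flash_attn branch. Below is a minimal standalone sketch of the offset arithmetic, not code from the commit, assuming an F16 cache and hypothetical values n_embd_k_gqa = n_embd_v_gqa = 128 and kv_head = 7: the K cache (and the V cache when flash attention is enabled) stores one contiguous row per KV cell, so the per-layer write pointer skips kv_head full rows, while the non-flash-attention V cache is laid out transposed, so skipping kv_head cells advances only kv_head elements within a row.

#include <cstdio>
#include <cstddef>

int main() {
    const size_t elem_size    = 2;   // assumed F16 cache: ggml_element_size() == 2 bytes
    const size_t n_embd_k_gqa = 128; // hypothetical K elements per KV cell
    const size_t n_embd_v_gqa = 128; // hypothetical V elements per KV cell
    const size_t kv_head      = 7;   // first cell this batch writes

    // K cache: one contiguous row per cell, so skip kv_head full rows
    // (for non-quantized types ggml_row_size(type, n) == n * elem_size).
    const size_t k_offset = (elem_size * n_embd_k_gqa) * kv_head;    // 1792 bytes

    // V cache with flash attention: same row-per-cell layout as K.
    const size_t v_offset_fa = (elem_size * n_embd_v_gqa) * kv_head; // 1792 bytes

    // V cache without flash attention: stored transposed, cells are the
    // fastest-moving dimension, so only kv_head elements are skipped.
    const size_t v_offset_no_fa = kv_head * elem_size;               // 14 bytes

    printf("k_offset=%zu  v_offset(flash)=%zu  v_offset(no flash)=%zu\n",
           k_offset, v_offset_fa, v_offset_no_fa);
    return 0;
}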
@@ -14735,33 +14735,7 @@ static int llama_decode_internal(
         ggml_backend_sched_alloc_graph(lctx.sched, gf);
 
 #ifdef GGML_USE_CUDA
-        // Copy K and V cache pointers to backend
-
-        // Stage pointers for each layer in host vectors
-        std::vector<const char *> k_cache_ptrs;
-        std::vector<const char *> v_cache_ptrs;
-        const int64_t n_layer = model.hparams.n_layer;
-        const int64_t kv_head = kv_self.head;
-        for (int il = 0; il < n_layer; ++il) {
-            // K cache pointer for this layer
-            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-            ggml_tensor * tmp_tensor = kv_self.k_l[il];
-            size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
-            k_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
-            // V cache pointer for this layer
-            const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-            tmp_tensor = kv_self.v_l[il];
-            if (cparams.flash_attn) {
-                tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
-            } else {
-                tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
-            }
-            v_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
-        }
-
-        // copy host vector data to backend
-        ggml_backend_copy_k_cache_ptrs(k_cache_ptrs.data(), k_cache_ptrs.size());
-        ggml_backend_copy_v_cache_ptrs(v_cache_ptrs.data(), v_cache_ptrs.size());
+        ggml_backend_copy_kv_cache_ptrs(model.hparams.n_layer, kv_self.head, kv_self.k_l.data(), kv_self.v_l.data(), hparams.n_embd_k_gqa(), hparams.n_embd_v_gqa(), cparams.flash_attn);
 #endif
 
         llama_set_inputs(lctx, u_batch);