llama : only use default buffer types for the KV cache (#10358)

This commit is contained in:
Diego Devesa 2024-11-17 12:25:45 +01:00 committed by GitHub
parent 20a780c7b6
commit be5caccef9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 9 additions and 16 deletions

View file

@ -3460,21 +3460,13 @@ static bool llama_kv_cache_init(
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
const llama_model::buft_list_t * buft_list;
ggml_backend_buffer_type_t buft;
if (offload) {
buft_list = model.dev_layer.at(i).buft_list;
auto * dev = model.dev_layer.at(i).dev;
buft = ggml_backend_dev_buffer_type(dev);
} else {
buft_list = &model.cpu_buft_list;
buft = ggml_backend_cpu_buffer_type();
}
ggml_backend_buffer_type_t buft = select_buft(*buft_list,
[&](ggml_context * ctx) {
ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
if (hparams.rope_type == LLAMA_ROPE_TYPE_NONE) {
return k;
}
ggml_tensor * p = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
return ggml_rope(ctx, k, p, hparams.n_rot, hparams.rope_type);
});
ggml_context * ctx = ctx_for_buft(buft);
if (!ctx) {