llama : only use default buffer types for the KV cache (#10358)
parent 20a780c7b6
commit be5caccef9
2 changed files with 9 additions and 16 deletions
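The hunk below drops the select_buft() probe, which built a throwaway K tensor (plus a ggml_rope node for models that use RoPE) to pick a working buffer type from the layer's buffer-type list, and instead takes the default buffer type directly: the layer device's default when the KV cache is offloaded, or the CPU buffer type otherwise.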
@@ -3460,21 +3460,13 @@ static bool llama_kv_cache_init(
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
-        const llama_model::buft_list_t * buft_list;
+        ggml_backend_buffer_type_t buft;
         if (offload) {
-            buft_list = model.dev_layer.at(i).buft_list;
+            auto * dev = model.dev_layer.at(i).dev;
+            buft = ggml_backend_dev_buffer_type(dev);
         } else {
-            buft_list = &model.cpu_buft_list;
+            buft = ggml_backend_cpu_buffer_type();
         }
-        ggml_backend_buffer_type_t buft = select_buft(*buft_list,
-            [&](ggml_context * ctx) {
-                ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-                if (hparams.rope_type == LLAMA_ROPE_TYPE_NONE) {
-                    return k;
-                }
-                ggml_tensor * p = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
-                return ggml_rope(ctx, k, p, hparams.n_rot, hparams.rope_type);
-            });
         ggml_context * ctx = ctx_for_buft(buft);
 
         if (!ctx) {
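For reference, a minimal sketch of the resulting selection logic, based on the added lines above. The helper name kv_layer_buft is hypothetical; ggml_backend_dev_buffer_type and ggml_backend_cpu_buffer_type are the public ggml backend API calls used in the hunk.

// Minimal sketch, assuming ggml's backend API from ggml-backend.h.
// kv_layer_buft is a hypothetical helper, not part of llama.cpp.
#include "ggml-backend.h"

// After this commit, each KV cache layer uses the default buffer type of its
// backend device when offloaded, and host (CPU) memory otherwise.
static ggml_backend_buffer_type_t kv_layer_buft(ggml_backend_dev_t dev, bool offload) {
    if (offload) {
        return ggml_backend_dev_buffer_type(dev); // device default, e.g. VRAM
    }
    return ggml_backend_cpu_buffer_type();        // host memory
}

The removed path instead tested each candidate in the layer's buffer-type list against a small graph; with that gone, the KV cache allocation no longer depends on which ops a buffer type can run.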