cuda: add q8_0->f32 cpy operation (#9571)
llama: enable K-shift for quantized KV cache It will fail on unsupported backends or quant types.
This commit is contained in:
parent
0b3bf966f4
commit
116efee0ee
3 changed files with 82 additions and 9 deletions
|
@ -9930,17 +9930,36 @@ struct llm_build_context {
|
|||
const int64_t n_head_kv = hparams.n_head_kv(il);
|
||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||
struct ggml_tensor * rope_factors = build_rope_factors(il);
|
||||
struct ggml_tensor * tmp =
|
||||
// we rotate only the first n_rot dimensions
|
||||
ggml_rope_ext_inplace(ctx0,
|
||||
ggml_view_3d(ctx0, kv_self.k_l[il],
|
||||
n_embd_head_k, n_head_kv, n_ctx,
|
||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
|
||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
||||
0),
|
||||
struct ggml_tensor * k =
|
||||
ggml_view_3d(ctx0, kv_self.k_l[il],
|
||||
n_embd_head_k, n_head_kv, n_ctx,
|
||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
|
||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
||||
0);
|
||||
|
||||
struct ggml_tensor * tmp;
|
||||
if (ggml_is_quantized(k->type)) {
|
||||
// dequantize to f32 -> RoPE -> quantize back
|
||||
tmp = ggml_cast(ctx0, k, GGML_TYPE_F32);
|
||||
cb(tmp, "K_f32", il);
|
||||
for (auto * backend : lctx.backends) {
|
||||
// Figure out which backend KV cache belongs to
|
||||
if (ggml_backend_supports_buft(backend, lctx.model.buft_layer[il].buft)) {
|
||||
ggml_backend_sched_set_tensor_backend(lctx.sched, tmp, backend);
|
||||
break;
|
||||
}
|
||||
}
|
||||
tmp = ggml_rope_ext_inplace(ctx0, tmp,
|
||||
lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
|
||||
cb(tmp, "K_shifted_f32", il);
|
||||
tmp = ggml_cpy(ctx0, tmp, k);
|
||||
} else {
|
||||
// we rotate only the first n_rot dimensions
|
||||
tmp = ggml_rope_ext_inplace(ctx0, k,
|
||||
lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
}
|
||||
cb(tmp, "K_shifted", il);
|
||||
ggml_build_forward_expand(gf, tmp);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue