use sliding window for phi3
This commit is contained in:
parent
50e05353e8
commit
03cb5cda6d
1 changed files with 29 additions and 9 deletions
|
@ -4974,6 +4974,8 @@ static void llm_load_hparams(
|
|||
} break;
|
||||
case LLM_ARCH_PHI3:
|
||||
{
|
||||
hparams.n_swa = 2048;
|
||||
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
|
@ -10843,7 +10845,7 @@ struct llm_build_context {
|
|||
struct ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
auto residual = inpL;
|
||||
|
@ -10901,7 +10903,7 @@ struct llm_build_context {
|
|||
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
||||
Kcur, Vcur, Qcur, KQ_mask_swa, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
|
@ -14108,18 +14110,23 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|||
"causal attention is not supported by this model"
|
||||
);
|
||||
|
||||
if (lctx.inp_KQ_mask) {
|
||||
if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) {
|
||||
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
|
||||
if (cparams.causal_attn && !lctx.is_encoding) {
|
||||
const int64_t n_kv = kv_self.n;
|
||||
const int64_t n_tokens = batch.n_tokens;
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
|
||||
|
||||
float * data = (float *) lctx.inp_KQ_mask->data;
|
||||
float * data = nullptr;
|
||||
float * data_swa = nullptr;
|
||||
|
||||
if (lctx.inp_KQ_mask) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
|
||||
data_swa = (float *) lctx.inp_KQ_mask->data;
|
||||
}
|
||||
|
||||
if (lctx.inp_KQ_mask_swa) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer));
|
||||
data_swa = (float *) lctx.inp_KQ_mask_swa->data;
|
||||
}
|
||||
|
||||
|
@ -14142,7 +14149,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|||
f = 0.0f;
|
||||
}
|
||||
}
|
||||
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
|
||||
|
||||
if (data) {
|
||||
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
|
||||
}
|
||||
|
||||
// may need to cut off old tokens for sliding window
|
||||
if (data_swa) {
|
||||
|
@ -14154,9 +14164,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|||
}
|
||||
}
|
||||
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_kv; ++j) {
|
||||
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
||||
if (data) {
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_kv; ++j) {
|
||||
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (data_swa) {
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_kv; ++j) {
|
||||
data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue