fix data_swa uninitialized
This commit is contained in:
parent
7df7530b8f
commit
ab2c3de9b3
1 changed files with 4 additions and 3 deletions
|
@ -12687,12 +12687,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||||
|
|
||||||
float * data = (float *) lctx.inp_KQ_mask->data;
|
float * data = (float *) lctx.inp_KQ_mask->data;
|
||||||
float * data_swa = nullptr;
|
float * data_swa = nullptr;
|
||||||
|
const llama_pos n_keep_swa = hparams.n_ctx_swa - batch.n_tokens;
|
||||||
|
|
||||||
if (lctx.model.arch == LLM_ARCH_GEMMA2) {
|
if (lctx.model.arch == LLM_ARCH_GEMMA2) {
|
||||||
GGML_ASSERT(!lctx.inp_KQ_mask_l.empty() && "gemma 2 requires different KQ mask per layer");
|
GGML_ASSERT(!lctx.inp_KQ_mask_l.empty() && "gemma 2 requires different KQ mask per layer");
|
||||||
GGML_ASSERT(hparams.n_ctx_swa > 0);
|
GGML_ASSERT(hparams.n_ctx_swa > 0);
|
||||||
data_swa = (float *) lctx.inp_KQ_mask_l[0]->data;
|
data_swa = (float *) lctx.inp_KQ_mask_l[0]->data;
|
||||||
data = (float *) lctx.inp_KQ_mask_l[1]->data;
|
data = (float *) lctx.inp_KQ_mask_l[1]->data;
|
||||||
|
// because layer masks are alternate for gemma 2, we only need to take first 2 layers
|
||||||
}
|
}
|
||||||
|
|
||||||
// For causal attention, use only the previous KV cells
|
// For causal attention, use only the previous KV cells
|
||||||
|
@ -12717,9 +12719,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||||
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
|
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
|
||||||
|
|
||||||
// may need to cut off old tokens for sliding window
|
// may need to cut off old tokens for sliding window
|
||||||
if (data_swa && f != -INFINITY) {
|
if (data_swa) {
|
||||||
const llama_pos n_keep = hparams.n_ctx_swa - batch.n_tokens;
|
if (pos - lctx.kv_self.cells[i].pos > n_keep_swa) {
|
||||||
if (pos - lctx.kv_self.cells[i].pos > n_keep) {
|
|
||||||
f = -INFINITY;
|
f = -INFINITY;
|
||||||
}
|
}
|
||||||
data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
|
data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue