fix data_swa uninitialized
parent 7df7530b8f
commit ab2c3de9b3
1 changed file with 4 additions and 3 deletions
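What the change does, as read from the diff below: before this commit the sliding-window mask was only filled in when `f != -INFINITY`, so entries of `data_swa` for KV cells that the causal mask had already excluded were never written and stayed uninitialized. After the change the `if (data_swa)` branch is taken for every cell, so each element of the SWA mask is assigned a value, and the cutoff `n_keep_swa = hparams.n_ctx_swa - batch.n_tokens` is computed once before the loops instead of being recomputed per element.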
@@ -12687,12 +12687,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             float * data     = (float *) lctx.inp_KQ_mask->data;
             float * data_swa = nullptr;
+            const llama_pos n_keep_swa = hparams.n_ctx_swa - batch.n_tokens;
+
             if (lctx.model.arch == LLM_ARCH_GEMMA2) {
                 GGML_ASSERT(!lctx.inp_KQ_mask_l.empty() && "gemma 2 requires different KQ mask per layer");
                 GGML_ASSERT(hparams.n_ctx_swa > 0);
                 data_swa = (float *) lctx.inp_KQ_mask_l[0]->data;
                 data     = (float *) lctx.inp_KQ_mask_l[1]->data;
                 // because layer masks are alternate for gemma 2, we only need to take first 2 layers
             }

             // For causal attention, use only the previous KV cells
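As an aside on the gemma 2 comment in the hunk above: these context lines only show that inp_KQ_mask_l[0] receives the sliding-window mask (data_swa) and inp_KQ_mask_l[1] the full causal mask (data). Below is a tiny hypothetical sketch of the per-layer selection the comment implies; the helper name and the even/odd mapping are assumptions for illustration, not taken from this commit.

#include <cstdio>

// Hypothetical sketch: if gemma 2 layers alternate between sliding-window and
// full attention, each layer can pick one of the two mask buffers by parity,
// so only the first two entries of inp_KQ_mask_l ever need to be filled.
// Which parity maps to which mask is an assumption here.
static int kq_mask_index_for_layer(int il) {
    return il % 2 == 0 ? 0   // even layer -> inp_KQ_mask_l[0], sliding-window mask
                       : 1;  // odd layer  -> inp_KQ_mask_l[1], full causal mask
}

int main() {
    for (int il = 0; il < 6; ++il) {
        std::printf("layer %d uses inp_KQ_mask_l[%d]\n", il, kq_mask_index_for_layer(il));
    }
    return 0;
}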
@@ -12717,9 +12719,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                         data[h*(n_kv*n_tokens) + j*n_kv + i] = f;

                         // may need to cut off old tokens for sliding window
-                        if (data_swa && f != -INFINITY) {
-                            const llama_pos n_keep = hparams.n_ctx_swa - batch.n_tokens;
-                            if (pos - lctx.kv_self.cells[i].pos > n_keep) {
+                        if (data_swa) {
+                            if (pos - lctx.kv_self.cells[i].pos > n_keep_swa) {
                                 f = -INFINITY;
                             }
                             data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
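To make the fix concrete, here is a minimal standalone sketch (my reconstruction, not repository code) of the mask-filling loop before and after this commit. The head dimension h is dropped, and n_kv, n_tokens, the cell positions and the causal test are made-up stand-ins for the values llama_set_inputs derives from the KV cache and the batch; only the handling of data_swa mirrors the diff.

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // Made-up stand-ins for values the real function takes from the KV cache / batch.
    const int n_kv       = 8;
    const int n_tokens   = 4;
    const int n_ctx_swa  = 6;
    const int n_keep_swa = n_ctx_swa - n_tokens;   // computed once, as in the new code

    std::vector<float> data    (n_kv * n_tokens, 0.0f);
    // Sentinel value so slots that are never written stand out in the output.
    std::vector<float> data_swa(n_kv * n_tokens, 12345.0f);

    for (int j = 0; j < n_tokens; ++j) {
        const int pos = n_kv - n_tokens + j;       // position of batch token j
        for (int i = 0; i < n_kv; ++i) {
            const int cell_pos = i;                // position stored in KV cell i

            // causal mask: cells after the current token are excluded
            float f = (cell_pos > pos) ? -INFINITY : 0.0f;
            data[j*n_kv + i] = f;

            // OLD guard (the bug): `if (data_swa && f != -INFINITY)` skipped the
            // write whenever f was already -INFINITY, leaving that data_swa slot
            // at its previous (uninitialized) contents.
            //
            // NEW behavior: always write the sliding-window mask entry.
            if (pos - cell_pos > n_keep_swa) {
                f = -INFINITY;
            }
            data_swa[j*n_kv + i] = f;
        }
    }

    // With the old guard, the masked-out slots would still print 12345 here.
    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_kv; ++i) {
            std::printf("%9.0f", data_swa[j*n_kv + i]);
        }
        std::printf("\n");
    }
    return 0;
}

Running the sketch prints only 0 and -inf and no 12345 sentinels, i.e. every entry of the SWA mask is written exactly once per cell.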