fix data_swa uninitialized

This commit is contained in:
ngxson 2024-06-30 20:18:53 +02:00
parent 7df7530b8f
commit ab2c3de9b3

View file

@ -12687,12 +12687,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
float * data = (float *) lctx.inp_KQ_mask->data;
float * data_swa = nullptr;
const llama_pos n_keep_swa = hparams.n_ctx_swa - batch.n_tokens;
if (lctx.model.arch == LLM_ARCH_GEMMA2) {
GGML_ASSERT(!lctx.inp_KQ_mask_l.empty() && "gemma 2 requires different KQ mask per layer");
GGML_ASSERT(hparams.n_ctx_swa > 0);
data_swa = (float *) lctx.inp_KQ_mask_l[0]->data;
data = (float *) lctx.inp_KQ_mask_l[1]->data;
// because layer masks are alternate for gemma 2, we only need to take first 2 layers
}
// For causal attention, use only the previous KV cells
@ -12717,9 +12719,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
// may need to cut off old tokens for sliding window
if (data_swa && f != -INFINITY) {
const llama_pos n_keep = hparams.n_ctx_swa - batch.n_tokens;
if (pos - lctx.kv_self.cells[i].pos > n_keep) {
if (data_swa) {
if (pos - lctx.kv_self.cells[i].pos > n_keep_swa) {
f = -INFINITY;
}
data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;