From f281d76f4188a093faa88b4976408da689e019c9 Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Mon, 12 Feb 2024 12:13:21 -0600 Subject: [PATCH] bring back non-causal attention --- llama.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index e13850e2b..ec2a42736 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4849,7 +4849,6 @@ struct llm_build_context { const int32_t n_orig_ctx; const bool do_rope_shift; - const bool causal_attn; const bool do_pooling; const llm_build_cb & cb; @@ -4894,7 +4893,6 @@ struct llm_build_context { kv_head (worst_case ? n_ctx - n_tokens : kv_self.head), n_orig_ctx (cparams.n_yarn_orig_ctx), do_rope_shift (worst_case || kv_self.has_shift), - causal_attn (hparams.causal_attn), do_pooling (hparams.pooling_layer && cparams.do_pooling), cb (cb), buf_compute_meta (lctx.buf_compute_meta) { @@ -7361,7 +7359,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { for (int i = 0; i < n_kv; ++i) { float f; - if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { + if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || + (hparams.causal_attn && lctx.kv_self.cells[i].pos > pos)) { f = -INFINITY; } else { f = 0;