bring back non-causal attention

This commit is contained in:
Author: Douglas Hanley — 2024-02-12 12:13:21 -06:00
parent 1549493e94
commit f281d76f41

View file

@@ -4849,7 +4849,6 @@ struct llm_build_context {
const int32_t n_orig_ctx;
const bool do_rope_shift;
const bool causal_attn;
const bool do_pooling;
const llm_build_cb & cb;
@@ -4894,7 +4893,6 @@ struct llm_build_context {
kv_head (worst_case ? n_ctx - n_tokens : kv_self.head),
n_orig_ctx (cparams.n_yarn_orig_ctx),
do_rope_shift (worst_case || kv_self.has_shift),
causal_attn (hparams.causal_attn),
do_pooling (hparams.pooling_layer && cparams.do_pooling),
cb (cb),
buf_compute_meta (lctx.buf_compute_meta) {
@@ -7361,7 +7359,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
for (int i = 0; i < n_kv; ++i) {
float f;
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) ||
(hparams.causal_attn && lctx.kv_self.cells[i].pos > pos)) {
f = -INFINITY;
} else {
f = 0;