bring back non-causal attention

This commit is contained in:
Douglas Hanley 2024-02-12 12:13:21 -06:00
parent 1549493e94
commit f281d76f41

View file

@ -4849,7 +4849,6 @@ struct llm_build_context {
const int32_t n_orig_ctx; const int32_t n_orig_ctx;
const bool do_rope_shift; const bool do_rope_shift;
const bool causal_attn;
const bool do_pooling; const bool do_pooling;
const llm_build_cb & cb; const llm_build_cb & cb;
@ -4894,7 +4893,6 @@ struct llm_build_context {
kv_head (worst_case ? n_ctx - n_tokens : kv_self.head), kv_head (worst_case ? n_ctx - n_tokens : kv_self.head),
n_orig_ctx (cparams.n_yarn_orig_ctx), n_orig_ctx (cparams.n_yarn_orig_ctx),
do_rope_shift (worst_case || kv_self.has_shift), do_rope_shift (worst_case || kv_self.has_shift),
causal_attn (hparams.causal_attn),
do_pooling (hparams.pooling_layer && cparams.do_pooling), do_pooling (hparams.pooling_layer && cparams.do_pooling),
cb (cb), cb (cb),
buf_compute_meta (lctx.buf_compute_meta) { buf_compute_meta (lctx.buf_compute_meta) {
@ -7361,7 +7359,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
for (int i = 0; i < n_kv; ++i) { for (int i = 0; i < n_kv; ++i) {
float f; float f;
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { if (!lctx.kv_self.cells[i].has_seq_id(seq_id) ||
(hparams.causal_attn && lctx.kv_self.cells[i].pos > pos)) {
f = -INFINITY; f = -INFINITY;
} else { } else {
f = 0; f = 0;