llama : remove the h loop in llama_set_inputs

This commit removes the loop over h in the lctx.inp_KQ_mask block in
llama_set_inputs. It also removes the uses of h in the index
calculations, since h is always zero there.

The motivation for this is to simplify the code and make it easier to
understand.

Signed-off-by: Daniel Bevenius <daniel.bevenius@gmail.com>
Daniel Bevenius 2024-07-10 07:13:02 +02:00
parent a59f8fdc85
commit e95c2e1513
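
Since the removed loop only ever executed with h == 0, the
h*(n_kv*n_tokens) term in the mask index is always zero and the flat
index reduces to j*n_kv + i. A minimal standalone sketch of that
equivalence (the sizes below are invented for illustration and are not
taken from the commit):

    #include <cassert>

    int main() {
        const int n_kv     = 32; // hypothetical KV cache size
        const int n_tokens = 8;  // hypothetical batch size
        const int h        = 0;  // the removed loop only ever ran with h == 0

        for (int j = 0; j < n_tokens; ++j) {
            for (int i = 0; i < n_kv; ++i) {
                // with h fixed at zero, the old and new indices are identical
                assert(h*(n_kv*n_tokens) + j*n_kv + i == j*n_kv + i);
            }
        }
        return 0;
    }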

@@ -13983,38 +13983,36 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             // For causal attention, use only the previous KV cells
             // of the correct sequence for each token of the batch.
             // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-            for (int h = 0; h < 1; ++h) {
-                for (int j = 0; j < n_tokens; ++j) {
-                    const llama_pos pos = batch.pos[j];
-                    const llama_seq_id seq_id = batch.seq_id[j][0];
-
-                    for (int i = 0; i < n_kv; ++i) {
-                        float f;
-                        if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
-                            f = -INFINITY;
-                        } else {
-                            if (hparams.use_alibi) {
-                                f = -fabs(lctx.kv_self.cells[i].pos - pos);
-                            } else {
-                                f = 0.0f;
-                            }
-                        }
-                        data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
-
-                        // may need to cut off old tokens for sliding window
-                        if (data_swa) {
-                            if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
-                                f = -INFINITY;
-                            }
-                            data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
-                        }
-                    }
-                }
-
-                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                    for (int j = 0; j < n_kv; ++j) {
-                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                    }
-                }
-            }
+            for (int j = 0; j < n_tokens; ++j) {
+                const llama_pos pos = batch.pos[j];
+                const llama_seq_id seq_id = batch.seq_id[j][0];
+
+                for (int i = 0; i < n_kv; ++i) {
+                    float f;
+                    if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
+                        f = -INFINITY;
+                    } else {
+                        if (hparams.use_alibi) {
+                            f = -fabs(lctx.kv_self.cells[i].pos - pos);
+                        } else {
+                            f = 0.0f;
+                        }
+                    }
+                    data[j*n_kv + i] = f;
+
+                    // may need to cut off old tokens for sliding window
+                    if (data_swa) {
+                        if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
+                            f = -INFINITY;
+                        }
+                        data_swa[j*n_kv + i] = f;
+                    }
+                }
+            }
+
+            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                for (int j = 0; j < n_kv; ++j) {
+                    data[i*n_kv + j] = -INFINITY;
+                }
+            }
         } else {
@@ -14026,29 +14024,27 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             float * data = (float *) lctx.inp_KQ_mask->data;
 
-            for (int h = 0; h < 1; ++h) {
-                for (int j = 0; j < n_tokens; ++j) {
-                    const llama_seq_id seq_id = batch.seq_id[j][0];
-
-                    for (int i = 0; i < n_tokens; ++i) {
-                        float f = -INFINITY;
-                        for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-                            if (batch.seq_id[i][s] == seq_id) {
-                                if (hparams.use_alibi) {
-                                    f = -fabs(batch.pos[i] - batch.pos[j]);
-                                } else {
-                                    f = 0.0f;
-                                }
-                                break;
-                            }
-                        }
-
-                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
-                    }
-
-                    for (int i = n_tokens; i < n_stride; ++i) {
-                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
-                    }
-                }
-            }
+            for (int j = 0; j < n_tokens; ++j) {
+                const llama_seq_id seq_id = batch.seq_id[j][0];
+
+                for (int i = 0; i < n_tokens; ++i) {
+                    float f = -INFINITY;
+                    for (int s = 0; s < batch.n_seq_id[i]; ++s) {
+                        if (batch.seq_id[i][s] == seq_id) {
+                            if (hparams.use_alibi) {
+                                f = -fabs(batch.pos[i] - batch.pos[j]);
+                            } else {
+                                f = 0.0f;
+                            }
+                            break;
+                        }
+                    }
+
+                    data[j*n_stride + i] = f;
+                }
+
+                for (int i = n_tokens; i < n_stride; ++i) {
+                    data[j*n_stride + i] = -INFINITY;
+                }
+            }
         }
     }
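
The second hunk applies the same simplification to the non-causal path,
where each row of the mask is n_stride floats wide and the dropped term
is h*(n_tokens*n_tokens). A rough standalone sketch (buffer sizes are
invented for the example; this is not code from llama.cpp) showing that
the padding writes touch the same elements with and without the h loop:

    #include <cassert>
    #include <cmath>
    #include <vector>

    int main() {
        const int n_tokens = 3; // hypothetical batch size
        const int n_stride = 8; // hypothetical padded row width

        std::vector<float> with_h(n_tokens * n_stride, 0.0f);
        std::vector<float> without_h(n_tokens * n_stride, 0.0f);

        // old shape: loop over h even though it only takes the value 0
        for (int h = 0; h < 1; ++h) {
            for (int j = 0; j < n_tokens; ++j) {
                for (int i = n_tokens; i < n_stride; ++i) {
                    with_h[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
                }
            }
        }

        // new shape: identical writes without the outer loop
        for (int j = 0; j < n_tokens; ++j) {
            for (int i = n_tokens; i < n_stride; ++i) {
                without_h[j*n_stride + i] = -INFINITY;
            }
        }

        assert(with_h == without_h);
        return 0;
    }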