llama : remove the h loop in llama_set_inputs

This commit removes the loop over h in the lctx.inp_KQ_mask block in
llama_set_inputs. It also removes the uses of h in the index
calculations, since h is always zero there.

The motivation for this is to simplify the code and make it easier to
understand.
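
For illustration only (not code from this change; n_kv and n_tokens are
made-up sizes): a minimal standalone sketch showing that the removed
loop runs exactly once with h == 0, so the h*(n_kv*n_tokens) offset is
always zero and the simplified indexing fills exactly the same mask
entries:

    #include <assert.h>

    int main(void) {
        // Stand-in sizes; in llama_set_inputs these come from the KV cache
        // and the batch.
        const int n_kv     = 4;
        const int n_tokens = 3;

        float with_h[4 * 3]    = {0};
        float without_h[4 * 3] = {0};

        // Old form: the loop over h runs for h = 0 only, i.e. exactly once.
        for (int h = 0; h < 1; ++h) {
            for (int j = 0; j < n_tokens; ++j) {
                for (int i = 0; i < n_kv; ++i) {
                    with_h[h*(n_kv*n_tokens) + j*n_kv + i] = (float)(j*n_kv + i);
                }
            }
        }

        // New form: the always-zero h*(n_kv*n_tokens) term is dropped.
        for (int j = 0; j < n_tokens; ++j) {
            for (int i = 0; i < n_kv; ++i) {
                without_h[j*n_kv + i] = (float)(j*n_kv + i);
            }
        }

        // Both variants write the same values to the same positions.
        for (int k = 0; k < n_kv*n_tokens; ++k) {
            assert(with_h[k] == without_h[k]);
        }
        return 0;
    }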

Signed-off-by: Daniel Bevenius <daniel.bevenius@gmail.com>
Author: Daniel Bevenius
Date:   2024-07-10 07:13:02 +02:00
Parent: a59f8fdc85
Commit: e95c2e1513

@@ -13983,7 +13983,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             // For causal attention, use only the previous KV cells
             // of the correct sequence for each token of the batch.
             // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-            for (int h = 0; h < 1; ++h) {
                 for (int j = 0; j < n_tokens; ++j) {
                     const llama_pos    pos    = batch.pos[j];
                     const llama_seq_id seq_id = batch.seq_id[j][0];
@@ -13999,22 +13998,21 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                                 f = 0.0f;
                             }
                         }
-                        data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                        data[j*n_kv + i] = f;
 
                         // may need to cut off old tokens for sliding window
                         if (data_swa) {
                             if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
                                 f = -INFINITY;
                             }
-                            data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                            data_swa[j*n_kv + i] = f;
                         }
                     }
                 }
 
                 for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
                     for (int j = 0; j < n_kv; ++j) {
-                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                        data[i*n_kv + j] = -INFINITY;
                     }
                 }
-            }
         } else {
@@ -14026,7 +14024,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             float * data = (float *) lctx.inp_KQ_mask->data;
 
-            for (int h = 0; h < 1; ++h) {
                 for (int j = 0; j < n_tokens; ++j) {
                     const llama_seq_id seq_id = batch.seq_id[j][0];
@@ -14043,12 +14040,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                             }
                         }
 
-                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
+                        data[j*n_stride + i] = f;
                     }
 
                     for (int i = n_tokens; i < n_stride; ++i) {
-                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
+                        data[j*n_stride + i] = -INFINITY;
                     }
                 }
-            }
         }