llama : remove the h loop in llama_set_inputs

This commit removes the loop over h in the lctx.inp_KQ_mask block in
llama_set_inputs. It also removes the usage of h as this will always be
zero.

The motivation for this is to simplify the code and make it easier to
understand.

Signed-off-by: Daniel Bevenius <daniel.bevenius@gmail.com>
This commit is contained in:
Daniel Bevenius 2024-07-10 07:13:02 +02:00
parent a59f8fdc85
commit e95c2e1513
Failed to extract signature

View file

@ -13983,7 +13983,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
// For causal attention, use only the previous KV cells
// of the correct sequence for each token of the batch.
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j][0];
@ -13999,22 +13998,21 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
f = 0.0f;
}
}
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
data[j*n_kv + i] = f;
// may need to cut off old tokens for sliding window
if (data_swa) {
if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
f = -INFINITY;
}
data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
data_swa[j*n_kv + i] = f;
}
}
}
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
data[i*n_kv + j] = -INFINITY;
}
}
} else {
@ -14026,7 +14024,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
float * data = (float *) lctx.inp_KQ_mask->data;
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_seq_id seq_id = batch.seq_id[j][0];
@ -14043,12 +14040,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
data[j*n_stride + i] = f;
}
for (int i = n_tokens; i < n_stride; ++i) {
data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
}
data[j*n_stride + i] = -INFINITY;
}
}
}