llama : remove the h loop in llama_set_inputs
This commit removes the loop over h in the lctx.inp_KQ_mask block in llama_set_inputs. It also removes the usage of h, as this will always be zero. The motivation for this is to simplify the code and make it easier to understand.

Signed-off-by: Daniel Bevenius <daniel.bevenius@gmail.com>
parent a59f8fdc85
commit e95c2e1513
1 changed file with 38 additions and 42 deletions
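The reasoning behind the change can be checked mechanically: the loop header is for (int h = 0; h < 1; ++h), so h only ever takes the value 0 and the h*(n_kv*n_tokens) term in the mask offset is always zero. Below is a minimal standalone sketch (not llama.cpp code; the n_kv and n_tokens values are made up) confirming that the old and new index expressions address the same element:

// Minimal sketch, not llama.cpp code: n_kv and n_tokens are made-up sizes.
// With h pinned to 0 by the single-iteration loop, the old 3-D offset
// h*(n_kv*n_tokens) + j*n_kv + i is identical to the new 2-D offset
// j*n_kv + i, so dropping the loop cannot change which element is written.
#include <cassert>
#include <cstdint>

int main() {
    const int64_t n_kv     = 32;
    const int64_t n_tokens = 8;

    for (int h = 0; h < 1; ++h) {                    // the loop the commit removes
        for (int64_t j = 0; j < n_tokens; ++j) {
            for (int64_t i = 0; i < n_kv; ++i) {
                const int64_t old_idx = h*(n_kv*n_tokens) + j*n_kv + i;
                const int64_t new_idx = j*n_kv + i;
                assert(old_idx == new_idx);          // h == 0 makes the first term vanish
            }
        }
    }
    return 0;
}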
@@ -13983,7 +13983,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             // For causal attention, use only the previous KV cells
             // of the correct sequence for each token of the batch.
             // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-            for (int h = 0; h < 1; ++h) {
                 for (int j = 0; j < n_tokens; ++j) {
                     const llama_pos pos = batch.pos[j];
                     const llama_seq_id seq_id = batch.seq_id[j][0];
@@ -13999,22 +13998,21 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                                 f = 0.0f;
                             }
                         }
-                        data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                        data[j*n_kv + i] = f;

                         // may need to cut off old tokens for sliding window
                         if (data_swa) {
                             if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
                                 f = -INFINITY;
                             }
-                            data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                            data_swa[j*n_kv + i] = f;
                         }
                     }
                 }

                 for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
                     for (int j = 0; j < n_kv; ++j) {
-                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                        data[i*n_kv + j] = -INFINITY;
                     }
                 }
-            }
         } else {
@@ -14026,7 +14024,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {

             float * data = (float *) lctx.inp_KQ_mask->data;

-            for (int h = 0; h < 1; ++h) {
                 for (int j = 0; j < n_tokens; ++j) {
                     const llama_seq_id seq_id = batch.seq_id[j][0];

@@ -14043,12 +14040,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                             }
                         }

-                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
+                        data[j*n_stride + i] = f;
                     }

                     for (int i = n_tokens; i < n_stride; ++i) {
-                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
+                        data[j*n_stride + i] = -INFINITY;
                     }
                 }
-            }
         }
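For context on what the loops in the diff compute, the following is a condensed, self-contained sketch of the causal branch after the change: one mask row per batch token, one column per KV cell, 0.0f where the cell belongs to the token's sequence and is not in the future, and -INFINITY otherwise. The KvCell struct and the toy cache and batch are simplified stand-ins, not the actual llama.cpp data structures, and the ALiBi and sliding-window handling from the hunks above is omitted:

// Condensed sketch of the causal KQ mask fill, with simplified stand-in types.
#include <cmath>
#include <cstdio>
#include <vector>

struct KvCell {
    int pos;     // position stored in this KV cell
    int seq_id;  // single sequence id (the real cache tracks a set of ids)
};

int main() {
    // toy KV cache: 4 cells of sequence 0 at positions 0..3
    std::vector<KvCell> cells = { {0, 0}, {1, 0}, {2, 0}, {3, 0} };
    const int n_kv = (int) cells.size();

    // toy batch: 2 tokens of sequence 0 at positions 2 and 3
    constexpr int n_tokens = 2;
    const int pos[n_tokens]    = { 2, 3 };
    const int seq_id[n_tokens] = { 0, 0 };

    std::vector<float> mask(n_tokens * n_kv);

    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_kv; ++i) {
            const bool same_seq  = cells[i].seq_id == seq_id[j];
            const bool in_future = cells[i].pos > pos[j];
            // no h*(n_kv*n_tokens) term: the mask is indexed as a plain 2-D array
            mask[j*n_kv + i] = (same_seq && !in_future) ? 0.0f : -INFINITY;
        }
    }

    // print the mask: token 0 (pos 2) sees cells 0..2, token 1 (pos 3) sees all 4
    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_kv; ++i) {
            std::printf("%6.1f ", mask[j*n_kv + i]);
        }
        std::printf("\n");
    }
    return 0;
}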