llama : remove the h loop in llama_set_inputs
This commit removes the loop over h in the lctx.inp_KQ_mask block in llama_set_inputs. It also removes the usage of h as this will always be zero. The motivation for this is to simplify the code and make it easier to understand. Signed-off-by: Daniel Bevenius <daniel.bevenius@gmail.com>
This commit is contained in:
parent
a59f8fdc85
commit
e95c2e1513
1 changed files with 38 additions and 42 deletions
|
@ -13983,38 +13983,36 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|||
// For causal attention, use only the previous KV cells
|
||||
// of the correct sequence for each token of the batch.
|
||||
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
const llama_pos pos = batch.pos[j];
|
||||
const llama_seq_id seq_id = batch.seq_id[j][0];
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
const llama_pos pos = batch.pos[j];
|
||||
const llama_seq_id seq_id = batch.seq_id[j][0];
|
||||
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
float f;
|
||||
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
|
||||
f = -INFINITY;
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
float f;
|
||||
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
|
||||
f = -INFINITY;
|
||||
} else {
|
||||
if (hparams.use_alibi) {
|
||||
f = -fabs(lctx.kv_self.cells[i].pos - pos);
|
||||
} else {
|
||||
if (hparams.use_alibi) {
|
||||
f = -fabs(lctx.kv_self.cells[i].pos - pos);
|
||||
} else {
|
||||
f = 0.0f;
|
||||
}
|
||||
f = 0.0f;
|
||||
}
|
||||
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
|
||||
}
|
||||
data[j*n_kv + i] = f;
|
||||
|
||||
// may need to cut off old tokens for sliding window
|
||||
if (data_swa) {
|
||||
if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
|
||||
f = -INFINITY;
|
||||
}
|
||||
data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
|
||||
// may need to cut off old tokens for sliding window
|
||||
if (data_swa) {
|
||||
if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
|
||||
f = -INFINITY;
|
||||
}
|
||||
data_swa[j*n_kv + i] = f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_kv; ++j) {
|
||||
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
||||
}
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_kv; ++j) {
|
||||
data[i*n_kv + j] = -INFINITY;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -14026,29 +14024,27 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|||
|
||||
float * data = (float *) lctx.inp_KQ_mask->data;
|
||||
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
const llama_seq_id seq_id = batch.seq_id[j][0];
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
const llama_seq_id seq_id = batch.seq_id[j][0];
|
||||
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
float f = -INFINITY;
|
||||
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
|
||||
if (batch.seq_id[i][s] == seq_id) {
|
||||
if (hparams.use_alibi) {
|
||||
f = -fabs(batch.pos[i] - batch.pos[j]);
|
||||
} else {
|
||||
f = 0.0f;
|
||||
}
|
||||
break;
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
float f = -INFINITY;
|
||||
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
|
||||
if (batch.seq_id[i][s] == seq_id) {
|
||||
if (hparams.use_alibi) {
|
||||
f = -fabs(batch.pos[i] - batch.pos[j]);
|
||||
} else {
|
||||
f = 0.0f;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
|
||||
}
|
||||
|
||||
for (int i = n_tokens; i < n_stride; ++i) {
|
||||
data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
|
||||
}
|
||||
data[j*n_stride + i] = f;
|
||||
}
|
||||
|
||||
for (int i = n_tokens; i < n_stride; ++i) {
|
||||
data[j*n_stride + i] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue