refactor llama_batch_get_one

This commit is contained in:
Xuan Son Nguyen 2024-10-11 11:48:09 +02:00
parent 17880771ad
commit b226c5b1a7
2 changed files with 80 additions and 70 deletions

View file

@ -232,8 +232,11 @@ extern "C" {
// - token : the token ids of the input (used when embd is NULL) // - token : the token ids of the input (used when embd is NULL)
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
// - pos : the positions of the respective token in the sequence // - pos : the positions of the respective token in the sequence
// (if set to NULL, the token position will be tracked automatically by llama_decode)
// - seq_id : the sequence to which the respective token belongs // - seq_id : the sequence to which the respective token belongs
// (if set to NULL, the sequence ID will be assumed to be 0)
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
// (if set to NULL, only the logits for the last token will be returned)
// //
typedef struct llama_batch { typedef struct llama_batch {
int32_t n_tokens; int32_t n_tokens;
@ -244,15 +247,6 @@ extern "C" {
int32_t * n_seq_id; int32_t * n_seq_id;
llama_seq_id ** seq_id; llama_seq_id ** seq_id;
int8_t * logits; // TODO: rename this to "output" int8_t * logits; // TODO: rename this to "output"
// NOTE: helpers for smooth API transition - can be deprecated in the future
// for future-proof code, use the above fields instead and ignore everything below
//
// pos[i] = all_pos_0 + i*all_pos_1
//
llama_pos all_pos_0; // used if pos == NULL
llama_pos all_pos_1; // used if pos == NULL
llama_seq_id all_seq_id; // used if seq_id == NULL
} llama_batch; } llama_batch;
enum llama_model_kv_override_type { enum llama_model_kv_override_type {
@ -775,15 +769,15 @@ extern "C" {
// Decoding // Decoding
// //
// Return batch for single sequence of tokens starting at pos_0 // Return batch for single sequence of tokens
// The sequence ID will be fixed to 0
// The position of the tokens will be tracked automatically by llama_decode
// //
// NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
// //
LLAMA_API struct llama_batch llama_batch_get_one( LLAMA_API struct llama_batch llama_batch_get_one(
llama_token * tokens, llama_token * tokens,
int32_t n_tokens, int32_t n_tokens);
llama_pos pos_0,
llama_seq_id seq_id);
// Allocates a batch of tokens on the heap that can hold a maximum of n_tokens // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
// Each token can be assigned up to n_seq_max sequence ids // Each token can be assigned up to n_seq_max sequence ids

View file

@ -2941,9 +2941,6 @@ struct llama_sbatch_seq {
llama_seq_id * seq_id; llama_seq_id * seq_id;
size_t offset; size_t offset;
size_t length; size_t length;
// helper for smoother batch API transition -- can be deprecated in the future
llama_seq_id all_seq_id; // used if seq_id == NULL
}; };
// sequence-length-aware batch splitting // sequence-length-aware batch splitting
@ -3038,30 +3035,18 @@ struct llama_sbatch {
} else { } else {
ubatch.embd = nullptr; ubatch.embd = nullptr;
} }
// from here on, the else branches are deprecated; if (ubatch.equal_seqs) {
// they are helpers for smoother batch API transition for (size_t i = 0; i < length; ++i) {
if (batch->pos) { ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
if (ubatch.equal_seqs) {
for (size_t i = 0; i < length; ++i) {
ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
}
} else {
// simple split
ubatch.pos = batch->pos + seq.offset;
} }
} else { } else {
for (size_t i = 0; i < length; ++i) { // simple split
llama_pos bi = ids[seq.offset + i]; ubatch.pos = batch->pos + seq.offset;
ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1);
}
} }
if (ubatch.equal_seqs) { if (ubatch.equal_seqs) {
ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id; ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
if (seq.seq_id) { if (seq.seq_id) {
ubatch.seq_id[ubatch.n_seqs] = seq.seq_id; ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
} else {
GGML_ASSERT(seq.n_seq_id == 1);
ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id;
} }
} else { } else {
// simple split // simple split
@ -3074,10 +3059,6 @@ struct llama_sbatch {
} }
if (batch->seq_id) { if (batch->seq_id) {
ubatch.seq_id = batch->seq_id + seq.offset; ubatch.seq_id = batch->seq_id + seq.offset;
} else {
for (size_t i = 0; i < length; ++i) {
ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id;
}
} }
} }
if (logits_all) { if (logits_all) {
@ -3196,7 +3177,6 @@ struct llama_sbatch {
s.seq_id = nullptr; s.seq_id = nullptr;
s.offset = 0; s.offset = 0;
s.length = n_tokens; s.length = n_tokens;
s.all_seq_id = batch.all_seq_id;
return; return;
} }
std::sort(ids.begin(), ids.end(), std::sort(ids.begin(), ids.end(),
@ -3219,7 +3199,7 @@ struct llama_sbatch {
if (batch.pos) { if (batch.pos) {
return batch.pos[a] < batch.pos[b]; return batch.pos[a] < batch.pos[b];
} }
// no pos, sort by id (assuming batch.all_pos_1 is positive) // no pos, sort by id
return a < b; return a < b;
} }
// shared prompts go first // shared prompts go first
@ -3229,30 +3209,25 @@ struct llama_sbatch {
// init seq // init seq
llama_sbatch_seq * last_seq = nullptr; llama_sbatch_seq * last_seq = nullptr;
if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) { for (size_t i = 0; i < n_tokens; ++i) {
for (size_t i = 0; i < n_tokens; ++i) { const size_t bi = ids[i];
const size_t bi = ids[i]; const int32_t n_seqs = batch.n_seq_id[bi];
const int32_t n_seqs = batch.n_seq_id[bi]; llama_seq_id * seq_ids = batch.seq_id[bi];
llama_seq_id * seq_ids = batch.seq_id[bi]; if (last_seq != nullptr) {
if (last_seq != nullptr) { bool same = n_seqs == last_seq->n_seq_id;
bool same = n_seqs == last_seq->n_seq_id; for (int32_t j = 0; same && j < n_seqs; ++j) {
for (int32_t j = 0; same && j < n_seqs; ++j) { if (seq_ids[j] != last_seq->seq_id[j]) {
if (seq_ids[j] != last_seq->seq_id[j]) { same = false;
same = false;
}
}
if (same) {
last_seq->length += 1;
continue;
} }
} }
llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id}; if (same) {
seq.push_back(new_seq); last_seq->length += 1;
last_seq = &seq.back(); continue;
}
} }
} else { llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id};
seq.push_back(new_seq); seq.push_back(new_seq);
last_seq = &seq.back();
} }
// keep shared prompts first at the end, then sort by length descending. // keep shared prompts first at the end, then sort by length descending.
std::sort(seq.begin(), seq.end(), std::sort(seq.begin(), seq.end(),
@ -21069,9 +21044,7 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
struct llama_batch llama_batch_get_one( struct llama_batch llama_batch_get_one(
llama_token * tokens, llama_token * tokens,
int32_t n_tokens, int32_t n_tokens) {
llama_pos pos_0,
llama_seq_id seq_id) {
return { return {
/*n_tokens =*/ n_tokens, /*n_tokens =*/ n_tokens,
/*tokens =*/ tokens, /*tokens =*/ tokens,
@ -21080,9 +21053,6 @@ struct llama_batch llama_batch_get_one(
/*n_seq_id =*/ nullptr, /*n_seq_id =*/ nullptr,
/*seq_id =*/ nullptr, /*seq_id =*/ nullptr,
/*logits =*/ nullptr, /*logits =*/ nullptr,
/*all_pos_0 =*/ pos_0,
/*all_pos_1 =*/ 1,
/*all_seq_id =*/ seq_id,
}; };
} }
@ -21095,9 +21065,6 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
/*n_seq_id =*/ nullptr, /*n_seq_id =*/ nullptr,
/*seq_id =*/ nullptr, /*seq_id =*/ nullptr,
/*logits =*/ nullptr, /*logits =*/ nullptr,
/*all_pos_0 =*/ 0,
/*all_pos_1 =*/ 0,
/*all_seq_id =*/ 0,
}; };
if (embd) { if (embd) {
@ -21133,10 +21100,58 @@ void llama_batch_free(struct llama_batch batch) {
if (batch.logits) free(batch.logits); if (batch.logits) free(batch.logits);
} }
// temporary allocate memory for the input batch if needed
struct llama_batch_allocr {
static const llama_seq_id default_seq_id = 0;
std::array<llama_seq_id, 1> seq_id_0 = {default_seq_id};
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id;
std::vector<int8_t> logits;
// fulfill the batch returned by llama_batch_get_one
struct llama_batch get_fulfilled_batch(struct llama_context * ctx, struct llama_batch in_batch) {
struct llama_batch batch = in_batch;
if (!batch.pos) {
// determine the last position in KV cache
llama_pos last_pos;
for (const auto & cell : ctx->kv_self.cells) {
if (cell.seq_id.find(default_seq_id) != cell.seq_id.end()) {
last_pos = std::max(last_pos, cell.pos);
}
}
pos.resize(batch.n_tokens);
for (int32_t i = 1; i <= batch.n_tokens; i++) {
pos[i] = i+last_pos;
}
batch.pos = pos.data();
}
if (!batch.n_seq_id) {
n_seq_id.reserve(batch.n_tokens);
for (int32_t i = 1; i <= batch.n_tokens; i++) {
n_seq_id[i] = seq_id_0.size();
}
batch.n_seq_id = n_seq_id.data();
}
if (!batch.seq_id) {
seq_id.reserve(batch.n_tokens);
for (int32_t i = 1; i <= batch.n_tokens; i++) {
seq_id[i] = seq_id_0.data();
}
batch.seq_id = seq_id.data();
}
if (!batch.logits) {
logits.reserve(batch.n_tokens);
logits[logits.size() - 1] = true;
batch.logits = logits.data();
}
}
};
int32_t llama_encode( int32_t llama_encode(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_batch batch) { struct llama_batch batch) {
const int ret = llama_encode_internal(*ctx, batch); llama_batch_allocr batch_allocr;
const int ret = llama_encode_internal(*ctx, batch_allocr.get_fulfilled_batch(ctx, batch));
if (ret < 0) { if (ret < 0) {
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
} }
@ -21147,7 +21162,8 @@ int32_t llama_encode(
int32_t llama_decode( int32_t llama_decode(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_batch batch) { struct llama_batch batch) {
const int ret = llama_decode_internal(*ctx, batch); llama_batch_allocr batch_allocr;
const int ret = llama_decode_internal(*ctx, batch_allocr.get_fulfilled_batch(ctx, batch));
if (ret < 0) { if (ret < 0) {
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
} }