refactor llama_batch_get_one
This commit is contained in:
parent
17880771ad
commit
b226c5b1a7
2 changed files with 80 additions and 70 deletions
|
@ -232,8 +232,11 @@ extern "C" {
|
||||||
// - token : the token ids of the input (used when embd is NULL)
|
// - token : the token ids of the input (used when embd is NULL)
|
||||||
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
|
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
|
||||||
// - pos : the positions of the respective token in the sequence
|
// - pos : the positions of the respective token in the sequence
|
||||||
|
// (if set to NULL, the token position will be tracked automatically by llama_decode)
|
||||||
// - seq_id : the sequence to which the respective token belongs
|
// - seq_id : the sequence to which the respective token belongs
|
||||||
|
// (if set to NULL, the sequence ID will be assumed to be 0)
|
||||||
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
|
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
|
||||||
|
// (if set to NULL, only the logits for last token will be returned)
|
||||||
//
|
//
|
||||||
typedef struct llama_batch {
|
typedef struct llama_batch {
|
||||||
int32_t n_tokens;
|
int32_t n_tokens;
|
||||||
|
@ -244,15 +247,6 @@ extern "C" {
|
||||||
int32_t * n_seq_id;
|
int32_t * n_seq_id;
|
||||||
llama_seq_id ** seq_id;
|
llama_seq_id ** seq_id;
|
||||||
int8_t * logits; // TODO: rename this to "output"
|
int8_t * logits; // TODO: rename this to "output"
|
||||||
|
|
||||||
// NOTE: helpers for smooth API transition - can be deprecated in the future
|
|
||||||
// for future-proof code, use the above fields instead and ignore everything below
|
|
||||||
//
|
|
||||||
// pos[i] = all_pos_0 + i*all_pos_1
|
|
||||||
//
|
|
||||||
llama_pos all_pos_0; // used if pos == NULL
|
|
||||||
llama_pos all_pos_1; // used if pos == NULL
|
|
||||||
llama_seq_id all_seq_id; // used if seq_id == NULL
|
|
||||||
} llama_batch;
|
} llama_batch;
|
||||||
|
|
||||||
enum llama_model_kv_override_type {
|
enum llama_model_kv_override_type {
|
||||||
|
@ -775,15 +769,15 @@ extern "C" {
|
||||||
// Decoding
|
// Decoding
|
||||||
//
|
//
|
||||||
|
|
||||||
// Return batch for single sequence of tokens starting at pos_0
|
// Return batch for single sequence of tokens
|
||||||
|
// The sequence ID will be fixed to 0
|
||||||
|
// The position of the tokens will be tracked automatically by llama_decode
|
||||||
//
|
//
|
||||||
// NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
|
// NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
|
||||||
//
|
//
|
||||||
LLAMA_API struct llama_batch llama_batch_get_one(
|
LLAMA_API struct llama_batch llama_batch_get_one(
|
||||||
llama_token * tokens,
|
llama_token * tokens,
|
||||||
int32_t n_tokens,
|
int32_t n_tokens);
|
||||||
llama_pos pos_0,
|
|
||||||
llama_seq_id seq_id);
|
|
||||||
|
|
||||||
// Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
|
// Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
|
||||||
// Each token can be assigned up to n_seq_max sequence ids
|
// Each token can be assigned up to n_seq_max sequence ids
|
||||||
|
|
130
src/llama.cpp
130
src/llama.cpp
|
@ -2941,9 +2941,6 @@ struct llama_sbatch_seq {
|
||||||
llama_seq_id * seq_id;
|
llama_seq_id * seq_id;
|
||||||
size_t offset;
|
size_t offset;
|
||||||
size_t length;
|
size_t length;
|
||||||
|
|
||||||
// helper for smoother batch API transition -- can be deprecated in the future
|
|
||||||
llama_seq_id all_seq_id; // used if seq_id == NULL
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// sequence-length-aware batch splitting
|
// sequence-length-aware batch splitting
|
||||||
|
@ -3038,30 +3035,18 @@ struct llama_sbatch {
|
||||||
} else {
|
} else {
|
||||||
ubatch.embd = nullptr;
|
ubatch.embd = nullptr;
|
||||||
}
|
}
|
||||||
// from here on, the else branches are deprecated;
|
if (ubatch.equal_seqs) {
|
||||||
// they are helpers for smoother batch API transition
|
for (size_t i = 0; i < length; ++i) {
|
||||||
if (batch->pos) {
|
ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
|
||||||
if (ubatch.equal_seqs) {
|
|
||||||
for (size_t i = 0; i < length; ++i) {
|
|
||||||
ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// simple split
|
|
||||||
ubatch.pos = batch->pos + seq.offset;
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (size_t i = 0; i < length; ++i) {
|
// simple split
|
||||||
llama_pos bi = ids[seq.offset + i];
|
ubatch.pos = batch->pos + seq.offset;
|
||||||
ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if (ubatch.equal_seqs) {
|
if (ubatch.equal_seqs) {
|
||||||
ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
|
ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
|
||||||
if (seq.seq_id) {
|
if (seq.seq_id) {
|
||||||
ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
|
ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
|
||||||
} else {
|
|
||||||
GGML_ASSERT(seq.n_seq_id == 1);
|
|
||||||
ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id;
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// simple split
|
// simple split
|
||||||
|
@ -3074,10 +3059,6 @@ struct llama_sbatch {
|
||||||
}
|
}
|
||||||
if (batch->seq_id) {
|
if (batch->seq_id) {
|
||||||
ubatch.seq_id = batch->seq_id + seq.offset;
|
ubatch.seq_id = batch->seq_id + seq.offset;
|
||||||
} else {
|
|
||||||
for (size_t i = 0; i < length; ++i) {
|
|
||||||
ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (logits_all) {
|
if (logits_all) {
|
||||||
|
@ -3196,7 +3177,6 @@ struct llama_sbatch {
|
||||||
s.seq_id = nullptr;
|
s.seq_id = nullptr;
|
||||||
s.offset = 0;
|
s.offset = 0;
|
||||||
s.length = n_tokens;
|
s.length = n_tokens;
|
||||||
s.all_seq_id = batch.all_seq_id;
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
std::sort(ids.begin(), ids.end(),
|
std::sort(ids.begin(), ids.end(),
|
||||||
|
@ -3219,7 +3199,7 @@ struct llama_sbatch {
|
||||||
if (batch.pos) {
|
if (batch.pos) {
|
||||||
return batch.pos[a] < batch.pos[b];
|
return batch.pos[a] < batch.pos[b];
|
||||||
}
|
}
|
||||||
// no pos, sort by id (assuming batch.all_pos_1 is positive)
|
// no pos, sort by id
|
||||||
return a < b;
|
return a < b;
|
||||||
}
|
}
|
||||||
// shared prompts go first
|
// shared prompts go first
|
||||||
|
@ -3229,30 +3209,25 @@ struct llama_sbatch {
|
||||||
// init seq
|
// init seq
|
||||||
llama_sbatch_seq * last_seq = nullptr;
|
llama_sbatch_seq * last_seq = nullptr;
|
||||||
|
|
||||||
if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) {
|
for (size_t i = 0; i < n_tokens; ++i) {
|
||||||
for (size_t i = 0; i < n_tokens; ++i) {
|
const size_t bi = ids[i];
|
||||||
const size_t bi = ids[i];
|
const int32_t n_seqs = batch.n_seq_id[bi];
|
||||||
const int32_t n_seqs = batch.n_seq_id[bi];
|
llama_seq_id * seq_ids = batch.seq_id[bi];
|
||||||
llama_seq_id * seq_ids = batch.seq_id[bi];
|
if (last_seq != nullptr) {
|
||||||
if (last_seq != nullptr) {
|
bool same = n_seqs == last_seq->n_seq_id;
|
||||||
bool same = n_seqs == last_seq->n_seq_id;
|
for (int32_t j = 0; same && j < n_seqs; ++j) {
|
||||||
for (int32_t j = 0; same && j < n_seqs; ++j) {
|
if (seq_ids[j] != last_seq->seq_id[j]) {
|
||||||
if (seq_ids[j] != last_seq->seq_id[j]) {
|
same = false;
|
||||||
same = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (same) {
|
|
||||||
last_seq->length += 1;
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id};
|
if (same) {
|
||||||
seq.push_back(new_seq);
|
last_seq->length += 1;
|
||||||
last_seq = &seq.back();
|
continue;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
|
||||||
llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id};
|
|
||||||
seq.push_back(new_seq);
|
seq.push_back(new_seq);
|
||||||
|
last_seq = &seq.back();
|
||||||
}
|
}
|
||||||
// keep shared prompts first at the end, then sort by length descending.
|
// keep shared prompts first at the end, then sort by length descending.
|
||||||
std::sort(seq.begin(), seq.end(),
|
std::sort(seq.begin(), seq.end(),
|
||||||
|
@ -21069,9 +21044,7 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
|
||||||
|
|
||||||
struct llama_batch llama_batch_get_one(
|
struct llama_batch llama_batch_get_one(
|
||||||
llama_token * tokens,
|
llama_token * tokens,
|
||||||
int32_t n_tokens,
|
int32_t n_tokens) {
|
||||||
llama_pos pos_0,
|
|
||||||
llama_seq_id seq_id) {
|
|
||||||
return {
|
return {
|
||||||
/*n_tokens =*/ n_tokens,
|
/*n_tokens =*/ n_tokens,
|
||||||
/*tokens =*/ tokens,
|
/*tokens =*/ tokens,
|
||||||
|
@ -21080,9 +21053,6 @@ struct llama_batch llama_batch_get_one(
|
||||||
/*n_seq_id =*/ nullptr,
|
/*n_seq_id =*/ nullptr,
|
||||||
/*seq_id =*/ nullptr,
|
/*seq_id =*/ nullptr,
|
||||||
/*logits =*/ nullptr,
|
/*logits =*/ nullptr,
|
||||||
/*all_pos_0 =*/ pos_0,
|
|
||||||
/*all_pos_1 =*/ 1,
|
|
||||||
/*all_seq_id =*/ seq_id,
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -21095,9 +21065,6 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
|
||||||
/*n_seq_id =*/ nullptr,
|
/*n_seq_id =*/ nullptr,
|
||||||
/*seq_id =*/ nullptr,
|
/*seq_id =*/ nullptr,
|
||||||
/*logits =*/ nullptr,
|
/*logits =*/ nullptr,
|
||||||
/*all_pos_0 =*/ 0,
|
|
||||||
/*all_pos_1 =*/ 0,
|
|
||||||
/*all_seq_id =*/ 0,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if (embd) {
|
if (embd) {
|
||||||
|
@ -21133,10 +21100,58 @@ void llama_batch_free(struct llama_batch batch) {
|
||||||
if (batch.logits) free(batch.logits);
|
if (batch.logits) free(batch.logits);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// temporary allocate memory for the input batch if needed
|
||||||
|
struct llama_batch_allocr {
|
||||||
|
static const llama_seq_id default_seq_id = 0;
|
||||||
|
std::array<llama_seq_id, 1> seq_id_0 = {default_seq_id};
|
||||||
|
std::vector<llama_pos> pos;
|
||||||
|
std::vector<int32_t> n_seq_id;
|
||||||
|
std::vector<llama_seq_id *> seq_id;
|
||||||
|
std::vector<int8_t> logits;
|
||||||
|
// fulfill the batch returned by llama_batch_get_one
|
||||||
|
struct llama_batch get_fulfilled_batch(struct llama_context * ctx, struct llama_batch in_batch) {
|
||||||
|
struct llama_batch batch = in_batch;
|
||||||
|
if (!batch.pos) {
|
||||||
|
// determine the last position in KV cache
|
||||||
|
llama_pos last_pos;
|
||||||
|
for (const auto & cell : ctx->kv_self.cells) {
|
||||||
|
if (cell.seq_id.find(default_seq_id) != cell.seq_id.end()) {
|
||||||
|
last_pos = std::max(last_pos, cell.pos);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pos.resize(batch.n_tokens);
|
||||||
|
for (int32_t i = 1; i <= batch.n_tokens; i++) {
|
||||||
|
pos[i] = i+last_pos;
|
||||||
|
}
|
||||||
|
batch.pos = pos.data();
|
||||||
|
}
|
||||||
|
if (!batch.n_seq_id) {
|
||||||
|
n_seq_id.reserve(batch.n_tokens);
|
||||||
|
for (int32_t i = 1; i <= batch.n_tokens; i++) {
|
||||||
|
n_seq_id[i] = seq_id_0.size();
|
||||||
|
}
|
||||||
|
batch.n_seq_id = n_seq_id.data();
|
||||||
|
}
|
||||||
|
if (!batch.seq_id) {
|
||||||
|
seq_id.reserve(batch.n_tokens);
|
||||||
|
for (int32_t i = 1; i <= batch.n_tokens; i++) {
|
||||||
|
seq_id[i] = seq_id_0.data();
|
||||||
|
}
|
||||||
|
batch.seq_id = seq_id.data();
|
||||||
|
}
|
||||||
|
if (!batch.logits) {
|
||||||
|
logits.reserve(batch.n_tokens);
|
||||||
|
logits[logits.size() - 1] = true;
|
||||||
|
batch.logits = logits.data();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
int32_t llama_encode(
|
int32_t llama_encode(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
struct llama_batch batch) {
|
struct llama_batch batch) {
|
||||||
const int ret = llama_encode_internal(*ctx, batch);
|
llama_batch_allocr batch_allocr;
|
||||||
|
const int ret = llama_encode_internal(*ctx, batch_allocr.get_fulfilled_batch(ctx, batch));
|
||||||
if (ret < 0) {
|
if (ret < 0) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
|
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
|
||||||
}
|
}
|
||||||
|
@ -21147,7 +21162,8 @@ int32_t llama_encode(
|
||||||
int32_t llama_decode(
|
int32_t llama_decode(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
struct llama_batch batch) {
|
struct llama_batch batch) {
|
||||||
const int ret = llama_decode_internal(*ctx, batch);
|
llama_batch_allocr batch_allocr;
|
||||||
|
const int ret = llama_decode_internal(*ctx, batch_allocr.get_fulfilled_batch(ctx, batch));
|
||||||
if (ret < 0) {
|
if (ret < 0) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
|
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue