This commit is contained in:
Xuan Son Nguyen 2024-10-11 13:59:26 +02:00
parent 59fd6b6119
commit 7740c969d0

View file

@ -21081,58 +21081,59 @@ void llama_batch_free(struct llama_batch batch) {
} }
// temporary allocate memory for the input batch if needed // temporary allocate memory for the input batch if needed
static const llama_seq_id batch_default_seq_id = 0;
struct llama_batch_allocr { struct llama_batch_allocr {
static const llama_seq_id default_seq_id = 0; std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
std::array<llama_seq_id, 1> seq_id_0 = {default_seq_id};
std::vector<llama_pos> pos; std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id; std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id; std::vector<llama_seq_id *> seq_id;
std::vector<int8_t> logits; std::vector<int8_t> logits;
// fulfill the batch returned by llama_batch_get_one struct llama_batch batch;
struct llama_batch get_fulfilled_batch(struct llama_context * ctx, struct llama_batch in_batch) { // optionally fulfill the batch returned by llama_batch_get_one
struct llama_batch batch = in_batch; llama_batch_allocr(struct llama_context * ctx, struct llama_batch in_batch) {
batch = in_batch;
if (!batch.pos) { if (!batch.pos) {
// determine the last position in KV cache // determine the last position in KV cache
llama_pos last_pos = 0; llama_pos last_pos = 0;
for (const auto & cell : ctx->kv_self.cells) { for (const auto & cell : ctx->kv_self.cells) {
if (cell.seq_id.find(default_seq_id) != cell.seq_id.end()) { if (cell.has_seq_id(batch_default_seq_id)) {
last_pos = std::max(last_pos, cell.pos); last_pos = std::max(last_pos, cell.pos);
} }
} }
last_pos++; // next position
pos.resize(batch.n_tokens); pos.resize(batch.n_tokens);
for (int32_t i = 1; i <= batch.n_tokens; i++) { for (int32_t i = 0; i < batch.n_tokens; i++) {
pos[i] = i+last_pos; pos[i] = i+last_pos;
} }
batch.pos = pos.data(); batch.pos = pos.data();
} }
if (!batch.n_seq_id) { if (!batch.n_seq_id) {
n_seq_id.reserve(batch.n_tokens); n_seq_id.resize(batch.n_tokens);
for (int32_t i = 1; i <= batch.n_tokens; i++) { for (int32_t i = 0; i < batch.n_tokens; i++) {
n_seq_id[i] = seq_id_0.size(); n_seq_id[i] = seq_id_0.size();
} }
batch.n_seq_id = n_seq_id.data(); batch.n_seq_id = n_seq_id.data();
} }
if (!batch.seq_id) { if (!batch.seq_id) {
seq_id.reserve(batch.n_tokens); seq_id.resize(batch.n_tokens);
for (int32_t i = 1; i <= batch.n_tokens; i++) { for (int32_t i = 0; i < batch.n_tokens; i++) {
seq_id[i] = seq_id_0.data(); seq_id[i] = seq_id_0.data();
} }
batch.seq_id = seq_id.data(); batch.seq_id = seq_id.data();
} }
if (!batch.logits) { if (!batch.logits) {
logits.reserve(batch.n_tokens); logits.resize(batch.n_tokens);
logits[logits.size() - 1] = true; logits[logits.size() - 1] = true;
batch.logits = logits.data(); batch.logits = logits.data();
} }
return batch;
} }
}; };
int32_t llama_encode( int32_t llama_encode(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_batch batch) { struct llama_batch batch) {
llama_batch_allocr batch_allocr; llama_batch_allocr batch_allocr(ctx, batch);
const int ret = llama_encode_internal(*ctx, batch_allocr.get_fulfilled_batch(ctx, batch)); const int ret = llama_encode_internal(*ctx, batch_allocr.batch);
if (ret < 0) { if (ret < 0) {
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
} }
@ -21143,8 +21144,8 @@ int32_t llama_encode(
int32_t llama_decode( int32_t llama_decode(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_batch batch) { struct llama_batch batch) {
llama_batch_allocr batch_allocr; llama_batch_allocr batch_allocr(ctx, batch);
const int ret = llama_decode_internal(*ctx, batch_allocr.get_fulfilled_batch(ctx, batch)); const int ret = llama_decode_internal(*ctx, batch_allocr.batch);
if (ret < 0) { if (ret < 0) {
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
} }