From b226c5b1a72cef6140bc62a3b94ef773ba1c9dc7 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 11 Oct 2024 11:48:09 +0200 Subject: [PATCH] refactor llama_batch_get_one --- include/llama.h | 20 +++----- src/llama.cpp | 130 +++++++++++++++++++++++++++--------------------- 2 files changed, 80 insertions(+), 70 deletions(-) diff --git a/include/llama.h b/include/llama.h index 7cae1bbe2..dabd64dbb 100644 --- a/include/llama.h +++ b/include/llama.h @@ -232,8 +232,11 @@ extern "C" { // - token : the token ids of the input (used when embd is NULL) // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) // - pos : the positions of the respective token in the sequence + // (if set to NULL, the token position will be tracked automatically by llama_decode) // - seq_id : the sequence to which the respective token belongs + // (if set to NULL, the sequence ID will be assumed to be 0) // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output + // (if set to NULL, only the logits for last token will be returned) // typedef struct llama_batch { int32_t n_tokens; @@ -244,15 +247,6 @@ extern "C" { int32_t * n_seq_id; llama_seq_id ** seq_id; int8_t * logits; // TODO: rename this to "output" - - // NOTE: helpers for smooth API transition - can be deprecated in the future - // for future-proof code, use the above fields instead and ignore everything below - // - // pos[i] = all_pos_0 + i*all_pos_1 - // - llama_pos all_pos_0; // used if pos == NULL - llama_pos all_pos_1; // used if pos == NULL - llama_seq_id all_seq_id; // used if seq_id == NULL } llama_batch; enum llama_model_kv_override_type { @@ -775,15 +769,15 @@ extern "C" { // Decoding // - // Return batch for single sequence of tokens starting at pos_0 + // Return batch for single sequence of tokens + // The sequence ID will be fixed to 0 + // The position of the tokens will be tracked automatically by llama_decode // // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it // LLAMA_API struct llama_batch llama_batch_get_one( llama_token * tokens, - int32_t n_tokens, - llama_pos pos_0, - llama_seq_id seq_id); + int32_t n_tokens); // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens // Each token can be assigned up to n_seq_max sequence ids diff --git a/src/llama.cpp b/src/llama.cpp index 3443b0689..825999b3c 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2941,9 +2941,6 @@ struct llama_sbatch_seq { llama_seq_id * seq_id; size_t offset; size_t length; - - // helper for smoother batch API transition -- can be deprecated in the future - llama_seq_id all_seq_id; // used if seq_id == NULL }; // sequence-length-aware batch splitting @@ -3038,30 +3035,18 @@ struct llama_sbatch { } else { ubatch.embd = nullptr; } - // from here on, the else branches are deprecated; - // they are helpers for smoother batch API transition - if (batch->pos) { - if (ubatch.equal_seqs) { - for (size_t i = 0; i < length; ++i) { - ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]]; - } - } else { - // simple split - ubatch.pos = batch->pos + seq.offset; + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]]; } } else { - for (size_t i = 0; i < length; ++i) { - llama_pos bi = ids[seq.offset + i]; - ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1); - } + // simple split + ubatch.pos = batch->pos + seq.offset; } if (ubatch.equal_seqs) { ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id; if (seq.seq_id) { ubatch.seq_id[ubatch.n_seqs] = seq.seq_id; - } else { - GGML_ASSERT(seq.n_seq_id == 1); - ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id; } } else { // simple split @@ -3074,10 +3059,6 @@ struct llama_sbatch { } if (batch->seq_id) { ubatch.seq_id = batch->seq_id + seq.offset; - } else { - for (size_t i = 0; i < length; ++i) { - ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id; - } } } if (logits_all) { @@ -3196,7 +3177,6 @@ struct llama_sbatch { s.seq_id = nullptr; s.offset = 0; s.length = n_tokens; - s.all_seq_id = batch.all_seq_id; return; } std::sort(ids.begin(), ids.end(), @@ -3219,7 +3199,7 @@ struct llama_sbatch { if (batch.pos) { return batch.pos[a] < batch.pos[b]; } - // no pos, sort by id (assuming batch.all_pos_1 is positive) + // no pos, sort by id return a < b; } // shared prompts go first @@ -3229,30 +3209,25 @@ struct llama_sbatch { // init seq llama_sbatch_seq * last_seq = nullptr; - if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) { - for (size_t i = 0; i < n_tokens; ++i) { - const size_t bi = ids[i]; - const int32_t n_seqs = batch.n_seq_id[bi]; - llama_seq_id * seq_ids = batch.seq_id[bi]; - if (last_seq != nullptr) { - bool same = n_seqs == last_seq->n_seq_id; - for (int32_t j = 0; same && j < n_seqs; ++j) { - if (seq_ids[j] != last_seq->seq_id[j]) { - same = false; - } - } - if (same) { - last_seq->length += 1; - continue; + for (size_t i = 0; i < n_tokens; ++i) { + const size_t bi = ids[i]; + const int32_t n_seqs = batch.n_seq_id[bi]; + llama_seq_id * seq_ids = batch.seq_id[bi]; + if (last_seq != nullptr) { + bool same = n_seqs == last_seq->n_seq_id; + for (int32_t j = 0; same && j < n_seqs; ++j) { + if (seq_ids[j] != last_seq->seq_id[j]) { + same = false; } } - llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id}; - seq.push_back(new_seq); - last_seq = &seq.back(); + if (same) { + last_seq->length += 1; + continue; + } } - } else { - llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id}; + llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1}; seq.push_back(new_seq); + last_seq = &seq.back(); } // keep shared prompts first at the end, then sort by length descending. std::sort(seq.begin(), seq.end(), @@ -21069,9 +21044,7 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) { struct llama_batch llama_batch_get_one( llama_token * tokens, - int32_t n_tokens, - llama_pos pos_0, - llama_seq_id seq_id) { + int32_t n_tokens) { return { /*n_tokens =*/ n_tokens, /*tokens =*/ tokens, @@ -21080,9 +21053,6 @@ struct llama_batch llama_batch_get_one( /*n_seq_id =*/ nullptr, /*seq_id =*/ nullptr, /*logits =*/ nullptr, - /*all_pos_0 =*/ pos_0, - /*all_pos_1 =*/ 1, - /*all_seq_id =*/ seq_id, }; } @@ -21095,9 +21065,6 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ /*n_seq_id =*/ nullptr, /*seq_id =*/ nullptr, /*logits =*/ nullptr, - /*all_pos_0 =*/ 0, - /*all_pos_1 =*/ 0, - /*all_seq_id =*/ 0, }; if (embd) { @@ -21133,10 +21100,58 @@ void llama_batch_free(struct llama_batch batch) { if (batch.logits) free(batch.logits); } +// temporary allocate memory for the input batch if needed +struct llama_batch_allocr { + static const llama_seq_id default_seq_id = 0; + std::array seq_id_0 = {default_seq_id}; + std::vector pos; + std::vector n_seq_id; + std::vector seq_id; + std::vector logits; + // fulfill the batch returned by llama_batch_get_one + struct llama_batch get_fulfilled_batch(struct llama_context * ctx, struct llama_batch in_batch) { + struct llama_batch batch = in_batch; + if (!batch.pos) { + // determine the last position in KV cache + llama_pos last_pos; + for (const auto & cell : ctx->kv_self.cells) { + if (cell.seq_id.find(default_seq_id) != cell.seq_id.end()) { + last_pos = std::max(last_pos, cell.pos); + } + } + pos.resize(batch.n_tokens); + for (int32_t i = 1; i <= batch.n_tokens; i++) { + pos[i] = i+last_pos; + } + batch.pos = pos.data(); + } + if (!batch.n_seq_id) { + n_seq_id.reserve(batch.n_tokens); + for (int32_t i = 1; i <= batch.n_tokens; i++) { + n_seq_id[i] = seq_id_0.size(); + } + batch.n_seq_id = n_seq_id.data(); + } + if (!batch.seq_id) { + seq_id.reserve(batch.n_tokens); + for (int32_t i = 1; i <= batch.n_tokens; i++) { + seq_id[i] = seq_id_0.data(); + } + batch.seq_id = seq_id.data(); + } + if (!batch.logits) { + logits.reserve(batch.n_tokens); + logits[logits.size() - 1] = true; + batch.logits = logits.data(); + } + } +}; + int32_t llama_encode( struct llama_context * ctx, struct llama_batch batch) { - const int ret = llama_encode_internal(*ctx, batch); + llama_batch_allocr batch_allocr; + const int ret = llama_encode_internal(*ctx, batch_allocr.get_fulfilled_batch(ctx, batch)); if (ret < 0) { LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); } @@ -21147,7 +21162,8 @@ int32_t llama_encode( int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch) { - const int ret = llama_decode_internal(*ctx, batch); + llama_batch_allocr batch_allocr; + const int ret = llama_decode_internal(*ctx, batch_allocr.get_fulfilled_batch(ctx, batch)); if (ret < 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); }