From 9779bf3ae0f0119d1b44000704aab5e99e09139a Mon Sep 17 00:00:00 2001 From: omroy12 Date: Thu, 30 Jan 2025 22:00:20 +0530 Subject: [PATCH] =?UTF-8?q?Om=20=F0=9F=8D=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/llama-batch.cpp | 544 ++++++++++++++++++-------------------------- 1 file changed, 221 insertions(+), 323 deletions(-) diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 01d5ca57f..f3dc503e1 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -1,368 +1,266 @@ #include "llama-batch.h" - #include #include +#include +#include +#include -llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) { - // clear empty sequences - // the previous ubatch is assumed to be gone, - // so nothing should refer to values in these sequences anymore. - for (size_t i = seq.size(); i-- > 0;) { - if (seq[i].length == 0) { - seq.pop_back(); - } else { - break; - } - } - ubatch_token.resize(!has_embd ? n_ubatch : 0); - ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0); - ubatch_pos.resize(n_ubatch); - ubatch_n_seq_id.resize(n_ubatch); - ubatch_seq_id.resize(n_ubatch); - ubatch_output.resize(n_ubatch); - llama_ubatch ubatch = { - /*equal_seqs =*/ true, - /*n_tokens =*/ 0, - /*n_seq_tokens =*/ 0, - /*n_seqs =*/ 0, - /*token =*/ !has_embd ? ubatch_token.data() : nullptr, - /*embd =*/ has_embd ? ubatch_embd.data() : nullptr, - /*pos =*/ ubatch_pos.data(), - /*n_seq_id =*/ ubatch_n_seq_id.data(), - /*seq_id =*/ ubatch_seq_id.data(), - /*output =*/ ubatch_output.data(), - }; - return ubatch; -} +class llama_batch_allocr { +public: + llama_batch batch; + std::vector pos; + std::vector n_seq_id; + std::vector seq_id; + std::vector logits; -void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) { - GGML_ASSERT(batch != nullptr); - GGML_ASSERT(length <= seq.length); - // Can only add sequences of equal lengths to a batch, - // otherwise it isn't clear to which sequence a token belongs - GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs); - GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs); - // NOTE: loops are separated for cache-friendliness - if (batch->token) { - if (ubatch.equal_seqs) { - for (size_t i = 0; i < length; ++i) { - ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]]; + llama_batch_allocr(const llama_batch& in_batch, llama_pos p0) { + batch = in_batch; + assert(batch.n_tokens > 0); + initialize_batch(p0); + } + +private: + void initialize_batch(llama_pos p0) { + if (!batch.pos) { + pos.resize(batch.n_tokens); + for (int32_t i = 0; i < batch.n_tokens; i++) { + pos[i] = i + p0; } - } else { - // simple split - ubatch.token = batch->token + seq.offset; + batch.pos = pos.data(); } - } else { - ubatch.token = nullptr; - } - if (batch->embd) { - if (ubatch.equal_seqs) { - for (size_t i = 0; i < length; ++i) { - memcpy( - ubatch.embd + (n_embd * (ubatch.n_tokens + i)), - batch->embd + (n_embd * ids[seq.offset + i]), - n_embd * sizeof(float) - ); + if (!batch.n_seq_id) { + n_seq_id.resize(batch.n_tokens); + for (int32_t i = 0; i < batch.n_tokens; i++) { + n_seq_id[i] = seq_id_0.size(); } - } else { - // simple split - ubatch.embd = batch->embd + (n_embd * seq.offset); + batch.n_seq_id = n_seq_id.data(); + } + if (!batch.seq_id) { + seq_id.resize(batch.n_tokens + 1); + seq_id[batch.n_tokens] = nullptr; + for (int32_t i = 0; i < batch.n_tokens; i++) { + seq_id[i] = seq_id_0.data(); + } + batch.seq_id = seq_id.data(); + } + if (!batch.logits) { + logits.resize(batch.n_tokens, 0); + logits[logits.size() - 1] = true; + batch.logits = logits.data(); } - } else { - ubatch.embd = nullptr; } - if (ubatch.equal_seqs) { - for (size_t i = 0; i < length; ++i) { - ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]]; - } - } else { - // simple split - ubatch.pos = batch->pos + seq.offset; +}; + +class llama_sbatch { +public: + const llama_batch* batch; + size_t n_embd; + bool logits_all; + size_t n_tokens; + std::vector ids; + std::vector seq; + std::vector out_ids; + + llama_sbatch() : batch(nullptr), n_embd(0), logits_all(false), n_tokens(0) {} + + llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd) { + clear_empty_sequences(); + resize_ubatch_data(n_ubatch, has_embd); + return build_ubatch(n_ubatch, has_embd); } - if (ubatch.equal_seqs) { - ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id; - if (seq.seq_id) { - ubatch.seq_id[ubatch.n_seqs] = seq.seq_id; - } - } else { - // simple split - if (batch->n_seq_id) { - ubatch.n_seq_id = batch->n_seq_id + seq.offset; - } else { - for (size_t i = 0; i < length; ++i) { - ubatch.n_seq_id[ubatch.n_seqs + i] = 1; + + void add_seq_to_ubatch(llama_ubatch& ubatch, llama_sbatch_seq& seq, size_t length) { + assert(batch != nullptr); + assert(length <= seq.length); + validate_sequence_length(ubatch, seq, length); + + handle_tokens_and_embeddings(ubatch, seq, length); + handle_positions_and_sequence_ids(ubatch, seq, length); + handle_logits(ubatch, seq, length); + + update_ubatch_and_seq(ubatch, seq, length); + } + +private: + void clear_empty_sequences() { + for (size_t i = seq.size(); i-- > 0;) { + if (seq[i].length == 0) { + seq.pop_back(); + } else { + break; } } - if (batch->seq_id) { - ubatch.seq_id = batch->seq_id + seq.offset; + } + + void resize_ubatch_data(size_t n_ubatch, bool has_embd) { + ubatch_token.resize(!has_embd ? n_ubatch : 0); + ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0); + ubatch_pos.resize(n_ubatch); + ubatch_n_seq_id.resize(n_ubatch); + ubatch_seq_id.resize(n_ubatch); + ubatch_output.resize(n_ubatch); + } + + llama_ubatch build_ubatch(size_t n_ubatch, bool has_embd) { + return { + true, 0, 0, 0, + !has_embd ? ubatch_token.data() : nullptr, + has_embd ? ubatch_embd.data() : nullptr, + ubatch_pos.data(), + ubatch_n_seq_id.data(), + ubatch_seq_id.data(), + ubatch_output.data() + }; + } + + void validate_sequence_length(llama_ubatch& ubatch, llama_sbatch_seq& seq, size_t length) { + assert(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == ubatch.n_tokens / ubatch.n_seqs); + assert((seq.n_seq_id != 0) == ubatch.equal_seqs); + } + + void handle_tokens_and_embeddings(llama_ubatch& ubatch, llama_sbatch_seq& seq, size_t length) { + if (batch->token) { + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]]; + } + } else { + ubatch.token = batch->token + seq.offset; + } + } + if (batch->embd) { + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + std::memcpy(ubatch.embd + n_embd * (ubatch.n_tokens + i), + batch->embd + n_embd * ids[seq.offset + i], + n_embd * sizeof(float)); + } + } else { + ubatch.embd = batch->embd + n_embd * seq.offset; + } } } - if (logits_all) { - for (size_t i = 0; i < length; ++i) { - ubatch.output[ubatch.n_tokens + i] = 1; - out_ids.push_back(ids[seq.offset + i]); - } - } else if (batch->logits) { + + void handle_positions_and_sequence_ids(llama_ubatch& ubatch, llama_sbatch_seq& seq, size_t length) { if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]]; + } + } else { + ubatch.pos = batch->pos + seq.offset; + } + + if (ubatch.equal_seqs) { + ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id; + if (seq.seq_id) { + ubatch.seq_id[ubatch.n_seqs] = seq.seq_id; + } + } else { + if (batch->n_seq_id) { + ubatch.n_seq_id = batch->n_seq_id + seq.offset; + } else { + for (size_t i = 0; i < length; ++i) { + ubatch.n_seq_id[ubatch.n_seqs + i] = 1; + } + } + if (batch->seq_id) { + ubatch.seq_id = batch->seq_id + seq.offset; + } + } + } + + void handle_logits(llama_ubatch& ubatch, llama_sbatch_seq& seq, size_t length) { + if (logits_all) { + for (size_t i = 0; i < length; ++i) { + ubatch.output[ubatch.n_tokens + i] = 1; + out_ids.push_back(ids[seq.offset + i]); + } + } else if (batch->logits) { + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + size_t id = ids[seq.offset + i]; + int8_t is_output = batch->logits[id]; + ubatch.output[ubatch.n_tokens + i] = is_output; + if (is_output) { + out_ids.push_back(id); + } + } + } else { + ubatch.output = batch->logits + seq.offset; + for (size_t i = 0; i < length; ++i) { + if (ubatch.output[i] != 0) { + out_ids.push_back(seq.offset + i); + } + } + } + } else { for (size_t i = 0; i < length; ++i) { size_t id = ids[seq.offset + i]; - int8_t is_output = batch->logits[id]; - ubatch.output[ubatch.n_tokens + i] = is_output; - if (is_output) { out_ids.push_back(id); } - } - } else { - // simple split - ubatch.output = batch->logits + seq.offset; - for (size_t i = 0; i < length; ++i) { - if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); } - } - } - } else { - // only get last output - for (size_t i = 0; i < length; ++i) { - size_t id = ids[seq.offset + i]; - int8_t is_last = id == ids.size() - 1; - ubatch.output[ubatch.n_tokens + i] = is_last; - if (is_last) { out_ids.push_back(id); } - } - } - if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) { - ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1; - } - ubatch.n_tokens += length; - ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits - seq.offset += length; - seq.length -= length; - n_tokens -= length; - GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs); -} - -llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) { - n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; - llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); - ubatch.equal_seqs = false; - if (!seq.empty()) { - llama_sbatch_seq & s = seq[0]; - size_t length = s.length < n_ubatch ? s.length : n_ubatch; - GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits - add_seq_to_ubatch(ubatch, s, length); - } - return ubatch; -} - -llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) { - n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; - llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); - if (!seq.empty()) { - size_t length = 0; - size_t n_tokens_in_ubatch = 0; - GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits - // smallest first, because it's easier to split this way; - // starting from the end to pop in constant time. - for (size_t i = seq.size(); i-- > 0;) { - llama_sbatch_seq & s = seq[i]; - GGML_ASSERT(s.length > 0); - if (length == 0) { - length = s.length < n_ubatch ? s.length : n_ubatch; - } - add_seq_to_ubatch(ubatch, s, length); - n_tokens_in_ubatch += length; - // shared prompts can't be mixed with any of their sequences, - // so it's safer to compute them in their own ubatch - if (s.n_seq_id > 1) { break; } - // stop when there isn't enough space for another sequence - if (length + n_tokens_in_ubatch > n_ubatch) { break; } - } - } - return ubatch; -} - -llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) { - n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; - llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); - if (!seq.empty()) { - llama_sbatch_seq & s = seq[seq.size() - 1]; - size_t length = s.length < n_ubatch ? s.length : n_ubatch; - GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits - add_seq_to_ubatch(ubatch, s, length); - } - return ubatch; -} - -void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) { - GGML_ASSERT(batch.n_tokens >= 0); - this->batch = &batch; - this->n_embd = n_embd; - this->logits_all = logits_all; - - n_tokens = batch.n_tokens; - ids.resize(n_tokens); - out_ids.clear(); - // TODO: reserve out_ids and seq - - for (size_t i = 0; i < n_tokens; ++i) { - ids[i] = i; - } - if (simple_split) { - seq.resize(1); - llama_sbatch_seq & s = seq[0]; - s.n_seq_id = 0; - s.seq_id = nullptr; - s.offset = 0; - s.length = n_tokens; - return; - } - std::sort(ids.begin(), ids.end(), - [&batch](size_t a, size_t b) { - int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1; - int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1; - // sort by seq_id, then by pos - if (n_seq_a == n_seq_b) { - if (batch.seq_id) { - for (int32_t i = 0; i < n_seq_a; ++i) { - llama_seq_id seq_id_a = batch.seq_id[a][i]; - llama_seq_id seq_id_b = batch.seq_id[b][i]; - // smaller seq_ids go first - if (seq_id_a != seq_id_b) { - return seq_id_a < seq_id_b; - } - } - } - // when all else is equal, sort by pos - if (batch.pos) { - return batch.pos[a] < batch.pos[b]; - } - // no pos, sort by id - return a < b; - } - // shared prompts go first - return n_seq_a > n_seq_b; - } - ); - // init seq - llama_sbatch_seq * last_seq = nullptr; - - for (size_t i = 0; i < n_tokens; ++i) { - const size_t bi = ids[i]; - const int32_t n_seqs = batch.n_seq_id[bi]; - llama_seq_id * seq_ids = batch.seq_id[bi]; - if (last_seq != nullptr) { - bool same = n_seqs == last_seq->n_seq_id; - for (int32_t j = 0; same && j < n_seqs; ++j) { - if (seq_ids[j] != last_seq->seq_id[j]) { - same = false; + int8_t is_last = id == ids.size() - 1; + ubatch.output[ubatch.n_tokens + i] = is_last; + if (is_last) { + out_ids.push_back(id); } } - if (same) { - last_seq->length += 1; - continue; - } } - llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1}; - seq.push_back(new_seq); - last_seq = &seq.back(); } - // keep shared prompts first at the end, then sort by length descending. - std::sort(seq.begin(), seq.end(), - [](llama_sbatch_seq & a, llama_sbatch_seq & b) { - if (a.n_seq_id == b.n_seq_id) { - return a.length > b.length; - } - return a.n_seq_id < b.n_seq_id; - } - ); -} -llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) { - batch = in_batch; - GGML_ASSERT(batch.n_tokens > 0); - if (!batch.pos) { - pos.resize(batch.n_tokens); - for (int32_t i = 0; i < batch.n_tokens; i++) { - pos[i] = i + p0; + void update_ubatch_and_seq(llama_ubatch& ubatch, llama_sbatch_seq& seq, size_t length) { + if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) { + ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1; } - batch.pos = pos.data(); - } - if (!batch.n_seq_id) { - n_seq_id.resize(batch.n_tokens); - for (int32_t i = 0; i < batch.n_tokens; i++) { - n_seq_id[i] = seq_id_0.size(); - } - batch.n_seq_id = n_seq_id.data(); - } - if (!batch.seq_id) { - seq_id.resize(batch.n_tokens + 1); - seq_id[batch.n_tokens] = NULL; - for (int32_t i = 0; i < batch.n_tokens; i++) { - seq_id[i] = seq_id_0.data(); - } - batch.seq_id = seq_id.data(); - } - if (!batch.logits) { - logits.resize(batch.n_tokens); - logits[logits.size() - 1] = true; - batch.logits = logits.data(); - } -} + ubatch.n_tokens += length; + ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; + seq.offset += length; + seq.length -= length; + n_tokens -= length; -// -// interface implementation -// + assert(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs); + } +}; -struct llama_batch llama_batch_get_one( - llama_token * tokens, - int32_t n_tokens) { +// Initialization functions remain unchanged +llama_batch llama_batch_get_one(llama_token * tokens, int32_t n_tokens) { return { - /*n_tokens =*/ n_tokens, - /*tokens =*/ tokens, - /*embd =*/ nullptr, - /*pos =*/ nullptr, - /*n_seq_id =*/ nullptr, - /*seq_id =*/ nullptr, - /*logits =*/ nullptr, + n_tokens, + tokens, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr }; } -struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { - llama_batch batch = { - /*n_tokens =*/ 0, - /*tokens =*/ nullptr, - /*embd =*/ nullptr, - /*pos =*/ nullptr, - /*n_seq_id =*/ nullptr, - /*seq_id =*/ nullptr, - /*logits =*/ nullptr, - }; +llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { + llama_batch batch = {0}; if (embd) { - batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd); + batch.embd = new float[n_tokens_alloc * embd](); } else { - batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc); + batch.token = new llama_token[n_tokens_alloc](); } - batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc); - batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc); - batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1)); + batch.pos = new llama_pos[n_tokens_alloc](); + batch.n_seq_id = new int32_t[n_tokens_alloc](); + batch.seq_id = new llama_seq_id*[n_tokens_alloc + 1]; for (int i = 0; i < n_tokens_alloc; ++i) { - batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); + batch.seq_id[i] = new llama_seq_id[n_seq_max](); } batch.seq_id[n_tokens_alloc] = nullptr; - - batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc); + batch.logits = new int8_t[n_tokens_alloc](); return batch; } void llama_batch_free(struct llama_batch batch) { - if (batch.token) free(batch.token); - if (batch.embd) free(batch.embd); - if (batch.pos) free(batch.pos); - if (batch.n_seq_id) free(batch.n_seq_id); - if (batch.seq_id) { - for (int i = 0; batch.seq_id[i] != nullptr; ++i) { - free(batch.seq_id[i]); - } - free(batch.seq_id); + delete[] batch.token; + delete[] batch.embd; + delete[] batch.pos; + delete[] batch.n_seq_id; + for (int i = 0; batch.seq_id[i] != nullptr; ++i) { + delete[] batch.seq_id[i]; } - if (batch.logits) free(batch.logits); + delete[] batch.seq_id; + delete[] batch.logits; }