server : add comments

Georgi Gerganov 2024-03-07 11:35:03 +02:00
parent 818d898fe7
commit 87a4a105b2

examples/server/server.cpp

@@ -147,6 +147,8 @@ struct server_slot {
     int32_t n_prompt_tokens_processed = 0;
 
     json prompt;
+
+    // when a task is submitted, we first tokenize the prompt and store it here
     std::vector<llama_token> prompt_tokens;
 
     std::string generated_text;
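
Note on the new `prompt_tokens` comment: tokenization happens once per task, via the llama.cpp tokenizer API. Below is a minimal sketch of that flow, not the server's verbatim code; it assumes the C API of this era, where `llama_tokenize` takes a model, a text buffer, and two special-token flags (flag names vary between versions), and a negative return value reports the required buffer size.

// Sketch only - assumes the llama.cpp C API circa March 2024.
#include <string>
#include <vector>
#include "llama.h"

static std::vector<llama_token> tokenize_prompt(const llama_model * model, const std::string & text) {
    // worst case: roughly one token per byte, plus room for a BOS token
    std::vector<llama_token> tokens(text.size() + 1);

    int32_t n = llama_tokenize(model, text.c_str(), (int32_t) text.size(),
                               tokens.data(), (int32_t) tokens.size(),
                               true,   // add the special (BOS) token
                               false); // do not parse special tokens in the text

    if (n < 0) {
        // buffer was too small: -n is the required size, retry once
        tokens.resize(-n);
        n = llama_tokenize(model, text.c_str(), (int32_t) text.size(),
                           tokens.data(), (int32_t) tokens.size(), true, false);
    }

    tokens.resize(n);
    return tokens; // this is what would be stored in slot.prompt_tokens
}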
@@ -451,8 +453,7 @@ struct server_queue {
         while (true) {
             LOG_VERBOSE("new task may arrive", {});
 
-            while (true)
-            {
+            while (true) {
                 std::unique_lock<std::mutex> lock(mutex_tasks);
                 if (queue_tasks.empty()) {
                     lock.unlock();
@@ -677,8 +678,7 @@ struct server_context {
         const int32_t n_ctx_slot = n_ctx / params.n_parallel;
 
         LOG_INFO("initializing slots", {{"n_slots", params.n_parallel}});
-        for (int i = 0; i < params.n_parallel; i++)
-        {
+        for (int i = 0; i < params.n_parallel; i++) {
             server_slot slot;
 
             slot.id = i;
@@ -700,9 +700,9 @@ struct server_context {
                 //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
 
                 LOG_INFO("slot self-extend", {
-                    {"id_slot", slot.id},
-                    {"ga_n", ga_n},
-                    {"ga_w", ga_w}
+                    {"id_slot", slot.id},
+                    {"ga_n",    ga_n},
+                    {"ga_w",    ga_w}
                 });
             }
 
@@ -1600,6 +1600,8 @@ struct server_context {
             queue_tasks.post(task);
         }
 
+        // apply context-shift if needed
+        // TODO: simplify and improve
         for (server_slot & slot : slots) {
             if (slot.ga_n == 1) {
                 if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
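
For context on the new "apply context-shift" comment: when a slot's sequence approaches `n_ctx`, the server keeps the first `n_keep` tokens, drops roughly half of the remainder, and shifts the surviving tokens down so generation can continue. A simplified sketch of that mechanism follows, assuming the `llama_kv_cache_seq_rm`/`llama_kv_cache_seq_add` API of this period; the real server code additionally accounts for `system_tokens`.

// Sketch of the context-shift mechanism referenced by the comment above.
static void context_shift(llama_context * ctx, llama_seq_id seq, int n_past, int n_keep) {
    const int n_left    = n_past - n_keep;
    const int n_discard = n_left / 2; // drop half of the shiftable window

    // erase the discarded token range from the KV cache ...
    llama_kv_cache_seq_rm (ctx, seq, n_keep, n_keep + n_discard);

    // ... then shift the positions of the remaining tokens down by n_discard,
    // making the cache contiguous again
    llama_kv_cache_seq_add(ctx, seq, n_keep + n_discard, n_past, -n_discard);
}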
@@ -1638,9 +1640,10 @@ struct server_context {
             }
         }
 
+        // start populating the batch for this iteration
        llama_batch_clear(batch);
 
-        // decode any currently ongoing sequences
+        // first, add sampled tokens from any ongoing sequences
        for (auto & slot : slots) {
            if (slot.state == SLOT_STATE_IDLE) {
                continue;
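
The loop above ("first, add sampled tokens...") amounts to appending one token per active slot to the shared batch. A hedged sketch using the `llama_batch_add` helper from `common.h`; field names follow `server_slot`, while the positions in the real code also offset by `system_tokens`.

// Sketch: one sampled token per ongoing sequence is queued for decoding.
for (auto & slot : slots) {
    if (slot.state == SLOT_STATE_IDLE) {
        continue; // nothing being generated in this slot
    }

    // seq id is slot.id + 1 (0 is reserved for the system prompt);
    // logits = true because the next token is sampled from this one
    llama_batch_add(batch, slot.sampled, slot.n_past, { slot.id + 1 }, true);

    slot.n_past += 1;
}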
@@ -1674,7 +1677,7 @@ struct server_context {
        // process in chunks of params.n_batch
        int32_t n_batch = params.n_batch;
 
-        // assign workload to the slots
+        // next, batch any pending prompts without exceeding n_batch
        if (params.cont_batching || batch.n_tokens == 0) {
            for (auto & slot : slots) {
                const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty());
@@ -1690,10 +1693,11 @@ struct server_context {
                    continue;
                }
 
-                // need process the prompt
+                // this slot still has a prompt to be processed
                if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) {
                    auto & prompt_tokens = slot.prompt_tokens;
 
+                    // we haven't tokenized the prompt yet - do it now:
                    if (prompt_tokens.empty()) {
                        LOG_VERBOSE("tokenizing prompt", {
                            {"id_slot", slot.id},
@@ -1770,9 +1774,9 @@ struct server_context {
 
                        LOG_VERBOSE("input truncated", {
-                            {"n_ctx", slot.n_ctx},
-                            {"n_keep", slot.params.n_keep},
-                            {"n_left", n_left},
+                            {"n_ctx",         slot.n_ctx},
+                            {"n_keep",        slot.params.n_keep},
+                            {"n_left",        n_left},
                            {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
                        });
 
                        GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
@@ -1786,6 +1790,7 @@ struct server_context {
                    } else {
                        GGML_ASSERT(slot.ga_n == 1);
 
+                        // reuse any previously computed tokens that are common with the new prompt
                        slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
 
                        // remove the non-common part from the cache
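
`common_part` (from the server's utils) returns the length of the longest common prefix of the cached tokens and the newly tokenized prompt; those tokens keep their KV-cache entries and are not re-evaluated. A sketch of what it computes:

// Sketch of the prefix-reuse computation behind slot.n_past above:
// tokens [0, i) are identical in cache and prompt, so only tokens
// from position i onward need to be decoded again.
static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) {
        i++;
    }
    return i;
}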
@@ -1802,7 +1807,7 @@ struct server_context {
                        // we have to evaluate at least 1 token to generate logits.
                        LOG_INFO("we have to evaluate at least 1 token to generate logits", {
-                        { "id_slot", slot.id },
-                        { "id_task", slot.id_task }
+                            { "id_slot", slot.id },
+                            { "id_task", slot.id_task }
                        });
 
                        slot.n_past--;
@@ -1836,6 +1841,8 @@ struct server_context {
                    int32_t ga_n = slot.ga_n;
                    int32_t ga_w = slot.ga_w;
 
+                    // add prompt tokens for processing in the current batch
+                    // TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow
                    for (; slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch; ++slot.n_past) {
                        if (slot.ga_n != 1) {
                            while (slot_npast >= ga_i + ga_w) {
@@ -1855,9 +1862,17 @@ struct server_context {
                        slot_npast++;
                    }
 
+                    LOG_VERBOSE("prompt processing progress", {
+                        {"id_slot", slot.id},
+                        {"n_past", slot.n_past},
+                        {"n_ctx", n_ctx},
+                        {"n_tokens", batch.n_tokens},
+                        {"progress", (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens},
+                    });
+
                    // entire prompt has been processed - start decoding new tokens
                    if (slot.n_past == slot.n_prompt_tokens) {
                        slot.state = SLOT_STATE_PROCESSING;
                        slot.command = SLOT_COMMAND_NONE;
 
                        GGML_ASSERT(batch.n_tokens > 0);
@@ -1868,14 +1883,7 @@ struct server_context {
                        slot.n_decoded = 0;
                        slot.i_batch = batch.n_tokens - 1;
 
-                        LOG_VERBOSE("prompt processed", {
-                            {"id_slot", slot.id},
-                            {"n_past", slot.n_past},
-                            {"n_ctx", n_ctx},
-                            {"n_tokens", batch.n_tokens},
-                        });
-                    } else {
-                        LOG_VERBOSE("prompt still processing", {
+                        LOG_VERBOSE("prompt done", {
                            {"id_slot", slot.id},
                            {"n_past", slot.n_past},
                            {"n_ctx", n_ctx},
@@ -1900,12 +1908,14 @@ struct server_context {
            {"n_tokens", batch.n_tokens},
        });
 
+        // process the created batch of tokens
        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
 
            for (auto & slot : slots) {
                if (slot.ga_n != 1) {
                    // context extension via Self-Extend
+                    // TODO: simplify and/or abstract this
                    while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
                        const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
                        const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
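
The `ib`/`bd` constants above implement the Self-Extend KV-cache remap: positions inside the current window are divided by the group factor `ga_n` (grouped attention), and later positions are shifted so the sequence stays contiguous. A hedged sketch of one remap step, mirroring the three `llama_kv_cache_seq_*` calls that follow these constants in the full source:

// Sketch of one Self-Extend remap step (grouped attention).
// ga_i: window start, ga_w: window width, ga_n: group factor.
static void self_extend_step(llama_context * ctx, llama_seq_id seq,
                             int ga_i, int ga_w, int ga_n, int n_past_se) {
    const int ib = (ga_n * ga_i) / ga_w;           // index of the current block
    const int bd = (ga_w / ga_n) * (ga_n - 1);     // positions saved per block
    const int dd = (ga_w / ga_n) - ib * bd - ga_w; // shift applied to the tail

    // shift the window into place, compress it by ga_n, then pull the tail back
    llama_kv_cache_seq_add(ctx, seq, ga_i, n_past_se, ib * bd);
    llama_kv_cache_seq_div(ctx, seq, ga_i + ib * bd, ga_i + ib * bd + ga_w, ga_n);
    llama_kv_cache_seq_add(ctx, seq, ga_i + ib * bd + ga_w, n_past_se + ib * bd, dd);
}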