From 87a4a105b2fafb291610c1e28f97b8ba07c6f2d7 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Thu, 7 Mar 2024 11:35:03 +0200
Subject: [PATCH] server : add comments

---
 examples/server/server.cpp | 56 ++++++++++++++++++++++----------------
 1 file changed, 33 insertions(+), 23 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index e1adbcee4..3bdbde954 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -147,6 +147,8 @@ struct server_slot {
     int32_t n_prompt_tokens_processed = 0;
 
     json prompt;
+
+    // when a task is submitted, we first tokenize the prompt and store it here
     std::vector<llama_token> prompt_tokens;
 
     std::string generated_text;
@@ -451,8 +453,7 @@ struct server_queue {
         while (true) {
             LOG_VERBOSE("new task may arrive", {});
 
-            while (true)
-            {
+            while (true) {
                 std::unique_lock<std::mutex> lock(mutex_tasks);
                 if (queue_tasks.empty()) {
                     lock.unlock();
@@ -677,8 +678,7 @@ struct server_context {
         const int32_t n_ctx_slot = n_ctx / params.n_parallel;
 
         LOG_INFO("initializing slots", {{"n_slots", params.n_parallel}});
-        for (int i = 0; i < params.n_parallel; i++)
-        {
+        for (int i = 0; i < params.n_parallel; i++) {
             server_slot slot;
 
             slot.id = i;
@@ -700,9 +700,9 @@ struct server_context {
                 //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
 
                 LOG_INFO("slot self-extend", {
-                {"id_slot", slot.id},
-                {"ga_n", ga_n},
-                {"ga_w", ga_w}
+                    {"id_slot", slot.id},
+                    {"ga_n",    ga_n},
+                    {"ga_w",    ga_w}
                 });
             }
 
@@ -1600,6 +1600,8 @@ struct server_context {
             queue_tasks.post(task);
         }
 
+        // apply context-shift if needed
+        // TODO: simplify and improve
        for (server_slot & slot : slots) {
            if (slot.ga_n == 1) {
                if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
@@ -1638,9 +1640,10 @@ struct server_context {
            }
        }
 
+        // start populating the batch for this iteration
        llama_batch_clear(batch);
 
-        // decode any currently ongoing sequences
+        // first, add sampled tokens from any ongoing sequences
        for (auto & slot : slots) {
            if (slot.state == SLOT_STATE_IDLE) {
                continue;
            }
@@ -1674,7 +1677,7 @@ struct server_context {
        // process in chunks of params.n_batch
        int32_t n_batch = params.n_batch;
 
-        // assign workload to the slots
+        // next, batch any pending prompts without exceeding n_batch
        if (params.cont_batching || batch.n_tokens == 0) {
            for (auto & slot : slots) {
                const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty());
@@ -1690,10 +1693,11 @@ struct server_context {
                    continue;
                }
 
-                // need process the prompt
+                // this slot still has a prompt to be processed
                if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) {
                    auto & prompt_tokens = slot.prompt_tokens;
 
+                    // we haven't tokenized the prompt yet - do it now:
                    if (prompt_tokens.empty()) {
                        LOG_VERBOSE("tokenizing prompt", {
                            {"id_slot", slot.id},
@@ -1770,9 +1774,9 @@ struct server_context {
 
                        LOG_VERBOSE("input truncated", {
                            {"n_ctx", slot.n_ctx},
-                        {"n_keep", slot.params.n_keep},
-                        {"n_left", n_left},
-                        {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
+                            {"n_keep",        slot.params.n_keep},
+                            {"n_left",        n_left},
+                            {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
                        });
 
                        GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
@@ -1786,6 +1790,7 @@ struct server_context {
                    } else {
                        GGML_ASSERT(slot.ga_n == 1);
 
+                        // reuse any previously computed tokens that are common with the new prompt
                        slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
 
                        // remove the non-common part from the cache
@@ -1802,7 +1807,7 @@ struct server_context {
                            // we have to evaluate at least 1 token to generate logits.
                            LOG_INFO("we have to evaluate at least 1 token to generate logits", {
                                { "id_slot", slot.id },
-                            { "id_task", slot.id_task }
+                                { "id_task", slot.id_task }
                            });
 
                            slot.n_past--;
@@ -1836,6 +1841,8 @@ struct server_context {
                    int32_t ga_n = slot.ga_n;
                    int32_t ga_w = slot.ga_w;
 
+                    // add prompt tokens for processing in the current batch
+                    // TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow
                    for (; slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch; ++slot.n_past) {
                        if (slot.ga_n != 1) {
                            while (slot_npast >= ga_i + ga_w) {
@@ -1855,9 +1862,17 @@ struct server_context {
                            slot_npast++;
                        }
 
+                        LOG_VERBOSE("prompt processing progress", {
+                            {"id_slot",  slot.id},
+                            {"n_past",   slot.n_past},
+                            {"n_ctx",    n_ctx},
+                            {"n_tokens", batch.n_tokens},
+                            {"progress", (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens},
+                        });
 
+                        // entire prompt has been processed - start decoding new tokens
                        if (slot.n_past == slot.n_prompt_tokens) {
-                            slot.state = SLOT_STATE_PROCESSING;
+                            slot.state   = SLOT_STATE_PROCESSING;
                            slot.command = SLOT_COMMAND_NONE;
 
                            GGML_ASSERT(batch.n_tokens > 0);
@@ -1868,14 +1883,7 @@ struct server_context {
                            slot.n_decoded = 0;
                            slot.i_batch   = batch.n_tokens - 1;
 
-                            LOG_VERBOSE("prompt processed", {
-                                {"id_slot", slot.id},
-                                {"n_past", slot.n_past},
-                                {"n_ctx", n_ctx},
-                                {"n_tokens", batch.n_tokens},
-                            });
-                        } else {
-                            LOG_VERBOSE("prompt still processing", {
+                            LOG_VERBOSE("prompt done", {
                                {"id_slot", slot.id},
                                {"n_past", slot.n_past},
                                {"n_ctx", n_ctx},
@@ -1900,12 +1908,14 @@ struct server_context {
            {"n_tokens", batch.n_tokens},
        });
 
+        // process the created batch of tokens
        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
 
            for (auto & slot : slots) {
                if (slot.ga_n != 1) {
                    // context extension via Self-Extend
+                    // TODO: simplify and/or abstract this
                    while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
                        const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
                        const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
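
For readers following the new comments: the prompt-cache reuse documented at the
common_part() call works by comparing the newly submitted prompt against the tokens
already evaluated for the slot, and keeping the shared prefix in the KV cache so only
the tail of the new prompt needs to be decoded. Below is a minimal stand-alone sketch
of that idea. It is an illustration, not the server.cpp implementation: the helper
name common_prefix_len and the hard-coded token ids are invented for the example, and
llama_token is assumed to be a 32-bit id as in llama.h.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // stand-in for llama.cpp's token id type
    using llama_token = std::int32_t;

    // length of the longest common prefix of two token sequences -
    // the quantity that server.cpp's common_part() assigns to slot.n_past
    static std::size_t common_prefix_len(const std::vector<llama_token> & a,
                                         const std::vector<llama_token> & b) {
        std::size_t i = 0;
        while (i < a.size() && i < b.size() && a[i] == b[i]) {
            i++;
        }
        return i;
    }

    int main() {
        // tokens already evaluated and stored in the slot's KV cache (illustrative ids)
        const std::vector<llama_token> cache_tokens  = {1, 15043, 29892, 3186};
        // tokens of the newly submitted prompt (shares a 3-token prefix)
        const std::vector<llama_token> prompt_tokens = {1, 15043, 29892, 590, 1024};

        // the first n_past tokens are identical, so their KV-cache entries can be
        // reused; evaluation resumes at prompt_tokens[n_past]
        const std::size_t n_past = common_prefix_len(cache_tokens, prompt_tokens);

        std::printf("n_past = %zu\n", n_past); // prints: n_past = 3
        return 0;
    }

One edge case, handled a few hunks later in this patch: if the cached prefix covers
the entire new prompt, slot.n_past is decremented by one so that at least one token
is evaluated, since the decoder must process at least one token to produce the logits
needed for sampling ("we have to evaluate at least 1 token to generate logits").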