From 2c81cde49383c71b8f5d1368ebaf5c9f51d8eee1 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 2 Sep 2024 22:09:35 +0200 Subject: [PATCH] server : simplify state machine for slot --- examples/server/server.cpp | 86 +++++++++++++++----------------------- 1 file changed, 33 insertions(+), 53 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 109dbc023..9518829a0 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -52,13 +52,8 @@ enum stop_type { enum slot_state { SLOT_STATE_IDLE, - SLOT_STATE_PROCESSING, -}; - -enum slot_command { - SLOT_COMMAND_NONE, - SLOT_COMMAND_LOAD_PROMPT, - SLOT_COMMAND_RELEASE, + SLOT_STATE_PROCESSING_PROMPT, + SLOT_STATE_GENERATING, }; enum server_state { @@ -135,7 +130,6 @@ struct server_slot { struct slot_params params; slot_state state = SLOT_STATE_IDLE; - slot_command command = SLOT_COMMAND_NONE; // used to determine the slot that has been used the longest int64_t t_last_used = -1; @@ -194,6 +188,8 @@ struct server_slot { double t_prompt_processing; // ms double t_token_generation; // ms + std::function callback_on_release; + void reset() { n_prompt_tokens = 0; generated_text = ""; @@ -229,24 +225,32 @@ struct server_slot { } bool available() const { - return state == SLOT_STATE_IDLE && command == SLOT_COMMAND_NONE; + return state == SLOT_STATE_IDLE; } bool is_processing() const { - return (state == SLOT_STATE_IDLE && command == SLOT_COMMAND_LOAD_PROMPT) || state == SLOT_STATE_PROCESSING; + return state != SLOT_STATE_IDLE; } void add_token_string(const completion_token_output & token) { - if (command == SLOT_COMMAND_RELEASE) { + if (!is_processing()) { return; } generated_token_probs.push_back(token); } void release() { - if (state == SLOT_STATE_PROCESSING) { + if (is_processing()) { t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; - command = SLOT_COMMAND_RELEASE; + state = SLOT_STATE_IDLE; + LOG_INFO("slot released", { + {"id_slot", id}, + {"id_task", id_task}, + {"n_past", n_past}, + {"truncated", truncated} + }); + callback_on_release(id); + // queue_tasks.notify_slot_changed(); } } @@ -716,6 +720,10 @@ struct server_context { slot.sparams = params.sparams; + slot.callback_on_release = [this](int) { + queue_tasks.notify_slot_changed(); + }; + slot.reset(); slots.push_back(slot); @@ -1077,7 +1085,7 @@ struct server_context { } } - slot.command = SLOT_COMMAND_LOAD_PROMPT; + slot.state = SLOT_STATE_PROCESSING_PROMPT; slot.prompt_tokens.clear(); LOG_INFO("slot is processing task", { @@ -1875,33 +1883,12 @@ struct server_context { system_prompt_update(); } - // release slots - for (auto & slot : slots) { - if (slot.command == SLOT_COMMAND_RELEASE) { - slot.state = SLOT_STATE_IDLE; - slot.command = SLOT_COMMAND_NONE; - slot.t_last_used = ggml_time_us(); - - LOG_INFO("slot released", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_ctx", n_ctx}, - {"n_past", slot.n_past}, - {"n_system_tokens", system_tokens.size()}, - {"n_cache_tokens", slot.cache_tokens.size()}, - {"truncated", slot.truncated} - }); - - queue_tasks.notify_slot_changed(); - } - } - // check if all slots are idle { bool all_idle = true; for (auto & slot : slots) { - if (slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_NONE) { + if (slot.is_processing()) { all_idle = false; break; } @@ -1972,7 +1959,7 @@ struct server_context { // frist, add sampled tokens from any ongoing sequences for (auto & slot : slots) { - if (slot.state == SLOT_STATE_IDLE) { + if (slot.state != SLOT_STATE_GENERATING) { continue; } @@ -2014,7 +2001,7 @@ struct server_context { if (params.cont_batching || batch.n_tokens == 0) { for (auto & slot : slots) { // this slot still has a prompt to be processed - if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) { + if (slot.state == SLOT_STATE_PROCESSING_PROMPT) { auto & prompt_tokens = slot.prompt_tokens; // we haven't tokenized the prompt yet - do it now: @@ -2082,8 +2069,6 @@ struct server_context { {"id_task", slot.id_task} }); - slot.state = SLOT_STATE_PROCESSING; - slot.command = SLOT_COMMAND_NONE; slot.release(); slot.print_timings(); send_final_response(slot); @@ -2093,8 +2078,6 @@ struct server_context { if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) { // this prompt is too large to process - discard it if (slot.n_prompt_tokens > n_ubatch) { - slot.state = SLOT_STATE_PROCESSING; - slot.command = SLOT_COMMAND_NONE; slot.release(); send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); continue; @@ -2254,8 +2237,7 @@ struct server_context { // entire prompt has been processed - start decoding new tokens if (slot.n_past == slot.n_prompt_tokens) { - slot.state = SLOT_STATE_PROCESSING; - slot.command = SLOT_COMMAND_NONE; + slot.state = SLOT_STATE_GENERATING; GGML_ASSERT(batch.n_tokens > 0); @@ -2342,13 +2324,11 @@ struct server_context { if (n_batch == 1 || ret < 0) { // if you get here, it means the KV cache is full - try increasing it via the context size LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", { - {"i", i}, - {"n_batch", ret}, - {"ret", ret}, + {"i", i}, + {"n_batch", n_batch}, + {"ret", ret}, }); for (auto & slot : slots) { - slot.state = SLOT_STATE_PROCESSING; - slot.command = SLOT_COMMAND_NONE; slot.release(); send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size."); } @@ -2360,16 +2340,16 @@ struct server_context { i -= n_batch; LOG_WARNING("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation", { - {"i", i}, - {"n_batch", n_batch}, - {"ret", ret}, + {"i", i}, + {"n_batch", n_batch}, + {"ret", ret}, }); continue; // continue loop of n_batch } for (auto & slot : slots) { - if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { + if (slot.state != SLOT_STATE_GENERATING || slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { continue; // continue loop of slots }