server : simplify state machine for slot

This commit is contained in:
Xuan Son Nguyen 2024-09-02 22:09:35 +02:00
parent f1485161e5
commit 2c81cde493

View file

@@ -52,13 +52,8 @@ enum stop_type {
enum slot_state { enum slot_state {
SLOT_STATE_IDLE, SLOT_STATE_IDLE,
SLOT_STATE_PROCESSING, SLOT_STATE_PROCESSING_PROMPT,
}; SLOT_STATE_GENERATING,
enum slot_command {
SLOT_COMMAND_NONE,
SLOT_COMMAND_LOAD_PROMPT,
SLOT_COMMAND_RELEASE,
}; };
enum server_state { enum server_state {
@@ -135,7 +130,6 @@ struct server_slot {
struct slot_params params; struct slot_params params;
slot_state state = SLOT_STATE_IDLE; slot_state state = SLOT_STATE_IDLE;
slot_command command = SLOT_COMMAND_NONE;
// used to determine the slot that has been used the longest // used to determine the slot that has been used the longest
int64_t t_last_used = -1; int64_t t_last_used = -1;
@@ -194,6 +188,8 @@ struct server_slot {
double t_prompt_processing; // ms double t_prompt_processing; // ms
double t_token_generation; // ms double t_token_generation; // ms
std::function<void(int)> callback_on_release;
void reset() { void reset() {
n_prompt_tokens = 0; n_prompt_tokens = 0;
generated_text = ""; generated_text = "";
@@ -229,24 +225,32 @@ struct server_slot {
} }
bool available() const { bool available() const {
return state == SLOT_STATE_IDLE && command == SLOT_COMMAND_NONE; return state == SLOT_STATE_IDLE;
} }
bool is_processing() const { bool is_processing() const {
return (state == SLOT_STATE_IDLE && command == SLOT_COMMAND_LOAD_PROMPT) || state == SLOT_STATE_PROCESSING; return state != SLOT_STATE_IDLE;
} }
void add_token_string(const completion_token_output & token) { void add_token_string(const completion_token_output & token) {
if (command == SLOT_COMMAND_RELEASE) { if (!is_processing()) {
return; return;
} }
generated_token_probs.push_back(token); generated_token_probs.push_back(token);
} }
void release() { void release() {
if (state == SLOT_STATE_PROCESSING) { if (is_processing()) {
t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
command = SLOT_COMMAND_RELEASE; state = SLOT_STATE_IDLE;
LOG_INFO("slot released", {
{"id_slot", id},
{"id_task", id_task},
{"n_past", n_past},
{"truncated", truncated}
});
callback_on_release(id);
// queue_tasks.notify_slot_changed();
} }
} }
@@ -716,6 +720,10 @@ struct server_context {
slot.sparams = params.sparams; slot.sparams = params.sparams;
slot.callback_on_release = [this](int) {
queue_tasks.notify_slot_changed();
};
slot.reset(); slot.reset();
slots.push_back(slot); slots.push_back(slot);
@@ -1077,7 +1085,7 @@ struct server_context {
} }
} }
slot.command = SLOT_COMMAND_LOAD_PROMPT; slot.state = SLOT_STATE_PROCESSING_PROMPT;
slot.prompt_tokens.clear(); slot.prompt_tokens.clear();
LOG_INFO("slot is processing task", { LOG_INFO("slot is processing task", {
@@ -1875,33 +1883,12 @@ struct server_context {
system_prompt_update(); system_prompt_update();
} }
// release slots
for (auto & slot : slots) {
if (slot.command == SLOT_COMMAND_RELEASE) {
slot.state = SLOT_STATE_IDLE;
slot.command = SLOT_COMMAND_NONE;
slot.t_last_used = ggml_time_us();
LOG_INFO("slot released", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
{"n_ctx", n_ctx},
{"n_past", slot.n_past},
{"n_system_tokens", system_tokens.size()},
{"n_cache_tokens", slot.cache_tokens.size()},
{"truncated", slot.truncated}
});
queue_tasks.notify_slot_changed();
}
}
// check if all slots are idle // check if all slots are idle
{ {
bool all_idle = true; bool all_idle = true;
for (auto & slot : slots) { for (auto & slot : slots) {
if (slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_NONE) { if (slot.is_processing()) {
all_idle = false; all_idle = false;
break; break;
} }
@@ -1972,7 +1959,7 @@ struct server_context {
// first, add sampled tokens from any ongoing sequences // first, add sampled tokens from any ongoing sequences
for (auto & slot : slots) { for (auto & slot : slots) {
if (slot.state == SLOT_STATE_IDLE) { if (slot.state != SLOT_STATE_GENERATING) {
continue; continue;
} }
@@ -2014,7 +2001,7 @@ struct server_context {
if (params.cont_batching || batch.n_tokens == 0) { if (params.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) { for (auto & slot : slots) {
// this slot still has a prompt to be processed // this slot still has a prompt to be processed
if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) { if (slot.state == SLOT_STATE_PROCESSING_PROMPT) {
auto & prompt_tokens = slot.prompt_tokens; auto & prompt_tokens = slot.prompt_tokens;
// we haven't tokenized the prompt yet - do it now: // we haven't tokenized the prompt yet - do it now:
@@ -2082,8 +2069,6 @@ struct server_context {
{"id_task", slot.id_task} {"id_task", slot.id_task}
}); });
slot.state = SLOT_STATE_PROCESSING;
slot.command = SLOT_COMMAND_NONE;
slot.release(); slot.release();
slot.print_timings(); slot.print_timings();
send_final_response(slot); send_final_response(slot);
@@ -2093,8 +2078,6 @@ struct server_context {
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) { if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
// this prompt is too large to process - discard it // this prompt is too large to process - discard it
if (slot.n_prompt_tokens > n_ubatch) { if (slot.n_prompt_tokens > n_ubatch) {
slot.state = SLOT_STATE_PROCESSING;
slot.command = SLOT_COMMAND_NONE;
slot.release(); slot.release();
send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
continue; continue;
@@ -2254,8 +2237,7 @@ struct server_context {
// entire prompt has been processed - start decoding new tokens // entire prompt has been processed - start decoding new tokens
if (slot.n_past == slot.n_prompt_tokens) { if (slot.n_past == slot.n_prompt_tokens) {
slot.state = SLOT_STATE_PROCESSING; slot.state = SLOT_STATE_GENERATING;
slot.command = SLOT_COMMAND_NONE;
GGML_ASSERT(batch.n_tokens > 0); GGML_ASSERT(batch.n_tokens > 0);
@@ -2342,13 +2324,11 @@ struct server_context {
if (n_batch == 1 || ret < 0) { if (n_batch == 1 || ret < 0) {
// if you get here, it means the KV cache is full - try increasing it via the context size // if you get here, it means the KV cache is full - try increasing it via the context size
LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", { LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", {
{"i", i}, {"i", i},
{"n_batch", ret}, {"n_batch", n_batch},
{"ret", ret}, {"ret", ret},
}); });
for (auto & slot : slots) { for (auto & slot : slots) {
slot.state = SLOT_STATE_PROCESSING;
slot.command = SLOT_COMMAND_NONE;
slot.release(); slot.release();
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size."); send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
} }
@@ -2360,16 +2340,16 @@ struct server_context {
i -= n_batch; i -= n_batch;
LOG_WARNING("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation", { LOG_WARNING("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation", {
{"i", i}, {"i", i},
{"n_batch", n_batch}, {"n_batch", n_batch},
{"ret", ret}, {"ret", ret},
}); });
continue; // continue loop of n_batch continue; // continue loop of n_batch
} }
for (auto & slot : slots) { for (auto & slot : slots) {
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { if (slot.state != SLOT_STATE_GENERATING || slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
continue; // continue loop of slots continue; // continue loop of slots
} }