From 91e7e0ff175c18debd20149cdbfcf205473810c1 Mon Sep 17 00:00:00 2001
From: ngxson
Date: Sun, 25 Feb 2024 12:56:34 +0100
Subject: [PATCH] refactor work queue related stuff

---
 examples/server/server.cpp | 66 +++++++++++++++++++++++++-------------
 examples/server/utils.hpp  | 31 +++++++++++-------
 2 files changed, 64 insertions(+), 33 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 19a8c1067..e1e6ebc57 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1021,13 +1021,23 @@ struct llama_server_context
         return slot.images.size() > 0;
     }
-    void send_error(task_server& task, const std::string &error)
+    void send_error(task_server &task, const std::string &error)
     {
-        LOG_TEE("task %i - error: %s\n", task.id, error.c_str());
+        send_error(task.id, task.multitask_id, error);
+    }
+
+    void send_error(llama_client_slot &slot, const std::string &error)
+    {
+        send_error(slot.task_id, slot.multitask_id, error);
+    }
+
+    void send_error(int task_id, int multitask_id, const std::string &error)
+    {
+        LOG_TEE("task %i - error: %s\n", task_id, error.c_str());
         task_result res;
-        res.id = task.id;
-        res.multitask_id = task.multitask_id;
-        res.stop = false;
+        res.id = task_id;
+        res.multitask_id = multitask_id;
+        res.stop = true;
         res.error = true;
         res.result_json = { { "content", error } };
         queue_results.send(res);
     }
@@ -1466,7 +1476,9 @@ struct llama_server_context
         queue_results.send(result);
     }
 
-    bool update_slots() {
+    void run_slots() {
+        bool has_next_response = false; // whether to schedule the next slot run, to generate the next token
+
         if (system_need_update)
         {
             LOG_TEE("updating system prompt\n");
@@ -1482,14 +1494,9 @@ struct llama_server_context
                 LOG_TEE("all slots are idle and system prompt is empty, clear the KV cache\n");
                 kv_cache_clear();
             }
-            return true;
+            return;
         }
 
-        task_server task;
-        task.type = TASK_TYPE_NEXT_RESPONSE;
-        task.target_id = -1;
-        queue_tasks.post(task);
-
         for (llama_client_slot &slot : slots)
         {
             if (slot.ga_n == 1)
@@ -1737,7 +1744,8 @@ struct llama_server_context
                 if (has_images && !ingest_images(slot, n_batch))
                 {
                     LOG_TEE("failed processing images\n");
-                    return false;
+                    send_error(slot, "failed processing images");
+                    continue;
                 }
 
                 // extract the logits only for the last token
@@ -1755,7 +1763,6 @@ struct llama_server_context
         if (batch.n_tokens == 0)
         {
             all_slots_are_idle = true;
-            return true;
         }
         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
         {
@@ -1812,7 +1819,13 @@ struct llama_server_context
                 {
                     // if you get here, it means the KV cache is full - try increasing it via the context size
                     LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
-                    return false;
+                    for (auto & slot : slots)
+                    {
+                        send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
+                        slot.release();
+                    }
+                    has_next_response = false;
+                    break;
                 }
 
                 LOG_TEE("%s : failed to find free space in the KV cache, retrying with smaller n_batch = %d\n", __func__, n_batch / 2);
@@ -1873,14 +1886,23 @@ struct llama_server_context
                     send_final_response(slot);
                 }
 
+                // if the slot has not yet finished its work, schedule the next run
+                if (slot.has_next_token)
+                {
+                    has_next_response = true;
+                }
+
                 slot.i_batch = -1;
             }
         }
-        return true;
-    }
 
-    void run_on_all_tasks_finished() {
-        update_slots();
+        if (has_next_response) {
+            LOG_VERBOSE("schedule next slot run", {});
+            task_server task;
+            task.type = TASK_TYPE_NEXT_RESPONSE;
+            task.target_id = -1;
+            queue_tasks.post(task);
+        }
     }
 };
@@ -3210,7 +3232,7 @@ int main(int argc, char **argv)
             bool running = true;
             while (running)
             {
-                running = llama.update_slots();
+                running = llama.run_slots();
             }
         }*/
     //);
@@ -3232,8 +3254,8 @@ int main(int argc, char **argv)
         &llama_server_context::process_single_task, &llama, std::placeholders::_1));
     llama.queue_tasks.on_finish_multitask(std::bind(
         &llama_server_context::on_finish_multitask, &llama, std::placeholders::_1));
-    llama.queue_tasks.on_all_tasks_finished(std::bind(
-        &llama_server_context::run_on_all_tasks_finished, &llama));
+    llama.queue_tasks.on_run_slots(std::bind(
+        &llama_server_context::run_slots, &llama));
     llama.queue_results.on_multitask_update(std::bind(
         &llama_server_queue::update_multitask, &llama.queue_tasks,
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 88545eb69..8cc63e7d4 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -227,7 +227,7 @@ struct llama_server_queue {
     // callback functions
     std::function<void(task_server&)> callback_new_task;
    std::function<void(task_multi&)>  callback_finish_multitask;
-    std::function<void(void)>         callback_all_task_finished;
+    std::function<void(void)>         callback_run_slots;
 
     // Add a new task to the end of the queue
     int post(task_server task) {
@@ -257,14 +257,14 @@ struct llama_server_queue {
         callback_new_task = callback;
     }
 
-    // Register function to process a multitask
+    // Register function to process a multitask when it is finished
     void on_finish_multitask(std::function<void(task_multi&)> callback) {
         callback_finish_multitask = callback;
     }
 
-    // Register the function to be called when the batch of tasks is finished
-    void on_all_tasks_finished(std::function<void(void)> callback) {
-        callback_all_task_finished = callback;
+    // Register the function to be called when all slots data is ready to be processed
+    void on_run_slots(std::function<void(void)> callback) {
+        callback_run_slots = callback;
     }
 
     // Call when the state of one slot is changed
@@ -286,7 +286,13 @@ struct llama_server_queue {
         condition_tasks.notify_all();
     }
 
-    // Start the main loop.
+    /**
+     * Main loop consists of these steps:
+     * - Wait until a new task arrives
+     * - Process the task (i.e. maybe copy data into slot)
+     * - Check if multitask is finished
+     * - Run all slots
+     */
     void start_loop() {
         running = true;
         while (true) {
@@ -306,8 +312,8 @@ struct llama_server_queue {
                     LOG_VERBOSE("callback_new_task", {});
                     callback_new_task(task);
                 }
-                LOG_VERBOSE("callback_all_task_finished", {});
-                // process and update all the multitasks
+                LOG_VERBOSE("update_multitasks", {});
+                // check if we have any finished multitasks
                 auto queue_iterator = queue_multitasks.begin();
                 while (queue_iterator != queue_multitasks.end())
                 {
@@ -324,8 +330,9 @@ struct llama_server_queue {
                         ++queue_iterator;
                     }
                 }
-                // all tasks in the current loop is finished
-                callback_all_task_finished();
+                // all tasks in the current loop are processed, slots data is now ready
+                LOG_VERBOSE("callback_run_slots", {});
+                callback_run_slots();
             }
             LOG_VERBOSE("wait for new task", {});
             // wait for new task
@@ -401,7 +408,9 @@ struct llama_server_response {
             condition_results.wait(lock, [&]{
                 return !queue_results.empty();
             });
-            LOG_VERBOSE("condition_results unblock", {});
+            LOG_VERBOSE("condition_results unblock", {
+                {"data", queue_results[0].result_json},
+            });
 
             for (int i = 0; i < (int) queue_results.size(); i++)
             {