refactor: rework server work-queue handling (rename update_slots to run_slots, schedule next run via task queue)

This commit is contained in:
ngxson 2024-02-25 12:56:34 +01:00
parent 9e359a4f47
commit 91e7e0ff17
2 changed files with 64 additions and 33 deletions

View file

@ -1023,11 +1023,21 @@ struct llama_server_context
void send_error(task_server &task, const std::string &error) void send_error(task_server &task, const std::string &error)
{ {
LOG_TEE("task %i - error: %s\n", task.id, error.c_str()); send_error(task.id, task.multitask_id, error);
}
void send_error(llama_client_slot &slot, const std::string &error)
{
send_error(slot.task_id, slot.multitask_id, error);
}
void send_error(int task_id, int multitask_id, const std::string &error)
{
LOG_TEE("task %i - error: %s\n", task_id, error.c_str());
task_result res; task_result res;
res.id = task.id; res.id = task_id;
res.multitask_id = task.multitask_id; res.multitask_id = multitask_id;
res.stop = false; res.stop = true;
res.error = true; res.error = true;
res.result_json = { { "content", error } }; res.result_json = { { "content", error } };
queue_results.send(res); queue_results.send(res);
@ -1466,7 +1476,9 @@ struct llama_server_context
queue_results.send(result); queue_results.send(result);
} }
bool update_slots() { void run_slots() {
bool has_next_response = false; // whether to schedule next slot run, to generate next token
if (system_need_update) if (system_need_update)
{ {
LOG_TEE("updating system prompt\n"); LOG_TEE("updating system prompt\n");
@ -1482,14 +1494,9 @@ struct llama_server_context
LOG_TEE("all slots are idle and system prompt is empty, clear the KV cache\n"); LOG_TEE("all slots are idle and system prompt is empty, clear the KV cache\n");
kv_cache_clear(); kv_cache_clear();
} }
return true; return;
} }
task_server task;
task.type = TASK_TYPE_NEXT_RESPONSE;
task.target_id = -1;
queue_tasks.post(task);
for (llama_client_slot &slot : slots) for (llama_client_slot &slot : slots)
{ {
if (slot.ga_n == 1) if (slot.ga_n == 1)
@ -1737,7 +1744,8 @@ struct llama_server_context
if (has_images && !ingest_images(slot, n_batch)) if (has_images && !ingest_images(slot, n_batch))
{ {
LOG_TEE("failed processing images\n"); LOG_TEE("failed processing images\n");
return false; send_error(slot, "failed processing images");
continue;
} }
// extract the logits only for the last token // extract the logits only for the last token
@ -1755,7 +1763,6 @@ struct llama_server_context
if (batch.n_tokens == 0) if (batch.n_tokens == 0)
{ {
all_slots_are_idle = true; all_slots_are_idle = true;
return true;
} }
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
@ -1812,7 +1819,13 @@ struct llama_server_context
{ {
// if you get here, it means the KV cache is full - try increasing it via the context size // if you get here, it means the KV cache is full - try increasing it via the context size
LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret); LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
return false; for (auto & slot : slots)
{
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
slot.release();
}
has_next_response = false;
break;
} }
LOG_TEE("%s : failed to find free space in the KV cache, retrying with smaller n_batch = %d\n", __func__, n_batch / 2); LOG_TEE("%s : failed to find free space in the KV cache, retrying with smaller n_batch = %d\n", __func__, n_batch / 2);
@ -1873,14 +1886,23 @@ struct llama_server_context
send_final_response(slot); send_final_response(slot);
} }
// if the slot has not yet finished its work, schedule the next run
if (slot.has_next_token)
{
has_next_response = true;
}
slot.i_batch = -1; slot.i_batch = -1;
} }
} }
return true;
}
void run_on_all_tasks_finished() { if (has_next_response) {
update_slots(); LOG_VERBOSE("schedule next slot run", {});
task_server task;
task.type = TASK_TYPE_NEXT_RESPONSE;
task.target_id = -1;
queue_tasks.post(task);
}
} }
}; };
@ -3210,7 +3232,7 @@ int main(int argc, char **argv)
bool running = true; bool running = true;
while (running) while (running)
{ {
running = llama.update_slots(); running = llama.run_slots();
} }
}*/ }*/
//); //);
@ -3232,8 +3254,8 @@ int main(int argc, char **argv)
&llama_server_context::process_single_task, &llama, std::placeholders::_1)); &llama_server_context::process_single_task, &llama, std::placeholders::_1));
llama.queue_tasks.on_finish_multitask(std::bind( llama.queue_tasks.on_finish_multitask(std::bind(
&llama_server_context::on_finish_multitask, &llama, std::placeholders::_1)); &llama_server_context::on_finish_multitask, &llama, std::placeholders::_1));
llama.queue_tasks.on_all_tasks_finished(std::bind( llama.queue_tasks.on_run_slots(std::bind(
&llama_server_context::run_on_all_tasks_finished, &llama)); &llama_server_context::run_slots, &llama));
llama.queue_results.on_multitask_update(std::bind( llama.queue_results.on_multitask_update(std::bind(
&llama_server_queue::update_multitask, &llama_server_queue::update_multitask,
&llama.queue_tasks, &llama.queue_tasks,

View file

@ -227,7 +227,7 @@ struct llama_server_queue {
// callback functions // callback functions
std::function<void(task_server&)> callback_new_task; std::function<void(task_server&)> callback_new_task;
std::function<void(task_multi&)> callback_finish_multitask; std::function<void(task_multi&)> callback_finish_multitask;
std::function<void(void)> callback_all_task_finished; std::function<void(void)> callback_run_slots;
// Add a new task to the end of the queue // Add a new task to the end of the queue
int post(task_server task) { int post(task_server task) {
@ -257,14 +257,14 @@ struct llama_server_queue {
callback_new_task = callback; callback_new_task = callback;
} }
// Register function to process a multitask // Register function to process a multitask when it is finished
void on_finish_multitask(std::function<void(task_multi&)> callback) { void on_finish_multitask(std::function<void(task_multi&)> callback) {
callback_finish_multitask = callback; callback_finish_multitask = callback;
} }
// Register the function to be called when the batch of tasks is finished // Register the function to be called when all slots data is ready to be processed
void on_all_tasks_finished(std::function<void(void)> callback) { void on_run_slots(std::function<void(void)> callback) {
callback_all_task_finished = callback; callback_run_slots = callback;
} }
// Call when the state of one slot is changed // Call when the state of one slot is changed
@ -286,7 +286,13 @@ struct llama_server_queue {
condition_tasks.notify_all(); condition_tasks.notify_all();
} }
// Start the main loop. /**
* Main loop consists of these steps:
* - Wait until a new task arrives
* - Process the task (i.e. maybe copy data into slot)
* - Check if multitask is finished
* - Run all slots
*/
void start_loop() { void start_loop() {
running = true; running = true;
while (true) { while (true) {
@ -306,8 +312,8 @@ struct llama_server_queue {
LOG_VERBOSE("callback_new_task", {}); LOG_VERBOSE("callback_new_task", {});
callback_new_task(task); callback_new_task(task);
} }
LOG_VERBOSE("callback_all_task_finished", {}); LOG_VERBOSE("update_multitasks", {});
// process and update all the multitasks // check if we have any finished multitasks
auto queue_iterator = queue_multitasks.begin(); auto queue_iterator = queue_multitasks.begin();
while (queue_iterator != queue_multitasks.end()) while (queue_iterator != queue_multitasks.end())
{ {
@ -324,8 +330,9 @@ struct llama_server_queue {
++queue_iterator; ++queue_iterator;
} }
} }
// all tasks in the current loop is finished // all tasks in the current loop are processed, slot data is now ready
callback_all_task_finished(); LOG_VERBOSE("callback_run_slots", {});
callback_run_slots();
} }
LOG_VERBOSE("wait for new task", {}); LOG_VERBOSE("wait for new task", {});
// wait for new task // wait for new task
@ -401,7 +408,9 @@ struct llama_server_response {
condition_results.wait(lock, [&]{ condition_results.wait(lock, [&]{
return !queue_results.empty(); return !queue_results.empty();
}); });
LOG_VERBOSE("condition_results unblock", {}); LOG_VERBOSE("condition_results unblock", {
{"data", queue_results[0].result_json},
});
for (int i = 0; i < (int) queue_results.size(); i++) for (int i = 0; i < (int) queue_results.size(); i++)
{ {