server: refactor work queue handling

ngxson 2024-02-25 12:56:34 +01:00
parent 9e359a4f47
commit 91e7e0ff17
2 changed files with 64 additions and 33 deletions

examples/server/server.cpp
@@ -1021,13 +1021,23 @@ struct llama_server_context
return slot.images.size() > 0;
}
void send_error(task_server& task, const std::string &error)
void send_error(task_server &task, const std::string &error)
{
LOG_TEE("task %i - error: %s\n", task.id, error.c_str());
send_error(task.id, task.multitask_id, error);
}
void send_error(llama_client_slot &slot, const std::string &error)
{
send_error(slot.task_id, slot.multitask_id, error);
}
void send_error(int task_id, int multitask_id, const std::string &error)
{
LOG_TEE("task %i - error: %s\n", task_id, error.c_str());
task_result res;
res.id = task.id;
res.multitask_id = task.multitask_id;
res.stop = false;
res.id = task_id;
res.multitask_id = multitask_id;
res.stop = true;
res.error = true;
res.result_json = { { "content", error } };
queue_results.send(res);
@@ -1466,7 +1476,9 @@ struct llama_server_context
queue_results.send(result);
}
bool update_slots() {
void run_slots() {
bool has_next_response = false; // whether to schedule another slot run to generate the next token
if (system_need_update)
{
LOG_TEE("updating system prompt\n");
@@ -1482,14 +1494,9 @@ struct llama_server_context
LOG_TEE("all slots are idle and system prompt is empty, clear the KV cache\n");
kv_cache_clear();
}
return true;
return;
}
task_server task;
task.type = TASK_TYPE_NEXT_RESPONSE;
task.target_id = -1;
queue_tasks.post(task);
for (llama_client_slot &slot : slots)
{
if (slot.ga_n == 1)
@@ -1737,7 +1744,8 @@ struct llama_server_context
if (has_images && !ingest_images(slot, n_batch))
{
LOG_TEE("failed processing images\n");
return false;
send_error(slot, "failed processing images");
continue;
}
// extract the logits only for the last token
@@ -1755,7 +1763,6 @@ struct llama_server_context
if (batch.n_tokens == 0)
{
all_slots_are_idle = true;
return true;
}
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
@@ -1812,7 +1819,13 @@ struct llama_server_context
{
// if you get here, it means the KV cache is full - try increasing it via the context size
LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
return false;
for (auto & slot : slots)
{
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
slot.release();
}
has_next_response = false;
break;
}
LOG_TEE("%s : failed to find free space in the KV cache, retrying with smaller n_batch = %d\n", __func__, n_batch / 2);
@@ -1873,14 +1886,23 @@ struct llama_server_context
send_final_response(slot);
}
// if the slot has not yet finished its work, schedule the next run
if (slot.has_next_token)
{
has_next_response = true;
}
slot.i_batch = -1;
}
}
return true;
}
void run_on_all_tasks_finished() {
update_slots();
if (has_next_response) {
LOG_VERBOSE("schedule next slot run", {});
task_server task;
task.type = TASK_TYPE_NEXT_RESPONSE;
task.target_id = -1;
queue_tasks.post(task);
}
}
};
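
For context, the scheduling change above is easiest to see in isolation: the old update_slots() posted a TASK_TYPE_NEXT_RESPONSE task unconditionally at the top and signalled termination through its bool return value, whereas run_slots() now posts that task at the bottom, and only while at least one slot still has a next token to produce. Below is a minimal, self-contained sketch of this self-rescheduling work-queue pattern; mini_queue, TASK_NEXT_RESPONSE and the token counter are illustrative stand-ins, not the server's real types.

// Self-contained sketch (illustrative only, not part of the commit).
#include <cstdio>
#include <deque>
#include <functional>

enum task_type { TASK_NEXT_RESPONSE };
struct task { task_type type; };

struct mini_queue {
    std::deque<task> tasks;
    std::function<void()> callback_run_slots;   // analogous to on_run_slots()

    void post(task t) { tasks.push_back(t); }

    void start_loop() {
        while (!tasks.empty()) {
            // drain all queued tasks (they carry no payload in this sketch)
            while (!tasks.empty()) { tasks.pop_front(); }
            // once the queue is empty, let the slots run one step;
            // the callback may post a new task to keep the loop alive
            callback_run_slots();
        }
    }
};

int main() {
    mini_queue queue;
    int tokens_left = 3;   // pretend one slot still has 3 tokens to generate

    // analogous to run_slots(): do one step, reschedule only if work remains
    queue.callback_run_slots = [&]() {
        std::printf("run_slots: generated a token, %d left\n", --tokens_left);
        bool has_next_response = tokens_left > 0;
        if (has_next_response) {
            queue.post({TASK_NEXT_RESPONSE});   // schedule the next slot run
        }
    };

    queue.post({TASK_NEXT_RESPONSE});   // the initial request kicks things off
    queue.start_loop();                 // ends once no slot reschedules itself
    return 0;
}
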
@@ -3210,7 +3232,7 @@ int main(int argc, char **argv)
bool running = true;
while (running)
{
running = llama.update_slots();
running = llama.run_slots();
}
}*/
//);
@@ -3232,8 +3254,8 @@ int main(int argc, char **argv)
&llama_server_context::process_single_task, &llama, std::placeholders::_1));
llama.queue_tasks.on_finish_multitask(std::bind(
&llama_server_context::on_finish_multitask, &llama, std::placeholders::_1));
llama.queue_tasks.on_all_tasks_finished(std::bind(
&llama_server_context::run_on_all_tasks_finished, &llama));
llama.queue_tasks.on_run_slots(std::bind(
&llama_server_context::run_slots, &llama));
llama.queue_results.on_multitask_update(std::bind(
&llama_server_queue::update_multitask,
&llama.queue_tasks,

examples/server/utils.hpp
@@ -227,7 +227,7 @@ struct llama_server_queue {
// callback functions
std::function<void(task_server&)> callback_new_task;
std::function<void(task_multi&)> callback_finish_multitask;
std::function<void(void)> callback_all_task_finished;
std::function<void(void)> callback_run_slots;
// Add a new task to the end of the queue
int post(task_server task) {
@@ -257,14 +257,14 @@ struct llama_server_queue {
callback_new_task = callback;
}
// Register function to process a multitask
// Register function to process a multitask when it is finished
void on_finish_multitask(std::function<void(task_multi&)> callback) {
callback_finish_multitask = callback;
}
// Register the function to be called when the batch of tasks is finished
void on_all_tasks_finished(std::function<void(void)> callback) {
callback_all_task_finished = callback;
// Register the function to be called when all slots' data is ready to be processed
void on_run_slots(std::function<void(void)> callback) {
callback_run_slots = callback;
}
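
Usage: in main() in the first file of this diff, the server context registers its run_slots member through this hook:

    llama.queue_tasks.on_run_slots(std::bind(
        &llama_server_context::run_slots, &llama));
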
// Call when the state of one slot is changed
@@ -286,7 +286,13 @@ struct llama_server_queue {
condition_tasks.notify_all();
}
// Start the main loop.
/**
* The main loop consists of these steps:
* - Wait until a new task arrives
* - Process the task (i.e. maybe copy data into a slot)
* - Check whether any multitask is finished
* - Run all slots
*/
void start_loop() {
running = true;
while (true) {
@@ -306,8 +312,8 @@ struct llama_server_queue {
LOG_VERBOSE("callback_new_task", {});
callback_new_task(task);
}
LOG_VERBOSE("callback_all_task_finished", {});
// process and update all the multitasks
LOG_VERBOSE("update_multitasks", {});
// check if we have any finished multitasks
auto queue_iterator = queue_multitasks.begin();
while (queue_iterator != queue_multitasks.end())
{
@@ -324,8 +330,9 @@ struct llama_server_queue {
++queue_iterator;
}
}
// all tasks in the current loop is finished
callback_all_task_finished();
// all tasks in the current loop are processed; the slots' data is now ready
LOG_VERBOSE("callback_run_slots", {});
callback_run_slots();
}
LOG_VERBOSE("wait for new task", {});
// wait for new task
@@ -401,7 +408,9 @@ struct llama_server_response {
condition_results.wait(lock, [&]{
return !queue_results.empty();
});
LOG_VERBOSE("condition_results unblock", {});
LOG_VERBOSE("condition_results unblock", {
{"data", queue_results[0].result_json},
});
for (int i = 0; i < (int) queue_results.size(); i++)
{
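
The recv path above is the standard condition-variable producer/consumer pattern: send() pushes a result and notifies, while the receiver blocks until queue_results is non-empty. A minimal, self-contained sketch of that pattern follows; the result and result_queue types are simplified stand-ins for the server's task_result and llama_server_response, not the real definitions.

// Self-contained sketch (illustrative only, not part of the commit).
#include <condition_variable>
#include <cstdio>
#include <deque>
#include <mutex>
#include <string>
#include <thread>

struct result { int id; bool stop; bool error; std::string content; };

struct result_queue {
    std::mutex mutex_results;
    std::condition_variable condition_results;
    std::deque<result> queue_results;

    // producer side, analogous to queue_results.send(res)
    void send(result res) {
        std::unique_lock<std::mutex> lock(mutex_results);
        queue_results.push_back(std::move(res));
        condition_results.notify_all();
    }

    // consumer side, analogous to the recv loop above: block until a result exists
    result recv() {
        std::unique_lock<std::mutex> lock(mutex_results);
        condition_results.wait(lock, [&] { return !queue_results.empty(); });
        result res = queue_results.front();
        queue_results.pop_front();
        return res;
    }
};

int main() {
    result_queue q;
    std::thread worker([&] {
        // an error result is marked stop = true so the waiting client stops streaming
        q.send({42, /*stop=*/true, /*error=*/true, "failed processing images"});
    });
    result res = q.recv();
    std::printf("task %d, stop=%d, error=%d: %s\n",
                res.id, (int) res.stop, (int) res.error, res.content.c_str());
    worker.join();
    return 0;
}
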