refactor work queue related stuff
This commit is contained in:
parent
9e359a4f47
commit
91e7e0ff17
2 changed files with 64 additions and 33 deletions
|
@ -1021,13 +1021,23 @@ struct llama_server_context
|
|||
return slot.images.size() > 0;
|
||||
}
|
||||
|
||||
void send_error(task_server& task, const std::string &error)
|
||||
void send_error(task_server &task, const std::string &error)
|
||||
{
|
||||
LOG_TEE("task %i - error: %s\n", task.id, error.c_str());
|
||||
send_error(task.id, task.multitask_id, error);
|
||||
}
|
||||
|
||||
void send_error(llama_client_slot &slot, const std::string &error)
|
||||
{
|
||||
send_error(slot.task_id, slot.multitask_id, error);
|
||||
}
|
||||
|
||||
void send_error(int task_id, int multitask_id, const std::string &error)
|
||||
{
|
||||
LOG_TEE("task %i - error: %s\n", task_id, error.c_str());
|
||||
task_result res;
|
||||
res.id = task.id;
|
||||
res.multitask_id = task.multitask_id;
|
||||
res.stop = false;
|
||||
res.id = task_id;
|
||||
res.multitask_id = multitask_id;
|
||||
res.stop = true;
|
||||
res.error = true;
|
||||
res.result_json = { { "content", error } };
|
||||
queue_results.send(res);
|
||||
|
@ -1466,7 +1476,9 @@ struct llama_server_context
|
|||
queue_results.send(result);
|
||||
}
|
||||
|
||||
bool update_slots() {
|
||||
void run_slots() {
|
||||
bool has_next_response = false; // whether to schedule next slot run, to generate next token
|
||||
|
||||
if (system_need_update)
|
||||
{
|
||||
LOG_TEE("updating system prompt\n");
|
||||
|
@ -1482,14 +1494,9 @@ struct llama_server_context
|
|||
LOG_TEE("all slots are idle and system prompt is empty, clear the KV cache\n");
|
||||
kv_cache_clear();
|
||||
}
|
||||
return true;
|
||||
return;
|
||||
}
|
||||
|
||||
task_server task;
|
||||
task.type = TASK_TYPE_NEXT_RESPONSE;
|
||||
task.target_id = -1;
|
||||
queue_tasks.post(task);
|
||||
|
||||
for (llama_client_slot &slot : slots)
|
||||
{
|
||||
if (slot.ga_n == 1)
|
||||
|
@ -1737,7 +1744,8 @@ struct llama_server_context
|
|||
if (has_images && !ingest_images(slot, n_batch))
|
||||
{
|
||||
LOG_TEE("failed processing images\n");
|
||||
return false;
|
||||
send_error(slot, "failed processing images");
|
||||
continue;
|
||||
}
|
||||
|
||||
// extract the logits only for the last token
|
||||
|
@ -1755,7 +1763,6 @@ struct llama_server_context
|
|||
if (batch.n_tokens == 0)
|
||||
{
|
||||
all_slots_are_idle = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
|
||||
|
@ -1812,7 +1819,13 @@ struct llama_server_context
|
|||
{
|
||||
// if you get here, it means the KV cache is full - try increasing it via the context size
|
||||
LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
|
||||
return false;
|
||||
for (auto & slot : slots)
|
||||
{
|
||||
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
|
||||
slot.release();
|
||||
}
|
||||
has_next_response = false;
|
||||
break;
|
||||
}
|
||||
|
||||
LOG_TEE("%s : failed to find free space in the KV cache, retrying with smaller n_batch = %d\n", __func__, n_batch / 2);
|
||||
|
@ -1873,14 +1886,23 @@ struct llama_server_context
|
|||
send_final_response(slot);
|
||||
}
|
||||
|
||||
// if slot is not yet finish its work, we schedule next run
|
||||
if (slot.has_next_token)
|
||||
{
|
||||
has_next_response = true;
|
||||
}
|
||||
|
||||
slot.i_batch = -1;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void run_on_all_tasks_finished() {
|
||||
update_slots();
|
||||
if (has_next_response) {
|
||||
LOG_VERBOSE("schedule next slot run", {});
|
||||
task_server task;
|
||||
task.type = TASK_TYPE_NEXT_RESPONSE;
|
||||
task.target_id = -1;
|
||||
queue_tasks.post(task);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -3210,7 +3232,7 @@ int main(int argc, char **argv)
|
|||
bool running = true;
|
||||
while (running)
|
||||
{
|
||||
running = llama.update_slots();
|
||||
running = llama.run_slots();
|
||||
}
|
||||
}*/
|
||||
//);
|
||||
|
@ -3232,8 +3254,8 @@ int main(int argc, char **argv)
|
|||
&llama_server_context::process_single_task, &llama, std::placeholders::_1));
|
||||
llama.queue_tasks.on_finish_multitask(std::bind(
|
||||
&llama_server_context::on_finish_multitask, &llama, std::placeholders::_1));
|
||||
llama.queue_tasks.on_all_tasks_finished(std::bind(
|
||||
&llama_server_context::run_on_all_tasks_finished, &llama));
|
||||
llama.queue_tasks.on_run_slots(std::bind(
|
||||
&llama_server_context::run_slots, &llama));
|
||||
llama.queue_results.on_multitask_update(std::bind(
|
||||
&llama_server_queue::update_multitask,
|
||||
&llama.queue_tasks,
|
||||
|
|
|
@ -227,7 +227,7 @@ struct llama_server_queue {
|
|||
// callback functions
|
||||
std::function<void(task_server&)> callback_new_task;
|
||||
std::function<void(task_multi&)> callback_finish_multitask;
|
||||
std::function<void(void)> callback_all_task_finished;
|
||||
std::function<void(void)> callback_run_slots;
|
||||
|
||||
// Add a new task to the end of the queue
|
||||
int post(task_server task) {
|
||||
|
@ -257,14 +257,14 @@ struct llama_server_queue {
|
|||
callback_new_task = callback;
|
||||
}
|
||||
|
||||
// Register function to process a multitask
|
||||
// Register function to process a multitask when it is finished
|
||||
void on_finish_multitask(std::function<void(task_multi&)> callback) {
|
||||
callback_finish_multitask = callback;
|
||||
}
|
||||
|
||||
// Register the function to be called when the batch of tasks is finished
|
||||
void on_all_tasks_finished(std::function<void(void)> callback) {
|
||||
callback_all_task_finished = callback;
|
||||
// Register the function to be called when all slots data is ready to be processed
|
||||
void on_run_slots(std::function<void(void)> callback) {
|
||||
callback_run_slots = callback;
|
||||
}
|
||||
|
||||
// Call when the state of one slot is changed
|
||||
|
@ -286,7 +286,13 @@ struct llama_server_queue {
|
|||
condition_tasks.notify_all();
|
||||
}
|
||||
|
||||
// Start the main loop.
|
||||
/**
|
||||
* Main loop consists of these steps:
|
||||
* - Wait until a new task arrives
|
||||
* - Process the task (i.e. maybe copy data into slot)
|
||||
* - Check if multitask is finished
|
||||
* - Run all slots
|
||||
*/
|
||||
void start_loop() {
|
||||
running = true;
|
||||
while (true) {
|
||||
|
@ -306,8 +312,8 @@ struct llama_server_queue {
|
|||
LOG_VERBOSE("callback_new_task", {});
|
||||
callback_new_task(task);
|
||||
}
|
||||
LOG_VERBOSE("callback_all_task_finished", {});
|
||||
// process and update all the multitasks
|
||||
LOG_VERBOSE("update_multitasks", {});
|
||||
// check if we have any finished multitasks
|
||||
auto queue_iterator = queue_multitasks.begin();
|
||||
while (queue_iterator != queue_multitasks.end())
|
||||
{
|
||||
|
@ -324,8 +330,9 @@ struct llama_server_queue {
|
|||
++queue_iterator;
|
||||
}
|
||||
}
|
||||
// all tasks in the current loop is finished
|
||||
callback_all_task_finished();
|
||||
// all tasks in the current loop is processed, slots data is now ready
|
||||
LOG_VERBOSE("callback_run_slots", {});
|
||||
callback_run_slots();
|
||||
}
|
||||
LOG_VERBOSE("wait for new task", {});
|
||||
// wait for new task
|
||||
|
@ -401,7 +408,9 @@ struct llama_server_response {
|
|||
condition_results.wait(lock, [&]{
|
||||
return !queue_results.empty();
|
||||
});
|
||||
LOG_VERBOSE("condition_results unblock", {});
|
||||
LOG_VERBOSE("condition_results unblock", {
|
||||
{"data", queue_results[0].result_json},
|
||||
});
|
||||
|
||||
for (int i = 0; i < (int) queue_results.size(); i++)
|
||||
{
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue