refactor work queue related stuff
This commit is contained in:
parent
9e359a4f47
commit
91e7e0ff17
2 changed files with 64 additions and 33 deletions
|
@ -1021,13 +1021,23 @@ struct llama_server_context
|
||||||
return slot.images.size() > 0;
|
return slot.images.size() > 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
void send_error(task_server& task, const std::string &error)
|
void send_error(task_server &task, const std::string &error)
|
||||||
{
|
{
|
||||||
LOG_TEE("task %i - error: %s\n", task.id, error.c_str());
|
send_error(task.id, task.multitask_id, error);
|
||||||
|
}
|
||||||
|
|
||||||
|
void send_error(llama_client_slot &slot, const std::string &error)
|
||||||
|
{
|
||||||
|
send_error(slot.task_id, slot.multitask_id, error);
|
||||||
|
}
|
||||||
|
|
||||||
|
void send_error(int task_id, int multitask_id, const std::string &error)
|
||||||
|
{
|
||||||
|
LOG_TEE("task %i - error: %s\n", task_id, error.c_str());
|
||||||
task_result res;
|
task_result res;
|
||||||
res.id = task.id;
|
res.id = task_id;
|
||||||
res.multitask_id = task.multitask_id;
|
res.multitask_id = multitask_id;
|
||||||
res.stop = false;
|
res.stop = true;
|
||||||
res.error = true;
|
res.error = true;
|
||||||
res.result_json = { { "content", error } };
|
res.result_json = { { "content", error } };
|
||||||
queue_results.send(res);
|
queue_results.send(res);
|
||||||
|
@ -1466,7 +1476,9 @@ struct llama_server_context
|
||||||
queue_results.send(result);
|
queue_results.send(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool update_slots() {
|
void run_slots() {
|
||||||
|
bool has_next_response = false; // whether to schedule next slot run, to generate next token
|
||||||
|
|
||||||
if (system_need_update)
|
if (system_need_update)
|
||||||
{
|
{
|
||||||
LOG_TEE("updating system prompt\n");
|
LOG_TEE("updating system prompt\n");
|
||||||
|
@ -1482,14 +1494,9 @@ struct llama_server_context
|
||||||
LOG_TEE("all slots are idle and system prompt is empty, clear the KV cache\n");
|
LOG_TEE("all slots are idle and system prompt is empty, clear the KV cache\n");
|
||||||
kv_cache_clear();
|
kv_cache_clear();
|
||||||
}
|
}
|
||||||
return true;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
task_server task;
|
|
||||||
task.type = TASK_TYPE_NEXT_RESPONSE;
|
|
||||||
task.target_id = -1;
|
|
||||||
queue_tasks.post(task);
|
|
||||||
|
|
||||||
for (llama_client_slot &slot : slots)
|
for (llama_client_slot &slot : slots)
|
||||||
{
|
{
|
||||||
if (slot.ga_n == 1)
|
if (slot.ga_n == 1)
|
||||||
|
@ -1737,7 +1744,8 @@ struct llama_server_context
|
||||||
if (has_images && !ingest_images(slot, n_batch))
|
if (has_images && !ingest_images(slot, n_batch))
|
||||||
{
|
{
|
||||||
LOG_TEE("failed processing images\n");
|
LOG_TEE("failed processing images\n");
|
||||||
return false;
|
send_error(slot, "failed processing images");
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// extract the logits only for the last token
|
// extract the logits only for the last token
|
||||||
|
@ -1755,7 +1763,6 @@ struct llama_server_context
|
||||||
if (batch.n_tokens == 0)
|
if (batch.n_tokens == 0)
|
||||||
{
|
{
|
||||||
all_slots_are_idle = true;
|
all_slots_are_idle = true;
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
|
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
|
||||||
|
@ -1812,7 +1819,13 @@ struct llama_server_context
|
||||||
{
|
{
|
||||||
// if you get here, it means the KV cache is full - try increasing it via the context size
|
// if you get here, it means the KV cache is full - try increasing it via the context size
|
||||||
LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
|
LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
|
||||||
return false;
|
for (auto & slot : slots)
|
||||||
|
{
|
||||||
|
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
|
||||||
|
slot.release();
|
||||||
|
}
|
||||||
|
has_next_response = false;
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_TEE("%s : failed to find free space in the KV cache, retrying with smaller n_batch = %d\n", __func__, n_batch / 2);
|
LOG_TEE("%s : failed to find free space in the KV cache, retrying with smaller n_batch = %d\n", __func__, n_batch / 2);
|
||||||
|
@ -1873,14 +1886,23 @@ struct llama_server_context
|
||||||
send_final_response(slot);
|
send_final_response(slot);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if slot is not yet finish its work, we schedule next run
|
||||||
|
if (slot.has_next_token)
|
||||||
|
{
|
||||||
|
has_next_response = true;
|
||||||
|
}
|
||||||
|
|
||||||
slot.i_batch = -1;
|
slot.i_batch = -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void run_on_all_tasks_finished() {
|
if (has_next_response) {
|
||||||
update_slots();
|
LOG_VERBOSE("schedule next slot run", {});
|
||||||
|
task_server task;
|
||||||
|
task.type = TASK_TYPE_NEXT_RESPONSE;
|
||||||
|
task.target_id = -1;
|
||||||
|
queue_tasks.post(task);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -3210,7 +3232,7 @@ int main(int argc, char **argv)
|
||||||
bool running = true;
|
bool running = true;
|
||||||
while (running)
|
while (running)
|
||||||
{
|
{
|
||||||
running = llama.update_slots();
|
running = llama.run_slots();
|
||||||
}
|
}
|
||||||
}*/
|
}*/
|
||||||
//);
|
//);
|
||||||
|
@ -3232,8 +3254,8 @@ int main(int argc, char **argv)
|
||||||
&llama_server_context::process_single_task, &llama, std::placeholders::_1));
|
&llama_server_context::process_single_task, &llama, std::placeholders::_1));
|
||||||
llama.queue_tasks.on_finish_multitask(std::bind(
|
llama.queue_tasks.on_finish_multitask(std::bind(
|
||||||
&llama_server_context::on_finish_multitask, &llama, std::placeholders::_1));
|
&llama_server_context::on_finish_multitask, &llama, std::placeholders::_1));
|
||||||
llama.queue_tasks.on_all_tasks_finished(std::bind(
|
llama.queue_tasks.on_run_slots(std::bind(
|
||||||
&llama_server_context::run_on_all_tasks_finished, &llama));
|
&llama_server_context::run_slots, &llama));
|
||||||
llama.queue_results.on_multitask_update(std::bind(
|
llama.queue_results.on_multitask_update(std::bind(
|
||||||
&llama_server_queue::update_multitask,
|
&llama_server_queue::update_multitask,
|
||||||
&llama.queue_tasks,
|
&llama.queue_tasks,
|
||||||
|
|
|
@ -227,7 +227,7 @@ struct llama_server_queue {
|
||||||
// callback functions
|
// callback functions
|
||||||
std::function<void(task_server&)> callback_new_task;
|
std::function<void(task_server&)> callback_new_task;
|
||||||
std::function<void(task_multi&)> callback_finish_multitask;
|
std::function<void(task_multi&)> callback_finish_multitask;
|
||||||
std::function<void(void)> callback_all_task_finished;
|
std::function<void(void)> callback_run_slots;
|
||||||
|
|
||||||
// Add a new task to the end of the queue
|
// Add a new task to the end of the queue
|
||||||
int post(task_server task) {
|
int post(task_server task) {
|
||||||
|
@ -257,14 +257,14 @@ struct llama_server_queue {
|
||||||
callback_new_task = callback;
|
callback_new_task = callback;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register function to process a multitask
|
// Register function to process a multitask when it is finished
|
||||||
void on_finish_multitask(std::function<void(task_multi&)> callback) {
|
void on_finish_multitask(std::function<void(task_multi&)> callback) {
|
||||||
callback_finish_multitask = callback;
|
callback_finish_multitask = callback;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register the function to be called when the batch of tasks is finished
|
// Register the function to be called when all slots data is ready to be processed
|
||||||
void on_all_tasks_finished(std::function<void(void)> callback) {
|
void on_run_slots(std::function<void(void)> callback) {
|
||||||
callback_all_task_finished = callback;
|
callback_run_slots = callback;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call when the state of one slot is changed
|
// Call when the state of one slot is changed
|
||||||
|
@ -286,7 +286,13 @@ struct llama_server_queue {
|
||||||
condition_tasks.notify_all();
|
condition_tasks.notify_all();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start the main loop.
|
/**
|
||||||
|
* Main loop consists of these steps:
|
||||||
|
* - Wait until a new task arrives
|
||||||
|
* - Process the task (i.e. maybe copy data into slot)
|
||||||
|
* - Check if multitask is finished
|
||||||
|
* - Run all slots
|
||||||
|
*/
|
||||||
void start_loop() {
|
void start_loop() {
|
||||||
running = true;
|
running = true;
|
||||||
while (true) {
|
while (true) {
|
||||||
|
@ -306,8 +312,8 @@ struct llama_server_queue {
|
||||||
LOG_VERBOSE("callback_new_task", {});
|
LOG_VERBOSE("callback_new_task", {});
|
||||||
callback_new_task(task);
|
callback_new_task(task);
|
||||||
}
|
}
|
||||||
LOG_VERBOSE("callback_all_task_finished", {});
|
LOG_VERBOSE("update_multitasks", {});
|
||||||
// process and update all the multitasks
|
// check if we have any finished multitasks
|
||||||
auto queue_iterator = queue_multitasks.begin();
|
auto queue_iterator = queue_multitasks.begin();
|
||||||
while (queue_iterator != queue_multitasks.end())
|
while (queue_iterator != queue_multitasks.end())
|
||||||
{
|
{
|
||||||
|
@ -324,8 +330,9 @@ struct llama_server_queue {
|
||||||
++queue_iterator;
|
++queue_iterator;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// all tasks in the current loop is finished
|
// all tasks in the current loop is processed, slots data is now ready
|
||||||
callback_all_task_finished();
|
LOG_VERBOSE("callback_run_slots", {});
|
||||||
|
callback_run_slots();
|
||||||
}
|
}
|
||||||
LOG_VERBOSE("wait for new task", {});
|
LOG_VERBOSE("wait for new task", {});
|
||||||
// wait for new task
|
// wait for new task
|
||||||
|
@ -401,7 +408,9 @@ struct llama_server_response {
|
||||||
condition_results.wait(lock, [&]{
|
condition_results.wait(lock, [&]{
|
||||||
return !queue_results.empty();
|
return !queue_results.empty();
|
||||||
});
|
});
|
||||||
LOG_VERBOSE("condition_results unblock", {});
|
LOG_VERBOSE("condition_results unblock", {
|
||||||
|
{"data", queue_results[0].result_json},
|
||||||
|
});
|
||||||
|
|
||||||
for (int i = 0; i < (int) queue_results.size(); i++)
|
for (int i = 0; i < (int) queue_results.size(); i++)
|
||||||
{
|
{
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue