server: fix a race condition caused by "request_completion"

This commit is contained in:
ngxson 2024-01-23 18:13:38 +01:00
parent d083c81761
commit 8f36df8fc9
2 changed files with 44 additions and 24 deletions

View file

@ -1122,9 +1122,10 @@ struct llama_server_context
queue_results.send(res);
}
int request_completion(json data, bool infill, bool embedding, int multitask_id)
void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
{
task_server task;
task.id = task_id;
task.target_id = 0;
task.data = std::move(data);
task.infill_mode = infill;
@ -1135,11 +1136,11 @@ struct llama_server_context
// when a completion task's prompt array is not a singleton, we split it into multiple requests
if (task.data.count("prompt") && task.data.at("prompt").size() > 1)
{
return split_multiprompt_task(task);
split_multiprompt_task(task_id, task);
}
// otherwise, it's a single-prompt task, we actually queue it
return queue_tasks.post(task);
queue_tasks.post(task);
}
// for multiple images processing
@ -1218,25 +1219,30 @@ struct llama_server_context
queue_tasks.post(task);
}
int split_multiprompt_task(task_server& multiprompt_task)
void split_multiprompt_task(int multitask_id, task_server& multiprompt_task)
{
int prompt_count = multiprompt_task.data.at("prompt").size();
assert(prompt_count > 1);
int multitask_id = queue_tasks.get_next_id();
// generate all the IDs for the subtasks
std::vector<int> subtask_ids(prompt_count);
for (int i = 0; i < prompt_count; i++)
{
subtask_ids[i] = queue_tasks.get_new_id();
}
// queue up the multitask so we can track its subtask progression
queue_tasks.add_multitask(multitask_id, subtask_ids);
// add subtasks
for (int i = 0; i < prompt_count; i++)
{
json subtask_data = multiprompt_task.data;
subtask_data["prompt"] = subtask_data["prompt"][i];
// subtasks inherit everything else (infill mode, embedding mode, etc.)
subtask_ids[i] = request_completion(subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
}
// queue up the multitask so we can track its subtask progression
queue_tasks.add_multitask(multitask_id, subtask_ids);
return multitask_id;
}
void process_single_task(task_server& task)
@ -2493,8 +2499,9 @@ int main(int argc, char **argv)
return;
}
json data = json::parse(req.body);
const int task_id = llama.request_completion(data, false, false, -1);
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, data, false, false, -1);
if (!json_value(data, "stream", false)) {
std::string completion_text;
task_result result = llama.queue_results.recv(task_id);
@ -2505,9 +2512,8 @@ int main(int argc, char **argv)
{
res.status = 404;
res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
llama.queue_results.remove_waiting_task_id(task_id);
return;
}
llama.queue_results.remove_waiting_task_id(task_id);
} else {
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink)
{
@ -2546,8 +2552,9 @@ int main(int argc, char **argv)
break;
}
}
sink.done();
llama.queue_results.remove_waiting_task_id(task_id);
sink.done();
return true;
};
@ -2592,8 +2599,9 @@ int main(int argc, char **argv)
}
json data = oaicompat_completion_params_parse(json::parse(req.body));
const int task_id = llama.request_completion(data, false, false, -1);
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, data, false, false, -1);
if (!json_value(data, "stream", false)) {
std::string completion_text;
@ -2608,9 +2616,8 @@ int main(int argc, char **argv)
} else {
res.status = 500;
res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
llama.queue_results.remove_waiting_task_id(task_id);
return;
}
llama.queue_results.remove_waiting_task_id(task_id);
} else {
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
while (true) {
@ -2671,7 +2678,9 @@ int main(int argc, char **argv)
return;
}
json data = json::parse(req.body);
const int task_id = llama.request_completion(data, true, false, -1);
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, data, true, false, -1);
if (!json_value(data, "stream", false)) {
std::string completion_text;
task_result result = llama.queue_results.recv(task_id);
@ -2683,8 +2692,8 @@ int main(int argc, char **argv)
{
res.status = 404;
res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
return;
}
llama.queue_results.remove_waiting_task_id(task_id);
} else {
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) {
while (true)
@ -2700,6 +2709,7 @@ int main(int argc, char **argv)
});
if (!sink.write(str.c_str(), str.size()))
{
llama.queue_results.remove_waiting_task_id(task_id);
return false;
}
if (result.stop)
@ -2713,8 +2723,8 @@ int main(int argc, char **argv)
}
}
llama.queue_results.remove_waiting_task_id(task_id);
sink.done();
return true;
};
@ -2788,8 +2798,16 @@ int main(int argc, char **argv)
image_data = "";
}
const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
// create and queue the task
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
// get the result
task_result result = llama.queue_results.recv(task_id);
llama.queue_results.remove_waiting_task_id(task_id);
// send the result
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
});

View file

@ -203,7 +203,9 @@ struct llama_server_queue {
// Add a new task to the end of the queue
int post(task_server task) {
std::unique_lock<std::mutex> lock(mutex_tasks);
task.id = id++;
if (task.id == -1) {
task.id = id++;
}
queue_tasks.push_back(std::move(task));
condition_tasks.notify_one();
return task.id;
@ -215,8 +217,8 @@ struct llama_server_queue {
queue_tasks_deferred.push_back(std::move(task));
}
// Get the next task id
int get_next_id() {
// Get the next id for creating a new task
int get_new_id() {
std::unique_lock<std::mutex> lock(mutex_tasks);
return id++;
}