server: fix a race condition caused by "request_completion"

This commit is contained in:
ngxson 2024-01-23 18:13:38 +01:00
parent d083c81761
commit 8f36df8fc9
2 changed files with 44 additions and 24 deletions

View file

@ -1122,9 +1122,10 @@ struct llama_server_context
queue_results.send(res); queue_results.send(res);
} }
int request_completion(json data, bool infill, bool embedding, int multitask_id) void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
{ {
task_server task; task_server task;
task.id = task_id;
task.target_id = 0; task.target_id = 0;
task.data = std::move(data); task.data = std::move(data);
task.infill_mode = infill; task.infill_mode = infill;
@ -1135,11 +1136,11 @@ struct llama_server_context
// when a completion task's prompt array is not a singleton, we split it into multiple requests // when a completion task's prompt array is not a singleton, we split it into multiple requests
if (task.data.count("prompt") && task.data.at("prompt").size() > 1) if (task.data.count("prompt") && task.data.at("prompt").size() > 1)
{ {
return split_multiprompt_task(task); split_multiprompt_task(task_id, task);
} }
// otherwise, it's a single-prompt task, we actually queue it // otherwise, it's a single-prompt task, we actually queue it
return queue_tasks.post(task); queue_tasks.post(task);
} }
// for multiple images processing // for multiple images processing
@ -1218,25 +1219,30 @@ struct llama_server_context
queue_tasks.post(task); queue_tasks.post(task);
} }
int split_multiprompt_task(task_server& multiprompt_task) void split_multiprompt_task(int multitask_id, task_server& multiprompt_task)
{ {
int prompt_count = multiprompt_task.data.at("prompt").size(); int prompt_count = multiprompt_task.data.at("prompt").size();
assert(prompt_count > 1); assert(prompt_count > 1);
int multitask_id = queue_tasks.get_next_id(); // generate all the IDs for subtasks
std::vector<int> subtask_ids(prompt_count); std::vector<int> subtask_ids(prompt_count);
for (int i = 0; i < prompt_count; i++) for (int i = 0; i < prompt_count; i++)
{
subtask_ids[i] = queue_tasks.get_new_id();
}
// queue up the multitask so we can track its subtask progression
queue_tasks.add_multitask(multitask_id, subtask_ids);
// add subtasks
for (int i = 0; i < prompt_count; i++)
{ {
json subtask_data = multiprompt_task.data; json subtask_data = multiprompt_task.data;
subtask_data["prompt"] = subtask_data["prompt"][i]; subtask_data["prompt"] = subtask_data["prompt"][i];
// subtasks inherit everything else (infill mode, embedding mode, etc.) // subtasks inherit everything else (infill mode, embedding mode, etc.)
subtask_ids[i] = request_completion(subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id); request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
} }
// queue up the multitask so we can track its subtask progression
queue_tasks.add_multitask(multitask_id, subtask_ids);
return multitask_id;
} }
void process_single_task(task_server& task) void process_single_task(task_server& task)
@ -2493,8 +2499,9 @@ int main(int argc, char **argv)
return; return;
} }
json data = json::parse(req.body); json data = json::parse(req.body);
const int task_id = llama.request_completion(data, false, false, -1); const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id); llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, data, false, false, -1);
if (!json_value(data, "stream", false)) { if (!json_value(data, "stream", false)) {
std::string completion_text; std::string completion_text;
task_result result = llama.queue_results.recv(task_id); task_result result = llama.queue_results.recv(task_id);
@ -2505,9 +2512,8 @@ int main(int argc, char **argv)
{ {
res.status = 404; res.status = 404;
res.set_content(result.result_json["content"], "text/plain; charset=utf-8"); res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
llama.queue_results.remove_waiting_task_id(task_id);
return;
} }
llama.queue_results.remove_waiting_task_id(task_id);
} else { } else {
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink)
{ {
@ -2546,8 +2552,9 @@ int main(int argc, char **argv)
break; break;
} }
} }
sink.done();
llama.queue_results.remove_waiting_task_id(task_id); llama.queue_results.remove_waiting_task_id(task_id);
sink.done();
return true; return true;
}; };
@ -2592,8 +2599,9 @@ int main(int argc, char **argv)
} }
json data = oaicompat_completion_params_parse(json::parse(req.body)); json data = oaicompat_completion_params_parse(json::parse(req.body));
const int task_id = llama.request_completion(data, false, false, -1); const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id); llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, data, false, false, -1);
if (!json_value(data, "stream", false)) { if (!json_value(data, "stream", false)) {
std::string completion_text; std::string completion_text;
@ -2608,9 +2616,8 @@ int main(int argc, char **argv)
} else { } else {
res.status = 500; res.status = 500;
res.set_content(result.result_json["content"], "text/plain; charset=utf-8"); res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
llama.queue_results.remove_waiting_task_id(task_id);
return;
} }
llama.queue_results.remove_waiting_task_id(task_id);
} else { } else {
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) { const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
while (true) { while (true) {
@ -2671,7 +2678,9 @@ int main(int argc, char **argv)
return; return;
} }
json data = json::parse(req.body); json data = json::parse(req.body);
const int task_id = llama.request_completion(data, true, false, -1); const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, data, true, false, -1);
if (!json_value(data, "stream", false)) { if (!json_value(data, "stream", false)) {
std::string completion_text; std::string completion_text;
task_result result = llama.queue_results.recv(task_id); task_result result = llama.queue_results.recv(task_id);
@ -2683,8 +2692,8 @@ int main(int argc, char **argv)
{ {
res.status = 404; res.status = 404;
res.set_content(result.result_json["content"], "text/plain; charset=utf-8"); res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
return;
} }
llama.queue_results.remove_waiting_task_id(task_id);
} else { } else {
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) { const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) {
while (true) while (true)
@ -2700,6 +2709,7 @@ int main(int argc, char **argv)
}); });
if (!sink.write(str.c_str(), str.size())) if (!sink.write(str.c_str(), str.size()))
{ {
llama.queue_results.remove_waiting_task_id(task_id);
return false; return false;
} }
if (result.stop) if (result.stop)
@ -2713,8 +2723,8 @@ int main(int argc, char **argv)
} }
} }
llama.queue_results.remove_waiting_task_id(task_id);
sink.done(); sink.done();
return true; return true;
}; };
@ -2788,8 +2798,16 @@ int main(int argc, char **argv)
image_data = ""; image_data = "";
} }
const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1); // create and queue the task
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
// get the result
task_result result = llama.queue_results.recv(task_id); task_result result = llama.queue_results.recv(task_id);
llama.queue_results.remove_waiting_task_id(task_id);
// send the result
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
}); });

View file

@ -203,7 +203,9 @@ struct llama_server_queue {
// Add a new task to the end of the queue // Add a new task to the end of the queue
int post(task_server task) { int post(task_server task) {
std::unique_lock<std::mutex> lock(mutex_tasks); std::unique_lock<std::mutex> lock(mutex_tasks);
if (task.id == -1) {
task.id = id++; task.id = id++;
}
queue_tasks.push_back(std::move(task)); queue_tasks.push_back(std::move(task));
condition_tasks.notify_one(); condition_tasks.notify_one();
return task.id; return task.id;
@ -215,8 +217,8 @@ struct llama_server_queue {
queue_tasks_deferred.push_back(std::move(task)); queue_tasks_deferred.push_back(std::move(task));
} }
// Get the next task id // Get the next id for creating a new task
int get_next_id() { int get_new_id() {
std::unique_lock<std::mutex> lock(mutex_tasks); std::unique_lock<std::mutex> lock(mutex_tasks);
return id++; return id++;
} }