server: fix a race condition caused by "request_completion"
This commit is contained in:
parent
d083c81761
commit
8f36df8fc9
2 changed files with 44 additions and 24 deletions
|
@ -1122,9 +1122,10 @@ struct llama_server_context
|
||||||
queue_results.send(res);
|
queue_results.send(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
int request_completion(json data, bool infill, bool embedding, int multitask_id)
|
void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
|
||||||
{
|
{
|
||||||
task_server task;
|
task_server task;
|
||||||
|
task.id = task_id;
|
||||||
task.target_id = 0;
|
task.target_id = 0;
|
||||||
task.data = std::move(data);
|
task.data = std::move(data);
|
||||||
task.infill_mode = infill;
|
task.infill_mode = infill;
|
||||||
|
@ -1135,11 +1136,11 @@ struct llama_server_context
|
||||||
// when a completion task's prompt array is not a singleton, we split it into multiple requests
|
// when a completion task's prompt array is not a singleton, we split it into multiple requests
|
||||||
if (task.data.count("prompt") && task.data.at("prompt").size() > 1)
|
if (task.data.count("prompt") && task.data.at("prompt").size() > 1)
|
||||||
{
|
{
|
||||||
return split_multiprompt_task(task);
|
split_multiprompt_task(task_id, task);
|
||||||
}
|
}
|
||||||
|
|
||||||
// otherwise, it's a single-prompt task, we actually queue it
|
// otherwise, it's a single-prompt task, we actually queue it
|
||||||
return queue_tasks.post(task);
|
queue_tasks.post(task);
|
||||||
}
|
}
|
||||||
|
|
||||||
// for multiple images processing
|
// for multiple images processing
|
||||||
|
@ -1218,25 +1219,30 @@ struct llama_server_context
|
||||||
queue_tasks.post(task);
|
queue_tasks.post(task);
|
||||||
}
|
}
|
||||||
|
|
||||||
int split_multiprompt_task(task_server& multiprompt_task)
|
void split_multiprompt_task(int multitask_id, task_server& multiprompt_task)
|
||||||
{
|
{
|
||||||
int prompt_count = multiprompt_task.data.at("prompt").size();
|
int prompt_count = multiprompt_task.data.at("prompt").size();
|
||||||
assert(prompt_count > 1);
|
assert(prompt_count > 1);
|
||||||
|
|
||||||
int multitask_id = queue_tasks.get_next_id();
|
// generate all the IDs for the subtasks
|
||||||
std::vector<int> subtask_ids(prompt_count);
|
std::vector<int> subtask_ids(prompt_count);
|
||||||
for (int i = 0; i < prompt_count; i++)
|
for (int i = 0; i < prompt_count; i++)
|
||||||
|
{
|
||||||
|
subtask_ids[i] = queue_tasks.get_new_id();
|
||||||
|
}
|
||||||
|
|
||||||
|
// queue up the multitask so we can track its subtask progression
|
||||||
|
queue_tasks.add_multitask(multitask_id, subtask_ids);
|
||||||
|
|
||||||
|
// add subtasks
|
||||||
|
for (int i = 0; i < prompt_count; i++)
|
||||||
{
|
{
|
||||||
json subtask_data = multiprompt_task.data;
|
json subtask_data = multiprompt_task.data;
|
||||||
subtask_data["prompt"] = subtask_data["prompt"][i];
|
subtask_data["prompt"] = subtask_data["prompt"][i];
|
||||||
|
|
||||||
// subtasks inherit everything else (infill mode, embedding mode, etc.)
|
// subtasks inherit everything else (infill mode, embedding mode, etc.)
|
||||||
subtask_ids[i] = request_completion(subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
|
request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
// queue up the multitask so we can track its subtask progression
|
|
||||||
queue_tasks.add_multitask(multitask_id, subtask_ids);
|
|
||||||
return multitask_id;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void process_single_task(task_server& task)
|
void process_single_task(task_server& task)
|
||||||
|
@ -2493,8 +2499,9 @@ int main(int argc, char **argv)
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
json data = json::parse(req.body);
|
json data = json::parse(req.body);
|
||||||
const int task_id = llama.request_completion(data, false, false, -1);
|
const int task_id = llama.queue_tasks.get_new_id();
|
||||||
llama.queue_results.add_waiting_task_id(task_id);
|
llama.queue_results.add_waiting_task_id(task_id);
|
||||||
|
llama.request_completion(task_id, data, false, false, -1);
|
||||||
if (!json_value(data, "stream", false)) {
|
if (!json_value(data, "stream", false)) {
|
||||||
std::string completion_text;
|
std::string completion_text;
|
||||||
task_result result = llama.queue_results.recv(task_id);
|
task_result result = llama.queue_results.recv(task_id);
|
||||||
|
@ -2505,9 +2512,8 @@ int main(int argc, char **argv)
|
||||||
{
|
{
|
||||||
res.status = 404;
|
res.status = 404;
|
||||||
res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
|
res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
|
||||||
llama.queue_results.remove_waiting_task_id(task_id);
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
llama.queue_results.remove_waiting_task_id(task_id);
|
||||||
} else {
|
} else {
|
||||||
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink)
|
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink)
|
||||||
{
|
{
|
||||||
|
@ -2546,8 +2552,9 @@ int main(int argc, char **argv)
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sink.done();
|
|
||||||
llama.queue_results.remove_waiting_task_id(task_id);
|
llama.queue_results.remove_waiting_task_id(task_id);
|
||||||
|
sink.done();
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -2592,8 +2599,9 @@ int main(int argc, char **argv)
|
||||||
}
|
}
|
||||||
json data = oaicompat_completion_params_parse(json::parse(req.body));
|
json data = oaicompat_completion_params_parse(json::parse(req.body));
|
||||||
|
|
||||||
const int task_id = llama.request_completion(data, false, false, -1);
|
const int task_id = llama.queue_tasks.get_new_id();
|
||||||
llama.queue_results.add_waiting_task_id(task_id);
|
llama.queue_results.add_waiting_task_id(task_id);
|
||||||
|
llama.request_completion(task_id, data, false, false, -1);
|
||||||
|
|
||||||
if (!json_value(data, "stream", false)) {
|
if (!json_value(data, "stream", false)) {
|
||||||
std::string completion_text;
|
std::string completion_text;
|
||||||
|
@ -2608,9 +2616,8 @@ int main(int argc, char **argv)
|
||||||
} else {
|
} else {
|
||||||
res.status = 500;
|
res.status = 500;
|
||||||
res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
|
res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
|
||||||
llama.queue_results.remove_waiting_task_id(task_id);
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
llama.queue_results.remove_waiting_task_id(task_id);
|
||||||
} else {
|
} else {
|
||||||
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
|
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
|
||||||
while (true) {
|
while (true) {
|
||||||
|
@ -2671,7 +2678,9 @@ int main(int argc, char **argv)
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
json data = json::parse(req.body);
|
json data = json::parse(req.body);
|
||||||
const int task_id = llama.request_completion(data, true, false, -1);
|
const int task_id = llama.queue_tasks.get_new_id();
|
||||||
|
llama.queue_results.add_waiting_task_id(task_id);
|
||||||
|
llama.request_completion(task_id, data, true, false, -1);
|
||||||
if (!json_value(data, "stream", false)) {
|
if (!json_value(data, "stream", false)) {
|
||||||
std::string completion_text;
|
std::string completion_text;
|
||||||
task_result result = llama.queue_results.recv(task_id);
|
task_result result = llama.queue_results.recv(task_id);
|
||||||
|
@ -2683,8 +2692,8 @@ int main(int argc, char **argv)
|
||||||
{
|
{
|
||||||
res.status = 404;
|
res.status = 404;
|
||||||
res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
|
res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
llama.queue_results.remove_waiting_task_id(task_id);
|
||||||
} else {
|
} else {
|
||||||
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) {
|
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) {
|
||||||
while (true)
|
while (true)
|
||||||
|
@ -2700,6 +2709,7 @@ int main(int argc, char **argv)
|
||||||
});
|
});
|
||||||
if (!sink.write(str.c_str(), str.size()))
|
if (!sink.write(str.c_str(), str.size()))
|
||||||
{
|
{
|
||||||
|
llama.queue_results.remove_waiting_task_id(task_id);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (result.stop)
|
if (result.stop)
|
||||||
|
@ -2713,8 +2723,8 @@ int main(int argc, char **argv)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llama.queue_results.remove_waiting_task_id(task_id);
|
||||||
sink.done();
|
sink.done();
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -2788,8 +2798,16 @@ int main(int argc, char **argv)
|
||||||
image_data = "";
|
image_data = "";
|
||||||
}
|
}
|
||||||
|
|
||||||
const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
|
// create and queue the task
|
||||||
|
const int task_id = llama.queue_tasks.get_new_id();
|
||||||
|
llama.queue_results.add_waiting_task_id(task_id);
|
||||||
|
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
|
||||||
|
|
||||||
|
// get the result
|
||||||
task_result result = llama.queue_results.recv(task_id);
|
task_result result = llama.queue_results.recv(task_id);
|
||||||
|
llama.queue_results.remove_waiting_task_id(task_id);
|
||||||
|
|
||||||
|
// send the result
|
||||||
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
|
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -203,7 +203,9 @@ struct llama_server_queue {
|
||||||
// Add a new task to the end of the queue
|
// Add a new task to the end of the queue
|
||||||
int post(task_server task) {
|
int post(task_server task) {
|
||||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||||
|
if (task.id == -1) {
|
||||||
task.id = id++;
|
task.id = id++;
|
||||||
|
}
|
||||||
queue_tasks.push_back(std::move(task));
|
queue_tasks.push_back(std::move(task));
|
||||||
condition_tasks.notify_one();
|
condition_tasks.notify_one();
|
||||||
return task.id;
|
return task.id;
|
||||||
|
@ -215,8 +217,8 @@ struct llama_server_queue {
|
||||||
queue_tasks_deferred.push_back(std::move(task));
|
queue_tasks_deferred.push_back(std::move(task));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the next task id
|
// Get the next id for creating a new task
|
||||||
int get_next_id() {
|
int get_new_id() {
|
||||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||||
return id++;
|
return id++;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue