* add multiprompt support
parent 3e73d31d9c
commit 09562678d9
1 changed file with 133 additions and 11 deletions
@@ -24,6 +24,7 @@
 #include <thread>
 #include <mutex>
 #include <chrono>
+#include <unordered_set>
 
 #ifndef SERVER_VERBOSE
 #define SERVER_VERBOSE 1
@@ -155,15 +156,23 @@ struct task_server {
     json data;
     bool infill_mode = false;
     bool embedding_mode = false;
+    int multitask_id = -1;
 };
 
 struct task_result {
     int id;
+    int multitask_id = -1;
     bool stop;
     bool error;
     json result_json;
 };
 
+struct task_multi {
+    int id;
+    std::unordered_set<int> subtasks_remaining{};
+    std::vector<task_result> results{};
+};
+
 // TODO: can become bool if we can't find use of more states
 enum slot_state
 {
@@ -406,6 +415,9 @@ struct llama_client_slot
     double t_prompt_processing; // ms
     double t_token_generation; // ms
 
+    // multitasks
+    int multitask_id = -1;
+
     void reset() {
         num_prompt_tokens = 0;
         generated_text = "";
@@ -512,7 +524,7 @@ struct llama_server_context
     bool all_slots_are_idle = false;
     bool add_bos_token = true;
 
-    int32_t id_gen;
+    std::atomic<int32_t> id_gen;
     int32_t n_ctx; // total context for all clients / slots
 
     // system prompt
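Note on the hunk above: id_gen becomes std::atomic<int32_t>, which lines up with request_completion and split_multiprompt_task_into_subtasks (further down in this diff) taking IDs from id_gen++ outside of mutex_tasks. A minimal standalone sketch of that pattern, with hypothetical names and only the standard library, not code from this commit:

    // Standalone sketch: unique IDs from an atomic counter, which is what
    // lets id_gen++ run safely from several threads without holding a mutex.
    #include <atomic>
    #include <cstdint>
    #include <cstdio>
    #include <set>
    #include <thread>
    #include <vector>

    int main() {
        std::atomic<int32_t> id_gen{0};
        std::vector<int32_t> ids[4];

        std::vector<std::thread> workers;
        for (int t = 0; t < 4; t++) {
            workers.emplace_back([&, t] {
                for (int i = 0; i < 1000; i++) {
                    ids[t].push_back(id_gen++); // atomic fetch-and-increment
                }
            });
        }
        for (auto & w : workers) { w.join(); }

        std::set<int32_t> unique;
        for (auto & v : ids) { unique.insert(v.begin(), v.end()); }
        std::printf("%zu unique ids out of 4000\n", unique.size());
    }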
@@ -529,8 +541,10 @@ struct llama_server_context
 
     std::vector<task_server> queue_tasks;
     std::vector<task_result> queue_results;
+    std::vector<task_multi> queue_multitasks;
     std::mutex mutex_tasks;
     std::mutex mutex_results;
+    std::mutex mutex_multitasks;
 
     ~llama_server_context()
     {
@@ -1112,17 +1126,40 @@ struct llama_server_context
         return slot.images.size() > 0;
     }
 
-    void send_error(int id, std::string error)
+    void send_error(task_server& task, std::string error)
     {
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
-        res.id = id;
+        res.id = task.id;
+        res.multitask_id = task.multitask_id;
         res.stop = false;
         res.error = true;
         res.result_json = { { "content", error } };
         queue_results.push_back(res);
     }
 
+    void add_multi_task(int id, std::vector<int>& sub_ids)
+    {
+        std::lock_guard<std::mutex> lock(mutex_multitasks);
+        task_multi multi;
+        multi.id = id;
+        std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
+        queue_multitasks.push_back(multi);
+    }
+
+    void update_multi_task(int multitask_id, int subtask_id, task_result& result)
+    {
+        std::lock_guard<std::mutex> lock(mutex_multitasks);
+        for (auto& multitask : queue_multitasks)
+        {
+            if (multitask.id == multitask_id)
+            {
+                multitask.subtasks_remaining.erase(subtask_id);
+                multitask.results.push_back(result);
+            }
+        }
+    }
+
     json get_model_props()
     {
         return get_formated_generation(slots[0]);
@@ -1167,6 +1204,7 @@ struct llama_server_context
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
         res.error = false;
         res.stop = false;
 
@@ -1206,6 +1244,7 @@ struct llama_server_context
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
         res.error = false;
         res.stop = true;
 
@@ -1251,6 +1290,16 @@ struct llama_server_context
             res.result_json["model"] = slot.oaicompat_model;
         }
 
+        // if this task has a multitask associated with it, then we update the multitask
+        if (slot.multitask_id != -1)
+        {
+            update_multi_task(slot.multitask_id, slot.task_id, res);
+        }
+        else // otherwise update the results queue
+        {
+
+        }
+
         queue_results.push_back(res);
     }
 
@@ -1259,6 +1308,7 @@ struct llama_server_context
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
         res.error = false;
         res.stop = true;
 
@@ -1285,9 +1335,8 @@ struct llama_server_context
         queue_results.push_back(res);
     }
 
-    int request_completion(json data, bool infill, bool embedding)
+    int request_completion(json data, bool infill, bool embedding, int multitask_id)
    {
-        std::lock_guard<std::mutex> lock(mutex_tasks);
         task_server task;
         task.id = id_gen++;
         task.target_id = 0;
@@ -1295,6 +1344,17 @@ struct llama_server_context
         task.infill_mode = infill;
         task.embedding_mode = embedding;
         task.type = COMPLETION_TASK;
+        task.multitask_id = multitask_id;
+
+        // when a completion task's prompt array is not a singleton, we split it into multiple requests
+        if (task.data.at("prompt").size() > 1)
+        {
+            auto id = split_multiprompt_task_into_subtasks(task);
+            return id;
+        }
+
+        // otherwise, it's a single-prompt task, we actually queue it
+        std::lock_guard<std::mutex> lock(mutex_tasks);
         queue_tasks.push_back(task);
         return task.id;
     }
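A note on the prompt-size check added above: with the nlohmann JSON library behind the server's json alias, size() on a scalar string reports 1 and size() on an array reports its element count, so only a "prompt" given as an array of two or more entries takes the split path. A rough, illustrative sketch (field values are made up, not from this commit):

    // Illustrative only: which request bodies would hit the new multiprompt path.
    // Assumes nlohmann::json, the JSON library the server already uses.
    #include <iostream>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    int main() {
        json single = { {"prompt", "lorem ipsum"}, {"n_predict", 16} };
        json multi  = { {"prompt", {"first prompt", "second prompt"}}, {"n_predict", 16} };

        // Mirrors the check in request_completion: a scalar prompt has size() == 1,
        // an array of two or more prompts triggers split_multiprompt_task_into_subtasks.
        std::cout << single.at("prompt").size() << '\n';  // 1 -> queued as a single task
        std::cout << multi.at("prompt").size()  << '\n';  // 2 -> split into subtasks
    }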
@@ -1313,8 +1373,17 @@ struct llama_server_context
 
             for (int i = 0; i < (int) queue_results.size(); i++)
             {
+                // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
+                if (queue_results[i].multitask_id == task_id)
+                {
+                    update_multi_task(task_id, queue_results[i].id, queue_results[i]);
+                    queue_results.erase(queue_results.begin() + i);
+                    continue;
+                }
+
                 if (queue_results[i].id == task_id)
                 {
+                    assert(queue_results[i].multitask_id == -1);
                     task_result res = queue_results[i];
                     queue_results.erase(queue_results.begin() + i);
                     return res;
@@ -1404,6 +1473,25 @@ struct llama_server_context
         queue_tasks.push_back(task);
     }
 
+    int split_multiprompt_task_into_subtasks(task_server& task)
+    {
+        auto prompt_count = task.data.at("prompt").size();
+        assert(prompt_count > 1);
+
+        int multitask_id = id_gen++;
+        std::vector<int> subtask_ids(prompt_count);
+        for (int i = 0; i < prompt_count; i++)
+        {
+            json subtask_data = task.data;
+            subtask_data["prompt"] = subtask_data["prompt"][i];
+
+            subtask_ids[i] = request_completion(subtask_data, task.infill_mode, task.embedding_mode, multitask_id);
+        }
+
+        add_multi_task(multitask_id, subtask_ids);
+        return multitask_id;
+    }
+
     void process_tasks()
     {
         std::lock_guard<std::mutex> lock(mutex_tasks);
@@ -1419,7 +1507,7 @@ struct llama_server_context
                     {
                         LOG_TEE("slot unavailable\n");
                         // send error result
-                        send_error(task.id, "slot unavailable");
+                        send_error(task, "slot unavailable");
                         return;
                     }
 
@@ -1433,11 +1521,12 @@ struct llama_server_context
                     slot->infill = task.infill_mode;
                     slot->embedding = task.embedding_mode;
                     slot->task_id = task.id;
+                    slot->multitask_id = task.multitask_id;
 
                     if (!launch_slot_with_data(slot, task.data))
                     {
                         // send error result
-                        send_error(task.id, "internal_error");
+                        send_error(task, "internal_error");
                         break;
                     }
                 } break;
@@ -1453,6 +1542,39 @@ struct llama_server_context
                 } break;
             }
         }
+
+        // remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue
+        std::lock_guard<std::mutex> lock_multitasks(mutex_multitasks);
+        auto queue_iterator = queue_multitasks.begin();
+        while (queue_iterator != queue_multitasks.end())
+        {
+            if (queue_iterator->subtasks_remaining.empty())
+            {
+                // all subtasks done == multitask is done
+                task_result aggregate_result{};
+                aggregate_result.id = queue_iterator->id;
+                aggregate_result.stop = true;
+                aggregate_result.error = false;
+
+                // collect json results into one json result
+                std::vector<json> result_jsons{};
+                for (auto& subres : queue_iterator->results)
+                {
+                    result_jsons.push_back(subres.result_json);
+                    aggregate_result.error = aggregate_result.error && subres.error;
+                }
+                aggregate_result.result_json = json{ "results", result_jsons };
+
+                std::lock_guard<std::mutex> lock(mutex_results);
+                queue_results.push_back(aggregate_result);
+
+                queue_iterator = queue_multitasks.erase(queue_iterator);
+            }
+            else
+            {
+                ++queue_iterator;
+            }
+        }
     }
 
     bool update_slots() {
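The block added to process_tasks above walks queue_multitasks and, once a multitask's subtasks_remaining set has emptied, folds the collected subtask results into one aggregate entry on the results queue and erases the multitask while iterating. A simplified standalone sketch of that bookkeeping, with hypothetical types and no locking, only to isolate the erase-while-iterating pattern:

    // Standalone sketch of the multitask bookkeeping pattern (not server code):
    // subtasks check in, and a finished multitask is erased while iterating.
    #include <cstdio>
    #include <string>
    #include <unordered_set>
    #include <vector>

    struct MultiTask {
        int id;
        std::unordered_set<int> subtasks_remaining;
        std::vector<std::string> results;
    };

    int main() {
        std::vector<MultiTask> queue_multitasks = {
            { 100, {1, 2}, {} },   // waiting on subtasks 1 and 2
        };

        // A subtask result arrives: drop its id from the set, keep its payload.
        auto record = [&](int multitask_id, int subtask_id, const std::string & result) {
            for (auto & m : queue_multitasks) {
                if (m.id == multitask_id) {
                    m.subtasks_remaining.erase(subtask_id);
                    m.results.push_back(result);
                }
            }
        };
        record(100, 1, "first result");
        record(100, 2, "second result");

        // Same shape as the new loop in process_tasks: aggregate and erase
        // multitasks whose remaining-subtask set is empty.
        auto it = queue_multitasks.begin();
        while (it != queue_multitasks.end()) {
            if (it->subtasks_remaining.empty()) {
                std::printf("multitask %d done with %zu results\n", it->id, it->results.size());
                it = queue_multitasks.erase(it);
            } else {
                ++it;
            }
        }
    }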
@@ -2596,7 +2718,7 @@ int main(int argc, char **argv)
     svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, false, false);
+                const int task_id = llama.request_completion(data, false, false, -1);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);
@@ -2685,7 +2807,7 @@ int main(int argc, char **argv)
             {
                 json data = oaicompat_completion_params_parse(json::parse(req.body));
 
-                const int task_id = llama.request_completion(data, false, false);
+                const int task_id = llama.request_completion(data, false, false, -1);
 
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
@@ -2754,7 +2876,7 @@ int main(int argc, char **argv)
     svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, true, false);
+                const int task_id = llama.request_completion(data, true, false, -1);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);
@@ -2858,7 +2980,7 @@ int main(int argc, char **argv)
                 {
                     prompt = "";
                 }
-                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true);
+                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true, -1);
                 task_result result = llama.next_result(task_id);
                 return res.set_content(result.result_json.dump(), "application/json");
             });