use unordered_set everywhere

This commit is contained in:
Xuan Son Nguyen 2024-09-02 11:18:39 +02:00
parent 83249aae0c
commit 24329aac1e

View file

@ -92,10 +92,10 @@ struct server_task {
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
// utility function
static std::vector<int> get_list_id(std::vector<server_task> tasks) {
std::vector<int> ids(tasks.size());
static std::unordered_set<int> get_list_id(std::vector<server_task> tasks) {
std::unordered_set<int> ids(tasks.size());
for (size_t i = 0; i < tasks.size(); i++) {
ids[i] = tasks[i].id;
ids.insert(tasks[i].id);
}
return ids;
}
@ -394,22 +394,35 @@ struct server_queue {
std::function<void(void)> callback_update_slots;
// Add a new task to the end of the queue
int post(server_task task) {
int post(server_task task, bool front = false) {
std::unique_lock<std::mutex> lock(mutex_tasks);
if (task.id == -1) {
task.id = id++;
LOG_VERBOSE("new task id", {{"new_id", task.id}});
}
if (front) {
queue_tasks.insert(queue_tasks.begin(), std::move(task));
} else {
queue_tasks.push_back(std::move(task));
}
condition_tasks.notify_one();
return task.id;
}
// multi-task version of post()
int post(std::vector<server_task> & tasks) {
int post(std::vector<server_task> & tasks, bool front = false) {
for (auto & task : tasks) {
post(task);
if (task.id == -1) {
task.id = id++;
LOG_VERBOSE("new task id", {{"new_id", task.id}});
}
if (front) {
queue_tasks.insert(queue_tasks.begin(), std::move(task));
} else {
queue_tasks.push_back(std::move(task));
}
}
condition_tasks.notify_one();
return 0;
}
@ -566,16 +579,6 @@ struct server_response {
return recv(id_tasks);
}
// multi-task version of recv()
server_task_result recv(std::vector<server_task> & tasks) {
std::unordered_set<int> id_tasks;
id_tasks.reserve(tasks.size());
for (const auto & t : tasks) {
id_tasks.insert(t.id);
}
return recv(id_tasks);
}
// Send a new result to a waiting id_task
void send(server_task_result result) {
LOG_VERBOSE("send new result", {{"id_task", result.id}});
@ -1487,7 +1490,7 @@ struct server_context {
// Functions to create new task(s) and receive result(s)
//
std::vector<server_task> create_tasks_completion(json data, server_task_cmpl_type cmpl_type) {
std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
std::vector<server_task> tasks;
auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
server_task task;
@ -1537,25 +1540,30 @@ struct server_context {
}
void cancel_tasks(std::unordered_set<int> & id_tasks) {
std::vector<server_task> cancel_tasks;
cancel_tasks.reserve(id_tasks.size());
for (const auto & id_task : id_tasks) {
LOG_VERBOSE("cancel task", {{"id_task", id_task}});
server_task task;
task.type = SERVER_TASK_TYPE_CANCEL;
task.id_target = id_task;
queue_tasks.post(task);
cancel_tasks.push_back(task);
queue_results.remove_waiting_task_id(id_task);
}
// push to beginning of the queue, so it has highest priority
queue_tasks.post(cancel_tasks, true);
}
void receive_cmpl_results(std::vector<int> & id_tasks, std::function<void(std::vector<server_task_result>&)> result_handler, std::function<void(json)> error_handler) {
std::unordered_set<int> id_set(id_tasks.begin(), id_tasks.end());
std::vector<server_task_result> results(id_set.size());
for (size_t i = 0; i < id_set.size(); i++) {
server_task_result result = queue_results.recv(id_set);
// receive the results from task(s) created by create_tasks_cmpl
void receive_cmpl_results(std::unordered_set<int> & id_tasks, std::function<void(std::vector<server_task_result>&)> result_handler, std::function<void(json)> error_handler) {
// TODO: currently, there is no way to detect the client has cancelled the request
std::vector<server_task_result> results(id_tasks.size());
for (size_t i = 0; i < id_tasks.size(); i++) {
server_task_result result = queue_results.recv(id_tasks);
if (result.error) {
error_handler(result.data);
cancel_tasks(id_set);
cancel_tasks(id_tasks);
break;
}
@ -1565,24 +1573,24 @@ struct server_context {
result_handler(results);
}
void receive_cmpl_results_stream(std::vector<int> & id_tasks, std::function<bool(server_task_result&)> result_handler, std::function<void(json)> error_handler) {
std::unordered_set<int> id_set(id_tasks.begin(), id_tasks.end());
// receive the results from task(s) created by create_tasks_cmpl, in stream mode
void receive_cmpl_results_stream(std::unordered_set<int> & id_tasks, std::function<bool(server_task_result&)> result_handler, std::function<void(json)> error_handler) {
size_t n_finished = 0;
while (true) {
server_task_result result = queue_results.recv(id_set);
server_task_result result = queue_results.recv(id_tasks);
if (!result_handler(result)) {
cancel_tasks(id_set);
cancel_tasks(id_tasks);
break;
}
if (result.error) {
error_handler(result.data);
cancel_tasks(id_set);
cancel_tasks(id_tasks);
break;
}
if (result.stop) {
if (++n_finished == id_set.size()) {
if (++n_finished == id_tasks.size()) {
break;
}
}
@ -2953,12 +2961,12 @@ int main(int argc, char ** argv) {
return;
}
std::vector<server_task> tasks = ctx_server.create_tasks_completion(data, cmpl_type);
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
bool stream = json_value(data, "stream", false);
std::vector<int> task_ids = server_task::get_list_id(tasks);
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
if (!stream) {
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
@ -2984,7 +2992,7 @@ int main(int argc, char ** argv) {
server_sent_event(sink, "error", error_data);
});
sink.done();
return true;
return false;
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
}
@ -3009,12 +3017,12 @@ int main(int argc, char ** argv) {
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
std::vector<server_task> tasks = ctx_server.create_tasks_completion(data, SERVER_TASK_CMPL_TYPE_NORMAL);
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
bool stream = json_value(data, "stream", false);
std::vector<int> task_ids = server_task::get_list_id(tasks);
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
const auto completion_id = gen_chatcmplid();
if (!stream) {
@ -3111,12 +3119,12 @@ int main(int argc, char ** argv) {
json responses = json::array();
bool error = false;
{
std::vector<server_task> tasks = ctx_server.create_tasks_completion({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
// get the result
std::vector<int> task_ids = server_task::get_list_id(tasks);
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
for (const auto & res : results) {