use unordered_set everywhere

Xuan Son Nguyen 2024-09-02 11:18:39 +02:00
parent 83249aae0c
commit 24329aac1e


@@ -92,10 +92,10 @@ struct server_task {
     server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;

     // utility function
-    static std::vector<int> get_list_id(std::vector<server_task> tasks) {
-        std::vector<int> ids(tasks.size());
+    static std::unordered_set<int> get_list_id(std::vector<server_task> tasks) {
+        std::unordered_set<int> ids(tasks.size());
         for (size_t i = 0; i < tasks.size(); i++) {
-            ids[i] = tasks[i].id;
+            ids.insert(tasks[i].id);
         }
         return ids;
     }
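Note: the replacement constructs the set with `std::unordered_set<int> ids(tasks.size())`, where the argument is a bucket-count hint rather than an element count, so the set starts empty and `insert` both fills and deduplicates it. A minimal standalone sketch of that behaviour (illustration only; `server_task` is reduced here to a bare `id` field):

#include <cstdio>
#include <unordered_set>
#include <vector>

// illustration only: a stripped-down server_task with just an id
struct server_task { int id; };

static std::unordered_set<int> get_list_id(std::vector<server_task> tasks) {
    // tasks.size() is a bucket-count hint, not the number of elements
    std::unordered_set<int> ids(tasks.size());
    for (size_t i = 0; i < tasks.size(); i++) {
        ids.insert(tasks[i].id);
    }
    return ids;
}

int main() {
    std::vector<server_task> tasks = {{3}, {1}, {3}};
    std::unordered_set<int> ids = get_list_id(tasks);
    printf("%zu unique ids\n", ids.size()); // prints "2 unique ids"
    return 0;
}

Duplicate ids collapse into a single entry, which the old std::vector version did not do.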
@@ -394,22 +394,35 @@ struct server_queue {
     std::function<void(void)> callback_update_slots;

     // Add a new task to the end of the queue
-    int post(server_task task) {
+    int post(server_task task, bool front = false) {
         std::unique_lock<std::mutex> lock(mutex_tasks);
         if (task.id == -1) {
             task.id = id++;
             LOG_VERBOSE("new task id", {{"new_id", task.id}});
         }
-        queue_tasks.push_back(std::move(task));
+        if (front) {
+            queue_tasks.insert(queue_tasks.begin(), std::move(task));
+        } else {
+            queue_tasks.push_back(std::move(task));
+        }
         condition_tasks.notify_one();
         return task.id;
     }

     // multi-task version of post()
-    int post(std::vector<server_task> & tasks) {
+    int post(std::vector<server_task> & tasks, bool front = false) {
         for (auto & task : tasks) {
-            post(task);
+            if (task.id == -1) {
+                task.id = id++;
+                LOG_VERBOSE("new task id", {{"new_id", task.id}});
+            }
+            if (front) {
+                queue_tasks.insert(queue_tasks.begin(), std::move(task));
+            } else {
+                queue_tasks.push_back(std::move(task));
+            }
         }
+        condition_tasks.notify_one();
         return 0;
     }
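The new `front` flag only changes where a task lands in `queue_tasks`; id assignment and the condition-variable notification are unchanged. A toy sketch of the two insertion policies on a plain vector (illustration only, no locking or condition variable):

#include <cstdio>
#include <vector>

int main() {
    // illustration only: the task queue reduced to a vector of ids
    std::vector<int> queue_tasks = {10, 11, 12};

    // post(task, /*front=*/false): append, processed last
    queue_tasks.push_back(13);

    // post(task, /*front=*/true): insert at the head, processed first
    queue_tasks.insert(queue_tasks.begin(), 99);

    for (int id : queue_tasks) {
        printf("%d ", id); // prints: 99 10 11 12 13
    }
    printf("\n");
    return 0;
}

Front insertion on a std::vector shifts the existing elements, which is presumably acceptable here given how short the pending-task queue stays.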
@@ -419,7 +432,7 @@ struct server_queue {
         queue_tasks_deferred.push_back(std::move(task));
     }

-    // Get the next id for creating anew task
+    // Get the next id for creating a new task
     int get_new_id() {
         std::unique_lock<std::mutex> lock(mutex_tasks);
         int new_id = id++;
@@ -566,16 +579,6 @@ struct server_response {
         return recv(id_tasks);
     }

-    // multi-task version of recv()
-    server_task_result recv(std::vector<server_task> & tasks) {
-        std::unordered_set<int> id_tasks;
-        id_tasks.reserve(tasks.size());
-        for (const auto & t : tasks) {
-            id_tasks.insert(t.id);
-        }
-        return recv(id_tasks);
-    }
-
     // Send a new result to a waiting id_task
     void send(server_task_result result) {
         LOG_VERBOSE("send new result", {{"id_task", result.id}});
@@ -1487,7 +1490,7 @@ struct server_context {
     // Functions to create new task(s) and receive result(s)
     //

-    std::vector<server_task> create_tasks_completion(json data, server_task_cmpl_type cmpl_type) {
+    std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
         std::vector<server_task> tasks;
         auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
             server_task task;
@@ -1537,25 +1540,30 @@ struct server_context {
     }

     void cancel_tasks(std::unordered_set<int> & id_tasks) {
+        std::vector<server_task> cancel_tasks;
+        cancel_tasks.reserve(id_tasks.size());
         for (const auto & id_task : id_tasks) {
             LOG_VERBOSE("cancel task", {{"id_task", id_task}});

             server_task task;
             task.type = SERVER_TASK_TYPE_CANCEL;
             task.id_target = id_task;

-            queue_tasks.post(task);
+            cancel_tasks.push_back(task);
             queue_results.remove_waiting_task_id(id_task);
         }
+        // push to beginning of the queue, so it has highest priority
+        queue_tasks.post(cancel_tasks, true);
     }

-    void receive_cmpl_results(std::vector<int> & id_tasks, std::function<void(std::vector<server_task_result>&)> result_handler, std::function<void(json)> error_handler) {
-        std::unordered_set<int> id_set(id_tasks.begin(), id_tasks.end());
-        std::vector<server_task_result> results(id_set.size());
-        for (size_t i = 0; i < id_set.size(); i++) {
-            server_task_result result = queue_results.recv(id_set);
+    // receive the results from task(s) created by create_tasks_cmpl
+    void receive_cmpl_results(std::unordered_set<int> & id_tasks, std::function<void(std::vector<server_task_result>&)> result_handler, std::function<void(json)> error_handler) {
+        // TODO: currently, there is no way to detect the client has cancelled the request
+        std::vector<server_task_result> results(id_tasks.size());
+        for (size_t i = 0; i < id_tasks.size(); i++) {
+            server_task_result result = queue_results.recv(id_tasks);
             if (result.error) {
                 error_handler(result.data);
-                cancel_tasks(id_set);
+                cancel_tasks(id_tasks);
                 break;
             }
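With this change the cancel requests for a batch are collected first and then posted in one call with `front = true`, so they are dequeued ahead of any completion work already waiting. A toy model of that ordering (illustration only; the real queue stores `server_task` objects and is guarded by a mutex):

#include <cstdio>
#include <deque>
#include <string>
#include <vector>

// illustration only: a toy queue with the same front/back policy as server_queue::post
struct toy_queue {
    std::deque<std::string> tasks;
    void post(const std::vector<std::string> & batch, bool front = false) {
        for (const auto & t : batch) {
            if (front) {
                tasks.push_front(t);
            } else {
                tasks.push_back(t);
            }
        }
    }
};

int main() {
    toy_queue q;
    q.post({"cmpl-1", "cmpl-2"});           // normal work, appended
    q.post({"cancel-1", "cancel-2"}, true); // cancels jump ahead of pending work
    while (!q.tasks.empty()) {
        printf("%s\n", q.tasks.front().c_str());
        q.tasks.pop_front();
    }
    // processes cancel-2, cancel-1, then cmpl-1, cmpl-2
    return 0;
}

Because each batched task is inserted at the head in turn, the cancels come out in reverse of their insertion order; that is harmless here since each cancel targets an independent task id.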
@@ -1565,24 +1573,24 @@ struct server_context {
         result_handler(results);
     }

-    void receive_cmpl_results_stream(std::vector<int> & id_tasks, std::function<bool(server_task_result&)> result_handler, std::function<void(json)> error_handler) {
-        std::unordered_set<int> id_set(id_tasks.begin(), id_tasks.end());
+    // receive the results from task(s) created by create_tasks_cmpl, in stream mode
+    void receive_cmpl_results_stream(std::unordered_set<int> & id_tasks, std::function<bool(server_task_result&)> result_handler, std::function<void(json)> error_handler) {
         size_t n_finished = 0;
         while (true) {
-            server_task_result result = queue_results.recv(id_set);
+            server_task_result result = queue_results.recv(id_tasks);
             if (!result_handler(result)) {
-                cancel_tasks(id_set);
+                cancel_tasks(id_tasks);
                 break;
             }

             if (result.error) {
                 error_handler(result.data);
-                cancel_tasks(id_set);
+                cancel_tasks(id_tasks);
                 break;
             }

             if (result.stop) {
-                if (++n_finished == id_set.size()) {
+                if (++n_finished == id_tasks.size()) {
                     break;
                 }
             }
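The streaming receiver exits once every id in the set has produced a result with `stop` set. A minimal sketch of just that stop condition (illustration only; `fake_result` stands in for `server_task_result` and the hard-coded array stands in for `queue_results.recv`):

#include <cstdio>
#include <unordered_set>

struct fake_result { int id; bool stop; };

int main() {
    std::unordered_set<int> id_tasks = {1, 2, 3};
    fake_result stream[] = {{1, false}, {1, true}, {2, true}, {3, true}};

    size_t n_finished = 0;
    for (const auto & result : stream) {
        if (result.stop) {
            if (++n_finished == id_tasks.size()) {
                printf("all %zu tasks finished\n", id_tasks.size());
                break;
            }
        }
    }
    return 0;
}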
@@ -2953,12 +2961,12 @@ int main(int argc, char ** argv) {
            return;
        }

-        std::vector<server_task> tasks = ctx_server.create_tasks_completion(data, cmpl_type);
+        std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type);
        ctx_server.queue_results.add_waiting_tasks(tasks);
        ctx_server.queue_tasks.post(tasks);

        bool stream = json_value(data, "stream", false);
-        std::vector<int> task_ids = server_task::get_list_id(tasks);
+        std::unordered_set<int> task_ids = server_task::get_list_id(tasks);

        if (!stream) {
            ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
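For reference, a stripped-down sketch of how a handler lambda plugs into the non-streaming receive path after this change: the ids travel as a single `std::unordered_set<int>` from `get_list_id()` straight into the receive call, with no vector-to-set conversion left on the caller side (the types and the `receive_all` helper below are hypothetical stand-ins, not the server's API):

#include <cstdio>
#include <functional>
#include <unordered_set>
#include <vector>

struct result_t { int id; };

// hypothetical stand-in for receive_cmpl_results: collect one result per id,
// then hand the batch to the caller-supplied handler
static void receive_all(std::unordered_set<int> & id_tasks,
                        std::function<void(std::vector<result_t> &)> on_results) {
    std::vector<result_t> results;
    for (int id : id_tasks) {
        results.push_back({id}); // stand-in for queue_results.recv(id_tasks)
    }
    on_results(results);
}

int main() {
    std::unordered_set<int> task_ids = {7, 8};
    receive_all(task_ids, [&](std::vector<result_t> & results) {
        printf("got %zu results\n", results.size());
    });
    return 0;
}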
@@ -2984,7 +2992,7 @@ int main(int argc, char ** argv) {
                server_sent_event(sink, "error", error_data);
            });
            sink.done();
-            return true;
+            return false;
        };

        res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
    }
@@ -3009,12 +3017,12 @@ int main(int argc, char ** argv) {
        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);

-        std::vector<server_task> tasks = ctx_server.create_tasks_completion(data, SERVER_TASK_CMPL_TYPE_NORMAL);
+        std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
        ctx_server.queue_results.add_waiting_tasks(tasks);
        ctx_server.queue_tasks.post(tasks);

        bool stream = json_value(data, "stream", false);
-        std::vector<int> task_ids = server_task::get_list_id(tasks);
+        std::unordered_set<int> task_ids = server_task::get_list_id(tasks);

        const auto completion_id = gen_chatcmplid();
        if (!stream) {
@@ -3111,12 +3119,12 @@ int main(int argc, char ** argv) {
        json responses = json::array();
        bool error = false;
        {
-            std::vector<server_task> tasks = ctx_server.create_tasks_completion({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
+            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
            ctx_server.queue_results.add_waiting_tasks(tasks);
            ctx_server.queue_tasks.post(tasks);

            // get the result
-            std::vector<int> task_ids = server_task::get_list_id(tasks);
+            std::unordered_set<int> task_ids = server_task::get_list_id(tasks);

            ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
                for (const auto & res : results) {