use unordered_set everywhere
This commit is contained in:
parent
83249aae0c
commit
24329aac1e
1 changed files with 47 additions and 39 deletions
|
@ -92,10 +92,10 @@ struct server_task {
|
||||||
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
|
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
|
||||||
|
|
||||||
// utility function
|
// utility function
|
||||||
static std::vector<int> get_list_id(std::vector<server_task> tasks) {
|
static std::unordered_set<int> get_list_id(std::vector<server_task> tasks) {
|
||||||
std::vector<int> ids(tasks.size());
|
std::unordered_set<int> ids(tasks.size());
|
||||||
for (size_t i = 0; i < tasks.size(); i++) {
|
for (size_t i = 0; i < tasks.size(); i++) {
|
||||||
ids[i] = tasks[i].id;
|
ids.insert(tasks[i].id);
|
||||||
}
|
}
|
||||||
return ids;
|
return ids;
|
||||||
}
|
}
|
||||||
|
@ -394,22 +394,35 @@ struct server_queue {
|
||||||
std::function<void(void)> callback_update_slots;
|
std::function<void(void)> callback_update_slots;
|
||||||
|
|
||||||
// Add a new task to the end of the queue
|
// Add a new task to the end of the queue
|
||||||
int post(server_task task) {
|
int post(server_task task, bool front = false) {
|
||||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||||
if (task.id == -1) {
|
if (task.id == -1) {
|
||||||
task.id = id++;
|
task.id = id++;
|
||||||
LOG_VERBOSE("new task id", {{"new_id", task.id}});
|
LOG_VERBOSE("new task id", {{"new_id", task.id}});
|
||||||
}
|
}
|
||||||
queue_tasks.push_back(std::move(task));
|
if (front) {
|
||||||
|
queue_tasks.insert(queue_tasks.begin(), std::move(task));
|
||||||
|
} else {
|
||||||
|
queue_tasks.push_back(std::move(task));
|
||||||
|
}
|
||||||
condition_tasks.notify_one();
|
condition_tasks.notify_one();
|
||||||
return task.id;
|
return task.id;
|
||||||
}
|
}
|
||||||
|
|
||||||
// multi-task version of post()
|
// multi-task version of post()
|
||||||
int post(std::vector<server_task> & tasks) {
|
int post(std::vector<server_task> & tasks, bool front = false) {
|
||||||
for (auto & task : tasks) {
|
for (auto & task : tasks) {
|
||||||
post(task);
|
if (task.id == -1) {
|
||||||
|
task.id = id++;
|
||||||
|
LOG_VERBOSE("new task id", {{"new_id", task.id}});
|
||||||
|
}
|
||||||
|
if (front) {
|
||||||
|
queue_tasks.insert(queue_tasks.begin(), std::move(task));
|
||||||
|
} else {
|
||||||
|
queue_tasks.push_back(std::move(task));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
condition_tasks.notify_one();
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -419,7 +432,7 @@ struct server_queue {
|
||||||
queue_tasks_deferred.push_back(std::move(task));
|
queue_tasks_deferred.push_back(std::move(task));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the next id for creating anew task
|
// Get the next id for creating a new task
|
||||||
int get_new_id() {
|
int get_new_id() {
|
||||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||||
int new_id = id++;
|
int new_id = id++;
|
||||||
|
@ -566,16 +579,6 @@ struct server_response {
|
||||||
return recv(id_tasks);
|
return recv(id_tasks);
|
||||||
}
|
}
|
||||||
|
|
||||||
// multi-task version of recv()
|
|
||||||
server_task_result recv(std::vector<server_task> & tasks) {
|
|
||||||
std::unordered_set<int> id_tasks;
|
|
||||||
id_tasks.reserve(tasks.size());
|
|
||||||
for (const auto & t : tasks) {
|
|
||||||
id_tasks.insert(t.id);
|
|
||||||
}
|
|
||||||
return recv(id_tasks);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send a new result to a waiting id_task
|
// Send a new result to a waiting id_task
|
||||||
void send(server_task_result result) {
|
void send(server_task_result result) {
|
||||||
LOG_VERBOSE("send new result", {{"id_task", result.id}});
|
LOG_VERBOSE("send new result", {{"id_task", result.id}});
|
||||||
|
@ -1487,7 +1490,7 @@ struct server_context {
|
||||||
// Functions to create new task(s) and receive result(s)
|
// Functions to create new task(s) and receive result(s)
|
||||||
//
|
//
|
||||||
|
|
||||||
std::vector<server_task> create_tasks_completion(json data, server_task_cmpl_type cmpl_type) {
|
std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
|
||||||
std::vector<server_task> tasks;
|
std::vector<server_task> tasks;
|
||||||
auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
|
auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
|
||||||
server_task task;
|
server_task task;
|
||||||
|
@ -1537,25 +1540,30 @@ struct server_context {
|
||||||
}
|
}
|
||||||
|
|
||||||
void cancel_tasks(std::unordered_set<int> & id_tasks) {
|
void cancel_tasks(std::unordered_set<int> & id_tasks) {
|
||||||
|
std::vector<server_task> cancel_tasks;
|
||||||
|
cancel_tasks.reserve(id_tasks.size());
|
||||||
for (const auto & id_task : id_tasks) {
|
for (const auto & id_task : id_tasks) {
|
||||||
LOG_VERBOSE("cancel task", {{"id_task", id_task}});
|
LOG_VERBOSE("cancel task", {{"id_task", id_task}});
|
||||||
server_task task;
|
server_task task;
|
||||||
task.type = SERVER_TASK_TYPE_CANCEL;
|
task.type = SERVER_TASK_TYPE_CANCEL;
|
||||||
task.id_target = id_task;
|
task.id_target = id_task;
|
||||||
queue_tasks.post(task);
|
cancel_tasks.push_back(task);
|
||||||
queue_results.remove_waiting_task_id(id_task);
|
queue_results.remove_waiting_task_id(id_task);
|
||||||
}
|
}
|
||||||
|
// push to beginning of the queue, so it has highest priority
|
||||||
|
queue_tasks.post(cancel_tasks, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
void receive_cmpl_results(std::vector<int> & id_tasks, std::function<void(std::vector<server_task_result>&)> result_handler, std::function<void(json)> error_handler) {
|
// receive the results from task(s) created by create_tasks_cmpl
|
||||||
std::unordered_set<int> id_set(id_tasks.begin(), id_tasks.end());
|
void receive_cmpl_results(std::unordered_set<int> & id_tasks, std::function<void(std::vector<server_task_result>&)> result_handler, std::function<void(json)> error_handler) {
|
||||||
std::vector<server_task_result> results(id_set.size());
|
// TODO: currently, there is no way to detect the client has cancelled the request
|
||||||
for (size_t i = 0; i < id_set.size(); i++) {
|
std::vector<server_task_result> results(id_tasks.size());
|
||||||
server_task_result result = queue_results.recv(id_set);
|
for (size_t i = 0; i < id_tasks.size(); i++) {
|
||||||
|
server_task_result result = queue_results.recv(id_tasks);
|
||||||
|
|
||||||
if (result.error) {
|
if (result.error) {
|
||||||
error_handler(result.data);
|
error_handler(result.data);
|
||||||
cancel_tasks(id_set);
|
cancel_tasks(id_tasks);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1565,24 +1573,24 @@ struct server_context {
|
||||||
result_handler(results);
|
result_handler(results);
|
||||||
}
|
}
|
||||||
|
|
||||||
void receive_cmpl_results_stream(std::vector<int> & id_tasks, std::function<bool(server_task_result&)> result_handler, std::function<void(json)> error_handler) {
|
// receive the results from task(s) created by create_tasks_cmpl, in stream mode
|
||||||
std::unordered_set<int> id_set(id_tasks.begin(), id_tasks.end());
|
void receive_cmpl_results_stream(std::unordered_set<int> & id_tasks, std::function<bool(server_task_result&)> result_handler, std::function<void(json)> error_handler) {
|
||||||
size_t n_finished = 0;
|
size_t n_finished = 0;
|
||||||
while (true) {
|
while (true) {
|
||||||
server_task_result result = queue_results.recv(id_set);
|
server_task_result result = queue_results.recv(id_tasks);
|
||||||
if (!result_handler(result)) {
|
if (!result_handler(result)) {
|
||||||
cancel_tasks(id_set);
|
cancel_tasks(id_tasks);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (result.error) {
|
if (result.error) {
|
||||||
error_handler(result.data);
|
error_handler(result.data);
|
||||||
cancel_tasks(id_set);
|
cancel_tasks(id_tasks);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (result.stop) {
|
if (result.stop) {
|
||||||
if (++n_finished == id_set.size()) {
|
if (++n_finished == id_tasks.size()) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2953,12 +2961,12 @@ int main(int argc, char ** argv) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<server_task> tasks = ctx_server.create_tasks_completion(data, cmpl_type);
|
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type);
|
||||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||||
ctx_server.queue_tasks.post(tasks);
|
ctx_server.queue_tasks.post(tasks);
|
||||||
|
|
||||||
bool stream = json_value(data, "stream", false);
|
bool stream = json_value(data, "stream", false);
|
||||||
std::vector<int> task_ids = server_task::get_list_id(tasks);
|
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
|
||||||
|
|
||||||
if (!stream) {
|
if (!stream) {
|
||||||
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
|
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
|
||||||
|
@ -2984,7 +2992,7 @@ int main(int argc, char ** argv) {
|
||||||
server_sent_event(sink, "error", error_data);
|
server_sent_event(sink, "error", error_data);
|
||||||
});
|
});
|
||||||
sink.done();
|
sink.done();
|
||||||
return true;
|
return false;
|
||||||
};
|
};
|
||||||
res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
|
res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
|
||||||
}
|
}
|
||||||
|
@ -3009,12 +3017,12 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
|
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
|
||||||
|
|
||||||
std::vector<server_task> tasks = ctx_server.create_tasks_completion(data, SERVER_TASK_CMPL_TYPE_NORMAL);
|
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
|
||||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||||
ctx_server.queue_tasks.post(tasks);
|
ctx_server.queue_tasks.post(tasks);
|
||||||
|
|
||||||
bool stream = json_value(data, "stream", false);
|
bool stream = json_value(data, "stream", false);
|
||||||
std::vector<int> task_ids = server_task::get_list_id(tasks);
|
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
|
||||||
const auto completion_id = gen_chatcmplid();
|
const auto completion_id = gen_chatcmplid();
|
||||||
|
|
||||||
if (!stream) {
|
if (!stream) {
|
||||||
|
@ -3111,12 +3119,12 @@ int main(int argc, char ** argv) {
|
||||||
json responses = json::array();
|
json responses = json::array();
|
||||||
bool error = false;
|
bool error = false;
|
||||||
{
|
{
|
||||||
std::vector<server_task> tasks = ctx_server.create_tasks_completion({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
|
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
|
||||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||||
ctx_server.queue_tasks.post(tasks);
|
ctx_server.queue_tasks.post(tasks);
|
||||||
|
|
||||||
// get the result
|
// get the result
|
||||||
std::vector<int> task_ids = server_task::get_list_id(tasks);
|
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
|
||||||
|
|
||||||
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
|
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
|
||||||
for (const auto & res : results) {
|
for (const auto & res : results) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue