use unordered_set everywhere
This commit is contained in:
parent
83249aae0c
commit
24329aac1e
1 changed files with 47 additions and 39 deletions
|
@ -92,10 +92,10 @@ struct server_task {
|
|||
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
|
||||
|
||||
// utility function
|
||||
static std::vector<int> get_list_id(std::vector<server_task> tasks) {
|
||||
std::vector<int> ids(tasks.size());
|
||||
static std::unordered_set<int> get_list_id(std::vector<server_task> tasks) {
|
||||
std::unordered_set<int> ids(tasks.size());
|
||||
for (size_t i = 0; i < tasks.size(); i++) {
|
||||
ids[i] = tasks[i].id;
|
||||
ids.insert(tasks[i].id);
|
||||
}
|
||||
return ids;
|
||||
}
|
||||
|
@ -394,22 +394,35 @@ struct server_queue {
|
|||
std::function<void(void)> callback_update_slots;
|
||||
|
||||
// Add a new task to the end of the queue
|
||||
int post(server_task task) {
|
||||
int post(server_task task, bool front = false) {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
if (task.id == -1) {
|
||||
task.id = id++;
|
||||
LOG_VERBOSE("new task id", {{"new_id", task.id}});
|
||||
}
|
||||
if (front) {
|
||||
queue_tasks.insert(queue_tasks.begin(), std::move(task));
|
||||
} else {
|
||||
queue_tasks.push_back(std::move(task));
|
||||
}
|
||||
condition_tasks.notify_one();
|
||||
return task.id;
|
||||
}
|
||||
|
||||
// multi-task version of post()
|
||||
int post(std::vector<server_task> & tasks) {
|
||||
int post(std::vector<server_task> & tasks, bool front = false) {
|
||||
for (auto & task : tasks) {
|
||||
post(task);
|
||||
if (task.id == -1) {
|
||||
task.id = id++;
|
||||
LOG_VERBOSE("new task id", {{"new_id", task.id}});
|
||||
}
|
||||
if (front) {
|
||||
queue_tasks.insert(queue_tasks.begin(), std::move(task));
|
||||
} else {
|
||||
queue_tasks.push_back(std::move(task));
|
||||
}
|
||||
}
|
||||
condition_tasks.notify_one();
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
@ -566,16 +579,6 @@ struct server_response {
|
|||
return recv(id_tasks);
|
||||
}
|
||||
|
||||
// multi-task version of recv()
|
||||
server_task_result recv(std::vector<server_task> & tasks) {
|
||||
std::unordered_set<int> id_tasks;
|
||||
id_tasks.reserve(tasks.size());
|
||||
for (const auto & t : tasks) {
|
||||
id_tasks.insert(t.id);
|
||||
}
|
||||
return recv(id_tasks);
|
||||
}
|
||||
|
||||
// Send a new result to a waiting id_task
|
||||
void send(server_task_result result) {
|
||||
LOG_VERBOSE("send new result", {{"id_task", result.id}});
|
||||
|
@ -1487,7 +1490,7 @@ struct server_context {
|
|||
// Functions to create new task(s) and receive result(s)
|
||||
//
|
||||
|
||||
std::vector<server_task> create_tasks_completion(json data, server_task_cmpl_type cmpl_type) {
|
||||
std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
|
||||
std::vector<server_task> tasks;
|
||||
auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
|
||||
server_task task;
|
||||
|
@ -1537,25 +1540,30 @@ struct server_context {
|
|||
}
|
||||
|
||||
void cancel_tasks(std::unordered_set<int> & id_tasks) {
|
||||
std::vector<server_task> cancel_tasks;
|
||||
cancel_tasks.reserve(id_tasks.size());
|
||||
for (const auto & id_task : id_tasks) {
|
||||
LOG_VERBOSE("cancel task", {{"id_task", id_task}});
|
||||
server_task task;
|
||||
task.type = SERVER_TASK_TYPE_CANCEL;
|
||||
task.id_target = id_task;
|
||||
queue_tasks.post(task);
|
||||
cancel_tasks.push_back(task);
|
||||
queue_results.remove_waiting_task_id(id_task);
|
||||
}
|
||||
// push to beginning of the queue, so it has highest priority
|
||||
queue_tasks.post(cancel_tasks, true);
|
||||
}
|
||||
|
||||
void receive_cmpl_results(std::vector<int> & id_tasks, std::function<void(std::vector<server_task_result>&)> result_handler, std::function<void(json)> error_handler) {
|
||||
std::unordered_set<int> id_set(id_tasks.begin(), id_tasks.end());
|
||||
std::vector<server_task_result> results(id_set.size());
|
||||
for (size_t i = 0; i < id_set.size(); i++) {
|
||||
server_task_result result = queue_results.recv(id_set);
|
||||
// receive the results from task(s) created by create_tasks_cmpl
|
||||
void receive_cmpl_results(std::unordered_set<int> & id_tasks, std::function<void(std::vector<server_task_result>&)> result_handler, std::function<void(json)> error_handler) {
|
||||
// TODO: currently, there is no way to detect the client has cancelled the request
|
||||
std::vector<server_task_result> results(id_tasks.size());
|
||||
for (size_t i = 0; i < id_tasks.size(); i++) {
|
||||
server_task_result result = queue_results.recv(id_tasks);
|
||||
|
||||
if (result.error) {
|
||||
error_handler(result.data);
|
||||
cancel_tasks(id_set);
|
||||
cancel_tasks(id_tasks);
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -1565,24 +1573,24 @@ struct server_context {
|
|||
result_handler(results);
|
||||
}
|
||||
|
||||
void receive_cmpl_results_stream(std::vector<int> & id_tasks, std::function<bool(server_task_result&)> result_handler, std::function<void(json)> error_handler) {
|
||||
std::unordered_set<int> id_set(id_tasks.begin(), id_tasks.end());
|
||||
// receive the results from task(s) created by create_tasks_cmpl, in stream mode
|
||||
void receive_cmpl_results_stream(std::unordered_set<int> & id_tasks, std::function<bool(server_task_result&)> result_handler, std::function<void(json)> error_handler) {
|
||||
size_t n_finished = 0;
|
||||
while (true) {
|
||||
server_task_result result = queue_results.recv(id_set);
|
||||
server_task_result result = queue_results.recv(id_tasks);
|
||||
if (!result_handler(result)) {
|
||||
cancel_tasks(id_set);
|
||||
cancel_tasks(id_tasks);
|
||||
break;
|
||||
}
|
||||
|
||||
if (result.error) {
|
||||
error_handler(result.data);
|
||||
cancel_tasks(id_set);
|
||||
cancel_tasks(id_tasks);
|
||||
break;
|
||||
}
|
||||
|
||||
if (result.stop) {
|
||||
if (++n_finished == id_set.size()) {
|
||||
if (++n_finished == id_tasks.size()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -2953,12 +2961,12 @@ int main(int argc, char ** argv) {
|
|||
return;
|
||||
}
|
||||
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_completion(data, cmpl_type);
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(tasks);
|
||||
|
||||
bool stream = json_value(data, "stream", false);
|
||||
std::vector<int> task_ids = server_task::get_list_id(tasks);
|
||||
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
|
||||
|
||||
if (!stream) {
|
||||
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
|
||||
|
@ -2984,7 +2992,7 @@ int main(int argc, char ** argv) {
|
|||
server_sent_event(sink, "error", error_data);
|
||||
});
|
||||
sink.done();
|
||||
return true;
|
||||
return false;
|
||||
};
|
||||
res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
|
||||
}
|
||||
|
@ -3009,12 +3017,12 @@ int main(int argc, char ** argv) {
|
|||
|
||||
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
|
||||
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_completion(data, SERVER_TASK_CMPL_TYPE_NORMAL);
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(tasks);
|
||||
|
||||
bool stream = json_value(data, "stream", false);
|
||||
std::vector<int> task_ids = server_task::get_list_id(tasks);
|
||||
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
|
||||
const auto completion_id = gen_chatcmplid();
|
||||
|
||||
if (!stream) {
|
||||
|
@ -3111,12 +3119,12 @@ int main(int argc, char ** argv) {
|
|||
json responses = json::array();
|
||||
bool error = false;
|
||||
{
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_completion({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(tasks);
|
||||
|
||||
// get the result
|
||||
std::vector<int> task_ids = server_task::get_list_id(tasks);
|
||||
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
|
||||
|
||||
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
|
||||
for (const auto & res : results) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue