apply review comments

This commit is contained in:
Xuan Son Nguyen 2024-12-05 22:35:07 +01:00
parent 2e560f90ff
commit a43e1dc66c

View file

@ -234,7 +234,7 @@ struct server_task_result {
};
// using unique_ptr for polymorphism of server_task_result
using task_result_ptr = std::unique_ptr<server_task_result>;
using server_task_result_ptr = std::unique_ptr<server_task_result>;
inline std::string stop_type_to_str(stop_type type) {
switch (type) {
@ -1097,7 +1097,7 @@ struct server_response {
std::unordered_set<int> waiting_task_ids;
// the main result queue (using ptr for polymorphism)
std::vector<task_result_ptr> queue_results;
std::vector<server_task_result_ptr> queue_results;
std::mutex mutex_results;
std::condition_variable condition_results;
@ -1137,7 +1137,7 @@ struct server_response {
}
// This function blocks the thread until there is a response for one of the id_tasks
task_result_ptr recv(const std::unordered_set<int> & id_tasks) {
server_task_result_ptr recv(const std::unordered_set<int> & id_tasks) {
while (true) {
std::unique_lock<std::mutex> lock(mutex_results);
condition_results.wait(lock, [&]{
@ -1146,7 +1146,7 @@ struct server_response {
for (int i = 0; i < (int) queue_results.size(); i++) {
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
task_result_ptr res = std::move(queue_results[i]);
server_task_result_ptr res = std::move(queue_results[i]);
queue_results.erase(queue_results.begin() + i);
return res;
}
@ -1157,13 +1157,13 @@ struct server_response {
}
// single-task version of recv()
task_result_ptr recv(int id_task) {
server_task_result_ptr recv(int id_task) {
std::unordered_set<int> id_tasks = {id_task};
return recv(id_tasks);
}
// Send a new result to a waiting id_task
void send(task_result_ptr && result) {
void send(server_task_result_ptr && result) {
SRV_DBG("sending result for task id = %d\n", result->id);
std::unique_lock<std::mutex> lock(mutex_results);
@ -2078,11 +2078,11 @@ struct server_context {
// receive the results from task(s) created by create_tasks_inference
void receive_multi_results(
const std::unordered_set<int> & id_tasks,
const std::function<void(std::vector<task_result_ptr>&)> & result_handler,
const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler,
const std::function<void(json)> & error_handler) {
std::vector<task_result_ptr> results(id_tasks.size());
std::vector<server_task_result_ptr> results(id_tasks.size());
for (size_t i = 0; i < id_tasks.size(); i++) {
task_result_ptr result = queue_results.recv(id_tasks);
server_task_result_ptr result = queue_results.recv(id_tasks);
if (result->is_error()) {
error_handler(result->to_json());
@ -2104,12 +2104,12 @@ struct server_context {
// receive the results from task(s) created by create_tasks_inference, in stream mode
void receive_cmpl_results_stream(
const std::unordered_set<int> & id_tasks, const
std::function<bool(task_result_ptr&)> & result_handler, const
std::function<void(json)> & error_handler) {
const std::unordered_set<int> & id_tasks,
const std::function<bool(server_task_result_ptr&)> & result_handler,
const std::function<void(json)> & error_handler) {
size_t n_finished = 0;
while (true) {
task_result_ptr result = queue_results.recv(id_tasks);
server_task_result_ptr result = queue_results.recv(id_tasks);
if (result->is_error()) {
error_handler(result->to_json());
@ -3108,7 +3108,7 @@ int main(int argc, char ** argv) {
ctx_server.queue_tasks.post(task, true); // high-priority task
// get the result
task_result_ptr result = ctx_server.queue_results.recv(task.id);
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
ctx_server.queue_results.remove_waiting_task_id(task.id);
if (result->is_error()) {
@ -3148,7 +3148,7 @@ int main(int argc, char ** argv) {
ctx_server.queue_tasks.post(task, true); // high-priority task
// get the result
task_result_ptr result = ctx_server.queue_results.recv(task.id);
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
ctx_server.queue_results.remove_waiting_task_id(task.id);
if (result->is_error()) {
@ -3257,7 +3257,7 @@ int main(int argc, char ** argv) {
const int id_task = ctx_server.queue_tasks.post(task);
ctx_server.queue_results.add_waiting_task_id(id_task);
task_result_ptr result = ctx_server.queue_results.recv(id_task);
server_task_result_ptr result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task);
if (result->is_error()) {
@ -3288,7 +3288,7 @@ int main(int argc, char ** argv) {
const int id_task = ctx_server.queue_tasks.post(task);
ctx_server.queue_results.add_waiting_task_id(id_task);
task_result_ptr result = ctx_server.queue_results.recv(id_task);
server_task_result_ptr result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task);
if (result->is_error()) {
@ -3310,7 +3310,7 @@ int main(int argc, char ** argv) {
const int id_task = ctx_server.queue_tasks.post(task);
ctx_server.queue_results.add_waiting_task_id(id_task);
task_result_ptr result = ctx_server.queue_results.recv(id_task);
server_task_result_ptr result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task);
if (result->is_error()) {
@ -3395,7 +3395,7 @@ int main(int argc, char ** argv) {
const auto task_ids = server_task::get_list_id(tasks);
if (!stream) {
ctx_server.receive_multi_results(task_ids, [&](std::vector<task_result_ptr> & results) {
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
if (results.size() == 1) {
// single result
res_ok(res, oai_compat ? results[0]->to_json_oai_compat() : results[0]->to_json());
@ -3414,7 +3414,7 @@ int main(int argc, char ** argv) {
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
} else {
const auto chunked_content_provider = [task_ids, &ctx_server, oai_compat](size_t, httplib::DataSink & sink) {
ctx_server.receive_cmpl_results_stream(task_ids, [&](task_result_ptr & result) -> bool {
ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
json res_json = oai_compat ? result->to_json_oai_compat() : result->to_json();
if (res_json.is_array()) {
for (const auto & res : res_json) {
@ -3609,7 +3609,7 @@ int main(int argc, char ** argv) {
// get the result
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
ctx_server.receive_multi_results(task_ids, [&](std::vector<task_result_ptr> & results) {
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
for (auto & res : results) {
GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
responses.push_back(res->to_json());
@ -3688,7 +3688,7 @@ int main(int argc, char ** argv) {
// get the result
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
ctx_server.receive_multi_results(task_ids, [&](std::vector<task_result_ptr> & results) {
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
for (auto & res : results) {
GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
responses.push_back(res->to_json());
@ -3747,7 +3747,7 @@ int main(int argc, char ** argv) {
const int id_task = ctx_server.queue_tasks.post(task);
ctx_server.queue_results.add_waiting_task_id(id_task);
task_result_ptr result = ctx_server.queue_results.recv(id_task);
server_task_result_ptr result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task);
if (result->is_error()) {