From a43e1dc66c911804483dfb67b675ff99034229d8 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 5 Dec 2024 22:35:07 +0100 Subject: [PATCH] apply review comments --- examples/server/server.cpp | 46 +++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 44e6ead3a..b58f10186 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -234,7 +234,7 @@ struct server_task_result { }; // using shared_ptr for polymorphism of server_task_result -using task_result_ptr = std::unique_ptr; +using server_task_result_ptr = std::unique_ptr; inline std::string stop_type_to_str(stop_type type) { switch (type) { @@ -1097,7 +1097,7 @@ struct server_response { std::unordered_set waiting_task_ids; // the main result queue (using ptr for polymorphism) - std::vector queue_results; + std::vector queue_results; std::mutex mutex_results; std::condition_variable condition_results; @@ -1137,7 +1137,7 @@ struct server_response { } // This function blocks the thread until there is a response for one of the id_tasks - task_result_ptr recv(const std::unordered_set & id_tasks) { + server_task_result_ptr recv(const std::unordered_set & id_tasks) { while (true) { std::unique_lock lock(mutex_results); condition_results.wait(lock, [&]{ @@ -1146,7 +1146,7 @@ struct server_response { for (int i = 0; i < (int) queue_results.size(); i++) { if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { - task_result_ptr res = std::move(queue_results[i]); + server_task_result_ptr res = std::move(queue_results[i]); queue_results.erase(queue_results.begin() + i); return res; } @@ -1157,13 +1157,13 @@ struct server_response { } // single-task version of recv() - task_result_ptr recv(int id_task) { + server_task_result_ptr recv(int id_task) { std::unordered_set id_tasks = {id_task}; return recv(id_tasks); } // Send a new result to a waiting id_task - void send(task_result_ptr && result) { + void send(server_task_result_ptr && result) { SRV_DBG("sending result for task id = %d\n", result->id); std::unique_lock lock(mutex_results); @@ -2078,11 +2078,11 @@ struct server_context { // receive the results from task(s) created by create_tasks_inference void receive_multi_results( const std::unordered_set & id_tasks, - const std::function&)> & result_handler, + const std::function&)> & result_handler, const std::function & error_handler) { - std::vector results(id_tasks.size()); + std::vector results(id_tasks.size()); for (size_t i = 0; i < id_tasks.size(); i++) { - task_result_ptr result = queue_results.recv(id_tasks); + server_task_result_ptr result = queue_results.recv(id_tasks); if (result->is_error()) { error_handler(result->to_json()); @@ -2104,12 +2104,12 @@ struct server_context { // receive the results from task(s) created by create_tasks_inference, in stream mode void receive_cmpl_results_stream( - const std::unordered_set & id_tasks, const - std::function & result_handler, const - std::function & error_handler) { + const std::unordered_set & id_tasks, + const std::function & result_handler, + const std::function & error_handler) { size_t n_finished = 0; while (true) { - task_result_ptr result = queue_results.recv(id_tasks); + server_task_result_ptr result = queue_results.recv(id_tasks); if (result->is_error()) { error_handler(result->to_json()); @@ -3108,7 +3108,7 @@ int main(int argc, char ** argv) { ctx_server.queue_tasks.post(task, true); // high-priority task // get the result - task_result_ptr result = ctx_server.queue_results.recv(task.id); + server_task_result_ptr result = ctx_server.queue_results.recv(task.id); ctx_server.queue_results.remove_waiting_task_id(task.id); if (result->is_error()) { @@ -3148,7 +3148,7 @@ int main(int argc, char ** argv) { ctx_server.queue_tasks.post(task, true); // high-priority task // get the result - task_result_ptr result = ctx_server.queue_results.recv(task.id); + server_task_result_ptr result = ctx_server.queue_results.recv(task.id); ctx_server.queue_results.remove_waiting_task_id(task.id); if (result->is_error()) { @@ -3257,7 +3257,7 @@ int main(int argc, char ** argv) { const int id_task = ctx_server.queue_tasks.post(task); ctx_server.queue_results.add_waiting_task_id(id_task); - task_result_ptr result = ctx_server.queue_results.recv(id_task); + server_task_result_ptr result = ctx_server.queue_results.recv(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task); if (result->is_error()) { @@ -3288,7 +3288,7 @@ int main(int argc, char ** argv) { const int id_task = ctx_server.queue_tasks.post(task); ctx_server.queue_results.add_waiting_task_id(id_task); - task_result_ptr result = ctx_server.queue_results.recv(id_task); + server_task_result_ptr result = ctx_server.queue_results.recv(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task); if (result->is_error()) { @@ -3310,7 +3310,7 @@ int main(int argc, char ** argv) { const int id_task = ctx_server.queue_tasks.post(task); ctx_server.queue_results.add_waiting_task_id(id_task); - task_result_ptr result = ctx_server.queue_results.recv(id_task); + server_task_result_ptr result = ctx_server.queue_results.recv(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task); if (result->is_error()) { @@ -3395,7 +3395,7 @@ int main(int argc, char ** argv) { const auto task_ids = server_task::get_list_id(tasks); if (!stream) { - ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { + ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { if (results.size() == 1) { // single result res_ok(res, oai_compat ? results[0]->to_json_oai_compat() : results[0]->to_json()); @@ -3414,7 +3414,7 @@ int main(int argc, char ** argv) { ctx_server.queue_results.remove_waiting_task_ids(task_ids); } else { const auto chunked_content_provider = [task_ids, &ctx_server, oai_compat](size_t, httplib::DataSink & sink) { - ctx_server.receive_cmpl_results_stream(task_ids, [&](task_result_ptr & result) -> bool { + ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool { json res_json = oai_compat ? result->to_json_oai_compat() : result->to_json(); if (res_json.is_array()) { for (const auto & res : res_json) { @@ -3609,7 +3609,7 @@ int main(int argc, char ** argv) { // get the result std::unordered_set task_ids = server_task::get_list_id(tasks); - ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { + ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { for (auto & res : results) { GGML_ASSERT(dynamic_cast(res.get()) != nullptr); responses.push_back(res->to_json()); @@ -3688,7 +3688,7 @@ int main(int argc, char ** argv) { // get the result std::unordered_set task_ids = server_task::get_list_id(tasks); - ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { + ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { for (auto & res : results) { GGML_ASSERT(dynamic_cast(res.get()) != nullptr); responses.push_back(res->to_json()); @@ -3747,7 +3747,7 @@ int main(int argc, char ** argv) { const int id_task = ctx_server.queue_tasks.post(task); ctx_server.queue_results.add_waiting_task_id(id_task); - task_result_ptr result = ctx_server.queue_results.recv(id_task); + server_task_result_ptr result = ctx_server.queue_results.recv(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task); if (result->is_error()) {