apply review comments
This commit is contained in:
parent
2e560f90ff
commit
a43e1dc66c
1 changed files with 23 additions and 23 deletions
|
@ -234,7 +234,7 @@ struct server_task_result {
|
||||||
};
|
};
|
||||||
|
|
||||||
// using unique_ptr for polymorphism of server_task_result
|
// using unique_ptr for polymorphism of server_task_result
|
||||||
using task_result_ptr = std::unique_ptr<server_task_result>;
|
using server_task_result_ptr = std::unique_ptr<server_task_result>;
|
||||||
|
|
||||||
inline std::string stop_type_to_str(stop_type type) {
|
inline std::string stop_type_to_str(stop_type type) {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
|
@ -1097,7 +1097,7 @@ struct server_response {
|
||||||
std::unordered_set<int> waiting_task_ids;
|
std::unordered_set<int> waiting_task_ids;
|
||||||
|
|
||||||
// the main result queue (using ptr for polymorphism)
|
// the main result queue (using ptr for polymorphism)
|
||||||
std::vector<task_result_ptr> queue_results;
|
std::vector<server_task_result_ptr> queue_results;
|
||||||
|
|
||||||
std::mutex mutex_results;
|
std::mutex mutex_results;
|
||||||
std::condition_variable condition_results;
|
std::condition_variable condition_results;
|
||||||
|
@ -1137,7 +1137,7 @@ struct server_response {
|
||||||
}
|
}
|
||||||
|
|
||||||
// This function blocks the thread until there is a response for one of the id_tasks
|
// This function blocks the thread until there is a response for one of the id_tasks
|
||||||
task_result_ptr recv(const std::unordered_set<int> & id_tasks) {
|
server_task_result_ptr recv(const std::unordered_set<int> & id_tasks) {
|
||||||
while (true) {
|
while (true) {
|
||||||
std::unique_lock<std::mutex> lock(mutex_results);
|
std::unique_lock<std::mutex> lock(mutex_results);
|
||||||
condition_results.wait(lock, [&]{
|
condition_results.wait(lock, [&]{
|
||||||
|
@ -1146,7 +1146,7 @@ struct server_response {
|
||||||
|
|
||||||
for (int i = 0; i < (int) queue_results.size(); i++) {
|
for (int i = 0; i < (int) queue_results.size(); i++) {
|
||||||
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
|
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
|
||||||
task_result_ptr res = std::move(queue_results[i]);
|
server_task_result_ptr res = std::move(queue_results[i]);
|
||||||
queue_results.erase(queue_results.begin() + i);
|
queue_results.erase(queue_results.begin() + i);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
@ -1157,13 +1157,13 @@ struct server_response {
|
||||||
}
|
}
|
||||||
|
|
||||||
// single-task version of recv()
|
// single-task version of recv()
|
||||||
task_result_ptr recv(int id_task) {
|
server_task_result_ptr recv(int id_task) {
|
||||||
std::unordered_set<int> id_tasks = {id_task};
|
std::unordered_set<int> id_tasks = {id_task};
|
||||||
return recv(id_tasks);
|
return recv(id_tasks);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send a new result to a waiting id_task
|
// Send a new result to a waiting id_task
|
||||||
void send(task_result_ptr && result) {
|
void send(server_task_result_ptr && result) {
|
||||||
SRV_DBG("sending result for task id = %d\n", result->id);
|
SRV_DBG("sending result for task id = %d\n", result->id);
|
||||||
|
|
||||||
std::unique_lock<std::mutex> lock(mutex_results);
|
std::unique_lock<std::mutex> lock(mutex_results);
|
||||||
|
@ -2078,11 +2078,11 @@ struct server_context {
|
||||||
// receive the results from task(s) created by create_tasks_inference
|
// receive the results from task(s) created by create_tasks_inference
|
||||||
void receive_multi_results(
|
void receive_multi_results(
|
||||||
const std::unordered_set<int> & id_tasks,
|
const std::unordered_set<int> & id_tasks,
|
||||||
const std::function<void(std::vector<task_result_ptr>&)> & result_handler,
|
const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler,
|
||||||
const std::function<void(json)> & error_handler) {
|
const std::function<void(json)> & error_handler) {
|
||||||
std::vector<task_result_ptr> results(id_tasks.size());
|
std::vector<server_task_result_ptr> results(id_tasks.size());
|
||||||
for (size_t i = 0; i < id_tasks.size(); i++) {
|
for (size_t i = 0; i < id_tasks.size(); i++) {
|
||||||
task_result_ptr result = queue_results.recv(id_tasks);
|
server_task_result_ptr result = queue_results.recv(id_tasks);
|
||||||
|
|
||||||
if (result->is_error()) {
|
if (result->is_error()) {
|
||||||
error_handler(result->to_json());
|
error_handler(result->to_json());
|
||||||
|
@ -2104,12 +2104,12 @@ struct server_context {
|
||||||
|
|
||||||
// receive the results from task(s) created by create_tasks_inference, in stream mode
|
// receive the results from task(s) created by create_tasks_inference, in stream mode
|
||||||
void receive_cmpl_results_stream(
|
void receive_cmpl_results_stream(
|
||||||
const std::unordered_set<int> & id_tasks, const
|
const std::unordered_set<int> & id_tasks,
|
||||||
std::function<bool(task_result_ptr&)> & result_handler, const
|
const std::function<bool(server_task_result_ptr&)> & result_handler,
|
||||||
std::function<void(json)> & error_handler) {
|
const std::function<void(json)> & error_handler) {
|
||||||
size_t n_finished = 0;
|
size_t n_finished = 0;
|
||||||
while (true) {
|
while (true) {
|
||||||
task_result_ptr result = queue_results.recv(id_tasks);
|
server_task_result_ptr result = queue_results.recv(id_tasks);
|
||||||
|
|
||||||
if (result->is_error()) {
|
if (result->is_error()) {
|
||||||
error_handler(result->to_json());
|
error_handler(result->to_json());
|
||||||
|
@ -3108,7 +3108,7 @@ int main(int argc, char ** argv) {
|
||||||
ctx_server.queue_tasks.post(task, true); // high-priority task
|
ctx_server.queue_tasks.post(task, true); // high-priority task
|
||||||
|
|
||||||
// get the result
|
// get the result
|
||||||
task_result_ptr result = ctx_server.queue_results.recv(task.id);
|
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
|
||||||
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
||||||
|
|
||||||
if (result->is_error()) {
|
if (result->is_error()) {
|
||||||
|
@ -3148,7 +3148,7 @@ int main(int argc, char ** argv) {
|
||||||
ctx_server.queue_tasks.post(task, true); // high-priority task
|
ctx_server.queue_tasks.post(task, true); // high-priority task
|
||||||
|
|
||||||
// get the result
|
// get the result
|
||||||
task_result_ptr result = ctx_server.queue_results.recv(task.id);
|
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
|
||||||
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
||||||
|
|
||||||
if (result->is_error()) {
|
if (result->is_error()) {
|
||||||
|
@ -3257,7 +3257,7 @@ int main(int argc, char ** argv) {
|
||||||
const int id_task = ctx_server.queue_tasks.post(task);
|
const int id_task = ctx_server.queue_tasks.post(task);
|
||||||
ctx_server.queue_results.add_waiting_task_id(id_task);
|
ctx_server.queue_results.add_waiting_task_id(id_task);
|
||||||
|
|
||||||
task_result_ptr result = ctx_server.queue_results.recv(id_task);
|
server_task_result_ptr result = ctx_server.queue_results.recv(id_task);
|
||||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||||
|
|
||||||
if (result->is_error()) {
|
if (result->is_error()) {
|
||||||
|
@ -3288,7 +3288,7 @@ int main(int argc, char ** argv) {
|
||||||
const int id_task = ctx_server.queue_tasks.post(task);
|
const int id_task = ctx_server.queue_tasks.post(task);
|
||||||
ctx_server.queue_results.add_waiting_task_id(id_task);
|
ctx_server.queue_results.add_waiting_task_id(id_task);
|
||||||
|
|
||||||
task_result_ptr result = ctx_server.queue_results.recv(id_task);
|
server_task_result_ptr result = ctx_server.queue_results.recv(id_task);
|
||||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||||
|
|
||||||
if (result->is_error()) {
|
if (result->is_error()) {
|
||||||
|
@ -3310,7 +3310,7 @@ int main(int argc, char ** argv) {
|
||||||
const int id_task = ctx_server.queue_tasks.post(task);
|
const int id_task = ctx_server.queue_tasks.post(task);
|
||||||
ctx_server.queue_results.add_waiting_task_id(id_task);
|
ctx_server.queue_results.add_waiting_task_id(id_task);
|
||||||
|
|
||||||
task_result_ptr result = ctx_server.queue_results.recv(id_task);
|
server_task_result_ptr result = ctx_server.queue_results.recv(id_task);
|
||||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||||
|
|
||||||
if (result->is_error()) {
|
if (result->is_error()) {
|
||||||
|
@ -3395,7 +3395,7 @@ int main(int argc, char ** argv) {
|
||||||
const auto task_ids = server_task::get_list_id(tasks);
|
const auto task_ids = server_task::get_list_id(tasks);
|
||||||
|
|
||||||
if (!stream) {
|
if (!stream) {
|
||||||
ctx_server.receive_multi_results(task_ids, [&](std::vector<task_result_ptr> & results) {
|
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
|
||||||
if (results.size() == 1) {
|
if (results.size() == 1) {
|
||||||
// single result
|
// single result
|
||||||
res_ok(res, oai_compat ? results[0]->to_json_oai_compat() : results[0]->to_json());
|
res_ok(res, oai_compat ? results[0]->to_json_oai_compat() : results[0]->to_json());
|
||||||
|
@ -3414,7 +3414,7 @@ int main(int argc, char ** argv) {
|
||||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||||
} else {
|
} else {
|
||||||
const auto chunked_content_provider = [task_ids, &ctx_server, oai_compat](size_t, httplib::DataSink & sink) {
|
const auto chunked_content_provider = [task_ids, &ctx_server, oai_compat](size_t, httplib::DataSink & sink) {
|
||||||
ctx_server.receive_cmpl_results_stream(task_ids, [&](task_result_ptr & result) -> bool {
|
ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
|
||||||
json res_json = oai_compat ? result->to_json_oai_compat() : result->to_json();
|
json res_json = oai_compat ? result->to_json_oai_compat() : result->to_json();
|
||||||
if (res_json.is_array()) {
|
if (res_json.is_array()) {
|
||||||
for (const auto & res : res_json) {
|
for (const auto & res : res_json) {
|
||||||
|
@ -3609,7 +3609,7 @@ int main(int argc, char ** argv) {
|
||||||
// get the result
|
// get the result
|
||||||
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
|
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
|
||||||
|
|
||||||
ctx_server.receive_multi_results(task_ids, [&](std::vector<task_result_ptr> & results) {
|
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
|
||||||
for (auto & res : results) {
|
for (auto & res : results) {
|
||||||
GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
|
GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
|
||||||
responses.push_back(res->to_json());
|
responses.push_back(res->to_json());
|
||||||
|
@ -3688,7 +3688,7 @@ int main(int argc, char ** argv) {
|
||||||
// get the result
|
// get the result
|
||||||
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
|
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
|
||||||
|
|
||||||
ctx_server.receive_multi_results(task_ids, [&](std::vector<task_result_ptr> & results) {
|
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
|
||||||
for (auto & res : results) {
|
for (auto & res : results) {
|
||||||
GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
|
GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
|
||||||
responses.push_back(res->to_json());
|
responses.push_back(res->to_json());
|
||||||
|
@ -3747,7 +3747,7 @@ int main(int argc, char ** argv) {
|
||||||
const int id_task = ctx_server.queue_tasks.post(task);
|
const int id_task = ctx_server.queue_tasks.post(task);
|
||||||
ctx_server.queue_results.add_waiting_task_id(id_task);
|
ctx_server.queue_results.add_waiting_task_id(id_task);
|
||||||
|
|
||||||
task_result_ptr result = ctx_server.queue_results.recv(id_task);
|
server_task_result_ptr result = ctx_server.queue_results.recv(id_task);
|
||||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||||
|
|
||||||
if (result->is_error()) {
|
if (result->is_error()) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue