apply review comments
This commit is contained in:
parent
2e560f90ff
commit
a43e1dc66c
1 changed file with 23 additions and 23 deletions
|
@ -234,7 +234,7 @@ struct server_task_result {
|
|||
};
|
||||
|
||||
// using unique_ptr for polymorphism of server_task_result
|
||||
using task_result_ptr = std::unique_ptr<server_task_result>;
|
||||
using server_task_result_ptr = std::unique_ptr<server_task_result>;
|
||||
|
||||
inline std::string stop_type_to_str(stop_type type) {
|
||||
switch (type) {
|
||||
|
@ -1097,7 +1097,7 @@ struct server_response {
|
|||
std::unordered_set<int> waiting_task_ids;
|
||||
|
||||
// the main result queue (using ptr for polymorphism)
|
||||
std::vector<task_result_ptr> queue_results;
|
||||
std::vector<server_task_result_ptr> queue_results;
|
||||
|
||||
std::mutex mutex_results;
|
||||
std::condition_variable condition_results;
|
||||
|
@ -1137,7 +1137,7 @@ struct server_response {
|
|||
}
|
||||
|
||||
// This function blocks the thread until there is a response for one of the id_tasks
|
||||
task_result_ptr recv(const std::unordered_set<int> & id_tasks) {
|
||||
server_task_result_ptr recv(const std::unordered_set<int> & id_tasks) {
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
condition_results.wait(lock, [&]{
|
||||
|
@ -1146,7 +1146,7 @@ struct server_response {
|
|||
|
||||
for (int i = 0; i < (int) queue_results.size(); i++) {
|
||||
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
|
||||
task_result_ptr res = std::move(queue_results[i]);
|
||||
server_task_result_ptr res = std::move(queue_results[i]);
|
||||
queue_results.erase(queue_results.begin() + i);
|
||||
return res;
|
||||
}
|
||||
|
@ -1157,13 +1157,13 @@ struct server_response {
|
|||
}
|
||||
|
||||
// single-task version of recv()
|
||||
task_result_ptr recv(int id_task) {
|
||||
server_task_result_ptr recv(int id_task) {
|
||||
std::unordered_set<int> id_tasks = {id_task};
|
||||
return recv(id_tasks);
|
||||
}
|
||||
|
||||
// Send a new result to a waiting id_task
|
||||
void send(task_result_ptr && result) {
|
||||
void send(server_task_result_ptr && result) {
|
||||
SRV_DBG("sending result for task id = %d\n", result->id);
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
|
@ -2078,11 +2078,11 @@ struct server_context {
|
|||
// receive the results from task(s) created by create_tasks_inference
|
||||
void receive_multi_results(
|
||||
const std::unordered_set<int> & id_tasks,
|
||||
const std::function<void(std::vector<task_result_ptr>&)> & result_handler,
|
||||
const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler,
|
||||
const std::function<void(json)> & error_handler) {
|
||||
std::vector<task_result_ptr> results(id_tasks.size());
|
||||
std::vector<server_task_result_ptr> results(id_tasks.size());
|
||||
for (size_t i = 0; i < id_tasks.size(); i++) {
|
||||
task_result_ptr result = queue_results.recv(id_tasks);
|
||||
server_task_result_ptr result = queue_results.recv(id_tasks);
|
||||
|
||||
if (result->is_error()) {
|
||||
error_handler(result->to_json());
|
||||
|
@ -2104,12 +2104,12 @@ struct server_context {
|
|||
|
||||
// receive the results from task(s) created by create_tasks_inference, in stream mode
|
||||
void receive_cmpl_results_stream(
|
||||
const std::unordered_set<int> & id_tasks, const
|
||||
std::function<bool(task_result_ptr&)> & result_handler, const
|
||||
std::function<void(json)> & error_handler) {
|
||||
const std::unordered_set<int> & id_tasks,
|
||||
const std::function<bool(server_task_result_ptr&)> & result_handler,
|
||||
const std::function<void(json)> & error_handler) {
|
||||
size_t n_finished = 0;
|
||||
while (true) {
|
||||
task_result_ptr result = queue_results.recv(id_tasks);
|
||||
server_task_result_ptr result = queue_results.recv(id_tasks);
|
||||
|
||||
if (result->is_error()) {
|
||||
error_handler(result->to_json());
|
||||
|
@ -3108,7 +3108,7 @@ int main(int argc, char ** argv) {
|
|||
ctx_server.queue_tasks.post(task, true); // high-priority task
|
||||
|
||||
// get the result
|
||||
task_result_ptr result = ctx_server.queue_results.recv(task.id);
|
||||
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
|
||||
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
||||
|
||||
if (result->is_error()) {
|
||||
|
@ -3148,7 +3148,7 @@ int main(int argc, char ** argv) {
|
|||
ctx_server.queue_tasks.post(task, true); // high-priority task
|
||||
|
||||
// get the result
|
||||
task_result_ptr result = ctx_server.queue_results.recv(task.id);
|
||||
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
|
||||
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
||||
|
||||
if (result->is_error()) {
|
||||
|
@ -3257,7 +3257,7 @@ int main(int argc, char ** argv) {
|
|||
const int id_task = ctx_server.queue_tasks.post(task);
|
||||
ctx_server.queue_results.add_waiting_task_id(id_task);
|
||||
|
||||
task_result_ptr result = ctx_server.queue_results.recv(id_task);
|
||||
server_task_result_ptr result = ctx_server.queue_results.recv(id_task);
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
|
||||
if (result->is_error()) {
|
||||
|
@ -3288,7 +3288,7 @@ int main(int argc, char ** argv) {
|
|||
const int id_task = ctx_server.queue_tasks.post(task);
|
||||
ctx_server.queue_results.add_waiting_task_id(id_task);
|
||||
|
||||
task_result_ptr result = ctx_server.queue_results.recv(id_task);
|
||||
server_task_result_ptr result = ctx_server.queue_results.recv(id_task);
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
|
||||
if (result->is_error()) {
|
||||
|
@ -3310,7 +3310,7 @@ int main(int argc, char ** argv) {
|
|||
const int id_task = ctx_server.queue_tasks.post(task);
|
||||
ctx_server.queue_results.add_waiting_task_id(id_task);
|
||||
|
||||
task_result_ptr result = ctx_server.queue_results.recv(id_task);
|
||||
server_task_result_ptr result = ctx_server.queue_results.recv(id_task);
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
|
||||
if (result->is_error()) {
|
||||
|
@ -3395,7 +3395,7 @@ int main(int argc, char ** argv) {
|
|||
const auto task_ids = server_task::get_list_id(tasks);
|
||||
|
||||
if (!stream) {
|
||||
ctx_server.receive_multi_results(task_ids, [&](std::vector<task_result_ptr> & results) {
|
||||
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
|
||||
if (results.size() == 1) {
|
||||
// single result
|
||||
res_ok(res, oai_compat ? results[0]->to_json_oai_compat() : results[0]->to_json());
|
||||
|
@ -3414,7 +3414,7 @@ int main(int argc, char ** argv) {
|
|||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||
} else {
|
||||
const auto chunked_content_provider = [task_ids, &ctx_server, oai_compat](size_t, httplib::DataSink & sink) {
|
||||
ctx_server.receive_cmpl_results_stream(task_ids, [&](task_result_ptr & result) -> bool {
|
||||
ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
|
||||
json res_json = oai_compat ? result->to_json_oai_compat() : result->to_json();
|
||||
if (res_json.is_array()) {
|
||||
for (const auto & res : res_json) {
|
||||
|
@ -3609,7 +3609,7 @@ int main(int argc, char ** argv) {
|
|||
// get the result
|
||||
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
|
||||
|
||||
ctx_server.receive_multi_results(task_ids, [&](std::vector<task_result_ptr> & results) {
|
||||
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
|
||||
for (auto & res : results) {
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
|
||||
responses.push_back(res->to_json());
|
||||
|
@ -3688,7 +3688,7 @@ int main(int argc, char ** argv) {
|
|||
// get the result
|
||||
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
|
||||
|
||||
ctx_server.receive_multi_results(task_ids, [&](std::vector<task_result_ptr> & results) {
|
||||
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
|
||||
for (auto & res : results) {
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
|
||||
responses.push_back(res->to_json());
|
||||
|
@ -3747,7 +3747,7 @@ int main(int argc, char ** argv) {
|
|||
const int id_task = ctx_server.queue_tasks.post(task);
|
||||
ctx_server.queue_results.add_waiting_task_id(id_task);
|
||||
|
||||
task_result_ptr result = ctx_server.queue_results.recv(id_task);
|
||||
server_task_result_ptr result = ctx_server.queue_results.recv(id_task);
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
|
||||
if (result->is_error()) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue