diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 74c25c268..e7ad0f07b 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -92,10 +92,10 @@ struct server_task { server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; // utility function - static std::vector get_list_id(std::vector tasks) { - std::vector ids(tasks.size()); + static std::unordered_set get_list_id(std::vector tasks) { + std::unordered_set ids(tasks.size()); for (size_t i = 0; i < tasks.size(); i++) { - ids[i] = tasks[i].id; + ids.insert(tasks[i].id); } return ids; } @@ -394,22 +394,35 @@ struct server_queue { std::function callback_update_slots; // Add a new task to the end of the queue - int post(server_task task) { + int post(server_task task, bool front = false) { std::unique_lock lock(mutex_tasks); if (task.id == -1) { task.id = id++; LOG_VERBOSE("new task id", {{"new_id", task.id}}); } - queue_tasks.push_back(std::move(task)); + if (front) { + queue_tasks.insert(queue_tasks.begin(), std::move(task)); + } else { + queue_tasks.push_back(std::move(task)); + } condition_tasks.notify_one(); return task.id; } // multi-task version of post() - int post(std::vector & tasks) { + int post(std::vector & tasks, bool front = false) { for (auto & task : tasks) { - post(task); + if (task.id == -1) { + task.id = id++; + LOG_VERBOSE("new task id", {{"new_id", task.id}}); + } + if (front) { + queue_tasks.insert(queue_tasks.begin(), std::move(task)); + } else { + queue_tasks.push_back(std::move(task)); + } } + condition_tasks.notify_one(); return 0; } @@ -419,7 +432,7 @@ struct server_queue { queue_tasks_deferred.push_back(std::move(task)); } - // Get the next id for creating anew task + // Get the next id for creating a new task int get_new_id() { std::unique_lock lock(mutex_tasks); int new_id = id++; @@ -566,16 +579,6 @@ struct server_response { return recv(id_tasks); } - // multi-task version of recv() - server_task_result recv(std::vector & tasks) { - std::unordered_set id_tasks; - id_tasks.reserve(tasks.size()); - for (const auto & t : tasks) { - id_tasks.insert(t.id); - } - return recv(id_tasks); - } - // Send a new result to a waiting id_task void send(server_task_result result) { LOG_VERBOSE("send new result", {{"id_task", result.id}}); @@ -1487,7 +1490,7 @@ struct server_context { // Functions to create new task(s) and receive result(s) // - std::vector create_tasks_completion(json data, server_task_cmpl_type cmpl_type) { + std::vector create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) { std::vector tasks; auto create_task = [&](json & task_data, bool replace_prompt, json prompt) { server_task task; @@ -1537,25 +1540,30 @@ struct server_context { } void cancel_tasks(std::unordered_set & id_tasks) { + std::vector cancel_tasks; + cancel_tasks.reserve(id_tasks.size()); for (const auto & id_task : id_tasks) { LOG_VERBOSE("cancel task", {{"id_task", id_task}}); server_task task; task.type = SERVER_TASK_TYPE_CANCEL; task.id_target = id_task; - queue_tasks.post(task); + cancel_tasks.push_back(task); queue_results.remove_waiting_task_id(id_task); } + // push to beginning of the queue, so it has highest priority + queue_tasks.post(cancel_tasks, true); } - void receive_cmpl_results(std::vector & id_tasks, std::function&)> result_handler, std::function error_handler) { - std::unordered_set id_set(id_tasks.begin(), id_tasks.end()); - std::vector results(id_set.size()); - for (size_t i = 0; i < id_set.size(); i++) { - server_task_result result = queue_results.recv(id_set); + // receive the results from task(s) created by create_tasks_cmpl + void receive_cmpl_results(std::unordered_set & id_tasks, std::function&)> result_handler, std::function error_handler) { + // TODO: currently, there is no way to detect the client has cancelled the request + std::vector results(id_tasks.size()); + for (size_t i = 0; i < id_tasks.size(); i++) { + server_task_result result = queue_results.recv(id_tasks); if (result.error) { error_handler(result.data); - cancel_tasks(id_set); + cancel_tasks(id_tasks); break; } @@ -1565,24 +1573,24 @@ struct server_context { result_handler(results); } - void receive_cmpl_results_stream(std::vector & id_tasks, std::function result_handler, std::function error_handler) { - std::unordered_set id_set(id_tasks.begin(), id_tasks.end()); + // receive the results from task(s) created by create_tasks_cmpl, in stream mode + void receive_cmpl_results_stream(std::unordered_set & id_tasks, std::function result_handler, std::function error_handler) { size_t n_finished = 0; while (true) { - server_task_result result = queue_results.recv(id_set); + server_task_result result = queue_results.recv(id_tasks); if (!result_handler(result)) { - cancel_tasks(id_set); + cancel_tasks(id_tasks); break; } if (result.error) { error_handler(result.data); - cancel_tasks(id_set); + cancel_tasks(id_tasks); break; } if (result.stop) { - if (++n_finished == id_set.size()) { + if (++n_finished == id_tasks.size()) { break; } } @@ -2953,12 +2961,12 @@ int main(int argc, char ** argv) { return; } - std::vector tasks = ctx_server.create_tasks_completion(data, cmpl_type); + std::vector tasks = ctx_server.create_tasks_cmpl(data, cmpl_type); ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(tasks); bool stream = json_value(data, "stream", false); - std::vector task_ids = server_task::get_list_id(tasks); + std::unordered_set task_ids = server_task::get_list_id(tasks); if (!stream) { ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { @@ -2984,7 +2992,7 @@ int main(int argc, char ** argv) { server_sent_event(sink, "error", error_data); }); sink.done(); - return true; + return false; }; res.set_chunked_content_provider("text/event-stream", chunked_content_provider); } @@ -3009,12 +3017,12 @@ int main(int argc, char ** argv) { json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); - std::vector tasks = ctx_server.create_tasks_completion(data, SERVER_TASK_CMPL_TYPE_NORMAL); + std::vector tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL); ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(tasks); bool stream = json_value(data, "stream", false); - std::vector task_ids = server_task::get_list_id(tasks); + std::unordered_set task_ids = server_task::get_list_id(tasks); const auto completion_id = gen_chatcmplid(); if (!stream) { @@ -3111,12 +3119,12 @@ int main(int argc, char ** argv) { json responses = json::array(); bool error = false; { - std::vector tasks = ctx_server.create_tasks_completion({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING); + std::vector tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING); ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(tasks); // get the result - std::vector task_ids = server_task::get_list_id(tasks); + std::unordered_set task_ids = server_task::get_list_id(tasks); ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { for (const auto & res : results) {