From cb666718b1d4fe94de819f4888035d53b73b4133 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 4 Dec 2024 23:53:25 +0100 Subject: [PATCH] refactor handle_completions_generic --- examples/server/server.cpp | 100 ++++++++++++++++++------------------- 1 file changed, 48 insertions(+), 52 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 9057c0a4c..0ab09db22 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2716,7 +2716,16 @@ int main(int argc, char ** argv) { res_ok(res, {{ "success", true }}); }; - const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) { + // handle completion-like requests (completion, chat, infill) + // we can optionally provide a custom format for partial results and final results + const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok]( + server_task_inf_type inf_type, + json & data, + httplib::Response & res, + const std::function(server_task_result_cmpl_partial&)> & format_partial = nullptr, + const std::function&)> & format_final = nullptr, + // wether to send [DONE] event after completion (required for OAI-compat) + bool send_done_event = false) { if (ctx_server.params_base.embedding) { res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); return; @@ -2731,7 +2740,9 @@ int main(int argc, char ** argv) { if (!stream) { ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { - if (results.size() == 1) { + if (format_final) { + res_ok(res, format_final(results)); + } else if (results.size() == 1) { // single result res_ok(res, results[0].to_json()); } else { @@ -2748,12 +2759,25 @@ int main(int argc, char ** argv) { ctx_server.queue_results.remove_waiting_task_ids(task_ids); } else { - const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) { + const auto chunked_content_provider = [task_ids, &ctx_server, format_partial = std::move(format_partial), send_done_event](size_t, httplib::DataSink & sink) { ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_cmpl_partial & result) -> bool { - return server_sent_event(sink, "data", result.to_json()); + if (format_partial) { + for (const auto & res : format_partial(result)) { + if (!server_sent_event(sink, "data", res)) { + return false; + } + } + return true; + } else { + return server_sent_event(sink, "data", result.to_json()); + } }, [&](const json & error_data) { server_sent_event(sink, "error", error_data); }); + if (send_done_event) { + static const std::string ev_done = "data: [DONE]\n\n"; + sink.write(ev_done.data(), ev_done.size()); + } sink.done(); return false; }; @@ -2768,7 +2792,13 @@ int main(int argc, char ** argv) { const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) { json data = json::parse(req.body); - return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res); + return handle_completions_generic( + SERVER_TASK_INF_TYPE_COMPLETION, + data, + res, + // TODO: support OAI-compat response via format_partial and format_final + /* format_partial */ nullptr, + /* format_final */ nullptr); }; const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) { @@ -2821,8 +2851,7 @@ int main(int argc, char ** argv) { return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res); }; - // TODO: maybe merge this function with "handle_completions_generic" - const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) { + const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &handle_completions_generic, verbose](const httplib::Request & req, httplib::Response & res) { if (ctx_server.params_base.embedding) { res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); return; @@ -2830,53 +2859,20 @@ int main(int argc, char ** argv) { json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); - std::vector tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION); - ctx_server.queue_results.add_waiting_tasks(tasks); - ctx_server.queue_tasks.post(tasks); - - bool stream = json_value(data, "stream", false); - const auto task_ids = server_task::get_list_id(tasks); const auto completion_id = gen_chatcmplid(); + std::string model_name = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); - if (!stream) { - ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { - // multitask is never support in chat completion, there is only one result - json result_oai = format_final_response_oaicompat(data, results[0], completion_id, /*.streaming =*/ false, verbose); - res_ok(res, result_oai); - }, [&](const json & error_data) { - res_error(res, error_data); - }); - - ctx_server.queue_results.remove_waiting_task_ids(task_ids); - } else { - std::string model_name = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); - const auto chunked_content_provider = [task_ids, &ctx_server, completion_id, model_name](size_t, httplib::DataSink & sink) { - ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_cmpl_partial & result) -> bool { - std::vector result_array = format_partial_response_oaicompat(model_name, result, completion_id); - for (auto & event_data : result_array) { - if (event_data.empty()) { - continue; // skip the stop token - } - if (!server_sent_event(sink, "data", event_data)) { - return false; // connection is closed - } - } - return true; // ok - }, [&](const json & error_data) { - server_sent_event(sink, "error", error_data); - }); - static const std::string ev_done = "data: [DONE]\n\n"; - sink.write(ev_done.data(), ev_done.size()); - sink.done(); - return true; - }; - - auto on_complete = [task_ids, &ctx_server] (bool) { - ctx_server.queue_results.remove_waiting_task_ids(task_ids); - }; - - res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); - } + return handle_completions_generic( + SERVER_TASK_INF_TYPE_COMPLETION, + data, + res, + /* format_partial */ [data, model_name, completion_id](server_task_result_cmpl_partial & result) { + return format_partial_response_oaicompat(model_name, result, completion_id); + }, + /* format_final */ [data, verbose, model_name](std::vector & results) { + return format_final_response_oaicompat(data, results[0], model_name, false, verbose); + }, + /* send_done_event */ true); }; const auto handle_models = [¶ms, &ctx_server](const httplib::Request &, httplib::Response & res) {