From ffc4441b1d9c03a8c5b65ee53bdc961d4dfe0de0 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 5 Dec 2024 23:29:27 +0100 Subject: [PATCH] remove virtual for to_json_oai_compat() --- examples/server/server.cpp | 46 ++++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 95d4bfd37..3685df0d9 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -128,10 +128,11 @@ struct slot_params { bool can_speculative; // OAI-compat fields - bool oaicompat = false; + bool verbose = false; + bool oaicompat = false; + bool oaicompat_chat = true; std::string oaicompat_model; std::string oaicompat_cmpl_id; - bool verbose = false; json to_json() { std::vector samplers; @@ -226,10 +227,6 @@ struct server_task_result { return -1; } virtual json to_json() = 0; - virtual json to_json_oai_compat() { - // used by server_task_result_cmpl_final and server_task_result_cmpl_partial - return json(); - } virtual ~server_task_result() = default; }; @@ -299,16 +296,21 @@ struct server_task_result_cmpl_final : server_task_result { slot_params generation_params; // OAI-compat fields + bool verbose = false; + bool oaicompat = false; + bool oaicompat_chat = true; // TODO: support oaicompat for non-chat std::string oaicompat_model; std::string oaicompat_cmpl_id; - bool verbose = false; virtual int get_index() override { return index; } virtual json to_json() override { - // non-OAI-compat JSON + if (oaicompat) { + return to_json_oai_compat(); + } + // otherwise, non-OAI-compat JSON json res = json { {"index", index}, {"content", content}, @@ -332,7 +334,7 @@ struct server_task_result_cmpl_final : server_task_result { return res; } - virtual json to_json_oai_compat() override { + json to_json_oai_compat() { std::string finish_reason = "length"; if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { finish_reason = "stop"; @@ -388,9 +390,11 @@ struct server_task_result_cmpl_partial : server_task_result { result_timings timings; // OAI-compat fields + bool verbose = false; + bool oaicompat = false; + bool oaicompat_chat = true; // TODO: support oaicompat for non-chat std::string oaicompat_model; std::string oaicompat_cmpl_id; - bool verbose = false; virtual int get_index() override { return index; @@ -401,6 +405,9 @@ struct server_task_result_cmpl_partial : server_task_result { } virtual json to_json() override { + if (oaicompat) { + return to_json_oai_compat(); + } bool is_stop = stop != STOP_TYPE_NONE; // non-OAI-compat JSON json res = json { @@ -425,7 +432,7 @@ struct server_task_result_cmpl_partial : server_task_result { return res; } - virtual json to_json_oai_compat() override { + json to_json_oai_compat() { bool first = n_decoded == 0; std::string finish_reason; @@ -1461,6 +1468,7 @@ struct server_context { if (data.count("__oaicompat") != 0) { std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; slot.params.oaicompat = true; + slot.params.oaicompat_chat = json_value(data, "__oaicompat_chat", false); slot.params.oaicompat_model = json_value(data, "model", model_name); slot.params.oaicompat_cmpl_id = json_value(data, "completion_id", std::string()); } else { @@ -1850,9 +1858,11 @@ struct server_context { res->stop = slot.stop; + res->verbose = slot.params.verbose; + res->oaicompat = slot.params.oaicompat; + res->oaicompat_chat = slot.params.oaicompat_chat; res->oaicompat_model = slot.params.oaicompat_model; res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; - res->verbose = slot.params.verbose; // populate res.probs_output if (slot.params.sampling.n_probs > 0) { @@ -1899,9 +1909,11 @@ struct server_context { res->stopping_word = slot.stopping_word; res->stop = slot.stop; + res->verbose = slot.params.verbose; + res->oaicompat = slot.params.oaicompat; + res->oaicompat_chat = slot.params.oaicompat_chat; res->oaicompat_model = slot.params.oaicompat_model; res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; - res->verbose = slot.params.verbose; // populate res.probs_output if (slot.params.sampling.n_probs > 0) { @@ -3397,12 +3409,12 @@ int main(int argc, char ** argv) { ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { if (results.size() == 1) { // single result - res_ok(res, oai_compat ? results[0]->to_json_oai_compat() : results[0]->to_json()); + res_ok(res, results[0]->to_json()); } else { // multiple results (multitask) json arr = json::array(); for (auto & res : results) { - arr.push_back(oai_compat ? res->to_json_oai_compat() : res->to_json()); + arr.push_back(res->to_json()); } res_ok(res, arr); } @@ -3414,7 +3426,7 @@ int main(int argc, char ** argv) { } else { const auto chunked_content_provider = [task_ids, &ctx_server, oai_compat](size_t, httplib::DataSink & sink) { ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool { - json res_json = oai_compat ? result->to_json_oai_compat() : result->to_json(); + json res_json = result->to_json(); if (res_json.is_array()) { for (const auto & res : res_json) { if (!server_sent_event(sink, "data", res)) { @@ -3506,7 +3518,7 @@ int main(int argc, char ** argv) { } json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); - + data["__oaicompat_chat"] = true; return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res, true); };