remove virtual for to_json_oai_compat()

Xuan Son Nguyen 2024-12-05 23:29:27 +01:00
parent 4c3d2580b2
commit ffc4441b1d

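The change in a nutshell: to_json_oai_compat() no longer needs to be a virtual on server_task_result, because callers only ever invoke to_json(), and each result type decides internally (via its oaicompat flag) whether to emit the OAI-compatible shape. Below is a minimal sketch of that dispatch pattern, not the real server code: task_result, cmpl_result, and the payload fields are simplified stand-ins for illustration only.

    // Illustrative sketch of the post-commit pattern (simplified stand-in types).
    #include <memory>
    #include <string>
    #include <nlohmann/json.hpp> // nlohmann::json, as used by server.cpp (in-tree path may differ)

    using json = nlohmann::json;

    struct task_result {                    // stand-in for server_task_result
        virtual json to_json() = 0;         // single virtual entry point
        virtual ~task_result() = default;
    };

    struct cmpl_result : task_result {      // stand-in for server_task_result_cmpl_final
        bool        oaicompat = false;
        std::string content;

        json to_json() override {
            if (oaicompat) {
                return to_json_oai_compat(); // internal dispatch, no virtual needed
            }
            return json {{"content", content}}; // non-OAI-compat shape
        }

        // plain member function: only this type knows its OAI-compat shape
        json to_json_oai_compat() {
            return json {{"choices", json::array({ json {{"text", content}} })}};
        }
    };

    int main() {
        std::unique_ptr<task_result> res = std::make_unique<cmpl_result>();
        // callers no longer pick the format; they always call to_json()
        return res->to_json().is_object() ? 0 : 1;
    }

Handlers can then serialize every result the same way (result->to_json()), which is why the hunks around lines 3397 and 3414 below drop the oai_compat ternaries.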

@@ -128,10 +128,11 @@ struct slot_params {
     bool can_speculative;
 
     // OAI-compat fields
-    bool oaicompat = false;
+    bool verbose = false;
+    bool oaicompat = false;
+    bool oaicompat_chat = true;
     std::string oaicompat_model;
     std::string oaicompat_cmpl_id;
-    bool verbose = false;
 
     json to_json() {
         std::vector<std::string> samplers;
@@ -226,10 +227,6 @@ struct server_task_result {
         return -1;
     }
 
     virtual json to_json() = 0;
-    virtual json to_json_oai_compat() {
-        // used by server_task_result_cmpl_final and server_task_result_cmpl_partial
-        return json();
-    }
     virtual ~server_task_result() = default;
 };
@@ -299,16 +296,21 @@ struct server_task_result_cmpl_final : server_task_result {
     slot_params generation_params;
 
     // OAI-compat fields
+    bool verbose = false;
+    bool oaicompat = false;
+    bool oaicompat_chat = true; // TODO: support oaicompat for non-chat
     std::string oaicompat_model;
     std::string oaicompat_cmpl_id;
-    bool verbose = false;
 
     virtual int get_index() override {
         return index;
     }
 
     virtual json to_json() override {
-        // non-OAI-compat JSON
+        if (oaicompat) {
+            return to_json_oai_compat();
+        }
+        // otherwise, non-OAI-compat JSON
         json res = json {
             {"index", index},
             {"content", content},
@@ -332,7 +334,7 @@ struct server_task_result_cmpl_final : server_task_result {
         return res;
     }
 
-    virtual json to_json_oai_compat() override {
+    json to_json_oai_compat() {
         std::string finish_reason = "length";
         if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
             finish_reason = "stop";
@@ -388,9 +390,11 @@ struct server_task_result_cmpl_partial : server_task_result {
     result_timings timings;
 
     // OAI-compat fields
+    bool verbose = false;
+    bool oaicompat = false;
+    bool oaicompat_chat = true; // TODO: support oaicompat for non-chat
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
-    bool verbose = false;
 
     virtual int get_index() override {
         return index;
@@ -401,6 +405,9 @@ struct server_task_result_cmpl_partial : server_task_result {
     }
 
     virtual json to_json() override {
+        if (oaicompat) {
+            return to_json_oai_compat();
+        }
         bool is_stop = stop != STOP_TYPE_NONE;
         // non-OAI-compat JSON
         json res = json {
@@ -425,7 +432,7 @@ struct server_task_result_cmpl_partial : server_task_result {
         return res;
     }
 
-    virtual json to_json_oai_compat() override {
+    json to_json_oai_compat() {
         bool first = n_decoded == 0;
 
         std::string finish_reason;
@@ -1461,6 +1468,7 @@ struct server_context {
        if (data.count("__oaicompat") != 0) {
            std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
            slot.params.oaicompat = true;
+            slot.params.oaicompat_chat = json_value(data, "__oaicompat_chat", false);
            slot.params.oaicompat_model = json_value(data, "model", model_name);
            slot.params.oaicompat_cmpl_id = json_value(data, "completion_id", std::string());
        } else {
@@ -1850,9 +1858,11 @@ struct server_context {
            res->stop = slot.stop;
 
+            res->verbose = slot.params.verbose;
+            res->oaicompat = slot.params.oaicompat;
+            res->oaicompat_chat = slot.params.oaicompat_chat;
            res->oaicompat_model = slot.params.oaicompat_model;
            res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
-            res->verbose = slot.params.verbose;
 
            // populate res.probs_output
            if (slot.params.sampling.n_probs > 0) {
@@ -1899,9 +1909,11 @@ struct server_context {
            res->stopping_word = slot.stopping_word;
            res->stop = slot.stop;
 
+            res->verbose = slot.params.verbose;
+            res->oaicompat = slot.params.oaicompat;
+            res->oaicompat_chat = slot.params.oaicompat_chat;
            res->oaicompat_model = slot.params.oaicompat_model;
            res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
-            res->verbose = slot.params.verbose;
 
            // populate res.probs_output
            if (slot.params.sampling.n_probs > 0) {
@@ -3397,12 +3409,12 @@ int main(int argc, char ** argv) {
        ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
            if (results.size() == 1) {
                // single result
-                res_ok(res, oai_compat ? results[0]->to_json_oai_compat() : results[0]->to_json());
+                res_ok(res, results[0]->to_json());
            } else {
                // multiple results (multitask)
                json arr = json::array();
                for (auto & res : results) {
-                    arr.push_back(oai_compat ? res->to_json_oai_compat() : res->to_json());
+                    arr.push_back(res->to_json());
                }
                res_ok(res, arr);
            }
@@ -3414,7 +3426,7 @@ int main(int argc, char ** argv) {
        } else {
            const auto chunked_content_provider = [task_ids, &ctx_server, oai_compat](size_t, httplib::DataSink & sink) {
                ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
-                    json res_json = oai_compat ? result->to_json_oai_compat() : result->to_json();
+                    json res_json = result->to_json();
                    if (res_json.is_array()) {
                        for (const auto & res : res_json) {
                            if (!server_sent_event(sink, "data", res)) {
@@ -3506,7 +3518,7 @@ int main(int argc, char ** argv) {
        }
 
        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
+        data["__oaicompat_chat"] = true;
        return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res, true);
    };