remove virtual for to_json_oai_compat()

This commit is contained in:
Xuan Son Nguyen 2024-12-05 23:29:27 +01:00
parent 4c3d2580b2
commit ffc4441b1d

View file

@ -128,10 +128,11 @@ struct slot_params {
bool can_speculative;
// OAI-compat fields
bool oaicompat = false;
bool verbose = false;
bool oaicompat = false;
bool oaicompat_chat = true;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
bool verbose = false;
json to_json() {
std::vector<std::string> samplers;
@ -226,10 +227,6 @@ struct server_task_result {
return -1;
}
virtual json to_json() = 0;
virtual json to_json_oai_compat() {
// used by server_task_result_cmpl_final and server_task_result_cmpl_partial
return json();
}
virtual ~server_task_result() = default;
};
@ -299,16 +296,21 @@ struct server_task_result_cmpl_final : server_task_result {
slot_params generation_params;
// OAI-compat fields
bool verbose = false;
bool oaicompat = false;
bool oaicompat_chat = true; // TODO: support oaicompat for non-chat
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
bool verbose = false;
virtual int get_index() override {
return index;
}
virtual json to_json() override {
// non-OAI-compat JSON
if (oaicompat) {
return to_json_oai_compat();
}
// otherwise, non-OAI-compat JSON
json res = json {
{"index", index},
{"content", content},
@ -332,7 +334,7 @@ struct server_task_result_cmpl_final : server_task_result {
return res;
}
virtual json to_json_oai_compat() override {
json to_json_oai_compat() {
std::string finish_reason = "length";
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
finish_reason = "stop";
@ -388,9 +390,11 @@ struct server_task_result_cmpl_partial : server_task_result {
result_timings timings;
// OAI-compat fields
bool verbose = false;
bool oaicompat = false;
bool oaicompat_chat = true; // TODO: support oaicompat for non-chat
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
bool verbose = false;
virtual int get_index() override {
return index;
@ -401,6 +405,9 @@ struct server_task_result_cmpl_partial : server_task_result {
}
virtual json to_json() override {
if (oaicompat) {
return to_json_oai_compat();
}
bool is_stop = stop != STOP_TYPE_NONE;
// non-OAI-compat JSON
json res = json {
@ -425,7 +432,7 @@ struct server_task_result_cmpl_partial : server_task_result {
return res;
}
virtual json to_json_oai_compat() override {
json to_json_oai_compat() {
bool first = n_decoded == 0;
std::string finish_reason;
@ -1461,6 +1468,7 @@ struct server_context {
if (data.count("__oaicompat") != 0) {
std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
slot.params.oaicompat = true;
slot.params.oaicompat_chat = json_value(data, "__oaicompat_chat", false);
slot.params.oaicompat_model = json_value(data, "model", model_name);
slot.params.oaicompat_cmpl_id = json_value(data, "completion_id", std::string());
} else {
@ -1850,9 +1858,11 @@ struct server_context {
res->stop = slot.stop;
res->verbose = slot.params.verbose;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_chat = slot.params.oaicompat_chat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->verbose = slot.params.verbose;
// populate res.probs_output
if (slot.params.sampling.n_probs > 0) {
@ -1899,9 +1909,11 @@ struct server_context {
res->stopping_word = slot.stopping_word;
res->stop = slot.stop;
res->verbose = slot.params.verbose;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_chat = slot.params.oaicompat_chat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->verbose = slot.params.verbose;
// populate res.probs_output
if (slot.params.sampling.n_probs > 0) {
@ -3397,12 +3409,12 @@ int main(int argc, char ** argv) {
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
if (results.size() == 1) {
// single result
res_ok(res, oai_compat ? results[0]->to_json_oai_compat() : results[0]->to_json());
res_ok(res, results[0]->to_json());
} else {
// multiple results (multitask)
json arr = json::array();
for (auto & res : results) {
arr.push_back(oai_compat ? res->to_json_oai_compat() : res->to_json());
arr.push_back(res->to_json());
}
res_ok(res, arr);
}
@ -3414,7 +3426,7 @@ int main(int argc, char ** argv) {
} else {
const auto chunked_content_provider = [task_ids, &ctx_server, oai_compat](size_t, httplib::DataSink & sink) {
ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
json res_json = oai_compat ? result->to_json_oai_compat() : result->to_json();
json res_json = result->to_json();
if (res_json.is_array()) {
for (const auto & res : res_json) {
if (!server_sent_event(sink, "data", res)) {
@ -3506,7 +3518,7 @@ int main(int argc, char ** argv) {
}
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
data["__oaicompat_chat"] = true;
return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res, true);
};