server : add OAI compat for /v1/completions

Xuan Son Nguyen 2024-12-25 13:48:02 +01:00
parent 9ba399dfa7
commit 90889fddc9
2 changed files with 198 additions and 52 deletions
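In short: the old oaicompat / oaicompat_chat boolean pair becomes a single oaicompat_type enum, /v1/completions gets its own OAI-style parameter parser plus "text_completion" response shapes, and the route is rebound to a new handler. A rough client-side sketch of the new endpoint follows (illustration only, not part of the commit; it assumes a server listening on localhost:8080 and uses cpp-httplib and nlohmann::json, which the server itself depends on):

    // illustrative client for the new OAI-compatible /v1/completions route
    #include "httplib.h"
    #include <nlohmann/json.hpp>
    #include <iostream>

    int main() {
        httplib::Client cli("http://localhost:8080");

        nlohmann::json body = {
            {"prompt",     "Once upon a time"},
            {"max_tokens", 16},
            {"stop",       "\n"}, // a bare string; the parser wraps it into an array
        };

        auto res = cli.Post("/v1/completions", body.dump(), "application/json");
        if (res && res->status == 200) {
            // expected shape (values illustrative):
            // {"object":"text_completion","id":"...","model":"...","created":...,
            //  "choices":[{"text":"...","index":0,"logprobs":null,"finish_reason":"stop"}],
            //  "usage":{"completion_tokens":...,"prompt_tokens":...,"total_tokens":...}}
            std::cout << res->body << std::endl;
        }
        return 0;
    }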

View file

@@ -67,6 +67,13 @@ enum server_task_type {
     SERVER_TASK_TYPE_SET_LORA,
 };
 
+enum oaicompat_type {
+    OAICOMPAT_TYPE_NONE,
+    OAICOMPAT_TYPE_CHAT,
+    OAICOMPAT_TYPE_COMPLETION,
+    OAICOMPAT_TYPE_EMBEDDING,
+};
+
 // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
 enum error_type {
     ERROR_TYPE_INVALID_REQUEST,
@@ -101,11 +108,10 @@ struct slot_params {
     struct common_params_speculative speculative;
 
     // OAI-compat fields
     bool verbose = false;
-    bool oaicompat = false;
-    bool oaicompat_chat = true;
+    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
     std::string oaicompat_model;
     std::string oaicompat_cmpl_id;
 
     json to_json() const {
         std::vector<std::string> samplers;
@@ -529,11 +535,10 @@ struct server_task_result_cmpl_final : server_task_result {
     slot_params generation_params;
 
     // OAI-compat fields
     bool verbose = false;
-    bool oaicompat = false;
-    bool oaicompat_chat = true; // TODO: support oaicompat for non-chat
+    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
     std::string oaicompat_model;
     std::string oaicompat_cmpl_id;
 
     virtual int get_index() override {
         return index;
@@ -544,9 +549,16 @@ struct server_task_result_cmpl_final : server_task_result {
     }
 
     virtual json to_json() override {
-        return oaicompat
-            ? (stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat())
-            : to_json_non_oaicompat();
+        switch (oaicompat) {
+            case OAICOMPAT_TYPE_NONE:
+                return to_json_non_oaicompat();
+            case OAICOMPAT_TYPE_COMPLETION:
+                return to_json_oaicompat();
+            case OAICOMPAT_TYPE_CHAT:
+                return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat();
+            default:
+                GGML_ASSERT(false && "Invalid oaicompat_type");
+        }
     }
 
     json to_json_non_oaicompat() {
@@ -574,6 +586,50 @@ struct server_task_result_cmpl_final : server_task_result {
         return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
     }
 
+    json to_json_oaicompat() {
+        std::time_t t = std::time(0);
+        json logprobs = json(nullptr); // OAI default to null
+        if (!stream && probs_output.size() > 0) {
+            logprobs = json{
+                {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
+            };
+        }
+        json finish_reason = "length";
+        if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
+            finish_reason = "stop";
+        }
+        json res = json {
+            {"choices", json::array({
+                json{
+                    {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
+                    {"index", index},
+                    {"logprobs", logprobs},
+                    {"finish_reason", finish_reason},
+                }
+            })},
+            {"created", t},
+            {"model", oaicompat_model},
+            {"system_fingerprint", build_info},
+            {"object", "text_completion"},
+            {"usage", json {
+                {"completion_tokens", n_decoded},
+                {"prompt_tokens",     n_prompt_tokens},
+                {"total_tokens",      n_decoded + n_prompt_tokens}
+            }},
+            {"id", oaicompat_cmpl_id}
+        };
+
+        // extra fields for debugging purposes
+        if (verbose) {
+            res["__verbose"] = to_json_non_oaicompat();
+        }
+        if (timings.prompt_n >= 0) {
+            res.push_back({"timings", timings.to_json()});
+        }
+
+        return res;
+    }
+
     json to_json_oaicompat_chat() {
         std::string finish_reason = "length";
         if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
@@ -671,11 +727,10 @@ struct server_task_result_cmpl_partial : server_task_result {
     result_timings timings;
 
     // OAI-compat fields
     bool verbose = false;
-    bool oaicompat = false;
-    bool oaicompat_chat = true; // TODO: support oaicompat for non-chat
+    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
     std::string oaicompat_model;
     std::string oaicompat_cmpl_id;
 
     virtual int get_index() override {
         return index;
@@ -686,7 +741,16 @@ struct server_task_result_cmpl_partial : server_task_result {
     }
 
     virtual json to_json() override {
-        return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
+        switch (oaicompat) {
+            case OAICOMPAT_TYPE_NONE:
+                return to_json_non_oaicompat();
+            case OAICOMPAT_TYPE_COMPLETION:
+                return to_json_oaicompat();
+            case OAICOMPAT_TYPE_CHAT:
+                return to_json_oaicompat_chat();
+            default:
+                GGML_ASSERT(false && "Invalid oaicompat_type");
+        }
     }
 
     json to_json_non_oaicompat() {
@@ -711,6 +775,41 @@ struct server_task_result_cmpl_partial : server_task_result {
     }
 
-    json to_json_oaicompat() {
+    json to_json_oaicompat() {
+        std::time_t t = std::time(0);
+        json logprobs = json(nullptr); // OAI default to null
+        if (prob_output.probs.size() > 0) {
+            logprobs = json{
+                {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
+            };
+        }
+        json res = json {
+            {"choices", json::array({
+                json{
+                    {"text", content},
+                    {"index", index},
+                    {"logprobs", logprobs},
+                    {"finish_reason", nullptr},
+                }
+            })},
+            {"created", t},
+            {"model", oaicompat_model},
+            {"system_fingerprint", build_info},
+            {"object", "text_completion"},
+            {"id", oaicompat_cmpl_id}
+        };
+
+        // extra fields for debugging purposes
+        if (verbose) {
+            res["__verbose"] = to_json_non_oaicompat();
+        }
+        if (timings.prompt_n >= 0) {
+            res.push_back({"timings", timings.to_json()});
+        }
+
+        return res;
+    }
+
+    json to_json_oaicompat_chat() {
         bool first = n_decoded == 0;
         std::time_t t = std::time(0);
         json choices;
@@ -789,14 +888,16 @@ struct server_task_result_embd : server_task_result {
     int32_t n_tokens;
 
     // OAI-compat fields
-    bool oaicompat = false;
+    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
 
     virtual int get_index() override {
         return index;
     }
 
     virtual json to_json() override {
-        return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
+        return oaicompat == OAICOMPAT_TYPE_EMBEDDING
+            ? to_json_oaicompat()
+            : to_json_non_oaicompat();
     }
 
     json to_json_non_oaicompat() {
@@ -2042,7 +2143,6 @@ struct server_context {
 
         res->verbose = slot.params.verbose;
         res->oaicompat = slot.params.oaicompat;
-        res->oaicompat_chat = slot.params.oaicompat_chat;
         res->oaicompat_model = slot.params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
 
@@ -2083,7 +2183,6 @@ struct server_context {
        res->verbose = slot.params.verbose;
        res->stream = slot.params.stream;
        res->oaicompat = slot.params.oaicompat;
-       res->oaicompat_chat = slot.params.oaicompat_chat;
        res->oaicompat_model = slot.params.oaicompat_model;
        res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
 
@@ -3504,12 +3603,11 @@ int main(int argc, char ** argv) {
     // handle completion-like requests (completion, chat, infill)
     // we can optionally provide a custom format for partial results and final results
-    const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](
+    const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
             server_task_type type,
             json & data,
            httplib::Response & res,
-            bool oaicompat = false,
-            bool oaicompat_chat = false) {
+            oaicompat_type oaicompat) {
         GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
 
         if (ctx_server.params_base.embedding) {
@@ -3534,9 +3632,8 @@ int main(int argc, char ** argv) {
             task.id_selected_slot = json_value(data, "id_slot", -1);
 
             // OAI-compat
             task.params.oaicompat = oaicompat;
-            task.params.oaicompat_chat = oaicompat_chat;
             task.params.oaicompat_cmpl_id = completion_id;
             // oaicompat_model is already populated by params_from_json_cmpl
 
             tasks.push_back(task);
@@ -3587,7 +3684,7 @@ int main(int argc, char ** argv) {
                 }, [&](const json & error_data) {
                     server_sent_event(sink, "error", error_data);
                 });
-                if (oaicompat) {
+                if (oaicompat != OAICOMPAT_TYPE_NONE) {
                     static const std::string ev_done = "data: [DONE]\n\n";
                     sink.write(ev_done.data(), ev_done.size());
                 }
@@ -3603,17 +3700,25 @@ int main(int argc, char ** argv) {
         }
     };
 
-    const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
         json data = json::parse(req.body);
-        return handle_completions_generic(
+        return handle_completions_impl(
             SERVER_TASK_TYPE_COMPLETION,
             data,
             res,
-            /* oaicompat */ false,
-            /* oaicompat_chat */ false);
+            OAICOMPAT_TYPE_NONE);
     };
 
-    const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
+        json data = oaicompat_completion_params_parse(json::parse(req.body));
+        return handle_completions_impl(
+            SERVER_TASK_TYPE_COMPLETION,
+            data,
+            res,
+            OAICOMPAT_TYPE_COMPLETION);
+    };
+
+    const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
         // check model compatibility
         std::string err;
         if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
@@ -3682,22 +3787,25 @@ int main(int argc, char ** argv) {
             tokenized_prompts[0]
         );
 
-        return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
+        return handle_completions_impl(
+            SERVER_TASK_TYPE_INFILL,
+            data,
+            res,
+            OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
     };
 
-    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
         if (ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
-        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
-        return handle_completions_generic(
+        json data = oaicompat_chat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
+        return handle_completions_impl(
             SERVER_TASK_TYPE_COMPLETION,
             data,
             res,
-            /* oaicompat */ true,
-            /* oaicompat_chat */ true);
+            OAICOMPAT_TYPE_CHAT);
     };
 
     const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
@@ -3770,10 +3878,10 @@ int main(int argc, char ** argv) {
         res_ok(res, data);
     };
 
-    const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) {
+    const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
         const json body = json::parse(req.body);
 
-        if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+        if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
             res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
@@ -3783,7 +3891,7 @@ int main(int argc, char ** argv) {
         if (body.count("input") != 0) {
             prompt = body.at("input");
         } else if (body.contains("content")) {
-            oaicompat = false;
+            oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible
             prompt = body.at("content");
         } else {
             res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
@@ -3852,16 +3960,18 @@ int main(int argc, char ** argv) {
         }
 
         // write JSON response
-        json root = oaicompat ? format_embeddings_response_oaicompat(body, responses, use_base64) : json(responses);
+        json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING
+            ? format_embeddings_response_oaicompat(body, responses, use_base64)
+            : json(responses);
         res_ok(res, root);
     };
 
     const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
-        handle_embeddings_impl(req, res, false);
+        handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE);
     };
 
     const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
-        handle_embeddings_impl(req, res, true);
+        handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING);
     };
 
     const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
@@ -4031,7 +4141,7 @@ int main(int argc, char ** argv) {
     svr->Get ("/v1/models", handle_models); // public endpoint (no API key check)
     svr->Post("/completion", handle_completions); // legacy
     svr->Post("/completions", handle_completions);
-    svr->Post("/v1/completions", handle_completions);
+    svr->Post("/v1/completions", handle_completions_oai);
     svr->Post("/chat/completions", handle_chat_completions);
     svr->Post("/v1/chat/completions", handle_chat_completions);
     svr->Post("/infill", handle_infill);

View file

@@ -549,10 +549,46 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) {
 // OAI utils
 //
 
-static json oaicompat_completion_params_parse(
+static json oaicompat_completion_params_parse(const json & body) {
+    json llama_params;
+
+    if (!body.contains("prompt")) {
+        throw std::runtime_error("\"prompt\" is required");
+    }
+
+    // Handle "stop" field
+    if (body.contains("stop") && body.at("stop").is_string()) {
+        llama_params["stop"] = json::array({body.at("stop").get<std::string>()});
+    } else {
+        llama_params["stop"] = json_value(body, "stop", json::array());
+    }
+
+    // Handle "n" field
+    int n_choices = json_value(body, "n", 1);
+    if (n_choices != 1) {
+        throw std::runtime_error("Only one completion choice is allowed");
+    }
+
+    // Params supported by OAI but unsupported by llama.cpp
+    if (body.contains("best_of")) {
+        throw std::runtime_error("Unsupported param: best_of");
+    }
+
+    // Copy remaining properties to llama_params
+    for (const auto & item : body.items()) {
+        // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
+        if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
+            llama_params[item.key()] = item.value();
+        }
+    }
+
+    return llama_params;
+}
+
+static json oaicompat_chat_completion_params_parse(
     const struct llama_model * model,
     const json & body, /* openai api json semantics */
     const std::string & chat_template) {
     json llama_params;
 
     // Apply chat template to the list of messages
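To make the new parser's normalization rules concrete, a minimal sketch of inputs and outcomes (illustration only, not part of the commit; assumes the json alias and the json_value helper available in this file):

    // a body with a bare-string "stop" and pass-through fields
    json body = {
        {"prompt",     "Once upon a time"},
        {"stop",       "\n"},
        {"max_tokens", 16},
    };
    json params = oaicompat_completion_params_parse(body);
    // params["stop"]       == ["\n"]              -- string wrapped into an array
    // params["prompt"]     == "Once upon a time"  -- copied by the generic loop
    // params["max_tokens"] == 16                  -- copied through; consumed later when slot params are built

    // bodies that would throw std::runtime_error:
    //   {{"stop", "\n"}}                   -> "\"prompt\" is required"
    //   {{"prompt", "hi"}, {"n", 2}}       -> "Only one completion choice is allowed"
    //   {{"prompt", "hi"}, {"best_of", 4}} -> "Unsupported param: best_of"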