diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 30ff3b149..af4b55a21 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -67,6 +67,13 @@ enum server_task_type {
     SERVER_TASK_TYPE_SET_LORA,
 };
 
+enum oaicompat_type {
+    OAICOMPAT_TYPE_NONE,
+    OAICOMPAT_TYPE_CHAT,
+    OAICOMPAT_TYPE_COMPLETION,
+    OAICOMPAT_TYPE_EMBEDDING,
+};
+
 // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
 enum error_type {
     ERROR_TYPE_INVALID_REQUEST,
@@ -101,11 +108,10 @@ struct slot_params {
     struct common_params_speculative speculative;
 
     // OAI-compat fields
-    bool        verbose        = false;
-    bool        oaicompat      = false;
-    bool        oaicompat_chat = true;
-    std::string oaicompat_model;
-    std::string oaicompat_cmpl_id;
+    bool           verbose = false;
+    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+    std::string    oaicompat_model;
+    std::string    oaicompat_cmpl_id;
 
     json to_json() const {
         std::vector<std::string> samplers;
@@ -529,11 +535,10 @@ struct server_task_result_cmpl_final : server_task_result {
     slot_params generation_params;
 
     // OAI-compat fields
-    bool        verbose = false;
-    bool        oaicompat = false;
-    bool        oaicompat_chat = true; // TODO: support oaicompat for non-chat
-    std::string oaicompat_model;
-    std::string oaicompat_cmpl_id;
+    bool           verbose = false;
+    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+    std::string    oaicompat_model;
+    std::string    oaicompat_cmpl_id;
 
     virtual int get_index() override {
         return index;
@@ -544,9 +549,16 @@ struct server_task_result_cmpl_final : server_task_result {
     }
 
     virtual json to_json() override {
-        return oaicompat
-            ? (stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat())
-            : to_json_non_oaicompat();
+        switch (oaicompat) {
+            case OAICOMPAT_TYPE_NONE:
+                return to_json_non_oaicompat();
+            case OAICOMPAT_TYPE_COMPLETION:
+                return to_json_oaicompat();
+            case OAICOMPAT_TYPE_CHAT:
+                return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat();
+            default:
+                GGML_ASSERT(false && "Invalid oaicompat_type");
+        }
     }
 
     json to_json_non_oaicompat() {
@@ -574,6 +586,50 @@ struct server_task_result_cmpl_final : server_task_result {
         return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
     }
 
+    json to_json_oaicompat() {
+        std::time_t t = std::time(0);
+        json logprobs = json(nullptr); // OAI default to null
+        if (!stream && probs_output.size() > 0) {
+            logprobs = json{
+                {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
+            };
+        }
+        json finish_reason = "length";
+        if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
+            finish_reason = "stop";
+        }
+        json res = json {
+            {"choices", json::array({
+                json{
+                    {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
+                    {"index", index},
+                    {"logprobs", logprobs},
+                    {"finish_reason", finish_reason},
+                }
+            })},
+            {"created", t},
+            {"model", oaicompat_model},
+            {"system_fingerprint", build_info},
+            {"object", "text_completion"},
+            {"usage", json {
+                {"completion_tokens", n_decoded},
+                {"prompt_tokens",     n_prompt_tokens},
+                {"total_tokens",      n_decoded + n_prompt_tokens}
+            }},
+            {"id", oaicompat_cmpl_id}
+        };
+
+        // extra fields for debugging purposes
+        if (verbose) {
+            res["__verbose"] = to_json_non_oaicompat();
+        }
+        if (timings.prompt_n >= 0) {
+            res.push_back({"timings", timings.to_json()});
+        }
+
+        return res;
+    }
+
     json to_json_oaicompat_chat() {
         std::string finish_reason = "length";
         if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
@@ -671,11 +727,10 @@ struct server_task_result_cmpl_partial : server_task_result {
     result_timings timings;
 
     // OAI-compat fields
-    bool        verbose = false;
-    bool        oaicompat = false;
-    bool        oaicompat_chat = true; // TODO: support oaicompat for non-chat
-    std::string oaicompat_model;
-    std::string oaicompat_cmpl_id;
+    bool           verbose = false;
+    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+    std::string    oaicompat_model;
+    std::string    oaicompat_cmpl_id;
 
     virtual int get_index() override {
         return index;
@@ -686,7 +741,16 @@ struct server_task_result_cmpl_partial : server_task_result {
     }
 
     virtual json to_json() override {
-        return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
+        switch (oaicompat) {
+            case OAICOMPAT_TYPE_NONE:
+                return to_json_non_oaicompat();
+            case OAICOMPAT_TYPE_COMPLETION:
+                return to_json_oaicompat();
+            case OAICOMPAT_TYPE_CHAT:
+                return to_json_oaicompat_chat();
+            default:
+                GGML_ASSERT(false && "Invalid oaicompat_type");
+        }
     }
 
     json to_json_non_oaicompat() {
@@ -711,6 +775,41 @@ struct server_task_result_cmpl_partial : server_task_result {
     }
 
     json to_json_oaicompat() {
+        std::time_t t = std::time(0);
+        json logprobs = json(nullptr); // OAI default to null
+        if (prob_output.probs.size() > 0) {
+            logprobs = json{
+                {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
+            };
+        }
+        json res = json {
+            {"choices", json::array({
+                json{
+                    {"text", content},
+                    {"index", index},
+                    {"logprobs", logprobs},
+                    {"finish_reason", nullptr},
+                }
+            })},
+            {"created", t},
+            {"model", oaicompat_model},
+            {"system_fingerprint", build_info},
+            {"object", "text_completion"},
+            {"id", oaicompat_cmpl_id}
+        };
+
+        // extra fields for debugging purposes
+        if (verbose) {
+            res["__verbose"] = to_json_non_oaicompat();
+        }
+        if (timings.prompt_n >= 0) {
+            res.push_back({"timings", timings.to_json()});
+        }
+
+        return res;
+    }
+
+    json to_json_oaicompat_chat() {
         bool first = n_decoded == 0;
         std::time_t t = std::time(0);
         json choices;
@@ -789,14 +888,16 @@ struct server_task_result_embd : server_task_result {
     int32_t n_tokens;
 
     // OAI-compat fields
-    bool oaicompat = false;
+    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
 
     virtual int get_index() override {
         return index;
     }
 
     virtual json to_json() override {
-        return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
+        return oaicompat == OAICOMPAT_TYPE_EMBEDDING
+            ? to_json_oaicompat()
+            : to_json_non_oaicompat();
     }
 
     json to_json_non_oaicompat() {
@@ -2042,7 +2143,6 @@ struct server_context {
 
         res->verbose           = slot.params.verbose;
         res->oaicompat         = slot.params.oaicompat;
-        res->oaicompat_chat    = slot.params.oaicompat_chat;
         res->oaicompat_model   = slot.params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
 
@@ -2083,7 +2183,6 @@ struct server_context {
         res->verbose           = slot.params.verbose;
         res->stream            = slot.params.stream;
         res->oaicompat         = slot.params.oaicompat;
-        res->oaicompat_chat    = slot.params.oaicompat_chat;
         res->oaicompat_model   = slot.params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
 
@@ -3504,12 +3603,11 @@ int main(int argc, char ** argv) {
 
     // handle completion-like requests (completion, chat, infill)
    // we can optionally provide a custom format for partial results and final results
-    const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](
+    const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
             server_task_type type,
             json & data,
             httplib::Response & res,
-            bool oaicompat = false,
-            bool oaicompat_chat = false) {
+            oaicompat_type oaicompat) {
         GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
 
         if (ctx_server.params_base.embedding) {
@@ -3534,9 +3632,8 @@ int main(int argc, char ** argv) {
                 task.id_selected_slot = json_value(data, "id_slot", -1);
 
                 // OAI-compat
-                task.params.oaicompat         = oaicompat;
-                task.params.oaicompat_chat    = oaicompat_chat;
-                task.params.oaicompat_cmpl_id = completion_id;
+                task.params.oaicompat         = oaicompat;
+                task.params.oaicompat_cmpl_id = completion_id;
                 // oaicompat_model is already populated by params_from_json_cmpl
 
                 tasks.push_back(task);
@@ -3587,7 +3684,7 @@ int main(int argc, char ** argv) {
                 }, [&](const json & error_data) {
                     server_sent_event(sink, "error", error_data);
                 });
-                if (oaicompat) {
+                if (oaicompat != OAICOMPAT_TYPE_NONE) {
                     static const std::string ev_done = "data: [DONE]\n\n";
                     sink.write(ev_done.data(), ev_done.size());
                 }
@@ -3603,17 +3700,25 @@ int main(int argc, char ** argv) {
         }
     };
 
-    const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
         json data = json::parse(req.body);
-        return handle_completions_generic(
+        return handle_completions_impl(
             SERVER_TASK_TYPE_COMPLETION,
             data,
             res,
-            /* oaicompat */ false,
-            /* oaicompat_chat */ false);
+            OAICOMPAT_TYPE_NONE);
     };
 
-    const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
+        json data = oaicompat_completion_params_parse(json::parse(req.body));
+        return handle_completions_impl(
+            SERVER_TASK_TYPE_COMPLETION,
+            data,
+            res,
+            OAICOMPAT_TYPE_COMPLETION);
+    };
+
+    const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
         // check model compatibility
         std::string err;
         if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
@@ -3682,22 +3787,25 @@ int main(int argc, char ** argv) {
             tokenized_prompts[0]
         );
 
-        return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
+        return handle_completions_impl(
+            SERVER_TASK_TYPE_INFILL,
+            data,
+            res,
+            OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
     };
 
-    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
         if (ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
-        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
-        return handle_completions_generic(
+        json data = oaicompat_chat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
+        return handle_completions_impl(
             SERVER_TASK_TYPE_COMPLETION,
             data,
             res,
-            /* oaicompat */ true,
-            /* oaicompat_chat */ true);
+            OAICOMPAT_TYPE_CHAT);
     };
 
     const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
@@ -3770,10 +3878,10 @@ int main(int argc, char ** argv) {
         res_ok(res, data);
     };
 
-    const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) {
+    const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
         const json body = json::parse(req.body);
 
-        if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+        if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
             res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
@@ -3783,7 +3891,7 @@ int main(int argc, char ** argv) {
         if (body.count("input") != 0) {
             prompt = body.at("input");
         } else if (body.contains("content")) {
-            oaicompat = false;
+            oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible
             prompt = body.at("content");
         } else {
             res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
@@ -3852,16 +3960,18 @@ int main(int argc, char ** argv) {
         }
 
         // write JSON response
-        json root = oaicompat ? format_embeddings_response_oaicompat(body, responses, use_base64) : json(responses);
+        json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING
+            ? format_embeddings_response_oaicompat(body, responses, use_base64)
+            : json(responses);
         res_ok(res, root);
     };
 
     const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
-        handle_embeddings_impl(req, res, false);
+        handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE);
     };
 
     const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
-        handle_embeddings_impl(req, res, true);
+        handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING);
    };
 
     const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
@@ -4031,7 +4141,7 @@ int main(int argc, char ** argv) {
     svr->Get ("/v1/models",           handle_models); // public endpoint (no API key check)
     svr->Post("/completion",          handle_completions); // legacy
     svr->Post("/completions",         handle_completions);
-    svr->Post("/v1/completions",      handle_completions);
+    svr->Post("/v1/completions",      handle_completions_oai);
     svr->Post("/chat/completions",    handle_chat_completions);
     svr->Post("/v1/chat/completions", handle_chat_completions);
     svr->Post("/infill",              handle_infill);
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 334f2f192..640cf6a29 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -549,10 +549,46 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) {
 // OAI utils
 //
 
-static json oaicompat_completion_params_parse(
-    const struct llama_model * model,
-    const json & body, /* openai api json semantics */
-    const std::string & chat_template) {
+static json oaicompat_completion_params_parse(const json & body) {
+    json llama_params;
+
+    if (!body.contains("prompt")) {
+        throw std::runtime_error("\"prompt\" is required");
+    }
+
+    // Handle "stop" field
+    if (body.contains("stop") && body.at("stop").is_string()) {
+        llama_params["stop"] = json::array({body.at("stop").get<std::string>()});
+    } else {
+        llama_params["stop"] = json_value(body, "stop", json::array());
+    }
+
+    // Handle "n" field
+    int n_choices = json_value(body, "n", 1);
+    if (n_choices != 1) {
+        throw std::runtime_error("Only one completion choice is allowed");
+    }
+
+    // Params supported by OAI but unsupported by llama.cpp
+    if (body.contains("best_of")) {
+        throw std::runtime_error("Unsupported param: best_of");
+    }
+
+    // Copy remaining properties to llama_params
+    for (const auto & item : body.items()) {
+        // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
+        if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
+            llama_params[item.key()] = item.value();
+        }
+    }
+
+    return llama_params;
+}
+
+static json oaicompat_chat_completion_params_parse(
+        const struct llama_model * model,
+        const json & body, /* openai api json semantics */
+        const std::string & chat_template) {
     json llama_params;
 
     // Apply chat template to the list of messages
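
The two sketches below are editorial additions, not part of the patch. First, the contract of the new `oaicompat_completion_params_parse` overload: `prompt` is required, a string-valued `stop` is normalized to a one-element array, `n != 1` and `best_of` are rejected, and every other key is copied through (with `n_predict` allowed to override a previously copied `max_tokens`). A minimal, hypothetical test sketch of that behavior, assuming `examples/server/utils.hpp` (with its usual llama.cpp build dependencies) is visible to the test target:

```cpp
// Hypothetical test sketch (not in the patch): exercises the request
// normalization performed by the new oaicompat_completion_params_parse().
#include "utils.hpp" // assumption: compiled inside the llama.cpp server target

#include <cassert>
#include <stdexcept>

int main() {
    // A string "stop" is wrapped into a one-element array.
    json out = oaicompat_completion_params_parse({
        {"prompt", "hello"},
        {"stop",   "\n"},
    });
    assert(out["stop"] == json::array({"\n"}));

    // Params supported by OAI but not by llama.cpp are rejected outright.
    bool threw = false;
    try {
        oaicompat_completion_params_parse({{"prompt", "hi"}, {"best_of", 4}});
    } catch (const std::runtime_error &) {
        threw = true;
    }
    assert(threw);
    return 0;
}
```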
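End to end, the patch routes `POST /v1/completions` through the new `handle_completions_oai`, so the endpoint now accepts and returns OpenAI-style payloads (`"object": "text_completion"`, `choices[0].text`, `usage`, ...). A minimal client sketch, assuming a llama-server listening on `http://localhost:8080`; it uses cpp-httplib and nlohmann/json, the same libraries the server itself builds on:

```cpp
// Hypothetical client sketch (not in the patch): POSTs an OAI-style
// completion request to the new /v1/completions handler.
#include "httplib.h"
#include <nlohmann/json.hpp>
#include <iostream>

using json = nlohmann::json;

int main() {
    httplib::Client cli("http://localhost:8080"); // assumption: local llama-server

    const json req = {
        {"prompt",     "Once upon a time"},
        {"max_tokens", 32},
        {"stop",       "\n\n"}, // string form; the server normalizes it to an array
    };

    auto res = cli.Post("/v1/completions", req.dump(), "application/json");
    if (res && res->status == 200) {
        const json body = json::parse(res->body);
        // Response shape built by to_json_oaicompat() above.
        std::cout << body["choices"][0]["text"].get<std::string>() << "\n";
    } else {
        std::cerr << "request failed\n";
    }
    return 0;
}
```

With `"stream": true` the handler instead emits SSE chunks followed by a final `data: [DONE]` line, and the generated text arrives in the partial chunks while the final result carries an empty `text`, matching the comment in `to_json_oaicompat()`.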