diff --git a/examples/server/server.cpp b/examples/server/server.cpp index b63a6f243..2ff3d9588 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2763,6 +2763,7 @@ int main(int argc, char ** argv) { res.set_header("Access-Control-Allow-Credentials", "true"); res.set_header("Access-Control-Allow-Methods", "POST"); res.set_header("Access-Control-Allow-Headers", "*"); + return res.set_content("", "application/json; charset=utf-8"); }); svr->set_logger(log_server_request); @@ -3371,21 +3372,15 @@ int main(int argc, char ** argv) { const json body = json::parse(req.body); bool is_openai = false; - // an input prompt can string or a list of tokens (integer) - std::vector prompts; + // an input prompt can be a string or a list of tokens (integer) + json prompts = json::array(); if (body.count("input") != 0) { is_openai = true; - if (body["input"].is_array()) { - // support multiple prompts - for (const json & elem : body["input"]) { - prompts.push_back(elem); - } - } else { - // single input prompt - prompts.push_back(body["input"]); - } + prompts = body["input"].is_array() + ? body["input"] // support multiple prompts + : json{body["input"]}; // single input prompt } else if (body.count("content") != 0) { - // only support single prompt here + // with "content", we only support single prompt std::string content = body["content"]; prompts.push_back(content); } else { @@ -3393,22 +3388,25 @@ int main(int argc, char ** argv) { return; } - // process all prompts + // create and queue the task json responses = json::array(); - for (auto & prompt : prompts) { - // TODO @ngxson : maybe support multitask for this endpoint? - // create and queue the task + { const int id_task = ctx_server.queue_tasks.get_new_id(); - ctx_server.queue_results.add_waiting_task_id(id_task); - ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true); + // if number of prompts is more than 1, we pass an array to create a multi-task + // otherwise, we pass a single prompt to make a single task + ctx_server.request_completion(id_task, -1, { + {"prompt", prompts.size() == 1 ? prompts[0] : prompts}, + {"n_predict", 0} + }, false, true); // get the result server_task_result result = ctx_server.queue_results.recv(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task); if (!result.error) { - // append to the responses - responses.push_back(result.data); + responses = result.data.count("results") + ? result.data["results"] // result for multi-task + : json{result.data}; // result for single task } else { // error received, ignore everything else res_error(res, result.data); @@ -3417,24 +3415,19 @@ int main(int argc, char ** argv) { } // write JSON response - json root; - if (is_openai) { - json res_oai = json::array(); - int i = 0; - for (auto & elem : responses) { - res_oai.push_back(json{ - {"embedding", json_value(elem, "embedding", json::array())}, - {"index", i++}, - {"object", "embedding"} - }); - } - root = format_embeddings_response_oaicompat(body, res_oai); - } else { - root = responses[0]; - } + json root = is_openai + ? format_embeddings_response_oaicompat(body, responses) + : responses[0]; return res.set_content(root.dump(), "application/json; charset=utf-8"); }; + auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) { + return [content, len, mime_type](const httplib::Request &, httplib::Response & res) { + res.set_content(reinterpret_cast(content), len, mime_type); + return false; + }; + }; + // // Router // @@ -3446,17 +3439,6 @@ int main(int argc, char ** argv) { } // using embedded static files - auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) { - return [content, len, mime_type](const httplib::Request &, httplib::Response & res) { - res.set_content(reinterpret_cast(content), len, mime_type); - return false; - }; - }; - - svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) { - // TODO @ngxson : I have no idea what it is... maybe this is redundant? - return res.set_content("", "application/json; charset=utf-8"); - }); svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8")); svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8")); svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8")); diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 48aeef4eb..2ddb2cd21 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -529,6 +529,16 @@ static std::vector format_partial_response_oaicompat(json result, const st } static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) { + json data = json::array(); + int i = 0; + for (auto & elem : embeddings) { + data.push_back(json{ + {"embedding", json_value(elem, "embedding", json::array())}, + {"index", i++}, + {"object", "embedding"} + }); + } + json res = json { {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, {"object", "list"}, @@ -536,7 +546,7 @@ static json format_embeddings_response_oaicompat(const json & request, const jso {"prompt_tokens", 0}, {"total_tokens", 0} }}, - {"data", embeddings} + {"data", data} }; return res;