Server: reorganize some http logic (#5939)

* refactor static file handler

* use set_pre_routing_handler for validate_api_key

* merge embedding handlers

* correct http verb for endpoints

* fix embedding response

* fix test case CORS Options

* fix code style
This commit is contained in:
Xuan Son Nguyen 2024-03-09 11:27:53 +01:00 committed by GitHub
parent e1fa9569ba
commit 950ba1ab84
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 379 additions and 358 deletions

View file

@ -113,7 +113,7 @@ struct server_params {
int32_t n_threads_http = -1;
std::string hostname = "127.0.0.1";
std::string public_path = "examples/server/public";
std::string public_path = "";
std::string chat_template = "";
std::string system_prompt = "";
@ -2145,7 +2145,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
printf(" --path PUBLIC_PATH path from which to serve static files (default: disabled)\n");
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
@ -2211,7 +2211,7 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
invalid_param = true;
break;
}
sparams.api_keys.emplace_back(argv[i]);
sparams.api_keys.push_back(argv[i]);
} else if (arg == "--api-key-file") {
if (++i >= argc) {
invalid_param = true;
@ -2712,7 +2712,143 @@ int main(int argc, char ** argv) {
res.set_header("Access-Control-Allow-Headers", "*");
});
svr->Get("/health", [&](const httplib::Request & req, httplib::Response & res) {
svr->set_logger(log_server_request);
svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
const char fmt[] = "500 Internal Server Error\n%s";
char buf[BUFSIZ];
try {
std::rethrow_exception(std::move(ep));
} catch (std::exception &e) {
snprintf(buf, sizeof(buf), fmt, e.what());
} catch (...) {
snprintf(buf, sizeof(buf), fmt, "Unknown Exception");
}
res.set_content(buf, "text/plain; charset=utf-8");
res.status = 500;
});
svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
if (res.status == 401) {
res.set_content("Unauthorized", "text/plain; charset=utf-8");
}
if (res.status == 400) {
res.set_content("Invalid request", "text/plain; charset=utf-8");
}
if (res.status == 404) {
res.set_content("File Not Found", "text/plain; charset=utf-8");
}
});
// set timeouts and change hostname and port
svr->set_read_timeout (sparams.read_timeout);
svr->set_write_timeout(sparams.write_timeout);
if (!svr->bind_to_port(sparams.hostname, sparams.port)) {
fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port);
return 1;
}
std::unordered_map<std::string, std::string> log_data;
log_data["hostname"] = sparams.hostname;
log_data["port"] = std::to_string(sparams.port);
if (sparams.api_keys.size() == 1) {
auto key = sparams.api_keys[0];
log_data["api_key"] = "api_key: ****" + key.substr(std::max((int)(key.length() - 4), 0));
} else if (sparams.api_keys.size() > 1) {
log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
}
// load the model
if (!ctx_server.load_model(params)) {
state.store(SERVER_STATE_ERROR);
return 1;
} else {
ctx_server.initialize();
state.store(SERVER_STATE_READY);
}
LOG_INFO("model loaded", {});
const auto model_meta = ctx_server.model_meta();
if (sparams.chat_template.empty()) { // custom chat template is not supplied
if (!ctx_server.validate_model_chat_template()) {
LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
sparams.chat_template = "chatml";
}
}
//
// Middlewares
//
auto middleware_validate_api_key = [&sparams](const httplib::Request & req, httplib::Response & res) {
// TODO: should we apply API key to all endpoints, including "/health" and "/models"?
static const std::set<std::string> protected_endpoints = {
"/props",
"/completion",
"/completions",
"/v1/completions",
"/chat/completions",
"/v1/chat/completions",
"/infill",
"/tokenize",
"/detokenize",
"/embedding",
"/embeddings",
"/v1/embeddings",
};
// If API key is not set, skip validation
if (sparams.api_keys.empty()) {
return true;
}
// If path is not in protected_endpoints list, skip validation
if (protected_endpoints.find(req.path) == protected_endpoints.end()) {
return true;
}
// Check for API key in the header
auto auth_header = req.get_header_value("Authorization");
std::string prefix = "Bearer ";
if (auth_header.substr(0, prefix.size()) == prefix) {
std::string received_api_key = auth_header.substr(prefix.size());
if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) {
return true; // API key is valid
}
}
// API key is invalid or not provided
// TODO: make another middleware for CORS related logic
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");
res.status = 401; // Unauthorized
LOG_WARNING("Unauthorized: Invalid API Key", {});
return false;
};
// register server middlewares
svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) {
if (!middleware_validate_api_key(req, res)) {
return httplib::Server::HandlerResponse::Handled;
}
return httplib::Server::HandlerResponse::Unhandled;
});
//
// Route handlers (or controllers)
//
const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) {
server_state current_state = state.load();
switch (current_state) {
case SERVER_STATE_READY:
@ -2765,252 +2901,136 @@ int main(int argc, char ** argv) {
res.status = 500; // HTTP Internal Server Error
} break;
}
});
if (sparams.slots_endpoint) {
svr->Get("/slots", [&](const httplib::Request &, httplib::Response & res) {
// request slots data using task queue
server_task task;
task.id = ctx_server.queue_tasks.get_new_id();
task.id_multi = -1;
task.id_target = -1;
task.type = SERVER_TASK_TYPE_METRICS;
ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task);
// get the result
server_task_result result = ctx_server.queue_results.recv(task.id);
ctx_server.queue_results.remove_waiting_task_id(task.id);
res.set_content(result.data["slots"].dump(), "application/json");
res.status = 200; // HTTP OK
});
}
if (sparams.metrics_endpoint) {
svr->Get("/metrics", [&](const httplib::Request &, httplib::Response & res) {
// request slots data using task queue
server_task task;
task.id = ctx_server.queue_tasks.get_new_id();
task.id_multi = -1;
task.id_target = -1;
task.type = SERVER_TASK_TYPE_METRICS;
task.data.push_back({{"reset_bucket", true}});
ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task);
// get the result
server_task_result result = ctx_server.queue_results.recv(task.id);
ctx_server.queue_results.remove_waiting_task_id(task.id);
json data = result.data;
const uint64_t n_prompt_tokens_processed = data["n_prompt_tokens_processed"];
const uint64_t t_prompt_processing = data["t_prompt_processing"];
const uint64_t n_tokens_predicted = data["n_tokens_predicted"];
const uint64_t t_tokens_generation = data["t_tokens_generation"];
const int32_t kv_cache_used_cells = data["kv_cache_used_cells"];
// metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
json all_metrics_def = json {
{"counter", {{
{"name", "prompt_tokens_total"},
{"help", "Number of prompt tokens processed."},
{"value", (uint64_t) data["n_prompt_tokens_processed_total"]}
}, {
{"name", "prompt_seconds_total"},
{"help", "Prompt process time"},
{"value", (uint64_t) data["t_prompt_processing_total"] / 1.e3}
}, {
{"name", "tokens_predicted_total"},
{"help", "Number of generation tokens processed."},
{"value", (uint64_t) data["n_tokens_predicted_total"]}
}, {
{"name", "tokens_predicted_seconds_total"},
{"help", "Predict process time"},
{"value", (uint64_t) data["t_tokens_generation_total"] / 1.e3}
}}},
{"gauge", {{
{"name", "prompt_tokens_seconds"},
{"help", "Average prompt throughput in tokens/s."},
{"value", n_prompt_tokens_processed ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed : 0.}
},{
{"name", "predicted_tokens_seconds"},
{"help", "Average generation throughput in tokens/s."},
{"value", n_tokens_predicted ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.}
},{
{"name", "kv_cache_usage_ratio"},
{"help", "KV-cache usage. 1 means 100 percent usage."},
{"value", 1. * kv_cache_used_cells / params.n_ctx}
},{
{"name", "kv_cache_tokens"},
{"help", "KV-cache tokens."},
{"value", (uint64_t) data["kv_cache_tokens_count"]}
},{
{"name", "requests_processing"},
{"help", "Number of request processing."},
{"value", (uint64_t) data["processing"]}
},{
{"name", "requests_deferred"},
{"help", "Number of request deferred."},
{"value", (uint64_t) data["deferred"]}
}}}
};
std::stringstream prometheus;
for (const auto & el : all_metrics_def.items()) {
const auto & type = el.key();
const auto & metrics_def = el.value();
for (const auto & metric_def : metrics_def) {
const std::string name = metric_def["name"];
const std::string help = metric_def["help"];
auto value = json_value(metric_def, "value", 0.);
prometheus << "# HELP llamacpp:" << name << " " << help << "\n"
<< "# TYPE llamacpp:" << name << " " << type << "\n"
<< "llamacpp:" << name << " " << value << "\n";
}
}
const int64_t t_start = data["t_start"];
res.set_header("Process-Start-Time-Unix", std::to_string(t_start));
res.set_content(prometheus.str(), "text/plain; version=0.0.4");
res.status = 200; // HTTP OK
});
}
svr->set_logger(log_server_request);
svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
const char fmt[] = "500 Internal Server Error\n%s";
char buf[BUFSIZ];
try {
std::rethrow_exception(std::move(ep));
} catch (std::exception &e) {
snprintf(buf, sizeof(buf), fmt, e.what());
} catch (...) {
snprintf(buf, sizeof(buf), fmt, "Unknown Exception");
}
res.set_content(buf, "text/plain; charset=utf-8");
res.status = 500;
});
svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
if (res.status == 401) {
res.set_content("Unauthorized", "text/plain; charset=utf-8");
}
if (res.status == 400) {
res.set_content("Invalid request", "text/plain; charset=utf-8");
}
if (res.status == 404) {
res.set_content("File Not Found", "text/plain; charset=utf-8");
}
});
// set timeouts and change hostname and port
svr->set_read_timeout (sparams.read_timeout);
svr->set_write_timeout(sparams.write_timeout);
if (!svr->bind_to_port(sparams.hostname, sparams.port)) {
fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port);
return 1;
}
// Set the base directory for serving static files
svr->set_base_dir(sparams.public_path);
std::unordered_map<std::string, std::string> log_data;
log_data["hostname"] = sparams.hostname;
log_data["port"] = std::to_string(sparams.port);
if (sparams.api_keys.size() == 1) {
log_data["api_key"] = "api_key: ****" + sparams.api_keys[0].substr(sparams.api_keys[0].length() - 4);
} else if (sparams.api_keys.size() > 1) {
log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
}
// load the model
if (!ctx_server.load_model(params)) {
state.store(SERVER_STATE_ERROR);
return 1;
} else {
ctx_server.initialize();
state.store(SERVER_STATE_READY);
}
LOG_INFO("model loaded", {});
const auto model_meta = ctx_server.model_meta();
if (sparams.chat_template.empty()) { // custom chat template is not supplied
if (!ctx_server.validate_model_chat_template()) {
LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
sparams.chat_template = "chatml";
}
}
// Middleware for API key validation
auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
// If API key is not set, skip validation
if (sparams.api_keys.empty()) {
return true;
}
// Check for API key in the header
auto auth_header = req.get_header_value("Authorization");
std::string prefix = "Bearer ";
if (auth_header.substr(0, prefix.size()) == prefix) {
std::string received_api_key = auth_header.substr(prefix.size());
if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) {
return true; // API key is valid
}
}
// API key is invalid or not provided
res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");
res.status = 401; // Unauthorized
LOG_WARNING("Unauthorized: Invalid API Key", {});
return false;
};
// this is only called if no index.html is found in the public --path
svr->Get("/", [](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char*>(&index_html), index_html_len, "text/html; charset=utf-8");
return false;
});
const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
if (!sparams.slots_endpoint) {
res.status = 501;
res.set_content("This server does not support slots endpoint.", "text/plain; charset=utf-8");
return;
}
// this is only called if no index.js is found in the public --path
svr->Get("/index.js", [](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char *>(&index_js), index_js_len, "text/javascript; charset=utf-8");
return false;
});
// request slots data using task queue
server_task task;
task.id = ctx_server.queue_tasks.get_new_id();
task.id_multi = -1;
task.id_target = -1;
task.type = SERVER_TASK_TYPE_METRICS;
// this is only called if no index.html is found in the public --path
svr->Get("/completion.js", [](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char*>(&completion_js), completion_js_len, "application/javascript; charset=utf-8");
return false;
});
ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task);
// this is only called if no index.html is found in the public --path
svr->Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char*>(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, "application/javascript; charset=utf-8");
return false;
});
// get the result
server_task_result result = ctx_server.queue_results.recv(task.id);
ctx_server.queue_results.remove_waiting_task_id(task.id);
svr->Get("/props", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_content(result.data["slots"].dump(), "application/json");
res.status = 200; // HTTP OK
};
const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
if (!sparams.metrics_endpoint) {
res.status = 501;
res.set_content("This server does not support metrics endpoint.", "text/plain; charset=utf-8");
return;
}
// request slots data using task queue
server_task task;
task.id = ctx_server.queue_tasks.get_new_id();
task.id_multi = -1;
task.id_target = -1;
task.type = SERVER_TASK_TYPE_METRICS;
task.data.push_back({{"reset_bucket", true}});
ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task);
// get the result
server_task_result result = ctx_server.queue_results.recv(task.id);
ctx_server.queue_results.remove_waiting_task_id(task.id);
json data = result.data;
const uint64_t n_prompt_tokens_processed = data["n_prompt_tokens_processed"];
const uint64_t t_prompt_processing = data["t_prompt_processing"];
const uint64_t n_tokens_predicted = data["n_tokens_predicted"];
const uint64_t t_tokens_generation = data["t_tokens_generation"];
const int32_t kv_cache_used_cells = data["kv_cache_used_cells"];
// metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
json all_metrics_def = json {
{"counter", {{
{"name", "prompt_tokens_total"},
{"help", "Number of prompt tokens processed."},
{"value", (uint64_t) data["n_prompt_tokens_processed_total"]}
}, {
{"name", "prompt_seconds_total"},
{"help", "Prompt process time"},
{"value", (uint64_t) data["t_prompt_processing_total"] / 1.e3}
}, {
{"name", "tokens_predicted_total"},
{"help", "Number of generation tokens processed."},
{"value", (uint64_t) data["n_tokens_predicted_total"]}
}, {
{"name", "tokens_predicted_seconds_total"},
{"help", "Predict process time"},
{"value", (uint64_t) data["t_tokens_generation_total"] / 1.e3}
}}},
{"gauge", {{
{"name", "prompt_tokens_seconds"},
{"help", "Average prompt throughput in tokens/s."},
{"value", n_prompt_tokens_processed ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed : 0.}
},{
{"name", "predicted_tokens_seconds"},
{"help", "Average generation throughput in tokens/s."},
{"value", n_tokens_predicted ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.}
},{
{"name", "kv_cache_usage_ratio"},
{"help", "KV-cache usage. 1 means 100 percent usage."},
{"value", 1. * kv_cache_used_cells / params.n_ctx}
},{
{"name", "kv_cache_tokens"},
{"help", "KV-cache tokens."},
{"value", (uint64_t) data["kv_cache_tokens_count"]}
},{
{"name", "requests_processing"},
{"help", "Number of request processing."},
{"value", (uint64_t) data["processing"]}
},{
{"name", "requests_deferred"},
{"help", "Number of request deferred."},
{"value", (uint64_t) data["deferred"]}
}}}
};
std::stringstream prometheus;
for (const auto & el : all_metrics_def.items()) {
const auto & type = el.key();
const auto & metrics_def = el.value();
for (const auto & metric_def : metrics_def) {
const std::string name = metric_def["name"];
const std::string help = metric_def["help"];
auto value = json_value(metric_def, "value", 0.);
prometheus << "# HELP llamacpp:" << name << " " << help << "\n"
<< "# TYPE llamacpp:" << name << " " << type << "\n"
<< "llamacpp:" << name << " " << value << "\n";
}
}
const int64_t t_start = data["t_start"];
res.set_header("Process-Start-Time-Unix", std::to_string(t_start));
res.set_content(prometheus.str(), "text/plain; version=0.0.4");
res.status = 200; // HTTP OK
};
const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = {
{ "user_name", ctx_server.name_user.c_str() },
@ -3020,13 +3040,10 @@ int main(int argc, char ** argv) {
};
res.set_content(data.dump(), "application/json; charset=utf-8");
});
};
const auto completions = [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
const auto handle_completions = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) {
return;
}
json data = json::parse(req.body);
@ -3102,11 +3119,7 @@ int main(int argc, char ** argv) {
}
};
svr->Post("/completion", completions); // legacy
svr->Post("/completions", completions);
svr->Post("/v1/completions", completions);
svr->Get("/v1/models", [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
const auto handle_models = [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json models = {
@ -3123,14 +3136,10 @@ int main(int argc, char ** argv) {
};
res.set_content(models.dump(), "application/json; charset=utf-8");
});
};
const auto chat_completions = [&ctx_server, &validate_api_key, &sparams](const httplib::Request & req, httplib::Response & res) {
const auto handle_chat_completions = [&ctx_server, &sparams](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) {
return;
}
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);
const int id_task = ctx_server.queue_tasks.get_new_id();
@ -3201,14 +3210,8 @@ int main(int argc, char ** argv) {
}
};
svr->Post("/chat/completions", chat_completions);
svr->Post("/v1/chat/completions", chat_completions);
svr->Post("/infill", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
const auto handle_infill = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) {
return;
}
json data = json::parse(req.body);
@ -3266,13 +3269,9 @@ int main(int argc, char ** argv) {
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
}
});
};
svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
return res.set_content("", "application/json; charset=utf-8");
});
svr->Post("/tokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body);
@ -3282,9 +3281,9 @@ int main(int argc, char ** argv) {
}
const json data = format_tokenizer_response(tokens);
return res.set_content(data.dump(), "application/json; charset=utf-8");
});
};
svr->Post("/detokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body);
@ -3296,9 +3295,9 @@ int main(int argc, char ** argv) {
const json data = format_detokenized_response(content);
return res.set_content(data.dump(), "application/json; charset=utf-8");
});
};
svr->Post("/embedding", [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
const auto handle_embeddings = [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!params.embedding) {
res.status = 501;
@ -3307,94 +3306,114 @@ int main(int argc, char ** argv) {
}
const json body = json::parse(req.body);
bool is_openai = false;
json prompt;
if (body.count("content") != 0) {
prompt = body["content"];
} else {
prompt = "";
}
// create and queue the task
const int id_task = ctx_server.queue_tasks.get_new_id();
ctx_server.queue_results.add_waiting_task_id(id_task);
ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0} }, false, true);
// get the result
server_task_result result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task);
// send the result
return res.set_content(result.data.dump(), "application/json; charset=utf-8");
});
svr->Post("/v1/embeddings", [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!params.embedding) {
res.status = 501;
res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8");
return;
}
const json body = json::parse(req.body);
json prompt;
// an input prompt can string or a list of tokens (integer)
std::vector<json> prompts;
if (body.count("input") != 0) {
prompt = body["input"];
if (prompt.is_array()) {
json data = json::array();
int i = 0;
for (const json & elem : prompt) {
const int id_task = ctx_server.queue_tasks.get_new_id();
ctx_server.queue_results.add_waiting_task_id(id_task);
ctx_server.request_completion(id_task, -1, { {"prompt", elem}, { "n_predict", 0} }, false, true);
// get the result
server_task_result result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task);
json embedding = json{
{"embedding", json_value(result.data, "embedding", json::array())},
{"index", i++},
{"object", "embedding"}
};
data.push_back(embedding);
is_openai = true;
if (body["input"].is_array()) {
// support multiple prompts
for (const json & elem : body["input"]) {
prompts.push_back(elem);
}
json result = format_embeddings_response_oaicompat(body, data);
return res.set_content(result.dump(), "application/json; charset=utf-8");
} else {
// single input prompt
prompts.push_back(body["input"]);
}
} else if (body.count("content") != 0) {
// only support single prompt here
std::string content = body["content"];
prompts.push_back(content);
} else {
prompt = "";
// TODO @ngxson : should return an error here
prompts.push_back("");
}
// create and queue the task
const int id_task = ctx_server.queue_tasks.get_new_id();
// process all prompts
json responses = json::array();
for (auto & prompt : prompts) {
// TODO @ngxson : maybe support multitask for this endpoint?
// create and queue the task
const int id_task = ctx_server.queue_tasks.get_new_id();
ctx_server.queue_results.add_waiting_task_id(id_task);
ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true);
ctx_server.queue_results.add_waiting_task_id(id_task);
ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true);
// get the result
server_task_result result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task);
json data = json::array({json{
{"embedding", json_value(result.data, "embedding", json::array())},
{"index", 0},
{"object", "embedding"}
}}
);
json root = format_embeddings_response_oaicompat(body, data);
// get the result
server_task_result result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task);
responses.push_back(result.data);
}
// write JSON response
json root;
if (is_openai) {
json res_oai = json::array();
int i = 0;
for (auto & elem : responses) {
res_oai.push_back(json{
{"embedding", json_value(elem, "embedding", json::array())},
{"index", i++},
{"object", "embedding"}
});
}
root = format_embeddings_response_oaicompat(body, res_oai);
} else {
root = responses[0];
}
return res.set_content(root.dump(), "application/json; charset=utf-8");
});
};
//
// Router
//
// register static assets routes
if (!sparams.public_path.empty()) {
// Set the base directory for serving static files
svr->set_base_dir(sparams.public_path);
}
// using embedded static files
auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
return false;
};
};
svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
// TODO @ngxson : I have no idea what it is... maybe this is redundant?
return res.set_content("", "application/json; charset=utf-8");
});
svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
svr->Get("/json-schema-to-grammar.mjs", handle_static_file(
json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));
// register API routes
svr->Get ("/health", handle_health);
svr->Get ("/slots", handle_slots);
svr->Get ("/metrics", handle_metrics);
svr->Get ("/props", handle_props);
svr->Get ("/v1/models", handle_models);
svr->Post("/completion", handle_completions); // legacy
svr->Post("/completions", handle_completions);
svr->Post("/v1/completions", handle_completions);
svr->Post("/chat/completions", handle_chat_completions);
svr->Post("/v1/chat/completions", handle_chat_completions);
svr->Post("/infill", handle_infill);
svr->Post("/embedding", handle_embeddings); // legacy
svr->Post("/embeddings", handle_embeddings);
svr->Post("/v1/embeddings", handle_embeddings);
svr->Post("/tokenize", handle_tokenize);
svr->Post("/detokenize", handle_detokenize);
//
// Start the server
//
if (sparams.n_threads_http < 1) {
// +2 threads for monitoring endpoints
sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);