Server: reorganize some http logic (#5939)
* refactor static file handler
* use set_pre_routing_handler for validate_api_key
* merge embedding handlers
* correct http verb for endpoints
* fix embedding response
* fix test case CORS Options
* fix code style
parent e1fa9569ba
commit 950ba1ab84
4 changed files with 379 additions and 358 deletions
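The core of this commit is replacing the per-route `validate_api_key(req, res)` calls with a single `middleware_validate_api_key` registered through cpp-httplib's `set_pre_routing_handler`, which runs before route dispatch and can short-circuit a request by returning `HandlerResponse::Handled`. Below is a minimal standalone sketch of that pattern, not the llama.cpp code itself; the API key and the protected endpoint set are illustrative placeholders.

    // toy server demonstrating the pre-routing middleware pattern (cpp-httplib)
    #include "httplib.h"

    #include <set>
    #include <string>

    int main() {
        httplib::Server svr;

        const std::string api_key = "secret-key"; // placeholder, not a real config value
        const std::set<std::string> protected_endpoints = { "/completion" };

        // runs before any route handler; returning Handled stops dispatch,
        // so the route handlers themselves stay free of auth logic
        svr.set_pre_routing_handler([&](const httplib::Request & req, httplib::Response & res) {
            if (protected_endpoints.find(req.path) == protected_endpoints.end()) {
                return httplib::Server::HandlerResponse::Unhandled; // public route
            }
            if (req.get_header_value("Authorization") == "Bearer " + api_key) {
                return httplib::Server::HandlerResponse::Unhandled; // key accepted
            }
            res.status = 401;
            res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");
            return httplib::Server::HandlerResponse::Handled; // short-circuit
        });

        svr.Post("/completion", [](const httplib::Request &, httplib::Response & res) {
            res.set_content("{}", "application/json; charset=utf-8");
        });

        svr.listen("127.0.0.1", 8080);
    }

Centralizing the check this way is what lets the handlers in the diff below drop their `if (!validate_api_key(req, res)) { return; }` preambles.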
@@ -113,7 +113,7 @@ struct server_params {
     int32_t n_threads_http = -1;
 
     std::string hostname = "127.0.0.1";
-    std::string public_path = "examples/server/public";
+    std::string public_path = "";
     std::string chat_template = "";
     std::string system_prompt = "";
 
@@ -2145,7 +2145,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
     printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
     printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
     printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
-    printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
+    printf(" --path PUBLIC_PATH path from which to serve static files (default: disabled)\n");
     printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
     printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
@@ -2211,7 +2211,7 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
                 invalid_param = true;
                 break;
             }
-            sparams.api_keys.emplace_back(argv[i]);
+            sparams.api_keys.push_back(argv[i]);
         } else if (arg == "--api-key-file") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -2712,7 +2712,143 @@ int main(int argc, char ** argv) {
         res.set_header("Access-Control-Allow-Headers", "*");
     });
 
-    svr->Get("/health", [&](const httplib::Request & req, httplib::Response & res) {
+    svr->set_logger(log_server_request);
+
+    svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
+        const char fmt[] = "500 Internal Server Error\n%s";
+
+        char buf[BUFSIZ];
+        try {
+            std::rethrow_exception(std::move(ep));
+        } catch (std::exception &e) {
+            snprintf(buf, sizeof(buf), fmt, e.what());
+        } catch (...) {
+            snprintf(buf, sizeof(buf), fmt, "Unknown Exception");
+        }
+
+        res.set_content(buf, "text/plain; charset=utf-8");
+        res.status = 500;
+    });
+
+    svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
+        if (res.status == 401) {
+            res.set_content("Unauthorized", "text/plain; charset=utf-8");
+        }
+        if (res.status == 400) {
+            res.set_content("Invalid request", "text/plain; charset=utf-8");
+        }
+        if (res.status == 404) {
+            res.set_content("File Not Found", "text/plain; charset=utf-8");
+        }
+    });
+
+    // set timeouts and change hostname and port
+    svr->set_read_timeout (sparams.read_timeout);
+    svr->set_write_timeout(sparams.write_timeout);
+
+    if (!svr->bind_to_port(sparams.hostname, sparams.port)) {
+        fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port);
+        return 1;
+    }
+
+    std::unordered_map<std::string, std::string> log_data;
+
+    log_data["hostname"] = sparams.hostname;
+    log_data["port"] = std::to_string(sparams.port);
+
+    if (sparams.api_keys.size() == 1) {
+        auto key = sparams.api_keys[0];
+        log_data["api_key"] = "api_key: ****" + key.substr(std::max((int)(key.length() - 4), 0));
+    } else if (sparams.api_keys.size() > 1) {
+        log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
+    }
+
+    // load the model
+    if (!ctx_server.load_model(params)) {
+        state.store(SERVER_STATE_ERROR);
+        return 1;
+    } else {
+        ctx_server.initialize();
+        state.store(SERVER_STATE_READY);
+    }
+
+    LOG_INFO("model loaded", {});
+
+    const auto model_meta = ctx_server.model_meta();
+
+    if (sparams.chat_template.empty()) { // custom chat template is not supplied
+        if (!ctx_server.validate_model_chat_template()) {
+            LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
+            sparams.chat_template = "chatml";
+        }
+    }
+
+    //
+    // Middlewares
+    //
+
+    auto middleware_validate_api_key = [&sparams](const httplib::Request & req, httplib::Response & res) {
+        // TODO: should we apply API key to all endpoints, including "/health" and "/models"?
+        static const std::set<std::string> protected_endpoints = {
+            "/props",
+            "/completion",
+            "/completions",
+            "/v1/completions",
+            "/chat/completions",
+            "/v1/chat/completions",
+            "/infill",
+            "/tokenize",
+            "/detokenize",
+            "/embedding",
+            "/embeddings",
+            "/v1/embeddings",
+        };
+
+        // If API key is not set, skip validation
+        if (sparams.api_keys.empty()) {
+            return true;
+        }
+
+        // If path is not in protected_endpoints list, skip validation
+        if (protected_endpoints.find(req.path) == protected_endpoints.end()) {
+            return true;
+        }
+
+        // Check for API key in the header
+        auto auth_header = req.get_header_value("Authorization");
+
+        std::string prefix = "Bearer ";
+        if (auth_header.substr(0, prefix.size()) == prefix) {
+            std::string received_api_key = auth_header.substr(prefix.size());
+            if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) {
+                return true; // API key is valid
+            }
+        }
+
+        // API key is invalid or not provided
+        // TODO: make another middleware for CORS related logic
+        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+        res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");
+        res.status = 401; // Unauthorized
+
+        LOG_WARNING("Unauthorized: Invalid API Key", {});
+
+        return false;
+    };
+
+    // register server middlewares
+    svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) {
+        if (!middleware_validate_api_key(req, res)) {
+            return httplib::Server::HandlerResponse::Handled;
+        }
+        return httplib::Server::HandlerResponse::Unhandled;
+    });
+
+    //
+    // Route handlers (or controllers)
+    //
+
+    const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) {
         server_state current_state = state.load();
         switch (current_state) {
             case SERVER_STATE_READY:
@@ -2765,252 +2901,136 @@ int main(int argc, char ** argv) {
                 res.status = 500; // HTTP Internal Server Error
             } break;
         }
-    });
-
-    if (sparams.slots_endpoint) {
-        svr->Get("/slots", [&](const httplib::Request &, httplib::Response & res) {
-            // request slots data using task queue
-            server_task task;
-            task.id = ctx_server.queue_tasks.get_new_id();
-            task.id_multi = -1;
-            task.id_target = -1;
-            task.type = SERVER_TASK_TYPE_METRICS;
-
-            ctx_server.queue_results.add_waiting_task_id(task.id);
-            ctx_server.queue_tasks.post(task);
-
-            // get the result
-            server_task_result result = ctx_server.queue_results.recv(task.id);
-            ctx_server.queue_results.remove_waiting_task_id(task.id);
-
-            res.set_content(result.data["slots"].dump(), "application/json");
-            res.status = 200; // HTTP OK
-        });
-    }
-
-    if (sparams.metrics_endpoint) {
-        svr->Get("/metrics", [&](const httplib::Request &, httplib::Response & res) {
-            // request slots data using task queue
-            server_task task;
-            task.id = ctx_server.queue_tasks.get_new_id();
-            task.id_multi = -1;
-            task.id_target = -1;
-            task.type = SERVER_TASK_TYPE_METRICS;
-            task.data.push_back({{"reset_bucket", true}});
-
-            ctx_server.queue_results.add_waiting_task_id(task.id);
-            ctx_server.queue_tasks.post(task);
-
-            // get the result
-            server_task_result result = ctx_server.queue_results.recv(task.id);
-            ctx_server.queue_results.remove_waiting_task_id(task.id);
-
-            json data = result.data;
-
-            const uint64_t n_prompt_tokens_processed = data["n_prompt_tokens_processed"];
-            const uint64_t t_prompt_processing = data["t_prompt_processing"];
-
-            const uint64_t n_tokens_predicted = data["n_tokens_predicted"];
-            const uint64_t t_tokens_generation = data["t_tokens_generation"];
-
-            const int32_t kv_cache_used_cells = data["kv_cache_used_cells"];
-
-            // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
-            json all_metrics_def = json {
-                {"counter", {{
-                        {"name", "prompt_tokens_total"},
-                        {"help", "Number of prompt tokens processed."},
-                        {"value", (uint64_t) data["n_prompt_tokens_processed_total"]}
-                }, {
-                        {"name", "prompt_seconds_total"},
-                        {"help", "Prompt process time"},
-                        {"value", (uint64_t) data["t_prompt_processing_total"] / 1.e3}
-                }, {
-                        {"name", "tokens_predicted_total"},
-                        {"help", "Number of generation tokens processed."},
-                        {"value", (uint64_t) data["n_tokens_predicted_total"]}
-                }, {
-                        {"name", "tokens_predicted_seconds_total"},
-                        {"help", "Predict process time"},
-                        {"value", (uint64_t) data["t_tokens_generation_total"] / 1.e3}
-                }}},
-                {"gauge", {{
-                        {"name", "prompt_tokens_seconds"},
-                        {"help", "Average prompt throughput in tokens/s."},
-                        {"value", n_prompt_tokens_processed ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed : 0.}
-                },{
-                        {"name", "predicted_tokens_seconds"},
-                        {"help", "Average generation throughput in tokens/s."},
-                        {"value", n_tokens_predicted ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.}
-                },{
-                        {"name", "kv_cache_usage_ratio"},
-                        {"help", "KV-cache usage. 1 means 100 percent usage."},
-                        {"value", 1. * kv_cache_used_cells / params.n_ctx}
-                },{
-                        {"name", "kv_cache_tokens"},
-                        {"help", "KV-cache tokens."},
-                        {"value", (uint64_t) data["kv_cache_tokens_count"]}
-                },{
-                        {"name", "requests_processing"},
-                        {"help", "Number of request processing."},
-                        {"value", (uint64_t) data["processing"]}
-                },{
-                        {"name", "requests_deferred"},
-                        {"help", "Number of request deferred."},
-                        {"value", (uint64_t) data["deferred"]}
-                }}}
-            };
-
-            std::stringstream prometheus;
-
-            for (const auto & el : all_metrics_def.items()) {
-                const auto & type = el.key();
-                const auto & metrics_def = el.value();
-
-                for (const auto & metric_def : metrics_def) {
-                    const std::string name = metric_def["name"];
-                    const std::string help = metric_def["help"];
-
-                    auto value = json_value(metric_def, "value", 0.);
-                    prometheus << "# HELP llamacpp:" << name << " " << help << "\n"
-                               << "# TYPE llamacpp:" << name << " " << type << "\n"
-                               << "llamacpp:" << name << " " << value << "\n";
-                }
-            }
-
-            const int64_t t_start = data["t_start"];
-            res.set_header("Process-Start-Time-Unix", std::to_string(t_start));
-
-            res.set_content(prometheus.str(), "text/plain; version=0.0.4");
-            res.status = 200; // HTTP OK
-        });
-    }
-
-    svr->set_logger(log_server_request);
-
-    svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
-        const char fmt[] = "500 Internal Server Error\n%s";
-
-        char buf[BUFSIZ];
-        try {
-            std::rethrow_exception(std::move(ep));
-        } catch (std::exception &e) {
-            snprintf(buf, sizeof(buf), fmt, e.what());
-        } catch (...) {
-            snprintf(buf, sizeof(buf), fmt, "Unknown Exception");
-        }
-
-        res.set_content(buf, "text/plain; charset=utf-8");
-        res.status = 500;
-    });
-
-    svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
-        if (res.status == 401) {
-            res.set_content("Unauthorized", "text/plain; charset=utf-8");
-        }
-        if (res.status == 400) {
-            res.set_content("Invalid request", "text/plain; charset=utf-8");
-        }
-        if (res.status == 404) {
-            res.set_content("File Not Found", "text/plain; charset=utf-8");
-        }
-    });
-
-    // set timeouts and change hostname and port
-    svr->set_read_timeout (sparams.read_timeout);
-    svr->set_write_timeout(sparams.write_timeout);
-
-    if (!svr->bind_to_port(sparams.hostname, sparams.port)) {
-        fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port);
-        return 1;
-    }
-
-    // Set the base directory for serving static files
-    svr->set_base_dir(sparams.public_path);
-
-    std::unordered_map<std::string, std::string> log_data;
-
-    log_data["hostname"] = sparams.hostname;
-    log_data["port"] = std::to_string(sparams.port);
-
-    if (sparams.api_keys.size() == 1) {
-        log_data["api_key"] = "api_key: ****" + sparams.api_keys[0].substr(sparams.api_keys[0].length() - 4);
-    } else if (sparams.api_keys.size() > 1) {
-        log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
-    }
-
-    // load the model
-    if (!ctx_server.load_model(params)) {
-        state.store(SERVER_STATE_ERROR);
-        return 1;
-    } else {
-        ctx_server.initialize();
-        state.store(SERVER_STATE_READY);
-    }
-
-    LOG_INFO("model loaded", {});
-
-    const auto model_meta = ctx_server.model_meta();
-
-    if (sparams.chat_template.empty()) { // custom chat template is not supplied
-        if (!ctx_server.validate_model_chat_template()) {
-            LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
-            sparams.chat_template = "chatml";
-        }
-    }
-
-    // Middleware for API key validation
-    auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
-        // If API key is not set, skip validation
-        if (sparams.api_keys.empty()) {
-            return true;
-        }
-
-        // Check for API key in the header
-        auto auth_header = req.get_header_value("Authorization");
-
-        std::string prefix = "Bearer ";
-        if (auth_header.substr(0, prefix.size()) == prefix) {
-            std::string received_api_key = auth_header.substr(prefix.size());
-            if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) {
-                return true; // API key is valid
-            }
-        }
-
-        // API key is invalid or not provided
-        res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");
-        res.status = 401; // Unauthorized
-
-        LOG_WARNING("Unauthorized: Invalid API Key", {});
-
-        return false;
-    };
-
-    // this is only called if no index.html is found in the public --path
-    svr->Get("/", [](const httplib::Request &, httplib::Response & res) {
-        res.set_content(reinterpret_cast<const char*>(&index_html), index_html_len, "text/html; charset=utf-8");
-        return false;
-    });
-
-    // this is only called if no index.js is found in the public --path
-    svr->Get("/index.js", [](const httplib::Request &, httplib::Response & res) {
-        res.set_content(reinterpret_cast<const char *>(&index_js), index_js_len, "text/javascript; charset=utf-8");
-        return false;
-    });
-
-    // this is only called if no index.html is found in the public --path
-    svr->Get("/completion.js", [](const httplib::Request &, httplib::Response & res) {
-        res.set_content(reinterpret_cast<const char*>(&completion_js), completion_js_len, "application/javascript; charset=utf-8");
-        return false;
-    });
-
-    // this is only called if no index.html is found in the public --path
-    svr->Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response & res) {
-        res.set_content(reinterpret_cast<const char*>(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, "application/javascript; charset=utf-8");
-        return false;
-    });
-
-    svr->Get("/props", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
+        if (!sparams.slots_endpoint) {
+            res.status = 501;
+            res.set_content("This server does not support slots endpoint.", "text/plain; charset=utf-8");
+            return;
+        }
+
+        // request slots data using task queue
+        server_task task;
+        task.id = ctx_server.queue_tasks.get_new_id();
+        task.id_multi = -1;
+        task.id_target = -1;
+        task.type = SERVER_TASK_TYPE_METRICS;
+
+        ctx_server.queue_results.add_waiting_task_id(task.id);
+        ctx_server.queue_tasks.post(task);
+
+        // get the result
+        server_task_result result = ctx_server.queue_results.recv(task.id);
+        ctx_server.queue_results.remove_waiting_task_id(task.id);
+
+        res.set_content(result.data["slots"].dump(), "application/json");
+        res.status = 200; // HTTP OK
+    };
+
+    const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
+        if (!sparams.metrics_endpoint) {
+            res.status = 501;
+            res.set_content("This server does not support metrics endpoint.", "text/plain; charset=utf-8");
+            return;
+        }
+
+        // request slots data using task queue
+        server_task task;
+        task.id = ctx_server.queue_tasks.get_new_id();
+        task.id_multi = -1;
+        task.id_target = -1;
+        task.type = SERVER_TASK_TYPE_METRICS;
+        task.data.push_back({{"reset_bucket", true}});
+
+        ctx_server.queue_results.add_waiting_task_id(task.id);
+        ctx_server.queue_tasks.post(task);
+
+        // get the result
+        server_task_result result = ctx_server.queue_results.recv(task.id);
+        ctx_server.queue_results.remove_waiting_task_id(task.id);
+
+        json data = result.data;
+
+        const uint64_t n_prompt_tokens_processed = data["n_prompt_tokens_processed"];
+        const uint64_t t_prompt_processing = data["t_prompt_processing"];
+
+        const uint64_t n_tokens_predicted = data["n_tokens_predicted"];
+        const uint64_t t_tokens_generation = data["t_tokens_generation"];
+
+        const int32_t kv_cache_used_cells = data["kv_cache_used_cells"];
+
+        // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
+        json all_metrics_def = json {
+            {"counter", {{
+                    {"name", "prompt_tokens_total"},
+                    {"help", "Number of prompt tokens processed."},
+                    {"value", (uint64_t) data["n_prompt_tokens_processed_total"]}
+            }, {
+                    {"name", "prompt_seconds_total"},
+                    {"help", "Prompt process time"},
+                    {"value", (uint64_t) data["t_prompt_processing_total"] / 1.e3}
+            }, {
+                    {"name", "tokens_predicted_total"},
+                    {"help", "Number of generation tokens processed."},
+                    {"value", (uint64_t) data["n_tokens_predicted_total"]}
+            }, {
+                    {"name", "tokens_predicted_seconds_total"},
+                    {"help", "Predict process time"},
+                    {"value", (uint64_t) data["t_tokens_generation_total"] / 1.e3}
+            }}},
+            {"gauge", {{
+                    {"name", "prompt_tokens_seconds"},
+                    {"help", "Average prompt throughput in tokens/s."},
+                    {"value", n_prompt_tokens_processed ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed : 0.}
+            },{
+                    {"name", "predicted_tokens_seconds"},
+                    {"help", "Average generation throughput in tokens/s."},
+                    {"value", n_tokens_predicted ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.}
+            },{
+                    {"name", "kv_cache_usage_ratio"},
+                    {"help", "KV-cache usage. 1 means 100 percent usage."},
+                    {"value", 1. * kv_cache_used_cells / params.n_ctx}
+            },{
+                    {"name", "kv_cache_tokens"},
+                    {"help", "KV-cache tokens."},
+                    {"value", (uint64_t) data["kv_cache_tokens_count"]}
+            },{
+                    {"name", "requests_processing"},
+                    {"help", "Number of request processing."},
+                    {"value", (uint64_t) data["processing"]}
+            },{
+                    {"name", "requests_deferred"},
+                    {"help", "Number of request deferred."},
+                    {"value", (uint64_t) data["deferred"]}
+            }}}
+        };
+
+        std::stringstream prometheus;
+
+        for (const auto & el : all_metrics_def.items()) {
+            const auto & type = el.key();
+            const auto & metrics_def = el.value();
+
+            for (const auto & metric_def : metrics_def) {
+                const std::string name = metric_def["name"];
+                const std::string help = metric_def["help"];
+
+                auto value = json_value(metric_def, "value", 0.);
+                prometheus << "# HELP llamacpp:" << name << " " << help << "\n"
+                           << "# TYPE llamacpp:" << name << " " << type << "\n"
+                           << "llamacpp:" << name << " " << value << "\n";
+            }
+        }
+
+        const int64_t t_start = data["t_start"];
+        res.set_header("Process-Start-Time-Unix", std::to_string(t_start));
+
+        res.set_content(prometheus.str(), "text/plain; version=0.0.4");
+        res.status = 200; // HTTP OK
+    };
+
+    const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         json data = {
             { "user_name", ctx_server.name_user.c_str() },
@@ -3020,13 +3040,10 @@ int main(int argc, char ** argv) {
         };
 
         res.set_content(data.dump(), "application/json; charset=utf-8");
-    });
+    };
 
-    const auto completions = [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_completions = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-        if (!validate_api_key(req, res)) {
-            return;
-        }
 
         json data = json::parse(req.body);
 
@@ -3102,11 +3119,7 @@ int main(int argc, char ** argv) {
         }
     };
 
-    svr->Post("/completion", completions); // legacy
-    svr->Post("/completions", completions);
-    svr->Post("/v1/completions", completions);
-
-    svr->Get("/v1/models", [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_models = [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
 
         json models = {
@@ -3123,14 +3136,10 @@ int main(int argc, char ** argv) {
         };
 
         res.set_content(models.dump(), "application/json; charset=utf-8");
-    });
+    };
 
-    const auto chat_completions = [&ctx_server, &validate_api_key, &sparams](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_chat_completions = [&ctx_server, &sparams](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-        if (!validate_api_key(req, res)) {
-            return;
-        }
-
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);
 
         const int id_task = ctx_server.queue_tasks.get_new_id();
@@ -3201,14 +3210,8 @@ int main(int argc, char ** argv) {
         }
     };
 
-    svr->Post("/chat/completions", chat_completions);
-    svr->Post("/v1/chat/completions", chat_completions);
-
-    svr->Post("/infill", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_infill = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-        if (!validate_api_key(req, res)) {
-            return;
-        }
 
         json data = json::parse(req.body);
 
@@ -3266,13 +3269,9 @@ int main(int argc, char ** argv) {
 
             res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
         }
-    });
+    };
 
-    svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
-        return res.set_content("", "application/json; charset=utf-8");
-    });
-
-    svr->Post("/tokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         const json body = json::parse(req.body);
 
@@ -3282,9 +3281,9 @@ int main(int argc, char ** argv) {
         }
         const json data = format_tokenizer_response(tokens);
         return res.set_content(data.dump(), "application/json; charset=utf-8");
-    });
+    };
 
-    svr->Post("/detokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         const json body = json::parse(req.body);
 
@@ -3296,9 +3295,9 @@ int main(int argc, char ** argv) {
 
         const json data = format_detokenized_response(content);
         return res.set_content(data.dump(), "application/json; charset=utf-8");
-    });
+    };
 
-    svr->Post("/embedding", [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_embeddings = [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         if (!params.embedding) {
             res.status = 501;
@@ -3307,94 +3306,114 @@ int main(int argc, char ** argv) {
         }
 
         const json body = json::parse(req.body);
+        bool is_openai = false;
 
-        json prompt;
-        if (body.count("content") != 0) {
-            prompt = body["content"];
-        } else {
-            prompt = "";
-        }
-
-        // create and queue the task
-        const int id_task = ctx_server.queue_tasks.get_new_id();
-
-        ctx_server.queue_results.add_waiting_task_id(id_task);
-        ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0} }, false, true);
-
-        // get the result
-        server_task_result result = ctx_server.queue_results.recv(id_task);
-        ctx_server.queue_results.remove_waiting_task_id(id_task);
-
-        // send the result
-        return res.set_content(result.data.dump(), "application/json; charset=utf-8");
-    });
-
-    svr->Post("/v1/embeddings", [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-        if (!params.embedding) {
-            res.status = 501;
-            res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8");
-            return;
-        }
-
-        const json body = json::parse(req.body);
-
-        json prompt;
+        // an input prompt can string or a list of tokens (integer)
+        std::vector<json> prompts;
         if (body.count("input") != 0) {
-            prompt = body["input"];
-            if (prompt.is_array()) {
-                json data = json::array();
-
-                int i = 0;
-                for (const json & elem : prompt) {
-                    const int id_task = ctx_server.queue_tasks.get_new_id();
-
-                    ctx_server.queue_results.add_waiting_task_id(id_task);
-                    ctx_server.request_completion(id_task, -1, { {"prompt", elem}, { "n_predict", 0} }, false, true);
-
-                    // get the result
-                    server_task_result result = ctx_server.queue_results.recv(id_task);
-                    ctx_server.queue_results.remove_waiting_task_id(id_task);
-
-                    json embedding = json{
-                        {"embedding", json_value(result.data, "embedding", json::array())},
-                        {"index", i++},
-                        {"object", "embedding"}
-                    };
-
-                    data.push_back(embedding);
+            is_openai = true;
+            if (body["input"].is_array()) {
+                // support multiple prompts
+                for (const json & elem : body["input"]) {
+                    prompts.push_back(elem);
                 }
-
-                json result = format_embeddings_response_oaicompat(body, data);
-
-                return res.set_content(result.dump(), "application/json; charset=utf-8");
             } else {
+                // single input prompt
+                prompts.push_back(body["input"]);
             }
+        } else if (body.count("content") != 0) {
+            // only support single prompt here
+            std::string content = body["content"];
+            prompts.push_back(content);
         } else {
-            prompt = "";
+            // TODO @ngxson : should return an error here
+            prompts.push_back("");
         }
 
-        // create and queue the task
-        const int id_task = ctx_server.queue_tasks.get_new_id();
+        // process all prompts
+        json responses = json::array();
+        for (auto & prompt : prompts) {
+            // TODO @ngxson : maybe support multitask for this endpoint?
+            // create and queue the task
+            const int id_task = ctx_server.queue_tasks.get_new_id();
 
-        ctx_server.queue_results.add_waiting_task_id(id_task);
-        ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true);
+            ctx_server.queue_results.add_waiting_task_id(id_task);
+            ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true);
 
-        // get the result
-        server_task_result result = ctx_server.queue_results.recv(id_task);
-        ctx_server.queue_results.remove_waiting_task_id(id_task);
-
-        json data = json::array({json{
-            {"embedding", json_value(result.data, "embedding", json::array())},
-            {"index", 0},
-            {"object", "embedding"}
-        }}
-        );
-
-        json root = format_embeddings_response_oaicompat(body, data);
+            // get the result
+            server_task_result result = ctx_server.queue_results.recv(id_task);
+            ctx_server.queue_results.remove_waiting_task_id(id_task);
+            responses.push_back(result.data);
+        }
 
+        // write JSON response
+        json root;
+        if (is_openai) {
+            json res_oai = json::array();
+            int i = 0;
+            for (auto & elem : responses) {
+                res_oai.push_back(json{
+                    {"embedding", json_value(elem, "embedding", json::array())},
+                    {"index", i++},
+                    {"object", "embedding"}
+                });
+            }
+            root = format_embeddings_response_oaicompat(body, res_oai);
+        } else {
+            root = responses[0];
+        }
         return res.set_content(root.dump(), "application/json; charset=utf-8");
-    });
+    };
 
+    //
+    // Router
+    //
+
+    // register static assets routes
+    if (!sparams.public_path.empty()) {
+        // Set the base directory for serving static files
+        svr->set_base_dir(sparams.public_path);
+    }
+
+    // using embedded static files
+    auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
+        return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
+            res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
+            return false;
+        };
+    };
+
+    svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
+        // TODO @ngxson : I have no idea what it is... maybe this is redundant?
+        return res.set_content("", "application/json; charset=utf-8");
+    });
+    svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
+    svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
+    svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
+    svr->Get("/json-schema-to-grammar.mjs", handle_static_file(
+        json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));
+
+    // register API routes
+    svr->Get ("/health",              handle_health);
+    svr->Get ("/slots",               handle_slots);
+    svr->Get ("/metrics",             handle_metrics);
+    svr->Get ("/props",               handle_props);
+    svr->Get ("/v1/models",           handle_models);
+    svr->Post("/completion",          handle_completions); // legacy
+    svr->Post("/completions",         handle_completions);
+    svr->Post("/v1/completions",      handle_completions);
+    svr->Post("/chat/completions",    handle_chat_completions);
+    svr->Post("/v1/chat/completions", handle_chat_completions);
+    svr->Post("/infill",              handle_infill);
+    svr->Post("/embedding",           handle_embeddings); // legacy
+    svr->Post("/embeddings",          handle_embeddings);
+    svr->Post("/v1/embeddings",       handle_embeddings);
+    svr->Post("/tokenize",            handle_tokenize);
+    svr->Post("/detokenize",          handle_detokenize);
+
+    //
+    // Start the server
+    //
     if (sparams.n_threads_http < 1) {
         // +2 threads for monitoring endpoints
         sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);