ManniX-ITA 2024-04-19 19:03:23 +02:00 committed by GitHub
parent b9613ef11a
commit 61b483d3a8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -2282,6 +2282,7 @@ struct server_context {
{"size", llama_model_size (model)},
};
}
};
static void server_print_usage(const char * argv0, const gpt_params & params, const server_params & sparams) {
@ -3114,37 +3115,13 @@ int main(int argc, char ** argv) {
return false;
};
// register server middlewares
svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) {
if (!middleware_validate_api_key(req, res)) {
return httplib::Server::HandlerResponse::Handled;
}
return httplib::Server::HandlerResponse::Unhandled;
});
auto middleware_model_loading = [&sparams, &res_error](const httplib::Request & req, httplib::Response & res, server_state current_state) {
static const std::set<std::string> invalid_endpoints = {
"/slots",
"/metrics",
"/props",
"/v1/models",
"/completion",
"/completions",
"/v1/completions",
"/chat/completions",
"/v1/chat/completions",
"/infill",
"/tokenize",
"/detokenize",
"/embedding",
"/embeddings",
"/v1/embeddings",
};
// If path is not in invalid_endpoints list, skip validation
if (invalid_endpoints.find(req.path) == invalid_endpoints.end()) {
// If path is not an health check skip validation
if (req.path == "/health" || req.path == "/v1/health") {
return true;
}
switch (current_state) {
case SERVER_STATE_LOADING_MODEL:
{
@ -3160,9 +3137,10 @@ int main(int argc, char ** argv) {
return true;
};
svr->set_pre_routing_handler([&state, &middleware_model_loading](const httplib::Request & req, httplib::Response & res) {
// register server middlewares
svr->set_pre_routing_handler([&middleware_validate_api_key, &state, &middleware_model_loading](const httplib::Request & req, httplib::Response & res) {
server_state current_state = state.load();
if (!middleware_model_loading(req, res, current_state)) {
if (!middleware_model_loading(req, res, current_state) || !middleware_validate_api_key(req, res)) {
return httplib::Server::HandlerResponse::Handled;
}
return httplib::Server::HandlerResponse::Unhandled;
@ -3705,7 +3683,7 @@ int main(int argc, char ** argv) {
return res.set_content(root.dump(), "application/json; charset=utf-8");
};
const auto handle_models = [&](const httplib::Request & req, httplib::Response & res) {
const auto handle_models = [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json model_meta = ctx_server.model_meta();
@ -3728,6 +3706,7 @@ int main(int argc, char ** argv) {
// register API routes
svr->Get ("/health", handle_health);
svr->Get ("/v1/health", handle_health);
svr->Get ("/slots", handle_slots);
svr->Get ("/metrics", handle_metrics);
svr->Get ("/props", handle_props);
@ -3756,11 +3735,6 @@ int main(int argc, char ** argv) {
};
};
//
// Router
//
// register static assets routes
if (!sparams.public_path.empty()) {
// Set the base directory for serving static files
svr->set_base_dir(sparams.public_path);
@ -3795,11 +3769,6 @@ int main(int argc, char ** argv) {
return 0;
});
if (state.load() == SERVER_STATE_ERROR) {
// HTTP Server could not bind the port
return 1;
}
// load the model
if (!ctx_server.load_model(params)) {
state.store(SERVER_STATE_ERROR);
@ -3811,8 +3780,6 @@ int main(int argc, char ** argv) {
LOG_INFO("model loaded", {});
const auto model_meta = ctx_server.model_meta();
// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
if (sparams.chat_template.empty()) {
if (!ctx_server.validate_model_chat_template()) {