diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 2cc8025fe..4286682d4 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2282,6 +2282,7 @@ struct server_context { {"size", llama_model_size (model)}, }; } + }; static void server_print_usage(const char * argv0, const gpt_params & params, const server_params & sparams) { @@ -3114,37 +3115,13 @@ int main(int argc, char ** argv) { return false; }; - // register server middlewares - svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) { - if (!middleware_validate_api_key(req, res)) { - return httplib::Server::HandlerResponse::Handled; - } - return httplib::Server::HandlerResponse::Unhandled; - }); - auto middleware_model_loading = [&sparams, &res_error](const httplib::Request & req, httplib::Response & res, server_state current_state) { - static const std::set<std::string> invalid_endpoints = { - "/slots", - "/metrics", - "/props", - "/v1/models", - "/completion", - "/completions", - "/v1/completions", - "/chat/completions", - "/v1/chat/completions", - "/infill", - "/tokenize", - "/detokenize", - "/embedding", - "/embeddings", - "/v1/embeddings", - }; - // If path is not in invalid_endpoints list, skip validation - if (invalid_endpoints.find(req.path) == invalid_endpoints.end()) { + // If path is a health check, skip validation + if (req.path == "/health" || req.path == "/v1/health") { return true; } + switch (current_state) { case SERVER_STATE_LOADING_MODEL: { @@ -3160,9 +3137,10 @@ int main(int argc, char ** argv) { return true; }; - svr->set_pre_routing_handler([&state, &middleware_model_loading](const httplib::Request & req, httplib::Response & res) { + // register server middlewares + svr->set_pre_routing_handler([&middleware_validate_api_key, &state, &middleware_model_loading](const httplib::Request & req, httplib::Response & res) { server_state current_state = state.load(); - if (!middleware_model_loading(req, 
res, current_state)) { + if (!middleware_model_loading(req, res, current_state) || !middleware_validate_api_key(req, res)) { return httplib::Server::HandlerResponse::Handled; } return httplib::Server::HandlerResponse::Unhandled; @@ -3705,7 +3683,7 @@ int main(int argc, char ** argv) { return res.set_content(root.dump(), "application/json; charset=utf-8"); }; - const auto handle_models = [&](const httplib::Request & req, httplib::Response & res) { + const auto handle_models = [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json model_meta = ctx_server.model_meta(); @@ -3728,6 +3706,7 @@ int main(int argc, char ** argv) { // register API routes svr->Get ("/health", handle_health); + svr->Get ("/v1/health", handle_health); svr->Get ("/slots", handle_slots); svr->Get ("/metrics", handle_metrics); svr->Get ("/props", handle_props); @@ -3756,11 +3735,6 @@ int main(int argc, char ** argv) { }; }; - // - // Router - // - - // register static assets routes if (!sparams.public_path.empty()) { // Set the base directory for serving static files svr->set_base_dir(sparams.public_path); @@ -3794,12 +3768,7 @@ int main(int argc, char ** argv) { return 0; }); - - if (state.load() == SERVER_STATE_ERROR) { - // HTTP Server could not bind the port - return 1; - } - + // load the model if (!ctx_server.load_model(params)) { state.store(SERVER_STATE_ERROR); @@ -3811,8 +3780,6 @@ int main(int argc, char ** argv) { LOG_INFO("model loaded", {}); - const auto model_meta = ctx_server.model_meta(); - // if a custom chat template is not supplied, we will use the one that comes with the model (if any) if (sparams.chat_template.empty()) { if (!ctx_server.validate_model_chat_template()) {