From 52a4d5974739160c2bdb5e2cd0893a15cb19f3b4 Mon Sep 17 00:00:00 2001
From: ManniX-ITA <20623405+mann1x@users.noreply.github.com>
Date: Thu, 18 Apr 2024 18:14:29 +0200
Subject: [PATCH] Moved endpoints registration before listener and fixes

- Moved endpoint registration before the HTTP listener starts
- Endpoints now return the correct error while the model is loading or has failed to load
- Server now exits if it fails to bind the port

---
 examples/server/server.cpp | 271 +++++++++++++++++++++++--------------
 1 file changed, 166 insertions(+), 105 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 6dcdc7c89..cc455d8ce 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2282,6 +2282,17 @@ struct server_context {
             {"size",        llama_model_size    (model)},
         };
     }
+
+    json empty_model_meta() const {
+        return json { // model not loaded (or failed to load): all-zero metadata
+            {"vocab_type",  0},
+            {"n_vocab",     0},
+            {"n_ctx_train", 0},
+            {"n_embd",      0},
+            {"n_params",    0},
+            {"size",        0},
+        };
+    }
 };
 
 static void server_print_usage(const char * argv0, const gpt_params & params, const server_params & sparams) {
@@ -3062,92 +3073,6 @@ int main(int argc, char ** argv) {
         }
     };
 
-    // register Health API routes
-    svr->Get ("/health", handle_health);
-
-    auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
-        return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
-            res.set_content(reinterpret_cast<const char *>(content), len, mime_type);
-            return false;
-        };
-    };
-
-    //
-    // Router
-    //
-
-    // register static assets routes
-    if (!sparams.public_path.empty()) {
-        // Set the base directory for serving static files
-        svr->set_base_dir(sparams.public_path);
-    }
-
-    // using embedded static files
-    svr->Get("/",              handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
-    svr->Get("/index.js",      handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
-    svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
-    svr->Get("/json-schema-to-grammar.mjs", handle_static_file(
-        json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));
-
-    //
-    // Start the server
-    //
-    if (sparams.n_threads_http < 1) {
-        // +2 threads for monitoring endpoints
-        sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
-    }
-    log_data["n_threads_http"] = std::to_string(sparams.n_threads_http);
-    svr->new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); };
-
-    LOG_INFO("HTTP server listening", log_data);
-
-    // run the HTTP server in a thread - see comment below
-    std::thread t([&]() {
-        if (!svr->listen_after_bind()) {
-            state.store(SERVER_STATE_ERROR);
-            return 1;
-        }
-
-        return 0;
-    });
-
-    // load the model
-    if (!ctx_server.load_model(params)) {
-        state.store(SERVER_STATE_ERROR);
-        return 1;
-    } else {
-        ctx_server.init();
-        state.store(SERVER_STATE_READY);
-    }
-
-    LOG_INFO("model loaded", {});
-
-    const auto model_meta = ctx_server.model_meta();
-
-    // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
-    if (sparams.chat_template.empty()) {
-        if (!ctx_server.validate_model_chat_template()) {
-            LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
-            sparams.chat_template = "chatml";
-        }
-    }
-
-    // print sample chat example to make it clear which template is used
-    {
-        json chat;
-        chat.push_back({{"role", "system"},    {"content", "You are a helpful assistant"}});
-        chat.push_back({{"role", "user"},      {"content", "Hello"}});
-        chat.push_back({{"role", "assistant"}, {"content", "Hi there"}});
-        chat.push_back({{"role", "user"},      {"content", "How are you?"}});
-
-        const std::string chat_example = format_chat(ctx_server.model, sparams.chat_template, chat);
-
-        LOG_INFO("chat template", {
-            {"chat_example", chat_example},
-            {"built_in", sparams.chat_template.empty()},
-        });
-    }
-
     //
     // Middlewares
     //
@@ -3208,6 +3133,52 @@ int main(int argc, char ** argv) {
         return httplib::Server::HandlerResponse::Unhandled;
     });
 
+    auto middleware_model_loading = [&sparams, &res_error](const httplib::Request & req, httplib::Response & res, server_state current_state) {
+        static const std::set<std::string> invalid_endpoints = {
+            "/slots",
+            "/metrics",
+            "/props",
+            "/v1/models",
+            "/completion",
+            "/completions",
+            "/v1/completions",
+            "/chat/completions",
+            "/v1/chat/completions",
+            "/infill",
+            "/tokenize",
+            "/detokenize",
+            "/embedding",
+            "/embeddings",
+            "/v1/embeddings",
+        };
+
+        // If path is not in invalid_endpoints list, skip validation
+        if (invalid_endpoints.find(req.path) == invalid_endpoints.end()) {
+            return true;
+        }
+        switch (current_state) {
+            case SERVER_STATE_LOADING_MODEL:
+                {
+                    res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
+                    return false;
+                } break;
+            case SERVER_STATE_ERROR:
+                {
+                    res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
+                    return false;
+                } break;
+        }
+        return true;
+    };
+
+    svr->set_pre_routing_handler([&state, &middleware_model_loading](const httplib::Request & req, httplib::Response & res) {
+        server_state current_state = state.load();
+        if (!middleware_model_loading(req, res, current_state)) {
+            return httplib::Server::HandlerResponse::Handled;
+        }
+        return httplib::Server::HandlerResponse::Unhandled;
+    });
+
     //
     // Route handlers (or controllers)
     //
@@ -3531,25 +3502,6 @@ int main(int argc, char ** argv) {
         }
     };
 
-    const auto handle_models = [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-
-        json models = {
-            {"object", "list"},
-            {"data", {
-                 {
-                     {"id",       params.model_alias},
-                     {"object",   "model"},
-                     {"created",  std::time(0)},
-                     {"owned_by", "llamacpp"},
-                     {"meta",     model_meta}
-                 },
-            }}
-        };
-
-        res.set_content(models.dump(), "application/json; charset=utf-8");
-    };
-
     const auto handle_chat_completions = [&ctx_server, &sparams, &res_error](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);
@@ -3763,9 +3715,30 @@ int main(int argc, char ** argv) {
             : responses[0];
 
         return res.set_content(root.dump(), "application/json; charset=utf-8");
     };
+
+    const auto handle_models = [&](const httplib::Request & req, httplib::Response & res) {
+        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+        json model_meta = ctx_server.model_meta();
+
+        json models = {
+            {"object", "list"},
+            {"data", {
+                 {
+                     {"id",       params.model_alias},
+                     {"object",   "model"},
+                     {"created",  std::time(0)},
+                     {"owned_by", "llamacpp"},
"llamacpp"}, + {"meta", model_meta} + }, + }} + }; + + res.set_content(models.dump(), "application/json; charset=utf-8"); + }; // register API routes + svr->Get ("/health", handle_health); svr->Get ("/slots", handle_slots); svr->Get ("/metrics", handle_metrics); svr->Get ("/props", handle_props); @@ -3787,6 +3760,94 @@ int main(int argc, char ** argv) { svr->Post("/slots/:id_slot", handle_slots_action); } + auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) { + return [content, len, mime_type](const httplib::Request &, httplib::Response & res) { + res.set_content(reinterpret_cast(content), len, mime_type); + return false; + }; + }; + + // + // Router + // + + // register static assets routes + if (!sparams.public_path.empty()) { + // Set the base directory for serving static files + svr->set_base_dir(sparams.public_path); + } + + // using embedded static files + svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8")); + svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8")); + svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8")); + svr->Get("/json-schema-to-grammar.mjs", handle_static_file( + json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8")); + + // + // Start the server + // + if (sparams.n_threads_http < 1) { + // +2 threads for monitoring endpoints + sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1); + } + log_data["n_threads_http"] = std::to_string(sparams.n_threads_http); + svr->new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); }; + + LOG_INFO("HTTP server listening", log_data); + + // run the HTTP server in a thread - see comment below + std::thread t([&]() { + if (!svr->listen_after_bind()) { + state.store(SERVER_STATE_ERROR); + return 1; + } + + return 0; + }); + + if (state.load() == SERVER_STATE_ERROR) { + // HTTP Server could not bind the port + return 1; + } + + // load the model + if (!ctx_server.load_model(params)) { + state.store(SERVER_STATE_ERROR); + return 1; + } else { + ctx_server.init(); + state.store(SERVER_STATE_READY); + } + + LOG_INFO("model loaded", {}); + + const auto model_meta = ctx_server.model_meta(); + + // if a custom chat template is not supplied, we will use the one that comes with the model (if any) + if (sparams.chat_template.empty()) { + if (!ctx_server.validate_model_chat_template()) { + LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. 
+            sparams.chat_template = "chatml";
+        }
+    }
+
+    // print sample chat example to make it clear which template is used
+    {
+        json chat;
+        chat.push_back({{"role", "system"},    {"content", "You are a helpful assistant"}});
+        chat.push_back({{"role", "user"},      {"content", "Hello"}});
+        chat.push_back({{"role", "assistant"}, {"content", "Hi there"}});
+        chat.push_back({{"role", "user"},      {"content", "How are you?"}});
+
+        const std::string chat_example = format_chat(ctx_server.model, sparams.chat_template, chat);
+
+        LOG_INFO("chat template", {
+            {"chat_example", chat_example},
+            {"built_in", sparams.chat_template.empty()},
+        });
+    }
+
     ctx_server.queue_tasks.on_new_task(std::bind(
         &server_context::process_single_task, &ctx_server, std::placeholders::_1));
     ctx_server.queue_tasks.on_finish_multitask(std::bind(