Move endpoint registration before the listener starts, and related fixes

- Moved endpoint registration before the HTTP listener starts
- Endpoints now return the correct error while the model is loading or after it has failed to load
- The server now exits if it fails to bind the port
ManniX-ITA 2024-04-18 18:14:29 +02:00 committed by GitHub
parent 4de4670c83
commit 52a4d59747
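In brief: routes are now registered before the listener starts, and cpp-httplib's set_pre_routing_handler (used in the diff below) short-circuits model-dependent endpoints until loading completes. The following is a minimal standalone sketch of that pattern, not the server's code verbatim: the gated path set is trimmed to two entries, the server_state enum is reduced to what the sketch needs, and the inline JSON error bodies stand in for the server's res_error/format_error_response helpers.

// sketch.cpp - gate API routes behind a readiness flag (cpp-httplib)
#include <atomic>
#include <set>
#include <string>
#include <thread>
#include "httplib.h"

enum server_state { SERVER_STATE_LOADING_MODEL, SERVER_STATE_READY, SERVER_STATE_ERROR };

int main() {
    std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
    httplib::Server svr;

    // 1. register routes first, so a request that arrives while the model is
    //    still loading resolves to a real handler instead of a 404
    svr.Get("/health", [](const httplib::Request &, httplib::Response & res) {
        res.set_content("{\"status\":\"ok\"}", "application/json");
    });

    // 2. the pre-routing hook rejects model-dependent paths until ready
    svr.set_pre_routing_handler([&](const httplib::Request & req, httplib::Response & res) {
        static const std::set<std::string> gated = { "/completion", "/v1/chat/completions" };
        if (gated.count(req.path) && state.load() != SERVER_STATE_READY) {
            res.status = 503;
            res.set_content("{\"error\":\"Loading model\"}", "application/json");
            return httplib::Server::HandlerResponse::Handled;   // stop here
        }
        return httplib::Server::HandlerResponse::Unhandled;     // route normally
    });

    // 3. bind first so a port conflict fails fast, then serve on a thread
    if (!svr.bind_to_port("127.0.0.1", 8080)) {
        return 1;
    }
    std::thread t([&] { svr.listen_after_bind(); });

    // ... load the model here, then flip the flag:
    state.store(SERVER_STATE_READY);

    t.join();
    return 0;
}

With this ordering, a request to a gated endpoint during startup gets an immediate "Loading model" error (and "Model failed to load" after a failed load) instead of a 404 or a hung connection, while /health and the static assets keep responding. The full diff follows.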


@@ -2282,6 +2282,17 @@ struct server_context {
{"size", llama_model_size (model)},
};
}
json empty_model_meta() const {
return json {
{"vocab_type", llama_vocab_type (0)},
{"n_vocab", llama_n_vocab (0)},
{"n_ctx_train", llama_n_ctx_train (0)},
{"n_embd", llama_n_embd (0)},
{"n_params", llama_model_n_params(0)},
{"size", llama_model_size (0)},
};
}
};
static void server_print_usage(const char * argv0, const gpt_params & params, const server_params & sparams) {
@@ -3062,92 +3073,6 @@ int main(int argc, char ** argv) {
}
};
// register Health API routes
svr->Get ("/health", handle_health);
auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
return false;
};
};
//
// Router
//
// register static assets routes
if (!sparams.public_path.empty()) {
// Set the base directory for serving static files
svr->set_base_dir(sparams.public_path);
}
// using embedded static files
svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
svr->Get("/json-schema-to-grammar.mjs", handle_static_file(
json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));
//
// Start the server
//
if (sparams.n_threads_http < 1) {
// +2 threads for monitoring endpoints
sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
}
log_data["n_threads_http"] = std::to_string(sparams.n_threads_http);
svr->new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); };
LOG_INFO("HTTP server listening", log_data);
// run the HTTP server in a thread - see comment below
std::thread t([&]() {
if (!svr->listen_after_bind()) {
state.store(SERVER_STATE_ERROR);
return 1;
}
return 0;
});
// load the model
if (!ctx_server.load_model(params)) {
state.store(SERVER_STATE_ERROR);
return 1;
} else {
ctx_server.init();
state.store(SERVER_STATE_READY);
}
LOG_INFO("model loaded", {});
const auto model_meta = ctx_server.model_meta();
// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
if (sparams.chat_template.empty()) {
if (!ctx_server.validate_model_chat_template()) {
LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
sparams.chat_template = "chatml";
}
}
// print sample chat example to make it clear which template is used
{
json chat;
chat.push_back({{"role", "system"}, {"content", "You are a helpful assistant"}});
chat.push_back({{"role", "user"}, {"content", "Hello"}});
chat.push_back({{"role", "assistant"}, {"content", "Hi there"}});
chat.push_back({{"role", "user"}, {"content", "How are you?"}});
const std::string chat_example = format_chat(ctx_server.model, sparams.chat_template, chat);
LOG_INFO("chat template", {
{"chat_example", chat_example},
{"built_in", sparams.chat_template.empty()},
});
}
//
// Middlewares
//
@@ -3208,6 +3133,52 @@ int main(int argc, char ** argv) {
return httplib::Server::HandlerResponse::Unhandled;
});
auto middleware_model_loading = [&sparams, &res_error](const httplib::Request & req, httplib::Response & res, server_state current_state) {
static const std::set<std::string> invalid_endpoints = {
"/slots",
"/metrics",
"/props",
"/v1/models",
"/completion",
"/completions",
"/v1/completions",
"/chat/completions",
"/v1/chat/completions",
"/infill",
"/tokenize",
"/detokenize",
"/embedding",
"/embeddings",
"/v1/embeddings",
};
// If path is not in invalid_endpoints list, skip validation
if (invalid_endpoints.find(req.path) == invalid_endpoints.end()) {
return true;
}
switch (current_state) {
case SERVER_STATE_LOADING_MODEL:
{
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
return false;
} break;
case SERVER_STATE_ERROR:
{
res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
return false;
} break;
}
return true;
};
svr->set_pre_routing_handler([&state, &middleware_model_loading](const httplib::Request & req, httplib::Response & res) {
server_state current_state = state.load();
if (!middleware_model_loading(req, res, current_state)) {
return httplib::Server::HandlerResponse::Handled;
}
return httplib::Server::HandlerResponse::Unhandled;
});
//
// Route handlers (or controllers)
//
@@ -3531,25 +3502,6 @@ int main(int argc, char ** argv) {
}
};
const auto handle_models = [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json models = {
{"object", "list"},
{"data", {
{
{"id", params.model_alias},
{"object", "model"},
{"created", std::time(0)},
{"owned_by", "llamacpp"},
{"meta", model_meta}
},
}}
};
res.set_content(models.dump(), "application/json; charset=utf-8");
};
const auto handle_chat_completions = [&ctx_server, &sparams, &res_error](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);
@@ -3763,9 +3715,30 @@ int main(int argc, char ** argv) {
: responses[0];
return res.set_content(root.dump(), "application/json; charset=utf-8");
};
const auto handle_models = [&](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json model_meta = ctx_server.model_meta();
json models = {
{"object", "list"},
{"data", {
{
{"id", params.model_alias},
{"object", "model"},
{"created", std::time(0)},
{"owned_by", "llamacpp"},
{"meta", model_meta}
},
}}
};
res.set_content(models.dump(), "application/json; charset=utf-8");
};
// register API routes
svr->Get ("/health", handle_health);
svr->Get ("/slots", handle_slots);
svr->Get ("/metrics", handle_metrics);
svr->Get ("/props", handle_props);
@@ -3787,6 +3760,94 @@ int main(int argc, char ** argv) {
svr->Post("/slots/:id_slot", handle_slots_action);
}
auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
return false;
};
};
//
// Router
//
// register static assets routes
if (!sparams.public_path.empty()) {
// Set the base directory for serving static files
svr->set_base_dir(sparams.public_path);
}
// using embedded static files
svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
svr->Get("/json-schema-to-grammar.mjs", handle_static_file(
json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));
//
// Start the server
//
if (sparams.n_threads_http < 1) {
// +2 threads for monitoring endpoints
sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
}
log_data["n_threads_http"] = std::to_string(sparams.n_threads_http);
svr->new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); };
LOG_INFO("HTTP server listening", log_data);
// run the HTTP server in a thread - see comment below
std::thread t([&]() {
if (!svr->listen_after_bind()) {
state.store(SERVER_STATE_ERROR);
return 1;
}
return 0;
});
if (state.load() == SERVER_STATE_ERROR) {
// HTTP Server could not bind the port
return 1;
}
// load the model
if (!ctx_server.load_model(params)) {
state.store(SERVER_STATE_ERROR);
return 1;
} else {
ctx_server.init();
state.store(SERVER_STATE_READY);
}
LOG_INFO("model loaded", {});
const auto model_meta = ctx_server.model_meta();
// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
if (sparams.chat_template.empty()) {
if (!ctx_server.validate_model_chat_template()) {
LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
sparams.chat_template = "chatml";
}
}
// print sample chat example to make it clear which template is used
{
json chat;
chat.push_back({{"role", "system"}, {"content", "You are a helpful assistant"}});
chat.push_back({{"role", "user"}, {"content", "Hello"}});
chat.push_back({{"role", "assistant"}, {"content", "Hi there"}});
chat.push_back({{"role", "user"}, {"content", "How are you?"}});
const std::string chat_example = format_chat(ctx_server.model, sparams.chat_template, chat);
LOG_INFO("chat template", {
{"chat_example", chat_example},
{"built_in", sparams.chat_template.empty()},
});
}
ctx_server.queue_tasks.on_new_task(std::bind(
&server_context::process_single_task, &ctx_server, std::placeholders::_1));
ctx_server.queue_tasks.on_finish_multitask(std::bind(