made ServerState atomic and turned two-line spaces into one-line

This commit is contained in:
Behnam M 2024-01-10 14:39:37 -05:00 committed by GitHub
parent 29a5a94d5c
commit efd36b0270
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -26,6 +26,7 @@
#include <mutex> #include <mutex>
#include <chrono> #include <chrono>
#include <condition_variable> #include <condition_variable>
#include <atomic>
#ifndef SERVER_VERBOSE #ifndef SERVER_VERBOSE
#define SERVER_VERBOSE 1 #define SERVER_VERBOSE 1
@ -146,15 +147,12 @@ static std::vector<uint8_t> base64_decode(const std::string & encoded_string)
// parallel // parallel
// //
enum ServerState { enum ServerState {
LOADING_MODEL, // Server is starting up, model not fully loaded yet LOADING_MODEL, // Server is starting up, model not fully loaded yet
READY, // Server is ready and model is loaded READY, // Server is ready and model is loaded
ERROR // An error occurred, load_model failed ERROR // An error occurred, load_model failed
}; };
enum task_type { enum task_type {
COMPLETION_TASK, COMPLETION_TASK,
CANCEL_TASK CANCEL_TASK
@ -2462,7 +2460,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
} }
} }
static std::string random_string() static std::string random_string()
{ {
static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
@ -2801,14 +2798,15 @@ int main(int argc, char **argv)
httplib::Server svr; httplib::Server svr;
ServerState server_state = LOADING_MODEL; std::atomic<ServerState> server_state{LOADING_MODEL};
svr.set_default_headers({{"Server", "llama.cpp"}, svr.set_default_headers({{"Server", "llama.cpp"},
{"Access-Control-Allow-Origin", "*"}, {"Access-Control-Allow-Origin", "*"},
{"Access-Control-Allow-Headers", "content-type"}}); {"Access-Control-Allow-Headers", "content-type"}});
svr.Get("/health", [&](const httplib::Request&, httplib::Response& res) { svr.Get("/health", [&](const httplib::Request&, httplib::Response& res) {
switch(server_state) { ServerState current_state = server_state.load();
switch(current_state) {
case READY: case READY:
res.set_content(R"({"status": "ok"})", "application/json"); res.set_content(R"({"status": "ok"})", "application/json");
res.status = 200; // HTTP OK res.status = 200; // HTTP OK
@ -2887,29 +2885,27 @@ int main(int argc, char **argv)
log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4); log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4);
} }
LOG_INFO("HTTP server listening", log_data); LOG_INFO("HTTP server listening", log_data);
// run the HTTP server in a thread - see comment below // run the HTTP server in a thread - see comment below
std::thread t([&]() std::thread t([&]()
{ {
if (!svr.listen_after_bind()) if (!svr.listen_after_bind())
{ {
server_state = ERROR; server_state.store(ERROR);
return 1; return 1;
} }
return 0; return 0;
}); });
// load the model // load the model
if (!llama.load_model(params)) if (!llama.load_model(params))
{ {
server_state = ERROR; server_state.store(ERROR);
return 1; return 1;
} else { } else {
llama.initialize(); llama.initialize();
server_state = READY; server_state.store(READY);
} }
// Middleware for API key validation // Middleware for API key validation
@ -2938,7 +2934,6 @@ int main(int argc, char **argv)
return false; return false;
}; };
// this is only called if no index.html is found in the public --path // this is only called if no index.html is found in the public --path
svr.Get("/", [](const httplib::Request &, httplib::Response &res) svr.Get("/", [](const httplib::Request &, httplib::Response &res)
{ {
@ -3046,7 +3041,6 @@ int main(int argc, char **argv)
} }
}); });
svr.Get("/v1/models", [&params](const httplib::Request&, httplib::Response& res) svr.Get("/v1/models", [&params](const httplib::Request&, httplib::Response& res)
{ {
std::time_t t = std::time(0); std::time_t t = std::time(0);