diff --git a/examples/server/server.cpp b/examples/server/server.cpp index b2a047414..f4fb0ac10 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -203,50 +203,6 @@ struct server_slot { double t_prompt_processing; // ms double t_token_generation; // ms - - bool check_server_health() { - // Check server health logic here - bool check_server_health(const std::string& server, const std::string& port) { - using namespace boost::asio; - io_service svc; - ip::tcp::socket socket(svc); - ip::tcp::resolver resolver(svc); - boost::system::error_code ec; - - // Try to connect - connect(socket, resolver.resolve({server, port}), ec); - if (ec) { - std::cout << "Connection failed: " << ec.message() << std::endl; - return false; - } - - // Send HTTP GET request to /health endpoint - std::string request = "GET /health HTTP/1.1\r\nHost: " + server + "\r\n\r\n"; - write(socket, buffer(request), ec); - if (ec) { - std::cout << "Write failed: " << ec.message() << std::endl; - return false; - } - - // Read the response - boost::asio::streambuf response; - read_until(socket, response, "\r\n", ec); - std::istream response_stream(&response); - std::string http_version; - unsigned int status_code; - response_stream >> http_version >> status_code; - - bool server_status_ok = false; - - // Check HTTP response status code - if (status_code == 200 || status_code == 500 || status_code == 503) { - server_status_ok = true; - } - - return server_status_ok - } - return true; // Return false if the server is unhealthy - } void reset() { n_prompt_tokens = 0; @@ -509,6 +465,50 @@ struct server_queue { condition_tasks.notify_all(); } + //adding server health checking + std::string hostname_health = "127.0.0.1"; + std::string port_health = "8080"; + + bool check_server_health(const std::string& server, const std::string& port) { + using namespace boost::asio; + io_service svc; + ip::tcp::socket socket(svc); + ip::tcp::resolver resolver(svc); + boost::system::error_code ec; + + // Try to connect + connect(socket, resolver.resolve({server, port}), ec); + if (ec) { + std::cout << "Connection failed: " << ec.message() << std::endl; + return false; + } + + // Send HTTP GET request to /health endpoint + std::string request = "GET /health HTTP/1.1\r\nHost: " + server + "\r\n\r\n"; + write(socket, buffer(request), ec); + if (ec) { + std::cout << "Write failed: " << ec.message() << std::endl; + return false; + } + + // Read the response + boost::asio::streambuf response; + read_until(socket, response, "\r\n", ec); + std::istream response_stream(&response); + std::string http_version; + unsigned int status_code; + response_stream >> http_version >> status_code; + + bool server_status_ok = false; + + // Check HTTP response status code + if (status_code == 200 || status_code == 500 || status_code == 503) { + server_status_ok = true; + } + + return server_status_ok + } + /** * Main loop consists of these steps: * - Wait until a new task arrives @@ -520,6 +520,13 @@ struct server_queue { running = true; while (true) { + bool health_check = check_server_health(hostname_health, port_health); + if (health_check == false) { + while(!queue_tasks.empty()) { + queue_tasks.erase(queue_tasks.begin()); + } + break; + } LOG_VERBOSE("new task may arrive", {}); while (true) { @@ -3473,13 +3480,6 @@ int main(int argc, char ** argv) { }; const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { - if (!check_server_health()) { - json error_response = {{"error", "Server is currently unavailable."}}; - res.set_content(error_response.dump(), "application/json; charset=utf-8"); - res.status = 503; // Service Unavailable - return; - } // added a server health checker to make sure completions only go through when there is a server running - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = json::parse(req.body);