diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 48ef8ff2a..1ebf32ee0 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -202,6 +202,32 @@ struct server_slot {
     double t_prompt_processing; // ms
     double t_token_generation; // ms
 
+    // Placeholder in-process health check: always reports healthy. No network
+    // round-trip is made, so it is safe to call from a request handler.
+    static bool check_server_health() {
+        return true;
+    }
+
+    // Probes server:port with "GET /health" and returns true only when the
+    // response status is 200. Connection failures, timeouts and every other
+    // status (including 500 and 503) are reported as unhealthy.
+    // NOTE(review): intended for probing a different server instance; calling
+    // it against this process from inside a handler looks risky - confirm.
+    static bool check_server_health(const std::string & server, const std::string & port) {
+        httplib::Client cli("http://" + server + ":" + port);
+        // Bound the probe so an unresponsive peer cannot stall the caller.
+        cli.set_connection_timeout(1, 0);
+        cli.set_read_timeout(1, 0);
+
+        const auto result = cli.Get("/health");
+        if (!result) {
+            std::cerr << "Health check request to " << server << ":" << port << " failed\n";
+            return false;
+        }
+
+        return result->status == 200;
+    }
+
     void reset() {
         n_prompt_tokens = 0;
         generated_text = "";
@@ -3427,6 +3453,14 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
+        // Reject completions with 503 when the health check reports the server as unavailable.
+        if (!server_slot::check_server_health()) {
+            json error_response = {{"error", "Server is currently unavailable."}};
+            res.set_content(error_response.dump(), "application/json; charset=utf-8");
+            res.status = 503; // Service Unavailable
+            return;
+        }
+
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
 
         json data = json::parse(req.body);