server : add optional API Key Authentication example (#4441)
* Add API key authentication for enhanced server-client security
* server : to snake_case

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
		
							parent
							
								
									ee4725a686
								
							
						
					
					
						commit
						88ae8952b6
					
				
					 3 changed files with 70 additions and 10 deletions
				
			
		|  | @ -36,6 +36,7 @@ using json = nlohmann::json; | |||
| struct server_params | ||||
| { | ||||
|     std::string hostname = "127.0.0.1"; | ||||
|     std::string api_key; | ||||
|     std::string public_path = "examples/server/public"; | ||||
|     int32_t port = 8080; | ||||
|     int32_t read_timeout = 600; | ||||
|  | @ -1953,6 +1954,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params, | |||
|     printf("  --host                ip address to listen (default  (default: %s)\n", sparams.hostname.c_str()); | ||||
|     printf("  --port PORT           port to listen (default  (default: %d)\n", sparams.port); | ||||
|     printf("  --path PUBLIC_PATH    path from which to serve static files (default %s)\n", sparams.public_path.c_str()); | ||||
|     printf("  --api-key API_KEY     optional api key to enhance server security. If set, requests must include this key for access.\n"); | ||||
|     printf("  -to N, --timeout N    server read/write timeout in seconds (default: %d)\n", sparams.read_timeout); | ||||
|     printf("  --embedding           enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled"); | ||||
|     printf("  -np N, --parallel N   number of slots for process requests (default: %d)\n", params.n_parallel); | ||||
|  | @ -2002,6 +2004,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, | |||
|             } | ||||
|             sparams.public_path = argv[i]; | ||||
|         } | ||||
|         else if (arg == "--api-key") | ||||
|         { | ||||
|             if (++i >= argc) | ||||
|             { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             sparams.api_key = argv[i]; | ||||
|         } | ||||
|         else if (arg == "--timeout" || arg == "-to") | ||||
|         { | ||||
|             if (++i >= argc) | ||||
|  | @ -2669,6 +2680,32 @@ int main(int argc, char **argv) | |||
| 
 | ||||
|     httplib::Server svr; | ||||
| 
 | ||||
|     // Middleware for API key validation
 | ||||
|     auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool { | ||||
|         // If API key is not set, skip validation
 | ||||
|         if (sparams.api_key.empty()) { | ||||
|             return true; | ||||
|         } | ||||
| 
 | ||||
|         // Check for API key in the header
 | ||||
|         auto auth_header = req.get_header_value("Authorization"); | ||||
|         std::string prefix = "Bearer "; | ||||
|         if (auth_header.substr(0, prefix.size()) == prefix) { | ||||
|             std::string received_api_key = auth_header.substr(prefix.size()); | ||||
|             if (received_api_key == sparams.api_key) { | ||||
|                 return true; // API key is valid
 | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         // API key is invalid or not provided
 | ||||
|         res.set_content("Unauthorized: Invalid API Key", "text/plain"); | ||||
|         res.status = 401; // Unauthorized
 | ||||
| 
 | ||||
|         LOG_WARNING("Unauthorized: Invalid API Key", {}); | ||||
| 
 | ||||
|         return false; | ||||
|     }; | ||||
| 
 | ||||
|     svr.set_default_headers({{"Server", "llama.cpp"}, | ||||
|                              {"Access-Control-Allow-Origin", "*"}, | ||||
|                              {"Access-Control-Allow-Headers", "content-type"}}); | ||||
|  | @ -2711,8 +2748,11 @@ int main(int argc, char **argv) | |||
|                 res.set_content(data.dump(), "application/json"); | ||||
|             }); | ||||
| 
 | ||||
|     svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res) | ||||
|     svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) | ||||
|             { | ||||
|                 if (!validate_api_key(req, res)) { | ||||
|                     return; | ||||
|                 } | ||||
|                 json data = json::parse(req.body); | ||||
|                 const int task_id = llama.request_completion(data, false, false, -1); | ||||
|                 if (!json_value(data, "stream", false)) { | ||||
|  | @ -2799,8 +2839,11 @@ int main(int argc, char **argv) | |||
|             }); | ||||
| 
 | ||||
|     // TODO: add mount point without "/v1" prefix -- how?
 | ||||
|     svr.Post("/v1/chat/completions", [&llama](const httplib::Request &req, httplib::Response &res) | ||||
|     svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) | ||||
|             { | ||||
|                 if (!validate_api_key(req, res)) { | ||||
|                     return; | ||||
|                 } | ||||
|                 json data = oaicompat_completion_params_parse(json::parse(req.body)); | ||||
| 
 | ||||
|                 const int task_id = llama.request_completion(data, false, false, -1); | ||||
|  | @ -2869,8 +2912,11 @@ int main(int argc, char **argv) | |||
|                 } | ||||
|             }); | ||||
| 
 | ||||
|     svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res) | ||||
|     svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) | ||||
|             { | ||||
|                 if (!validate_api_key(req, res)) { | ||||
|                     return; | ||||
|                 } | ||||
|                 json data = json::parse(req.body); | ||||
|                 const int task_id = llama.request_completion(data, true, false, -1); | ||||
|                 if (!json_value(data, "stream", false)) { | ||||
|  | @ -3005,11 +3051,15 @@ int main(int argc, char **argv) | |||
| 
 | ||||
|     svr.set_error_handler([](const httplib::Request &, httplib::Response &res) | ||||
|             { | ||||
|                 if (res.status == 401) | ||||
|                 { | ||||
|                     res.set_content("Unauthorized", "text/plain"); | ||||
|                 } | ||||
|                 if (res.status == 400) | ||||
|                 { | ||||
|                     res.set_content("Invalid request", "text/plain"); | ||||
|                 } | ||||
|                 else if (res.status != 500) | ||||
|                 else if (res.status == 404) | ||||
|                 { | ||||
|                     res.set_content("File Not Found", "text/plain"); | ||||
|                     res.status = 404; | ||||
|  | @ -3032,11 +3082,15 @@ int main(int argc, char **argv) | |||
|     // to make it ctrl+clickable:
 | ||||
|     LOG_TEE("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port); | ||||
| 
 | ||||
|     LOG_INFO("HTTP server listening", { | ||||
|                                           {"hostname", sparams.hostname}, | ||||
|                                           {"port", sparams.port}, | ||||
|                                       }); | ||||
|     std::unordered_map<std::string, std::string> log_data; | ||||
|     log_data["hostname"] = sparams.hostname; | ||||
|     log_data["port"] = std::to_string(sparams.port); | ||||
| 
 | ||||
|     if (!sparams.api_key.empty()) { | ||||
|         log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4); | ||||
|     } | ||||
| 
 | ||||
|     LOG_INFO("HTTP server listening", log_data); | ||||
|     // run the HTTP server in a thread - see comment below
 | ||||
|     std::thread t([&]() | ||||
|             { | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue