refactor static file handler

This commit is contained in:
ngxson 2024-03-08 12:06:13 +01:00
parent 581ed5c4fe
commit d8a8bd4cc6

View file

@ -112,7 +112,7 @@ struct server_params {
int32_t n_threads_http = -1; int32_t n_threads_http = -1;
std::string hostname = "127.0.0.1"; std::string hostname = "127.0.0.1";
std::string public_path = "examples/server/public"; std::string public_path = "";
std::string chat_template = ""; std::string chat_template = "";
std::string system_prompt = ""; std::string system_prompt = "";
@ -2092,7 +2092,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n"); printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str()); printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port); printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str()); printf(" --path PUBLIC_PATH path from which to serve static files (default: disabled)\n");
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n"); printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n"); printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout); printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
@ -2826,9 +2826,6 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
// Set the base directory for serving static files
svr.set_base_dir(sparams.public_path);
std::unordered_map<std::string, std::string> log_data; std::unordered_map<std::string, std::string> log_data;
log_data["hostname"] = sparams.hostname; log_data["hostname"] = sparams.hostname;
@ -2887,29 +2884,22 @@ int main(int argc, char ** argv) {
return false; return false;
}; };
// this is only called if no index.html is found in the public --path if (sparams.public_path.empty()) {
svr.Get("/", [](const httplib::Request &, httplib::Response & res) { auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
res.set_content(reinterpret_cast<const char*>(&index_html), index_html_len, "text/html; charset=utf-8"); return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
return false; return false;
}); };
};
// this is only called if no index.js is found in the public --path svr.Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
svr.Get("/index.js", [](const httplib::Request &, httplib::Response & res) { svr.Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
res.set_content(reinterpret_cast<const char *>(&index_js), index_js_len, "text/javascript; charset=utf-8"); svr.Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
return false; svr.Get("/json-schema-to-grammar.mjs", handle_static_file(json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));
}); } else {
// Set the base directory for serving static files
// this is only called if no index.html is found in the public --path svr.set_base_dir(sparams.public_path);
svr.Get("/completion.js", [](const httplib::Request &, httplib::Response & res) { }
res.set_content(reinterpret_cast<const char*>(&completion_js), completion_js_len, "application/javascript; charset=utf-8");
return false;
});
// this is only called if no index.html is found in the public --path
svr.Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char*>(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, "application/javascript; charset=utf-8");
return false;
});
svr.Get("/props", [&ctx_server](const httplib::Request & req, httplib::Response & res) { svr.Get("/props", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));