Enable CORS requests on all routes

This commit is contained in:
StrangebytesDev 2024-02-28 13:36:17 -08:00
parent 87c91c0766
commit 6baa61c1e0
2 changed files with 4 additions and 13 deletions

View file

@ -2821,11 +2821,11 @@ int main(int argc, char **argv)
svr.set_default_headers({{"Server", "llama.cpp"}}); svr.set_default_headers({{"Server", "llama.cpp"}});
// CORS preflight // Allow CORS requests on all routes
svr.Options(R"(.*)", [](const httplib::Request &req, httplib::Response &res) { svr.set_post_routing_handler([](const httplib::Request &req, httplib::Response &res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
res.set_header("Access-Control-Allow-Credentials", "true"); res.set_header("Access-Control-Allow-Credentials", "true");
res.set_header("Access-Control-Allow-Methods", "POST"); res.set_header("Access-Control-Allow-Methods", "*");
res.set_header("Access-Control-Allow-Headers", "*"); res.set_header("Access-Control-Allow-Headers", "*");
}); });
@ -3113,7 +3113,6 @@ int main(int argc, char **argv)
svr.Get("/props", [&llama](const httplib::Request & req, httplib::Response &res) svr.Get("/props", [&llama](const httplib::Request & req, httplib::Response &res)
{ {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = { json data = {
{ "user_name", llama.name_user.c_str() }, { "user_name", llama.name_user.c_str() },
{ "assistant_name", llama.name_assistant.c_str() }, { "assistant_name", llama.name_assistant.c_str() },
@ -3125,7 +3124,6 @@ int main(int argc, char **argv)
svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
{ {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) { if (!validate_api_key(req, res)) {
return; return;
} }
@ -3202,7 +3200,6 @@ int main(int argc, char **argv)
svr.Get("/v1/models", [&params](const httplib::Request& req, httplib::Response& res) svr.Get("/v1/models", [&params](const httplib::Request& req, httplib::Response& res)
{ {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
std::time_t t = std::time(0); std::time_t t = std::time(0);
json models = { json models = {
@ -3222,7 +3219,6 @@ int main(int argc, char **argv)
const auto chat_completions = [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res) const auto chat_completions = [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
{ {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) { if (!validate_api_key(req, res)) {
return; return;
} }
@ -3305,7 +3301,6 @@ int main(int argc, char **argv)
svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
{ {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) { if (!validate_api_key(req, res)) {
return; return;
} }
@ -3375,7 +3370,6 @@ int main(int argc, char **argv)
svr.Post("/tokenize", [&llama](const httplib::Request &req, httplib::Response &res) svr.Post("/tokenize", [&llama](const httplib::Request &req, httplib::Response &res)
{ {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body); const json body = json::parse(req.body);
std::vector<llama_token> tokens; std::vector<llama_token> tokens;
if (body.count("content") != 0) if (body.count("content") != 0)
@ -3388,7 +3382,6 @@ int main(int argc, char **argv)
svr.Post("/detokenize", [&llama](const httplib::Request &req, httplib::Response &res) svr.Post("/detokenize", [&llama](const httplib::Request &req, httplib::Response &res)
{ {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body); const json body = json::parse(req.body);
std::string content; std::string content;
if (body.count("tokens") != 0) if (body.count("tokens") != 0)
@ -3403,7 +3396,6 @@ int main(int argc, char **argv)
svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res) svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res)
{ {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body); const json body = json::parse(req.body);
json prompt; json prompt;
if (body.count("content") != 0) if (body.count("content") != 0)
@ -3439,7 +3431,6 @@ int main(int argc, char **argv)
svr.Post("/v1/embeddings", [&llama](const httplib::Request &req, httplib::Response &res) svr.Post("/v1/embeddings", [&llama](const httplib::Request &req, httplib::Response &res)
{ {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body); const json body = json::parse(req.body);
json prompt; json prompt;

View file

@ -46,5 +46,5 @@ Feature: Security
| localhost | Access-Control-Allow-Origin | localhost | | localhost | Access-Control-Allow-Origin | localhost |
| web.mydomain.fr | Access-Control-Allow-Origin | web.mydomain.fr | | web.mydomain.fr | Access-Control-Allow-Origin | web.mydomain.fr |
| origin | Access-Control-Allow-Credentials | true | | origin | Access-Control-Allow-Credentials | true |
| web.mydomain.fr | Access-Control-Allow-Methods | POST | | web.mydomain.fr | Access-Control-Allow-Methods | * |
| web.mydomain.fr | Access-Control-Allow-Headers | * | | web.mydomain.fr | Access-Control-Allow-Headers | * |