Implement credentialed CORS according to MDN

This commit is contained in:
Laura 2023-12-17 21:11:46 +01:00
parent bf571733fd
commit 0876952924

View file

@ -2710,9 +2710,15 @@ int main(int argc, char **argv)
return false;
};
svr.set_default_headers({{"Server", "llama.cpp"},
{"Access-Control-Allow-Origin", "*"},
{"Access-Control-Allow-Headers", "content-type"}});
svr.set_default_headers({{"Server", "llama.cpp"}});
// CORS preflight
svr.Options(R"(.*)", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
res.set_header("Access-Control-Allow-Credentials", "true");
res.set_header("Access-Control-Allow-Methods", "POST");
res.set_header("Access-Control-Allow-Headers", "*");
});
// this is only called if no index.html is found in the public --path
svr.Get("/", [](const httplib::Request &, httplib::Response &res)
@ -2744,7 +2750,7 @@ int main(int argc, char **argv)
svr.Get("/props", [&llama](const httplib::Request & /*req*/, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", "*");
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = {
{ "user_name", llama.name_user.c_str() },
{ "assistant_name", llama.name_assistant.c_str() }
@ -2754,6 +2760,7 @@ int main(int argc, char **argv)
svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) {
return;
}
@ -2821,10 +2828,9 @@ int main(int argc, char **argv)
}
});
svr.Get("/v1/models", [&params](const httplib::Request&, httplib::Response& res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
std::time_t t = std::time(0);
json models = {
@ -2842,9 +2848,11 @@ int main(int argc, char **argv)
res.set_content(models.dump(), "application/json; charset=utf-8");
});
// TODO: add mount point without "/v1" prefix -- how?
svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) {
return;
}
@ -2918,6 +2926,7 @@ int main(int argc, char **argv)
svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) {
return;
}
@ -2990,6 +2999,7 @@ int main(int argc, char **argv)
svr.Post("/tokenize", [&llama](const httplib::Request &req, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body);
std::vector<llama_token> tokens;
if (body.count("content") != 0)
@ -3002,6 +3012,7 @@ int main(int argc, char **argv)
svr.Post("/detokenize", [&llama](const httplib::Request &req, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body);
std::string content;
if (body.count("tokens") != 0)
@ -3016,6 +3027,7 @@ int main(int argc, char **argv)
svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body);
json prompt;
if (body.count("content") != 0)