merge embedding handlers

This commit is contained in:
ngxson 2024-03-08 14:55:54 +01:00
parent 07f120eb4e
commit 1866e18513
2 changed files with 278 additions and 279 deletions

View file

@ -42,7 +42,7 @@ see https://github.com/ggerganov/llama.cpp/issues/1437
- `-to N`, `--timeout N`: Server read/write timeout in seconds. Default: `600`.
- `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`.
- `--port`: Set the port to listen. Default: `8080`.
- `--path`: path from which to serve static files (default examples/server/public)
- `--path`: path from which to serve static files (default: disabled)
- `--api-key`: Set an api key for request authorization. By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token. May be used multiple times to enable multiple valid keys.
- `--api-key-file`: path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access. May be used in conjunction with `--api-key`'s.
- `--embedding`: Enable embedding extraction. Default: disabled.
@ -532,7 +532,7 @@ The HTTP server supports OAI-like API
### Extending or building alternative Web Front End
The default location for the static files is `examples/server/public`. You can extend the front end by running the server binary with `--path` set to `./your-directory` and importing `/completion.js` to get access to the llamaComplete() method.
You can extend the front end by running the server binary with `--path` set to `./your-directory` and importing `/completion.js` to get access to the llamaComplete() method.
Read the documentation in `/completion.js` to see convenient ways to access llama.

View file

@ -2625,168 +2625,6 @@ int main(int argc, char ** argv) {
res.set_header("Access-Control-Allow-Headers", "*");
});
svr.Get("/health", [&](const httplib::Request & req, httplib::Response & res) {
server_state current_state = state.load();
switch (current_state) {
case SERVER_STATE_READY:
{
// request slots data using task queue
server_task task;
task.id = ctx_server.queue_tasks.get_new_id();
task.type = SERVER_TASK_TYPE_METRICS;
task.id_target = -1;
ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task);
// get the result
server_task_result result = ctx_server.queue_results.recv(task.id);
ctx_server.queue_results.remove_waiting_task_id(task.id);
const int n_idle_slots = result.data["idle"];
const int n_processing_slots = result.data["processing"];
json health = {
{"status", "ok"},
{"slots_idle", n_idle_slots},
{"slots_processing", n_processing_slots}
};
res.status = 200; // HTTP OK
if (sparams.slots_endpoint && req.has_param("include_slots")) {
health["slots"] = result.data["slots"];
}
if (n_idle_slots == 0) {
health["status"] = "no slot available";
if (req.has_param("fail_on_no_slot")) {
res.status = 503; // HTTP Service Unavailable
}
}
res.set_content(health.dump(), "application/json");
break;
}
case SERVER_STATE_LOADING_MODEL:
{
res.set_content(R"({"status": "loading model"})", "application/json");
res.status = 503; // HTTP Service Unavailable
} break;
case SERVER_STATE_ERROR:
{
res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json");
res.status = 500; // HTTP Internal Server Error
} break;
}
});
if (sparams.slots_endpoint) {
svr.Get("/slots", [&](const httplib::Request &, httplib::Response & res) {
// request slots data using task queue
server_task task;
task.id = ctx_server.queue_tasks.get_new_id();
task.id_multi = -1;
task.id_target = -1;
task.type = SERVER_TASK_TYPE_METRICS;
ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task);
// get the result
server_task_result result = ctx_server.queue_results.recv(task.id);
ctx_server.queue_results.remove_waiting_task_id(task.id);
res.set_content(result.data["slots"].dump(), "application/json");
res.status = 200; // HTTP OK
});
}
if (sparams.metrics_endpoint) {
svr.Get("/metrics", [&](const httplib::Request &, httplib::Response & res) {
// request slots data using task queue
server_task task;
task.id = ctx_server.queue_tasks.get_new_id();
task.id_multi = -1;
task.id_target = -1;
task.type = SERVER_TASK_TYPE_METRICS;
ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task);
// get the result
server_task_result result = ctx_server.queue_results.recv(task.id);
ctx_server.queue_results.remove_waiting_task_id(task.id);
json data = result.data;
const uint64_t n_prompt_tokens_processed = data["n_prompt_tokens_processed"];
const uint64_t t_prompt_processing = data["t_prompt_processing"];
const uint64_t n_tokens_predicted = data["n_tokens_predicted"];
const uint64_t t_tokens_generation = data["t_tokens_generation"];
const int32_t kv_cache_used_cells = data["kv_cache_used_cells"];
// metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
json all_metrics_def = json {
{"counter", {{
{"name", "prompt_tokens_total"},
{"help", "Number of prompt tokens processed."},
{"value", data["n_prompt_tokens_processed_total"]}
}, {
{"name", "tokens_predicted_total"},
{"help", "Number of generation tokens processed."},
{"value", data["n_tokens_predicted_total"]}
}}},
{"gauge", {{
{"name", "prompt_tokens_seconds"},
{"help", "Average prompt throughput in tokens/s."},
{"value", n_prompt_tokens_processed ? 1e3 / t_prompt_processing * n_prompt_tokens_processed : 0}
},{
{"name", "predicted_tokens_seconds"},
{"help", "Average generation throughput in tokens/s."},
{"value", n_tokens_predicted ? 1e3 / t_tokens_generation * n_tokens_predicted : 0}
},{
{"name", "kv_cache_usage_ratio"},
{"help", "KV-cache usage. 1 means 100 percent usage."},
{"value", 1. * kv_cache_used_cells / params.n_ctx}
},{
{"name", "kv_cache_tokens"},
{"help", "KV-cache tokens."},
{"value", data["kv_cache_tokens_count"]}
},{
{"name", "requests_processing"},
{"help", "Number of request processing."},
{"value", data["processing"]}
},{
{"name", "requests_deferred"},
{"help", "Number of request deferred."},
{"value", data["deferred"]}
}}}
};
std::stringstream prometheus;
for (const auto & el : all_metrics_def.items()) {
const auto & type = el.key();
const auto & metrics_def = el.value();
for (const auto & metric_def : metrics_def) {
const std::string name = metric_def["name"];
const std::string help = metric_def["help"];
auto value = json_value(metric_def, "value", 0);
prometheus << "# HELP llamacpp:" << name << " " << help << "\n"
<< "# TYPE llamacpp:" << name << " " << type << "\n"
<< "llamacpp:" << name << " " << value << "\n";
}
}
res.set_content(prometheus.str(), "text/plain; version=0.0.4");
res.status = 200; // HTTP OK
});
}
svr.set_logger(log_server_request);
svr.set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
@ -2858,10 +2696,14 @@ int main(int argc, char ** argv) {
}
}
// Middleware for API key validation
//
// Middlewares
//
auto middleware_validate_api_key = [&sparams](const httplib::Request & req, httplib::Response & res) {
// TODO: should we apply API key to all endpoints, including "/health" and "/models"?
static const std::set<std::string> protected_endpoints = {
"/props",
"/completion",
"/completions",
"/v1/completions",
@ -2913,25 +2755,169 @@ int main(int argc, char ** argv) {
return httplib::Server::HandlerResponse::Unhandled;
});
if (sparams.public_path.empty()) {
// using embedded static files
auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
return false;
};
//
// Route handlers (or controllers)
//
const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) {
server_state current_state = state.load();
switch (current_state) {
case SERVER_STATE_READY:
{
// request slots data using task queue
server_task task;
task.id = ctx_server.queue_tasks.get_new_id();
task.type = SERVER_TASK_TYPE_METRICS;
task.id_target = -1;
ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task);
// get the result
server_task_result result = ctx_server.queue_results.recv(task.id);
ctx_server.queue_results.remove_waiting_task_id(task.id);
const int n_idle_slots = result.data["idle"];
const int n_processing_slots = result.data["processing"];
json health = {
{"status", "ok"},
{"slots_idle", n_idle_slots},
{"slots_processing", n_processing_slots}
};
res.status = 200; // HTTP OK
if (sparams.slots_endpoint && req.has_param("include_slots")) {
health["slots"] = result.data["slots"];
}
if (n_idle_slots == 0) {
health["status"] = "no slot available";
if (req.has_param("fail_on_no_slot")) {
res.status = 503; // HTTP Service Unavailable
}
}
res.set_content(health.dump(), "application/json");
break;
}
case SERVER_STATE_LOADING_MODEL:
{
res.set_content(R"({"status": "loading model"})", "application/json");
res.status = 503; // HTTP Service Unavailable
} break;
case SERVER_STATE_ERROR:
{
res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json");
res.status = 500; // HTTP Internal Server Error
} break;
}
};
const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
// request slots data using task queue
server_task task;
task.id = ctx_server.queue_tasks.get_new_id();
task.id_multi = -1;
task.id_target = -1;
task.type = SERVER_TASK_TYPE_METRICS;
ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task);
// get the result
server_task_result result = ctx_server.queue_results.recv(task.id);
ctx_server.queue_results.remove_waiting_task_id(task.id);
res.set_content(result.data["slots"].dump(), "application/json");
res.status = 200; // HTTP OK
};
const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
// request slots data using task queue
server_task task;
task.id = ctx_server.queue_tasks.get_new_id();
task.id_multi = -1;
task.id_target = -1;
task.type = SERVER_TASK_TYPE_METRICS;
ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task);
// get the result
server_task_result result = ctx_server.queue_results.recv(task.id);
ctx_server.queue_results.remove_waiting_task_id(task.id);
json data = result.data;
const uint64_t n_prompt_tokens_processed = data["n_prompt_tokens_processed"];
const uint64_t t_prompt_processing = data["t_prompt_processing"];
const uint64_t n_tokens_predicted = data["n_tokens_predicted"];
const uint64_t t_tokens_generation = data["t_tokens_generation"];
const int32_t kv_cache_used_cells = data["kv_cache_used_cells"];
// metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
json all_metrics_def = json {
{"counter", {{
{"name", "prompt_tokens_total"},
{"help", "Number of prompt tokens processed."},
{"value", data["n_prompt_tokens_processed_total"]}
}, {
{"name", "tokens_predicted_total"},
{"help", "Number of generation tokens processed."},
{"value", data["n_tokens_predicted_total"]}
}}},
{"gauge", {{
{"name", "prompt_tokens_seconds"},
{"help", "Average prompt throughput in tokens/s."},
{"value", n_prompt_tokens_processed ? 1e3 / t_prompt_processing * n_prompt_tokens_processed : 0}
},{
{"name", "predicted_tokens_seconds"},
{"help", "Average generation throughput in tokens/s."},
{"value", n_tokens_predicted ? 1e3 / t_tokens_generation * n_tokens_predicted : 0}
},{
{"name", "kv_cache_usage_ratio"},
{"help", "KV-cache usage. 1 means 100 percent usage."},
{"value", 1. * kv_cache_used_cells / params.n_ctx}
},{
{"name", "kv_cache_tokens"},
{"help", "KV-cache tokens."},
{"value", data["kv_cache_tokens_count"]}
},{
{"name", "requests_processing"},
{"help", "Number of request processing."},
{"value", data["processing"]}
},{
{"name", "requests_deferred"},
{"help", "Number of request deferred."},
{"value", data["deferred"]}
}}}
};
svr.Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
svr.Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
svr.Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
svr.Get("/json-schema-to-grammar.mjs", handle_static_file(json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));
} else {
// Set the base directory for serving static files
svr.set_base_dir(sparams.public_path);
}
std::stringstream prometheus;
svr.Get("/props", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
for (const auto & el : all_metrics_def.items()) {
const auto & type = el.key();
const auto & metrics_def = el.value();
for (const auto & metric_def : metrics_def) {
const std::string name = metric_def["name"];
const std::string help = metric_def["help"];
auto value = json_value(metric_def, "value", 0);
prometheus << "# HELP llamacpp:" << name << " " << help << "\n"
<< "# TYPE llamacpp:" << name << " " << type << "\n"
<< "llamacpp:" << name << " " << value << "\n";
}
}
res.set_content(prometheus.str(), "text/plain; version=0.0.4");
res.status = 200; // HTTP OK
};
const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = {
{ "user_name", ctx_server.name_user.c_str() },
@ -2941,7 +2927,7 @@ int main(int argc, char ** argv) {
};
res.set_content(data.dump(), "application/json; charset=utf-8");
});
};
const auto handle_completions = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
@ -3020,11 +3006,7 @@ int main(int argc, char ** argv) {
}
};
svr.Post("/completion", handle_completions); // legacy
svr.Post("/completions", handle_completions);
svr.Post("/v1/completions", handle_completions);
svr.Get("/v1/models", [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
const auto handle_models = [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json models = {
@ -3041,9 +3023,9 @@ int main(int argc, char ** argv) {
};
res.set_content(models.dump(), "application/json; charset=utf-8");
});
};
const auto chat_completions = [&ctx_server, &sparams](const httplib::Request & req, httplib::Response & res) {
const auto handle_chat_completions = [&ctx_server, &sparams](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);
@ -3115,10 +3097,7 @@ int main(int argc, char ** argv) {
}
};
svr.Post("/chat/completions", chat_completions);
svr.Post("/v1/chat/completions", chat_completions);
svr.Post("/infill", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
const auto handle_infill = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = json::parse(req.body);
@ -3177,13 +3156,9 @@ int main(int argc, char ** argv) {
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
}
});
};
svr.Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
return res.set_content("", "application/json; charset=utf-8");
});
svr.Post("/tokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body);
@ -3193,9 +3168,9 @@ int main(int argc, char ** argv) {
}
const json data = format_tokenizer_response(tokens);
return res.set_content(data.dump(), "application/json; charset=utf-8");
});
};
svr.Post("/detokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body);
@ -3207,9 +3182,9 @@ int main(int argc, char ** argv) {
const json data = format_detokenized_response(content);
return res.set_content(data.dump(), "application/json; charset=utf-8");
});
};
svr.Post("/embedding", [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
auto handle_embeddings = [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!params.embedding) {
res.status = 501;
@ -3218,94 +3193,118 @@ int main(int argc, char ** argv) {
}
const json body = json::parse(req.body);
bool is_openai = false;
json prompt;
if (body.count("content") != 0) {
prompt = body["content"];
} else {
prompt = "";
}
// create and queue the task
const int id_task = ctx_server.queue_tasks.get_new_id();
ctx_server.queue_results.add_waiting_task_id(id_task);
ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0} }, false, true);
// get the result
server_task_result result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task);
// send the result
return res.set_content(result.data.dump(), "application/json; charset=utf-8");
});
svr.Post("/v1/embeddings", [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!params.embedding) {
res.status = 501;
res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8");
return;
}
const json body = json::parse(req.body);
json prompt;
// an input prompt can string or a list of tokens (integer)
std::vector<json> prompts;
if (body.count("input") != 0) {
prompt = body["input"];
if (prompt.is_array()) {
json data = json::array();
int i = 0;
for (const json & elem : prompt) {
const int id_task = ctx_server.queue_tasks.get_new_id();
ctx_server.queue_results.add_waiting_task_id(id_task);
ctx_server.request_completion(id_task, -1, { {"prompt", elem}, { "n_predict", 0} }, false, true);
// get the result
server_task_result result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task);
json embedding = json{
{"embedding", json_value(result.data, "embedding", json::array())},
{"index", i++},
{"object", "embedding"}
};
data.push_back(embedding);
is_openai = true;
if (body["input"].is_array()) {
// support multiple prompts
for (const json & elem : body["input"]) {
prompts.push_back(elem);
}
json result = format_embeddings_response_oaicompat(body, data);
return res.set_content(result.dump(), "application/json; charset=utf-8");
} else {
// single input prompt
prompts.push_back(body["input"]);
}
} else if (body.count("content") != 0) {
// only support single prompt here
std::string content = body["content"];
prompts.push_back(content);
} else {
prompt = "";
// TODO @ngxson : should return an error here
prompts.push_back("");
}
// create and queue the task
const int id_task = ctx_server.queue_tasks.get_new_id();
// process all prompts
json responses = json::array();
for (auto & prompt : prompts) {
// TODO @ngxson : maybe support multitask for this endpoint?
// create and queue the task
const int id_task = ctx_server.queue_tasks.get_new_id();
ctx_server.queue_results.add_waiting_task_id(id_task);
ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true);
ctx_server.queue_results.add_waiting_task_id(id_task);
ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true);
// get the result
server_task_result result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task);
json data = json::array({json{
{"embedding", json_value(result.data, "embedding", json::array())},
{"index", 0},
{"object", "embedding"}
}}
);
json root = format_embeddings_response_oaicompat(body, data);
// get the result
server_task_result result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task);
responses.push_back(json_value(result.data, "embedding", json::array()));
}
// write JSON response
json root;
if (is_openai) {
json res_oai = json::array();
int i = 0;
for (auto & elem : responses) {
res_oai.push_back(json{
{"embedding", elem},
{"index", i++},
{"object", "embedding"}
});
}
root = format_embeddings_response_oaicompat(body, res_oai);
} else {
root = responses[0];
}
return res.set_content(root.dump(), "application/json; charset=utf-8");
});
};
//
// Router
//
// register static assets routes
if (!sparams.public_path.empty()) {
// Set the base directory for serving static files
svr.set_base_dir(sparams.public_path);
}
// using embedded static files
auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
return false;
};
};
svr.Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
// TODO @ngxson : I have no idea what it is... maybe this is redundant?
return res.set_content("", "application/json; charset=utf-8");
});
svr.Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
svr.Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
svr.Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
svr.Get("/json-schema-to-grammar.mjs", handle_static_file(
json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));
// register API routes
svr.Post("/health", handle_health);
svr.Post("/props", handle_props);
svr.Get ("/v1/models", handle_models);
svr.Post("/completion", handle_completions); // legacy
svr.Post("/completions", handle_completions);
svr.Post("/v1/completions", handle_completions);
svr.Post("/chat/completions", handle_chat_completions);
svr.Post("/v1/chat/completions", handle_chat_completions);
svr.Post("/infill", handle_infill);
svr.Post("/embedding", handle_embeddings); // legacy
svr.Post("/embeddings", handle_embeddings);
svr.Post("/v1/embeddings", handle_embeddings);
svr.Post("/tokenize", handle_tokenize);
svr.Post("/detokenize", handle_detokenize);
if (sparams.slots_endpoint) {
svr.Get("/slots", handle_slots);
}
if (sparams.metrics_endpoint) {
svr.Get("/metrics", handle_metrics);
}
//
// Start the server
//
if (sparams.n_threads_http < 1) {
// +2 threads for monitoring endpoints
sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);