Removed the embedding API endpoint and associated code.

digiwombat 2023-06-02 10:05:52 -04:00
parent 3ff27d30e3
commit 16e1c9813a

@@ -404,24 +404,6 @@ struct llama_server_context
         return token_text;
     }
-
-    std::vector<float> embedding(std::string content, int threads) {
-        content.insert(0, 1, ' ');
-        std::vector<llama_token> tokens = ::llama_tokenize(ctx, content, true);
-        if (!tokens.empty())
-        {
-            if (llama_eval(ctx, tokens.data(), tokens.size(), 0, threads))
-            {
-                fprintf(stderr, "%s : failed to eval\n", __func__);
-                std::vector<float> embeddings_;
-                return embeddings_;
-            }
-        }
-        const int n_embd = llama_n_embd(ctx);
-        auto *const embeddings = llama_get_embeddings(ctx);
-        std::vector<float> embeddings_(embeddings, embeddings + n_embd);
-        return embeddings_;
-    }
 };
 
 using namespace httplib;
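For reference, the removed embedding() helper was a thin wrapper around the llama.cpp embedding flow of the time: prepend a space, tokenize, run a single llama_eval() pass, then copy the n_embd floats out of llama_get_embeddings(). Below is a commented standalone restatement of that same pattern for anyone porting it elsewhere; it is a sketch, assuming a llama_context *ctx that was created with embedding output enabled:

// Standalone sketch of the removed embedding flow (llama.cpp C API, circa
// June 2023). Assumes `ctx` was created with embedding output enabled.
std::vector<float> get_embedding(llama_context *ctx, std::string content, int threads) {
    content.insert(0, 1, ' '); // llama.cpp convention: lead with a space before tokenizing
    std::vector<llama_token> tokens = ::llama_tokenize(ctx, content, true /* add BOS */);
    if (!tokens.empty()) {
        // A non-zero return from llama_eval means evaluation failed.
        if (llama_eval(ctx, tokens.data(), tokens.size(), 0 /* n_past */, threads)) {
            fprintf(stderr, "%s : failed to eval\n", __func__);
            return {};
        }
    }
    // Copy out the n_embd floats computed for the evaluated prompt.
    const int n_embd = llama_n_embd(ctx);
    const float *embeddings = llama_get_embeddings(ctx);
    return std::vector<float>(embeddings, embeddings + n_embd);
}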
@@ -440,7 +422,6 @@ void server_print_usage(int /*argc*/, char **argv, const gpt_params &params, con
     fprintf(stderr, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stderr, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
     fprintf(stderr, " not recommended: doubles context memory required and no measurable increase in quality\n");
-    fprintf(stderr, " --embedding enable embedding mode\n");
     fprintf(stderr, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
     if (llama_mlock_supported())
     {
@@ -521,10 +502,6 @@ void server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
             }
             params.model_alias = argv[i];
         }
-        else if (arg == "--embedding")
-        {
-            params.embedding = true;
-        }
         else if (arg == "-h" || arg == "--help")
         {
             server_print_usage(argc, argv, default_params, default_sparams);
@@ -820,16 +797,6 @@ int main(int argc, char **argv)
             { res.set_content("<h1>llama.cpp server works</h1>", "text/html"); });
 
     svr.Post("/completion", [&llama](const Request &req, Response &res) {
-        if (llama.params.embedding) {
-            json data = {
-                {"status", "error"},
-                {"reason", "To use completion function, disable embedding mode"}};
-            res.set_content(
-                data.dump(llama.json_indent, ' ', false, json::error_handler_t::replace),
-                "application/json");
-            res.status = 400;
-            return;
-        }
         llama.rewind();
 
         llama_reset_timings(llama.ctx);
@@ -956,38 +923,6 @@ int main(int argc, char **argv)
         return res.set_content(data.dump(llama.json_indent), "application/json");
     });
 
-    svr.Post("/embedding", [&llama](const Request &req, Response &res)
-            {
-                json data;
-                if(!llama.params.embedding) {
-                    std::vector<float> empty;
-                    data = {
-                        {"embedding", empty},
-                        {"error", "Server is not in embedding mode."} };
-                    fprintf(stderr, "[llama-server] : You need to enable embedding mode by adding --embedding when launching the server.\n");
-                    return res.set_content(data.dump(llama.json_indent), "application/json");
-                }
-                json body = json::parse(req.body);
-                if (body["content"].is_null()) {
-                    std::vector<float> empty;
-                    data = {
-                        {"embedding", empty},
-                        {"error", "The embedding content was not set."} };
-                    fprintf(stderr, "[llama-server] : The embedding content was not set.\n");
-                }
-                else
-                {
-                    std::string content = body["content"].get<std::string>();
-                    data = {
-                        {"embedding", llama.embedding(content, llama.params.n_threads) } };
-                }
-                return res.set_content(data.dump(llama.json_indent), "application/json");
-            });
-
-    if(params.embedding) {
-        fprintf(stderr, "NOTE: Embedding mode enabled. Completion is disabled in this mode.\n");
-    }
-
     svr.set_logger([](const Request& req, const Response& res) {
         json log = {
             { "status", res.status },