server : do not process embedding requests when disabled

This commit is contained in:
Georgi Gerganov 2024-03-05 18:58:26 +02:00
parent f84809b7ad
commit 22ae1a622e
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -602,6 +602,7 @@ struct llama_server_context {
} }
} }
{
slot.sparams.logit_bias.clear(); slot.sparams.logit_bias.clear();
if (json_value(data, "ignore_eos", false)) { if (json_value(data, "ignore_eos", false)) {
@ -636,7 +637,9 @@ struct llama_server_context {
} }
} }
} }
}
{
slot.params.antiprompt.clear(); slot.params.antiprompt.clear();
const auto & stop = data.find("stop"); const auto & stop = data.find("stop");
@ -647,7 +650,9 @@ struct llama_server_context {
} }
} }
} }
}
{
const auto & samplers_sequence = data.find("samplers"); const auto & samplers_sequence = data.find("samplers");
if (samplers_sequence != data.end() && samplers_sequence->is_array()) { if (samplers_sequence != data.end() && samplers_sequence->is_array()) {
std::vector<std::string> sampler_names; std::vector<std::string> sampler_names;
@ -660,12 +665,15 @@ struct llama_server_context {
} else { } else {
slot.sparams.samplers_sequence = default_sparams.samplers_sequence; slot.sparams.samplers_sequence = default_sparams.samplers_sequence;
} }
}
{
if (slot.ctx_sampling != nullptr) { if (slot.ctx_sampling != nullptr) {
llama_sampling_free(slot.ctx_sampling); llama_sampling_free(slot.ctx_sampling);
} }
slot.ctx_sampling = llama_sampling_init(slot.sparams); slot.ctx_sampling = llama_sampling_init(slot.sparams);
llama_set_rng_seed(ctx, slot.params.seed); llama_set_rng_seed(ctx, slot.params.seed);
}
slot.command = LOAD_PROMPT; slot.command = LOAD_PROMPT;
@ -1009,13 +1017,6 @@ struct llama_server_context {
const int n_embd = llama_n_embd(model); const int n_embd = llama_n_embd(model);
if (!params.embedding) {
LOG_WARNING("embedding disabled", {{"params.embedding", params.embedding}});
res.result_json = json {
{"embedding", std::vector<float>(n_embd, 0.0f)},
};
} else {
for (int i = 0; i < batch.n_tokens; ++i) { for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
continue; continue;
@ -1043,7 +1044,6 @@ struct llama_server_context {
{"embedding", std::vector<float>(embd, embd + n_embd)}, {"embedding", std::vector<float>(embd, embd + n_embd)},
}; };
} }
}
queue_results.send(res); queue_results.send(res);
} }
@ -1774,7 +1774,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n"); printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n"); printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout); printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled"); printf(" --embeddings enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel); printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
printf(" -spf FNAME, --system-prompt-file FNAME\n"); printf(" -spf FNAME, --system-prompt-file FNAME\n");
@ -2087,7 +2087,7 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; } else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; }
else { invalid_param = true; break; } else { invalid_param = true; break; }
} }
} else if (arg == "--embedding") { } else if (arg == "--embedding" || arg == "--embeddings") {
params.embedding = true; params.embedding = true;
} else if (arg == "-cb" || arg == "--cont-batching") { } else if (arg == "-cb" || arg == "--cont-batching") {
params.cont_batching = true; params.cont_batching = true;
@ -2860,6 +2860,7 @@ int main(int argc, char ** argv) {
svr.Post("/tokenize", [&llama](const httplib::Request & req, httplib::Response & res) { svr.Post("/tokenize", [&llama](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body); const json body = json::parse(req.body);
std::vector<llama_token> tokens; std::vector<llama_token> tokens;
if (body.count("content") != 0) { if (body.count("content") != 0) {
tokens = llama.tokenize(body["content"], false); tokens = llama.tokenize(body["content"], false);
@ -2871,6 +2872,7 @@ int main(int argc, char ** argv) {
svr.Post("/detokenize", [&llama](const httplib::Request & req, httplib::Response & res) { svr.Post("/detokenize", [&llama](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body); const json body = json::parse(req.body);
std::string content; std::string content;
if (body.count("tokens") != 0) { if (body.count("tokens") != 0) {
const std::vector<llama_token> tokens = body["tokens"]; const std::vector<llama_token> tokens = body["tokens"];
@ -2881,8 +2883,13 @@ int main(int argc, char ** argv) {
return res.set_content(data.dump(), "application/json; charset=utf-8"); return res.set_content(data.dump(), "application/json; charset=utf-8");
}); });
svr.Post("/embedding", [&llama](const httplib::Request & req, httplib::Response & res) { svr.Post("/embedding", [&params, &llama](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!params.embedding) {
res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8");
return;
}
const json body = json::parse(req.body); const json body = json::parse(req.body);
json prompt; json prompt;
@ -2906,8 +2913,13 @@ int main(int argc, char ** argv) {
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
}); });
svr.Post("/v1/embeddings", [&llama](const httplib::Request & req, httplib::Response & res) { svr.Post("/v1/embeddings", [&params, &llama](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!params.embedding) {
res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8");
return;
}
const json body = json::parse(req.body); const json body = json::parse(req.body);
json prompt; json prompt;