server : do not process embedding requests when disabled
This commit is contained in:
parent
f84809b7ad
commit
22ae1a622e
1 changed files with 88 additions and 76 deletions
|
@ -602,70 +602,78 @@ struct llama_server_context {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.sparams.logit_bias.clear();
|
{
|
||||||
|
slot.sparams.logit_bias.clear();
|
||||||
|
|
||||||
if (json_value(data, "ignore_eos", false)) {
|
if (json_value(data, "ignore_eos", false)) {
|
||||||
slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
|
slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto & logit_bias = data.find("logit_bias");
|
const auto & logit_bias = data.find("logit_bias");
|
||||||
if (logit_bias != data.end() && logit_bias->is_array()) {
|
if (logit_bias != data.end() && logit_bias->is_array()) {
|
||||||
const int n_vocab = llama_n_vocab(model);
|
const int n_vocab = llama_n_vocab(model);
|
||||||
for (const auto & el : *logit_bias) {
|
for (const auto & el : *logit_bias) {
|
||||||
if (el.is_array() && el.size() == 2) {
|
if (el.is_array() && el.size() == 2) {
|
||||||
float bias;
|
float bias;
|
||||||
if (el[1].is_number()) {
|
if (el[1].is_number()) {
|
||||||
bias = el[1].get<float>();
|
bias = el[1].get<float>();
|
||||||
} else if (el[1].is_boolean() && !el[1].get<bool>()) {
|
} else if (el[1].is_boolean() && !el[1].get<bool>()) {
|
||||||
bias = -INFINITY;
|
bias = -INFINITY;
|
||||||
} else {
|
} else {
|
||||||
continue;
|
continue;
|
||||||
}
|
|
||||||
|
|
||||||
if (el[0].is_number_integer()) {
|
|
||||||
llama_token tok = el[0].get<llama_token>();
|
|
||||||
if (tok >= 0 && tok < n_vocab) {
|
|
||||||
slot.sparams.logit_bias[tok] = bias;
|
|
||||||
}
|
}
|
||||||
} else if (el[0].is_string()) {
|
|
||||||
auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
|
if (el[0].is_number_integer()) {
|
||||||
for (auto tok : toks) {
|
llama_token tok = el[0].get<llama_token>();
|
||||||
slot.sparams.logit_bias[tok] = bias;
|
if (tok >= 0 && tok < n_vocab) {
|
||||||
|
slot.sparams.logit_bias[tok] = bias;
|
||||||
|
}
|
||||||
|
} else if (el[0].is_string()) {
|
||||||
|
auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
|
||||||
|
for (auto tok : toks) {
|
||||||
|
slot.sparams.logit_bias[tok] = bias;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.params.antiprompt.clear();
|
{
|
||||||
|
slot.params.antiprompt.clear();
|
||||||
|
|
||||||
const auto & stop = data.find("stop");
|
const auto & stop = data.find("stop");
|
||||||
if (stop != data.end() && stop->is_array()) {
|
if (stop != data.end() && stop->is_array()) {
|
||||||
for (const auto & word : *stop) {
|
for (const auto & word : *stop) {
|
||||||
if (!word.empty()) {
|
if (!word.empty()) {
|
||||||
slot.params.antiprompt.push_back(word);
|
slot.params.antiprompt.push_back(word);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto & samplers_sequence = data.find("samplers");
|
{
|
||||||
if (samplers_sequence != data.end() && samplers_sequence->is_array()) {
|
const auto & samplers_sequence = data.find("samplers");
|
||||||
std::vector<std::string> sampler_names;
|
if (samplers_sequence != data.end() && samplers_sequence->is_array()) {
|
||||||
for (const auto & sampler_name : *samplers_sequence) {
|
std::vector<std::string> sampler_names;
|
||||||
if (sampler_name.is_string()) {
|
for (const auto & sampler_name : *samplers_sequence) {
|
||||||
sampler_names.emplace_back(sampler_name);
|
if (sampler_name.is_string()) {
|
||||||
|
sampler_names.emplace_back(sampler_name);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
slot.sparams.samplers_sequence = sampler_types_from_names(sampler_names, false);
|
||||||
|
} else {
|
||||||
|
slot.sparams.samplers_sequence = default_sparams.samplers_sequence;
|
||||||
}
|
}
|
||||||
slot.sparams.samplers_sequence = sampler_types_from_names(sampler_names, false);
|
|
||||||
} else {
|
|
||||||
slot.sparams.samplers_sequence = default_sparams.samplers_sequence;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (slot.ctx_sampling != nullptr) {
|
{
|
||||||
llama_sampling_free(slot.ctx_sampling);
|
if (slot.ctx_sampling != nullptr) {
|
||||||
|
llama_sampling_free(slot.ctx_sampling);
|
||||||
|
}
|
||||||
|
slot.ctx_sampling = llama_sampling_init(slot.sparams);
|
||||||
|
llama_set_rng_seed(ctx, slot.params.seed);
|
||||||
}
|
}
|
||||||
slot.ctx_sampling = llama_sampling_init(slot.sparams);
|
|
||||||
llama_set_rng_seed(ctx, slot.params.seed);
|
|
||||||
|
|
||||||
slot.command = LOAD_PROMPT;
|
slot.command = LOAD_PROMPT;
|
||||||
|
|
||||||
|
@ -1009,40 +1017,32 @@ struct llama_server_context {
|
||||||
|
|
||||||
const int n_embd = llama_n_embd(model);
|
const int n_embd = llama_n_embd(model);
|
||||||
|
|
||||||
if (!params.embedding) {
|
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||||
LOG_WARNING("embedding disabled", {{"params.embedding", params.embedding}});
|
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
res.result_json = json {
|
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
||||||
{"embedding", std::vector<float>(n_embd, 0.0f)},
|
if (embd == NULL) {
|
||||||
};
|
embd = llama_get_embeddings_ith(ctx, i);
|
||||||
} else {
|
}
|
||||||
for (int i = 0; i < batch.n_tokens; ++i) {
|
|
||||||
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
if (embd == NULL) {
|
||||||
if (embd == NULL) {
|
LOG_ERROR("failed to get embeddings", {
|
||||||
embd = llama_get_embeddings_ith(ctx, i);
|
{"token", batch.token [i]},
|
||||||
}
|
|
||||||
|
|
||||||
if (embd == NULL) {
|
|
||||||
LOG_ERROR("failed to get embeddings", {
|
|
||||||
{"token", batch.token [i]},
|
|
||||||
{"seq_id", batch.seq_id[i][0]}
|
{"seq_id", batch.seq_id[i][0]}
|
||||||
});
|
});
|
||||||
|
|
||||||
res.result_json = json {
|
|
||||||
{"embedding", std::vector<float>(n_embd, 0.0f)},
|
|
||||||
};
|
|
||||||
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
res.result_json = json {
|
res.result_json = json {
|
||||||
{"embedding", std::vector<float>(embd, embd + n_embd)},
|
{"embedding", std::vector<float>(n_embd, 0.0f)},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
res.result_json = json {
|
||||||
|
{"embedding", std::vector<float>(embd, embd + n_embd)},
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
queue_results.send(res);
|
queue_results.send(res);
|
||||||
|
@ -1774,7 +1774,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
|
||||||
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
|
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
|
||||||
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
|
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
|
||||||
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
|
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
|
||||||
printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
|
printf(" --embeddings enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
|
||||||
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
|
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
|
||||||
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
|
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
|
||||||
printf(" -spf FNAME, --system-prompt-file FNAME\n");
|
printf(" -spf FNAME, --system-prompt-file FNAME\n");
|
||||||
|
@ -2087,7 +2087,7 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
|
||||||
else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; }
|
else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; }
|
||||||
else { invalid_param = true; break; }
|
else { invalid_param = true; break; }
|
||||||
}
|
}
|
||||||
} else if (arg == "--embedding") {
|
} else if (arg == "--embedding" || arg == "--embeddings") {
|
||||||
params.embedding = true;
|
params.embedding = true;
|
||||||
} else if (arg == "-cb" || arg == "--cont-batching") {
|
} else if (arg == "-cb" || arg == "--cont-batching") {
|
||||||
params.cont_batching = true;
|
params.cont_batching = true;
|
||||||
|
@ -2860,6 +2860,7 @@ int main(int argc, char ** argv) {
|
||||||
svr.Post("/tokenize", [&llama](const httplib::Request & req, httplib::Response & res) {
|
svr.Post("/tokenize", [&llama](const httplib::Request & req, httplib::Response & res) {
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
const json body = json::parse(req.body);
|
const json body = json::parse(req.body);
|
||||||
|
|
||||||
std::vector<llama_token> tokens;
|
std::vector<llama_token> tokens;
|
||||||
if (body.count("content") != 0) {
|
if (body.count("content") != 0) {
|
||||||
tokens = llama.tokenize(body["content"], false);
|
tokens = llama.tokenize(body["content"], false);
|
||||||
|
@ -2871,6 +2872,7 @@ int main(int argc, char ** argv) {
|
||||||
svr.Post("/detokenize", [&llama](const httplib::Request & req, httplib::Response & res) {
|
svr.Post("/detokenize", [&llama](const httplib::Request & req, httplib::Response & res) {
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
const json body = json::parse(req.body);
|
const json body = json::parse(req.body);
|
||||||
|
|
||||||
std::string content;
|
std::string content;
|
||||||
if (body.count("tokens") != 0) {
|
if (body.count("tokens") != 0) {
|
||||||
const std::vector<llama_token> tokens = body["tokens"];
|
const std::vector<llama_token> tokens = body["tokens"];
|
||||||
|
@ -2881,8 +2883,13 @@ int main(int argc, char ** argv) {
|
||||||
return res.set_content(data.dump(), "application/json; charset=utf-8");
|
return res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||||
});
|
});
|
||||||
|
|
||||||
svr.Post("/embedding", [&llama](const httplib::Request & req, httplib::Response & res) {
|
svr.Post("/embedding", [¶ms, &llama](const httplib::Request & req, httplib::Response & res) {
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
|
if (!params.embedding) {
|
||||||
|
res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const json body = json::parse(req.body);
|
const json body = json::parse(req.body);
|
||||||
|
|
||||||
json prompt;
|
json prompt;
|
||||||
|
@ -2906,8 +2913,13 @@ int main(int argc, char ** argv) {
|
||||||
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
|
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
|
||||||
});
|
});
|
||||||
|
|
||||||
svr.Post("/v1/embeddings", [&llama](const httplib::Request & req, httplib::Response & res) {
|
svr.Post("/v1/embeddings", [¶ms, &llama](const httplib::Request & req, httplib::Response & res) {
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
|
if (!params.embedding) {
|
||||||
|
res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const json body = json::parse(req.body);
|
const json body = json::parse(req.body);
|
||||||
|
|
||||||
json prompt;
|
json prompt;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue