From 22ae1a622ead2cc55392780d613705fc660357be Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 5 Mar 2024 18:58:26 +0200 Subject: [PATCH] server : do not process embedding requests when disabled --- examples/server/server.cpp | 164 ++++++++++++++++++++----------------- 1 file changed, 88 insertions(+), 76 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 898364617..64d418f8f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -602,70 +602,78 @@ struct llama_server_context { } } - slot.sparams.logit_bias.clear(); + { + slot.sparams.logit_bias.clear(); - if (json_value(data, "ignore_eos", false)) { - slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY; - } + if (json_value(data, "ignore_eos", false)) { + slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY; + } - const auto & logit_bias = data.find("logit_bias"); - if (logit_bias != data.end() && logit_bias->is_array()) { - const int n_vocab = llama_n_vocab(model); - for (const auto & el : *logit_bias) { - if (el.is_array() && el.size() == 2) { - float bias; - if (el[1].is_number()) { - bias = el[1].get(); - } else if (el[1].is_boolean() && !el[1].get()) { - bias = -INFINITY; - } else { - continue; - } - - if (el[0].is_number_integer()) { - llama_token tok = el[0].get(); - if (tok >= 0 && tok < n_vocab) { - slot.sparams.logit_bias[tok] = bias; + const auto & logit_bias = data.find("logit_bias"); + if (logit_bias != data.end() && logit_bias->is_array()) { + const int n_vocab = llama_n_vocab(model); + for (const auto & el : *logit_bias) { + if (el.is_array() && el.size() == 2) { + float bias; + if (el[1].is_number()) { + bias = el[1].get(); + } else if (el[1].is_boolean() && !el[1].get()) { + bias = -INFINITY; + } else { + continue; } - } else if (el[0].is_string()) { - auto toks = llama_tokenize(model, el[0].get(), false); - for (auto tok : toks) { - slot.sparams.logit_bias[tok] = bias; + + if (el[0].is_number_integer()) { + llama_token tok = el[0].get(); + if (tok >= 0 && tok < n_vocab) { + slot.sparams.logit_bias[tok] = bias; + } + } else if (el[0].is_string()) { + auto toks = llama_tokenize(model, el[0].get(), false); + for (auto tok : toks) { + slot.sparams.logit_bias[tok] = bias; + } } } } } } - slot.params.antiprompt.clear(); + { + slot.params.antiprompt.clear(); - const auto & stop = data.find("stop"); - if (stop != data.end() && stop->is_array()) { - for (const auto & word : *stop) { - if (!word.empty()) { - slot.params.antiprompt.push_back(word); + const auto & stop = data.find("stop"); + if (stop != data.end() && stop->is_array()) { + for (const auto & word : *stop) { + if (!word.empty()) { + slot.params.antiprompt.push_back(word); + } } } } - const auto & samplers_sequence = data.find("samplers"); - if (samplers_sequence != data.end() && samplers_sequence->is_array()) { - std::vector sampler_names; - for (const auto & sampler_name : *samplers_sequence) { - if (sampler_name.is_string()) { - sampler_names.emplace_back(sampler_name); + { + const auto & samplers_sequence = data.find("samplers"); + if (samplers_sequence != data.end() && samplers_sequence->is_array()) { + std::vector sampler_names; + for (const auto & sampler_name : *samplers_sequence) { + if (sampler_name.is_string()) { + sampler_names.emplace_back(sampler_name); + } } + slot.sparams.samplers_sequence = sampler_types_from_names(sampler_names, false); + } else { + slot.sparams.samplers_sequence = default_sparams.samplers_sequence; } - slot.sparams.samplers_sequence = sampler_types_from_names(sampler_names, false); - } else { - slot.sparams.samplers_sequence = default_sparams.samplers_sequence; } - if (slot.ctx_sampling != nullptr) { - llama_sampling_free(slot.ctx_sampling); + { + if (slot.ctx_sampling != nullptr) { + llama_sampling_free(slot.ctx_sampling); + } + slot.ctx_sampling = llama_sampling_init(slot.sparams); + llama_set_rng_seed(ctx, slot.params.seed); } - slot.ctx_sampling = llama_sampling_init(slot.sparams); - llama_set_rng_seed(ctx, slot.params.seed); slot.command = LOAD_PROMPT; @@ -1009,40 +1017,32 @@ struct llama_server_context { const int n_embd = llama_n_embd(model); - if (!params.embedding) { - LOG_WARNING("embedding disabled", {{"params.embedding", params.embedding}}); + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + continue; + } - res.result_json = json { - {"embedding", std::vector(n_embd, 0.0f)}, - }; - } else { - for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { - continue; - } + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + } - const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx, i); - } - - if (embd == NULL) { - LOG_ERROR("failed to get embeddings", { - {"token", batch.token [i]}, + if (embd == NULL) { + LOG_ERROR("failed to get embeddings", { + {"token", batch.token [i]}, {"seq_id", batch.seq_id[i][0]} - }); - - res.result_json = json { - {"embedding", std::vector(n_embd, 0.0f)}, - }; - - continue; - } + }); res.result_json = json { - {"embedding", std::vector(embd, embd + n_embd)}, + {"embedding", std::vector(n_embd, 0.0f)}, }; + + continue; } + + res.result_json = json { + {"embedding", std::vector(embd, embd + n_embd)}, + }; } queue_results.send(res); @@ -1774,7 +1774,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n"); printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n"); printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout); - printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled"); + printf(" --embeddings enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled"); printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel); printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); printf(" -spf FNAME, --system-prompt-file FNAME\n"); @@ -2087,7 +2087,7 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; } else { invalid_param = true; break; } } - } else if (arg == "--embedding") { + } else if (arg == "--embedding" || arg == "--embeddings") { params.embedding = true; } else if (arg == "-cb" || arg == "--cont-batching") { params.cont_batching = true; @@ -2860,6 +2860,7 @@ int main(int argc, char ** argv) { svr.Post("/tokenize", [&llama](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); + std::vector tokens; if (body.count("content") != 0) { tokens = llama.tokenize(body["content"], false); @@ -2871,6 +2872,7 @@ int main(int argc, char ** argv) { svr.Post("/detokenize", [&llama](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); + std::string content; if (body.count("tokens") != 0) { const std::vector tokens = body["tokens"]; @@ -2881,8 +2883,13 @@ int main(int argc, char ** argv) { return res.set_content(data.dump(), "application/json; charset=utf-8"); }); - svr.Post("/embedding", [&llama](const httplib::Request & req, httplib::Response & res) { + svr.Post("/embedding", [¶ms, &llama](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + if (!params.embedding) { + res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8"); + return; + } + const json body = json::parse(req.body); json prompt; @@ -2906,8 +2913,13 @@ int main(int argc, char ** argv) { return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); }); - svr.Post("/v1/embeddings", [&llama](const httplib::Request & req, httplib::Response & res) { + svr.Post("/v1/embeddings", [¶ms, &llama](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + if (!params.embedding) { + res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8"); + return; + } + const json body = json::parse(req.body); json prompt;