add --reranking argument

Xuan Son Nguyen 2024-09-27 15:25:39 +02:00
parent 84b0af8355
commit 0d6f6a799f
7 changed files with 43 additions and 18 deletions

View file

@@ -284,6 +284,10 @@ static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx
         params.kv_overrides.back().key[0] = 0;
     }

+    if (params.reranking && params.embedding) {
+        throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
+    }
+
     return true;
 }
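This makes the two modes mutually exclusive at parse time: an invocation that passes both flags (for example `llama-server -m model.gguf --embedding --reranking`; the binary name is assumed here, it does not appear in this diff) now fails immediately with the `std::invalid_argument` message above instead of starting the server in an ambiguous configuration.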
@@ -1750,6 +1754,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.embedding = true;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
+    add_opt(llama_arg(
+        {"--reranking", "--rerank"},
+        format("enable reranking endpoint on server (default: %s)", params.reranking ? "enabled" : "disabled"),
+        [](gpt_params & params) {
+            params.reranking = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING"));
     add_opt(llama_arg(
         {"--api-key"}, "KEY",
         "API key to use for authentication (default: none)",

View file

@@ -1023,6 +1023,11 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.flash_attn = params.flash_attn;
     cparams.no_perf = params.no_perf;

+    if (params.reranking) {
+        cparams.embeddings = true;
+        cparams.pooling_type = LLAMA_POOLING_TYPE_RANK;
+    }
+
     cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
     cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
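Reranking therefore rides on the embeddings pipeline, with `LLAMA_POOLING_TYPE_RANK` collapsing each query+document sequence into a single relevance score. A minimal sketch of how a caller might read that score through the llama.cpp C API (an illustration under the parameters above, not code from this commit; error handling omitted):

    // sketch.cpp: read a RANK-pooled relevance score after llama_decode().
    // Assumes ctx was created with cparams.embeddings = true and
    // cparams.pooling_type = LLAMA_POOLING_TYPE_RANK, as set in the hunk above.
    #include "llama.h"

    static float read_rank_score(llama_context * ctx, llama_seq_id seq_id) {
        // with RANK pooling the per-sequence "embedding" is a single float
        const float * out = llama_get_embeddings_seq(ctx, seq_id);
        return out ? out[0] : 0.0f; // higher = more relevant (model-dependent scale)
    }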

View file

@@ -271,6 +271,7 @@ struct gpt_params {
     int32_t embd_normalize = 2; // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
     std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
     std::string embd_sep = "\n"; // separator of embendings
+    bool reranking = false; // enable reranking support on server

     // server params
     int32_t port = 8080; // server listens on this network port

View file

@@ -2888,8 +2888,8 @@ int main(int argc, char ** argv) {
     };

     const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
-        if (ctx_server.params.embedding) {
-            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+        if (ctx_server.params.embedding || ctx_server.params.reranking) {
+            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
@@ -2949,8 +2949,8 @@ int main(int argc, char ** argv) {
     // TODO: maybe merge this function with "handle_completions_generic"
     const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
-        if (ctx_server.params.embedding) {
-            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+        if (ctx_server.params.embedding || ctx_server.params.reranking) {
+            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
@@ -3074,6 +3074,11 @@ int main(int argc, char ** argv) {
     };

     const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+        // TODO: somehow clean up this checks in the future
+        if (!ctx_server.params.embedding || ctx_server.params.reranking) {
+            res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
         const json body = json::parse(req.body);
         bool is_openai = false;
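Note the direction of this check: the embeddings endpoint now requires that `--embeddings` was passed explicitly and rejects requests when `--reranking` is active, even though reranking flips `cparams.embeddings` on internally. From the user's point of view the two flags stay mutually exclusive, consistent with the parse-time check added earlier.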
@@ -3125,6 +3130,10 @@ int main(int argc, char ** argv) {
     };

     const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+        if (!ctx_server.params.reranking) {
+            res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
         const json body = json::parse(req.body);

         // TODO: implement
@@ -3148,17 +3157,11 @@ int main(int argc, char ** argv) {
             return;
         }

-        json documents;
-        if (body.count("documents") != 0) {
-            documents = body.at("documents");
-            if (!documents.is_array() || documents.size() == 0) {
-                res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
-                return;
-            }
-        } else {
-            res_error(res, format_error_response("\"documents\" must be provided", ERROR_TYPE_INVALID_REQUEST));
-            return;
-        }
+        std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>());
+        if (documents.empty()) {
+            res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }

         // construct prompt object: array of ["query", "doc0", "doc1", ...]
         json prompt;
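The rewrite replaces the manual `body.count`/`body.at` dance with the server's `json_value` helper, which returns the typed value or the supplied default when the key is absent, so a single emptiness check covers both the missing and the empty case. For illustration, the request shape this handler expects could be built like this (values are made up; the `query` field is assumed, since neither it nor the route the handler is mounted on is visible in this hunk):

    // hypothetical rerank request body, built with nlohmann::json
    #include <nlohmann/json.hpp>
    using json = nlohmann::json;

    int main() {
        json body = {
            {"query",     "what is a panda?"}, // assumed field name
            {"documents", {"hi", "it is a bear", "pandas are bears native to china"}},
        };
        // json_value(body, "documents", std::vector<std::string>()) then yields the
        // three strings; an absent or empty "documents" produces an empty vector,
        // which the handler rejects with the "must be a non-empty string array" error.
    }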

View file

@@ -15,7 +15,7 @@ Feature: llama.cpp server
     And 128 as batch size
     And 128 as ubatch size
     And 512 KV cache size
-    And embeddings extraction
+    And enable embeddings endpoint
     Then the server is starting
     Then the server is healthy

View file

@@ -12,7 +12,7 @@ Feature: llama.cpp server
     And 512 as batch size
     And 512 as ubatch size
     And 512 KV cache size
-    And embeddings extraction
+    And enable reranking endpoint
     Then the server is starting
     Then the server is healthy
@@ -39,5 +39,4 @@ Feature: llama.cpp server
     """
     When reranking request
     Then reranking results are returned
-    # TODO: this result make no sense, probably need a better model?
-    Then reranking highest score is index 3 and lowest score is index 0
+    Then reranking highest score is index 2 and lowest score is index 3

View file

@@ -68,6 +68,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.server_api_key = None
     context.server_continuous_batching = False
     context.server_embeddings = False
+    context.server_reranking = False
     context.server_metrics = False
     context.server_process = None
     context.seed = None
@@ -176,10 +177,13 @@ def step_server_continuous_batching(context):
     context.server_continuous_batching = True


-@step('embeddings extraction')
+@step('enable embeddings endpoint')
 def step_server_embeddings(context):
     context.server_embeddings = True


+@step('enable reranking endpoint')
+def step_server_reranking(context):
+    context.server_reranking = True
+
+
 @step('prometheus compatible metrics exposed')
 def step_server_metrics(context):
@@ -1408,6 +1412,8 @@ def start_server_background(context):
         server_args.append('--cont-batching')
     if context.server_embeddings:
         server_args.append('--embedding')
+    if context.server_reranking:
+        server_args.append('--reranking')
     if context.server_metrics:
         server_args.append('--metrics')
     if context.model_alias: