add --reranking argument

commit 0d6f6a799f (parent 84b0af8355)
7 changed files with 43 additions and 18 deletions
common/arg.cpp

@@ -284,6 +284,10 @@ static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx
         params.kv_overrides.back().key[0] = 0;
     }
 
+    if (params.reranking && params.embedding) {
+        throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
+    }
+
     return true;
 }
 
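The new check runs at the end of argument parsing, so a conflicting command line surfaces as a std::invalid_argument that the caller can catch and report. A minimal standalone sketch of that pattern (the struct and function names here are illustrative stand-ins, not the commit's code):

    // sketch only: simplified stand-ins for gpt_params and its parser
    #include <cstdio>
    #include <stdexcept>

    struct params_demo { bool embedding = false; bool reranking = false; };

    static bool parse_demo(const params_demo & params) {
        // mirror of the new check: the two server modes are mutually exclusive
        if (params.reranking && params.embedding) {
            throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
        }
        return true;
    }

    int main() {
        params_demo p; p.embedding = p.reranking = true;
        try {
            parse_demo(p);
        } catch (const std::invalid_argument & e) {
            fprintf(stderr, "%s\n", e.what()); // reported once, then exit non-zero
            return 1;
        }
        return 0;
    }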
@@ -1750,6 +1754,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.embedding = true;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
+    add_opt(llama_arg(
+        {"--reranking", "--rerank"},
+        format("enable reranking endpoint on server (default: %s)", params.reranking ? "enabled" : "disabled"),
+        [](gpt_params & params) {
+            params.reranking = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING"));
    add_opt(llama_arg(
        {"--api-key"}, "KEY",
        "API key to use for authentication (default: none)",
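Like the existing --embedding option, the new flag gets two CLI aliases plus an environment variable. A simplified stand-in showing how such a flag/env pair can be resolved (this is not llama.cpp's actual llama_arg machinery, and treating mere presence of LLAMA_ARG_RERANKING as "enabled" is an assumption):

    // sketch only: simplified flag/env resolution for illustration
    #include <cstdlib>
    #include <cstring>

    struct params_demo { bool reranking = false; };

    static void apply_reranking(params_demo & params, int argc, char ** argv) {
        // the environment variable supplies a default ...
        if (std::getenv("LLAMA_ARG_RERANKING") != nullptr) {
            params.reranking = true; // assumption: presence enables the flag
        }
        // ... and either CLI alias can switch it on explicitly
        for (int i = 1; i < argc; i++) {
            if (std::strcmp(argv[i], "--reranking") == 0 || std::strcmp(argv[i], "--rerank") == 0) {
                params.reranking = true;
            }
        }
    }

    int main(int argc, char ** argv) {
        params_demo p;
        apply_reranking(p, argc, argv);
        return p.reranking ? 0 : 1; // exit code signals whether reranking ended up enabled
    }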
common/common.cpp

@@ -1023,6 +1023,11 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.flash_attn = params.flash_attn;
     cparams.no_perf    = params.no_perf;
 
+    if (params.reranking) {
+        cparams.embeddings   = true;
+        cparams.pooling_type = LLAMA_POOLING_TYPE_RANK;
+    }
+
     cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
     cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
 
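The context setup shows that reranking is not a separate pipeline: it enables the embeddings path and switches pooling to LLAMA_POOLING_TYPE_RANK, so each sequence is pooled into a single relevance score. The same mapping in isolation, with mirrored types for illustration only:

    // sketch only: mirrored enum/struct, not llama.cpp's own types
    enum pooling_demo { POOLING_NONE, POOLING_MEAN, POOLING_RANK };
    struct cparams_demo { bool embeddings = false; pooling_demo pooling_type = POOLING_NONE; };

    static cparams_demo context_params_for(bool reranking) {
        cparams_demo cp;
        if (reranking) {
            cp.embeddings   = true;         // rerank scores ride the embeddings pipeline
            cp.pooling_type = POOLING_RANK; // pool each sequence into one relevance score
        }
        return cp;
    }

    int main() {
        return context_params_for(true).pooling_type == POOLING_RANK ? 0 : 1;
    }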
common/common.h

@@ -271,6 +271,7 @@ struct gpt_params {
     int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
     std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
     std::string embd_sep = "\n"; // separator of embeddings
+    bool reranking = false; // enable reranking support on server
 
     // server params
     int32_t port = 8080; // server listens on this network port
 
examples/server/server.cpp

@@ -2888,8 +2888,8 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
-        if (ctx_server.params.embedding) {
-            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+        if (ctx_server.params.embedding || ctx_server.params.reranking) {
+            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
@@ -2949,8 +2949,8 @@ int main(int argc, char ** argv) {
 
     // TODO: maybe merge this function with "handle_completions_generic"
     const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
-        if (ctx_server.params.embedding) {
-            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+        if (ctx_server.params.embedding || ctx_server.params.reranking) {
+            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
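The same embedding-or-reranking guard now appears in both handle_completions_generic and handle_chat_completions. A hypothetical helper that would factor out the repeated predicate (a sketch, not part of this commit):

    // hypothetical refactor for illustration: one predicate for both handlers
    struct server_params_demo { bool embedding = false; bool reranking = false; };

    static bool supports_completions(const server_params_demo & params) {
        // completions are unavailable whenever the server runs in
        // embeddings or reranking mode
        return !params.embedding && !params.reranking;
    }

    int main() {
        server_params_demo p;
        p.reranking = true;
        return supports_completions(p) ? 1 : 0; // reranking mode disables completions
    }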
@@ -3074,6 +3074,11 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+        // TODO: somehow clean up these checks in the future
+        if (!ctx_server.params.embedding || ctx_server.params.reranking) {
+            res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
         const json body = json::parse(req.body);
         bool is_openai = false;
 
@@ -3125,6 +3130,10 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+        if (!ctx_server.params.reranking) {
+            res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
         const json body = json::parse(req.body);
 
         // TODO: implement
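The handler parses the request body as JSON; judging from the "documents" handling in the next hunk, a request carries a query plus a non-empty string array of documents. An illustrative body built with nlohmann::json, which the server already uses (the "query" field name is inferred from the prompt comment below, not spelled out in this diff):

    // illustrative request body only; field names beyond "documents" are assumptions
    #include <cstdio>
    #include <nlohmann/json.hpp>
    using json = nlohmann::json;

    int main() {
        json req = {
            {"query",     "what does the commit add?"},
            {"documents", {"a reranking flag", "an unrelated note"}}
        };
        printf("%s\n", req.dump(2).c_str()); // body to POST to the rerank endpoint
        return 0;
    }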
@@ -3148,17 +3157,11 @@ int main(int argc, char ** argv) {
             return;
         }
 
-        json documents;
-        if (body.count("documents") != 0) {
-            documents = body.at("documents");
-            if (!documents.is_array() || documents.size() == 0) {
+        std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>());
+        if (documents.empty()) {
             res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
-        } else {
-            res_error(res, format_error_response("\"documents\" must be provided", ERROR_TYPE_INVALID_REQUEST));
-            return;
-        }
 
         // construct prompt object: array of ["query", "doc0", "doc1", ...]
         json prompt;
 
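The rewrite replaces the manual type checks with json_value, which falls back to the supplied default when the key is missing or has the wrong type, so a single empty() test covers both the absent and the malformed case. A self-contained sketch of such a helper together with the ["query", "doc0", "doc1", ...] prompt array from the comment above (the helper body is an assumption for illustration; the server ships its own json_value):

    // sketch: a json_value-style helper plus the prompt array construction
    #include <cstdio>
    #include <string>
    #include <vector>
    #include <nlohmann/json.hpp>
    using json = nlohmann::json;

    template <typename T>
    static T json_value(const json & body, const std::string & key, const T & def) {
        if (!body.contains(key) || body.at(key).is_null()) {
            return def;
        }
        try {
            return body.at(key).get<T>(); // wrong type -> fall back to the default
        } catch (const json::exception &) {
            return def;
        }
    }

    int main() {
        json body = {{"query", "q"}, {"documents", {"doc0", "doc1"}}};
        std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>());
        if (documents.empty()) {
            fprintf(stderr, "\"documents\" must be a non-empty string array\n");
            return 1;
        }
        // construct prompt object: array of ["query", "doc0", "doc1", ...]
        json prompt = json::array();
        prompt.push_back(json_value(body, "query", std::string()));
        for (const auto & doc : documents) {
            prompt.push_back(doc);
        }
        printf("%s\n", prompt.dump().c_str());
        return 0;
    }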
examples/server/tests/features/embeddings.feature

@@ -15,7 +15,7 @@ Feature: llama.cpp server
     And 128 as batch size
     And 128 as ubatch size
     And 512 KV cache size
-    And embeddings extraction
+    And enable embeddings endpoint
     Then the server is starting
     Then the server is healthy
 
examples/server/tests/features/rerank.feature

@@ -12,7 +12,7 @@ Feature: llama.cpp server
     And 512 as batch size
     And 512 as ubatch size
     And 512 KV cache size
-    And embeddings extraction
+    And enable reranking endpoint
     Then the server is starting
     Then the server is healthy
 
@@ -39,5 +39,4 @@ Feature: llama.cpp server
     """
     When reranking request
     Then reranking results are returned
-    # TODO: this result make no sense, probably need a better model?
-    Then reranking highest score is index 3 and lowest score is index 0
+    Then reranking highest score is index 2 and lowest score is index 3
 
examples/server/tests/features/steps/steps.py

@@ -68,6 +68,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.server_api_key = None
     context.server_continuous_batching = False
     context.server_embeddings = False
+    context.server_reranking = False
     context.server_metrics = False
     context.server_process = None
     context.seed = None
 
@@ -176,10 +177,13 @@ def step_server_continuous_batching(context):
     context.server_continuous_batching = True
 
 
-@step('embeddings extraction')
+@step('enable embeddings endpoint')
 def step_server_embeddings(context):
     context.server_embeddings = True
 
+@step('enable reranking endpoint')
+def step_server_reranking(context):
+    context.server_reranking = True
 
 @step('prometheus compatible metrics exposed')
 def step_server_metrics(context):
 
@@ -1408,6 +1412,8 @@ def start_server_background(context):
         server_args.append('--cont-batching')
     if context.server_embeddings:
         server_args.append('--embedding')
+    if context.server_reranking:
+        server_args.append('--reranking')
     if context.server_metrics:
         server_args.append('--metrics')
     if context.model_alias: