add --reranking argument
commit 0d6f6a799f
parent 84b0af8355
7 changed files with 43 additions and 18 deletions
@@ -284,6 +284,10 @@ static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx
         params.kv_overrides.back().key[0] = 0;
     }
 
+    if (params.reranking && params.embedding) {
+        throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
+    }
+
     return true;
 }
 
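The new check makes --embedding and --reranking mutually exclusive at parse time. Below is a minimal standalone sketch of that pattern; gpt_params_sketch is a hypothetical stand-in for the real gpt_params, and in the actual code the parser's caller presumably catches the exception and reports it:

#include <cstdio>
#include <stdexcept>

struct gpt_params_sketch {      // hypothetical stand-in for gpt_params
    bool embedding = false;
    bool reranking = false;
};

static bool validate(const gpt_params_sketch & params) {
    // mirrors the check added above: the two server modes are exclusive
    if (params.reranking && params.embedding) {
        throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
    }
    return true;
}

int main() {
    gpt_params_sketch params;
    params.embedding = true; // as if --embedding were passed
    params.reranking = true; // ... together with --reranking
    try {
        validate(params);
    } catch (const std::invalid_argument & err) {
        fprintf(stderr, "%s\n", err.what()); // parsing fails with the message above
        return 1;
    }
    return 0;
}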
@@ -1750,6 +1754,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.embedding = true;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
+    add_opt(llama_arg(
+        {"--reranking", "--rerank"},
+        format("enable reranking endpoint on server (default: %s)", params.reranking ? "enabled" : "disabled"),
+        [](gpt_params & params) {
+            params.reranking = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING"));
     add_opt(llama_arg(
         {"--api-key"}, "KEY",
         "API key to use for authentication (default: none)",
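As with the --embedding option registered just above it, the new flag is limited to the server example and mirrored by an environment variable through set_env, so exporting LLAMA_ARG_RERANKING can enable reranking without the command-line flag. A hedged sketch of that fallback (simplified: the real dispatch lives inside the llama_arg machinery, which also handles parsing the variable's value):

#include <cstdlib>

// Simplified assumption: treat the mere presence of the variable as "enabled".
static bool env_flag_set(const char * name) {
    return std::getenv(name) != nullptr;
}

// usage sketch:
//   params.reranking = params.reranking || env_flag_set("LLAMA_ARG_RERANKING");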
@@ -1023,6 +1023,11 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.flash_attn = params.flash_attn;
     cparams.no_perf = params.no_perf;
 
+    if (params.reranking) {
+        cparams.embeddings = true;
+        cparams.pooling_type = LLAMA_POOLING_TYPE_RANK;
+    }
+
     cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
     cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
 
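Turning reranking on forces embeddings mode and switches pooling to LLAMA_POOLING_TYPE_RANK. Under rank pooling a reranker model reduces each query/document sequence to a single relevance score; a sketch of reading that back through the public llama.h accessor, with model/context setup and tokenization elided, and the single-float layout being an assumption based on this pooling mode:

#include "llama.h"

// With cparams.pooling_type == LLAMA_POOLING_TYPE_RANK, the pooled
// "embedding" of a sequence is assumed to be its relevance score.
static float rerank_score(llama_context * ctx, llama_seq_id seq) {
    const float * emb = llama_get_embeddings_seq(ctx, seq);
    return emb != nullptr ? emb[0] : 0.0f;
}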
@@ -271,6 +271,7 @@ struct gpt_params {
     int32_t embd_normalize = 2; // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
     std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
     std::string embd_sep = "\n"; // separator of embendings
+    bool reranking = false; // enable reranking support on server
 
     // server params
     int32_t port = 8080; // server listens on this network port
@@ -2888,8 +2888,8 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
-        if (ctx_server.params.embedding) {
-            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+        if (ctx_server.params.embedding || ctx_server.params.reranking) {
+            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
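Both completion handlers (and the embeddings/rerank handlers further down) reject misconfigured requests via format_error_response. Roughly, that helper maps the error type to an OpenAI-style JSON error payload; a simplified sketch for the ERROR_TYPE_NOT_SUPPORTED case used here, noting that the code and type string are assumptions and the real helper in the server sources covers all the error types:

// Sketch: what the guards above plausibly send back for NOT_SUPPORTED.
static json format_error_response_sketch(const std::string & message) {
    return json {
        {"code",    501},                    // assumed HTTP-style status code
        {"message", message},
        {"type",    "not_supported_error"},  // assumed type string
    };
}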
@@ -2949,8 +2949,8 @@ int main(int argc, char ** argv) {
 
     // TODO: maybe merge this function with "handle_completions_generic"
     const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
-        if (ctx_server.params.embedding) {
-            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+        if (ctx_server.params.embedding || ctx_server.params.reranking) {
+            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
@@ -3074,6 +3074,11 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+        // TODO: somehow clean up this checks in the future
+        if (!ctx_server.params.embedding || ctx_server.params.reranking) {
+            res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
         const json body = json::parse(req.body);
         bool is_openai = false;
 
@@ -3125,6 +3130,10 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+        if (!ctx_server.params.reranking) {
+            res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
         const json body = json::parse(req.body);
 
         // TODO: implement
@@ -3148,17 +3157,11 @@ int main(int argc, char ** argv) {
             return;
         }
 
-        json documents;
-        if (body.count("documents") != 0) {
-            documents = body.at("documents");
-            if (!documents.is_array() || documents.size() == 0) {
-                res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
-                return;
-            }
-        } else {
-            res_error(res, format_error_response("\"documents\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+        std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>());
+        if (documents.empty()) {
+            res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
 
         // construct prompt object: array of ["query", "doc0", "doc1", ...]
         json prompt;
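The rewritten parsing leans on the server's json_value helper: it returns the value at key converted to the requested type when the key is present and non-null, and default_value otherwise. A rough sketch of its behavior (the real helper also logs the failed conversion):

#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json; // the alias the server sources use

template <typename T>
static T json_value(const json & body, const std::string & key, const T & default_value) {
    if (body.contains(key) && !body.at(key).is_null()) {
        try {
            return body.at(key).get<T>();
        } catch (const json::type_error &) {
            return default_value; // wrong type -> fall back to the default
        }
    }
    return default_value;
}

A missing "documents" key, a null value, and an array with non-string elements all collapse to an empty vector here, so the single documents.empty() check covers every case the removed branches handled, at the cost of folding two distinct error messages into one.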
@@ -15,7 +15,7 @@ Feature: llama.cpp server
     And 128 as batch size
     And 128 as ubatch size
     And 512 KV cache size
-    And embeddings extraction
+    And enable embeddings endpoint
     Then the server is starting
     Then the server is healthy
 
@@ -12,7 +12,7 @@ Feature: llama.cpp server
     And 512 as batch size
     And 512 as ubatch size
     And 512 KV cache size
-    And embeddings extraction
+    And enable reranking endpoint
     Then the server is starting
     Then the server is healthy
 
@@ -39,5 +39,4 @@ Feature: llama.cpp server
     """
     When reranking request
     Then reranking results are returned
-    # TODO: this result make no sense, probably need a better model?
-    Then reranking highest score is index 2 and lowest score is index 3
+    Then reranking highest score is index 3 and lowest score is index 0
@@ -68,6 +68,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
    context.server_api_key = None
    context.server_continuous_batching = False
    context.server_embeddings = False
+   context.server_reranking = False
    context.server_metrics = False
    context.server_process = None
    context.seed = None
@@ -176,10 +177,13 @@ def step_server_continuous_batching(context):
     context.server_continuous_batching = True
 
 
-@step('embeddings extraction')
+@step('enable embeddings endpoint')
 def step_server_embeddings(context):
     context.server_embeddings = True
 
+@step('enable reranking endpoint')
+def step_server_reranking(context):
+    context.server_reranking = True
 
 @step('prometheus compatible metrics exposed')
 def step_server_metrics(context):
@@ -1408,6 +1412,8 @@ def start_server_background(context):
         server_args.append('--cont-batching')
     if context.server_embeddings:
         server_args.append('--embedding')
+    if context.server_reranking:
+        server_args.append('--reranking')
     if context.server_metrics:
         server_args.append('--metrics')
     if context.model_alias:
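Taken together: a server build started with the new flag (for example llama-server -m model.gguf --reranking, assuming the usual binary name) exposes the rerank endpoint and rejects completion and embedding requests, which is exactly what the updated feature files exercise.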