llama : add reranking support (#9510)

* py : add XLMRobertaForSequenceClassification [no ci]

* py : fix scalar-tensor conversion [no ci]

* py : fix position embeddings chop [no ci]

* llama : read new cls tensors [no ci]

* llama : add classigication head (wip) [no ci]

* llama : add "rank" pooling type

ggml-ci

* server : add rerank endpoint

ggml-ci

* llama : aboud ggml_repeat during classification

* rerank : cleanup + comments

* server : accept /rerank endpoint in addition to /v1/rerank [no ci]

* embedding : parse special tokens

* jina : support v1 reranker

* vocab : minor style

ggml-ci

* server : initiate tests for later

ggml-ci

* server : add docs

* llama : add comment [no ci]

* llama : fix uninitialized tensors

* ci : add rerank tests

ggml-ci

* add reranking test

* change test data

* Update examples/server/server.cpp

Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>

* add `--reranking` argument

* update server docs

* llama : fix comment [no ci]

ggml-ci

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
This commit is contained in:
Georgi Gerganov 2024-09-28 17:42:03 +03:00 committed by GitHub
parent 1b2f992cd2
commit f4d2b8846a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 602 additions and 56 deletions

View file

@ -284,6 +284,10 @@ static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx
params.kv_overrides.back().key[0] = 0;
}
if (params.reranking && params.embedding) {
throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
}
return true;
}
@ -391,7 +395,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
[](gpt_params & params) {
params.verbose_prompt = true;
}
).set_examples({LLAMA_EXAMPLE_MAIN}));
));
add_opt(llama_arg(
{"--no-display-prompt"},
format("don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false"),
@ -1093,13 +1097,14 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
}
).set_sparam());
add_opt(llama_arg(
{"--pooling"}, "{none,mean,cls,last}",
{"--pooling"}, "{none,mean,cls,last,rank}",
"pooling type for embeddings, use model default if unspecified",
[](gpt_params & params, const std::string & value) {
/**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; }
else { throw std::invalid_argument("invalid value"); }
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING"));
@ -1749,6 +1754,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
params.embedding = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
add_opt(llama_arg(
{"--reranking", "--rerank"},
format("enable reranking endpoint on server (default: %s)", params.reranking ? "enabled" : "disabled"),
[](gpt_params & params) {
params.reranking = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING"));
add_opt(llama_arg(
{"--api-key"}, "KEY",
"API key to use for authentication (default: none)",