llama : add reranking support (#9510)
* py : add XLMRobertaForSequenceClassification [no ci] * py : fix scalar-tensor conversion [no ci] * py : fix position embeddings chop [no ci] * llama : read new cls tensors [no ci] * llama : add classigication head (wip) [no ci] * llama : add "rank" pooling type ggml-ci * server : add rerank endpoint ggml-ci * llama : aboud ggml_repeat during classification * rerank : cleanup + comments * server : accept /rerank endpoint in addition to /v1/rerank [no ci] * embedding : parse special tokens * jina : support v1 reranker * vocab : minor style ggml-ci * server : initiate tests for later ggml-ci * server : add docs * llama : add comment [no ci] * llama : fix uninitialized tensors * ci : add rerank tests ggml-ci * add reranking test * change test data * Update examples/server/server.cpp Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com> * add `--reranking` argument * update server docs * llama : fix comment [no ci] ggml-ci --------- Co-authored-by: Xuan Son Nguyen <son@huggingface.co> Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
This commit is contained in:
parent
1b2f992cd2
commit
f4d2b8846a
18 changed files with 602 additions and 56 deletions
|
@ -135,7 +135,7 @@ int main(int argc, char ** argv) {
|
|||
// tokenize the prompts and trim
|
||||
std::vector<std::vector<int32_t>> inputs;
|
||||
for (const auto & prompt : prompts) {
|
||||
auto inp = ::llama_tokenize(ctx, prompt, true, false);
|
||||
auto inp = ::llama_tokenize(ctx, prompt, true, true);
|
||||
if (inp.size() > n_batch) {
|
||||
LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
|
||||
__func__, (long long int) inp.size(), (long long int) n_batch);
|
||||
|
@ -234,6 +234,11 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
LOG("\n");
|
||||
}
|
||||
} else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
|
||||
for (int j = 0; j < n_embd_count; j++) {
|
||||
// NOTE: if you change this log - update the tests in ci/run.sh
|
||||
LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
|
||||
}
|
||||
} else {
|
||||
// print the first part of the embeddings or for a single prompt, the full embedding
|
||||
for (int j = 0; j < n_prompts; j++) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue