llama : add reranking support (#9510)
* py : add XLMRobertaForSequenceClassification [no ci]
* py : fix scalar-tensor conversion [no ci]
* py : fix position embeddings chop [no ci]
* llama : read new cls tensors [no ci]
* llama : add classification head (wip) [no ci]
* llama : add "rank" pooling type ggml-ci
* server : add rerank endpoint ggml-ci
* llama : avoid ggml_repeat during classification
* rerank : cleanup + comments
* server : accept /rerank endpoint in addition to /v1/rerank [no ci]
* embedding : parse special tokens
* jina : support v1 reranker
* vocab : minor style ggml-ci
* server : initiate tests for later ggml-ci
* server : add docs
* llama : add comment [no ci]
* llama : fix uninitialized tensors
* ci : add rerank tests ggml-ci
* add reranking test
* change test data
* Update examples/server/server.cpp
Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
* add `--reranking` argument
* update server docs
* llama : fix comment [no ci] ggml-ci

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
parent 1b2f992cd2
commit f4d2b8846a
18 changed files with 602 additions and 56 deletions
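For context before the diffs: the new endpoint takes a query plus a list of documents and returns one relevance score per document. Below is a minimal client sketch, not part of this commit; it assumes a llama.cpp server started locally with a reranker model and the new --reranking flag, uses the third-party requests package, and follows the endpoint paths and response fields named in the commit message and the tests further down.

import requests

# Hypothetical local setup: a server launched with a reranker model and --reranking, e.g.
#   ./llama-server -m jina-reranker-v1-tiny-en.gguf --reranking --port 8080
resp = requests.post(
    "http://localhost:8080/v1/rerank",  # the commit also accepts /rerank; the tests below use /reranking
    json={
        "query": "Machine learning is",
        "documents": [
            "A machine is a physical system that uses power to perform an action.",
            "Machine learning is a field of study in artificial intelligence.",
        ],
    },
)
resp.raise_for_status()
# Each result pairs a document index with its relevance score (higher = more relevant).
for res in sorted(resp.json()["results"], key=lambda r: r["relevance_score"], reverse=True):
    print(res["index"], res["relevance_score"])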
@@ -15,7 +15,7 @@ Feature: llama.cpp server
     And   128 as batch size
     And   128 as ubatch size
     And   512 KV cache size
-    And   embeddings extraction
+    And   enable embeddings endpoint
     Then  the server is starting
     Then  the server is healthy
examples/server/tests/features/rerank.feature (new file, 42 additions)
@@ -0,0 +1,42 @@
+@llama.cpp
+@rerank
+Feature: llama.cpp server
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And   a model url https://huggingface.co/ggml-org/models/resolve/main/jina-reranker-v1-tiny-en/ggml-model-f16.gguf
+    And   a model file jina-reranker-v1-tiny-en.gguf
+    And   a model alias jina-reranker-v1-tiny-en
+    And   42 as server seed
+    And   2 slots
+    And   512 as batch size
+    And   512 as ubatch size
+    And   512 KV cache size
+    And   enable reranking endpoint
+    Then  the server is starting
+    Then  the server is healthy
+
+  Scenario: Rerank
+    Given a rerank query:
+      """
+      Machine learning is
+      """
+    And   a rerank document:
+      """
+      A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.
+      """
+    And   a rerank document:
+      """
+      Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.
+      """
+    And   a rerank document:
+      """
+      Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.
+      """
+    And   a rerank document:
+      """
+      Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine.
+      """
+    When  reranking request
+    Then  reranking results are returned
+    Then  reranking highest score is index 2 and lowest score is index 3
examples/server/tests/features/steps/steps.py

@@ -68,6 +68,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.server_api_key = None
     context.server_continuous_batching = False
     context.server_embeddings = False
+    context.server_reranking = False
     context.server_metrics = False
     context.server_process = None
     context.seed = None
@@ -83,6 +84,10 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.concurrent_tasks = []
     context.prompts = []

+    context.reranking_query = None
+    context.reranking_documents = []
+    context.reranking_results = None
+

 @step('a model file {hf_file} from HF repo {hf_repo}')
 def step_download_hf_model(context, hf_file: str, hf_repo: str):
@@ -172,10 +177,13 @@ def step_server_continuous_batching(context):
     context.server_continuous_batching = True


-@step('embeddings extraction')
+@step('enable embeddings endpoint')
 def step_server_embeddings(context):
     context.server_embeddings = True

+@step('enable reranking endpoint')
+def step_server_reranking(context):
+    context.server_reranking = True

 @step('prometheus compatible metrics exposed')
 def step_server_metrics(context):
@@ -452,6 +460,14 @@ def step_impl(context, n_ga_w):
 def step_prompt_passkey(context):
     context.prompt_passkey = context_text(context)

+@step('a rerank query')
+def step_set_rerank_query(context):
+    context.reranking_query = context_text(context)
+    context.reranking_documents = []
+
+@step('a rerank document')
+def step_set_rerank_document(context):
+    context.reranking_documents.append(context_text(context))

 @step('{n_prompts:d} fixed prompts')
 def step_fixed_prompts(context, n_prompts):
@@ -619,6 +635,22 @@ async def step_compute_embedding(context):
     context.embeddings = await request_embedding(context_text(context), None, base_url=context.base_url)


+@step('reranking request')
+@async_run_until_complete
+async def step_compute_reranking(context):
+    async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
+        async with session.post(f'{context.base_url}/reranking',
+                                json={
+                                    "query": context.reranking_query,
+                                    "documents": context.reranking_documents,
+                                }) as response:
+            if response.status == 200:
+                response_json = await response.json()
+                context.reranking_results = response_json['results']
+            else:
+                context.reranking_results = response.status
+
+
 @step('all embeddings are the same')
 @async_run_until_complete
 async def step_all_embeddings_are_the_same(context):
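For reference, a successful response, as the step above parses it, carries a results array pairing each document index with a relevance score. An illustrative shape follows; the scores are invented, and only the index and relevance_score fields are implied by the assertions further down.

# Illustrative /reranking response body (scores are made up):
example_response = {
    "results": [
        {"index": 0, "relevance_score": 0.41},
        {"index": 1, "relevance_score": 0.36},
        {"index": 2, "relevance_score": 0.92},  # expected highest in the scenario above
        {"index": 3, "relevance_score": 0.05},  # expected lowest (the off-topic French document)
    ],
}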
@@ -704,6 +736,24 @@ async def all_embeddings_are_generated(context):
     for i in range(n_embedding_requests):
         assert_embeddings(context.tasks_result.pop().pop())

+@step('reranking results are returned')
+def reranking_results_are_returned(context):
+    assert len(context.reranking_results) == len(context.reranking_documents)
+
+@step('reranking highest score is index {idx_high:d} and lowest score is index {idx_low:d}')
+def reranking_results_are_ordered(context, idx_high: int, idx_low: int):
+    max_score, max_idx = float('-inf'), 0
+    min_score, min_idx = float('inf'), 0
+    for res in context.reranking_results:
+        if max_score < res['relevance_score']:
+            max_score = res['relevance_score']
+            max_idx = res['index']
+        if min_score > res['relevance_score']:
+            min_score = res['relevance_score']
+            min_idx = res['index']
+    print(context.reranking_results)
+    assert max_idx == idx_high
+    assert min_idx == idx_low

 @step('adding special tokens')
 def step_tokenize_set_add_special(context):
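As an aside, the ordering check above can be written more compactly with Python's built-in max and min and a key function, which avoids sentinel initialization altogether. A sketch, not part of the commit:

# Sketch: the same assertion written with max()/min() and a key function.
def reranking_results_are_ordered(context, idx_high: int, idx_low: int):
    best = max(context.reranking_results, key=lambda r: r['relevance_score'])
    worst = min(context.reranking_results, key=lambda r: r['relevance_score'])
    assert best['index'] == idx_high
    assert worst['index'] == idx_low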
@@ -1362,6 +1412,8 @@ def start_server_background(context):
         server_args.append('--cont-batching')
     if context.server_embeddings:
         server_args.append('--embedding')
+    if context.server_reranking:
+        server_args.append('--reranking')
     if context.server_metrics:
         server_args.append('--metrics')
     if context.model_alias: