llama : add reranking support (#9510)
* py : add XLMRobertaForSequenceClassification [no ci] * py : fix scalar-tensor conversion [no ci] * py : fix position embeddings chop [no ci] * llama : read new cls tensors [no ci] * llama : add classification head (wip) [no ci] * llama : add "rank" pooling type ggml-ci * server : add rerank endpoint ggml-ci * llama : avoid ggml_repeat during classification * rerank : cleanup + comments * server : accept /rerank endpoint in addition to /v1/rerank [no ci] * embedding : parse special tokens * jina : support v1 reranker * vocab : minor style ggml-ci * server : initiate tests for later ggml-ci * server : add docs * llama : add comment [no ci] * llama : fix uninitialized tensors * ci : add rerank tests ggml-ci * add reranking test * change test data * Update examples/server/server.cpp Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com> * add `--reranking` argument * update server docs * llama : fix comment [no ci] ggml-ci --------- Co-authored-by: Xuan Son Nguyen <son@huggingface.co> Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
This commit is contained in:
parent
1b2f992cd2
commit
f4d2b8846a
18 changed files with 602 additions and 56 deletions
|
@ -68,6 +68,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
|
|||
context.server_api_key = None
|
||||
context.server_continuous_batching = False
|
||||
context.server_embeddings = False
|
||||
context.server_reranking = False
|
||||
context.server_metrics = False
|
||||
context.server_process = None
|
||||
context.seed = None
|
||||
|
@ -83,6 +84,10 @@ def step_server_config(context, server_fqdn: str, server_port: str):
|
|||
context.concurrent_tasks = []
|
||||
context.prompts = []
|
||||
|
||||
context.reranking_query = None
|
||||
context.reranking_documents = []
|
||||
context.reranking_results = None
|
||||
|
||||
|
||||
@step('a model file {hf_file} from HF repo {hf_repo}')
|
||||
def step_download_hf_model(context, hf_file: str, hf_repo: str):
|
||||
|
@ -172,10 +177,13 @@ def step_server_continuous_batching(context):
|
|||
context.server_continuous_batching = True
|
||||
|
||||
|
||||
@step('embeddings extraction')
@step('enable embeddings endpoint')
def step_server_embeddings(context):
    """Flag the scenario so the server is launched with embeddings enabled.

    Registered under both step phrases shown above.
    """
    # Read later by the server-start code, which appends '--embedding'.
    setattr(context, 'server_embeddings', True)
|
||||
|
||||
@step('enable reranking endpoint')
def step_server_reranking(context):
    """Flag the scenario so the server is launched with reranking enabled."""
    # Read later by the server-start code, which appends '--reranking'.
    setattr(context, 'server_reranking', True)
|
||||
|
||||
@step('prometheus compatible metrics exposed')
|
||||
def step_server_metrics(context):
|
||||
|
@ -452,6 +460,14 @@ def step_impl(context, n_ga_w):
|
|||
def step_prompt_passkey(context):
|
||||
context.prompt_passkey = context_text(context)
|
||||
|
||||
@step('a rerank query')
def step_set_rerank_query(context):
    """Store the step's text as the rerank query and start a fresh document list."""
    query_text = context_text(context)
    context.reranking_query = query_text
    # A new query discards any documents collected for a previous one.
    context.reranking_documents = list()
|
||||
|
||||
@step('a rerank document')
def step_set_rerank_document(context):
    """Append the step's text to the list of documents to be reranked."""
    documents = context.reranking_documents
    documents.append(context_text(context))
|
||||
|
||||
@step('{n_prompts:d} fixed prompts')
|
||||
def step_fixed_prompts(context, n_prompts):
|
||||
|
@ -619,6 +635,22 @@ async def step_compute_embedding(context):
|
|||
context.embeddings = await request_embedding(context_text(context), None, base_url=context.base_url)
|
||||
|
||||
|
||||
@step('reranking request')
@async_run_until_complete
async def step_compute_reranking(context):
    """POST the stored query and documents to the server's reranking endpoint.

    On HTTP 200 the 'results' entry of the JSON reply is stored in
    ``context.reranking_results``; on any other status the numeric status
    code is stored there instead.
    """
    async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
        payload = {
            "query": context.reranking_query,
            "documents": context.reranking_documents,
        }
        async with session.post(f'{context.base_url}/reranking', json=payload) as response:
            if response.status != 200:
                # Keep the status code so later steps can see the failure.
                context.reranking_results = response.status
                return
            reply = await response.json()
            context.reranking_results = reply['results']
|
||||
|
||||
|
||||
@step('all embeddings are the same')
|
||||
@async_run_until_complete
|
||||
async def step_all_embeddings_are_the_same(context):
|
||||
|
@ -704,6 +736,24 @@ async def all_embeddings_are_generated(context):
|
|||
for i in range(n_embedding_requests):
|
||||
assert_embeddings(context.tasks_result.pop().pop())
|
||||
|
||||
@step('reranking results are returned')
def reranking_results_are_returned(context):
    """Assert that one reranking result came back per submitted document.

    ``context.reranking_results`` is a list only after a successful request;
    the request step stores the HTTP status code there on a non-200 reply,
    and it is ``None`` before any request. Those cases are reported with an
    explicit assertion message instead of a ``TypeError`` from ``len()``.
    """
    results = context.reranking_results
    assert isinstance(results, list), f"no reranking results available: {results!r}"
    assert len(results) == len(context.reranking_documents)
|
||||
|
||||
@step('reranking highest score is index {idx_high:d} and lowest score is index {idx_low:d}')
def reranking_results_are_returned(context, idx_high: int, idx_low: int):
    """Assert which documents received the highest and the lowest score.

    Args:
        context: behave context; ``context.reranking_results`` must be the
            list of result dicts, each with 'relevance_score' and 'index'.
        idx_high: expected 'index' of the result with the highest score.
        idx_low: expected 'index' of the result with the lowest score.

    The extremes are taken over the actual scores. Seeding a running
    max/min with score 0 and index 0 would report index 0 as the highest
    when every score is negative, and as the lowest when every score is
    positive, regardless of the data.
    """
    results = context.reranking_results
    # The request step stores the HTTP status code here on a non-200 reply.
    assert isinstance(results, list) and results, f"no reranking results available: {results!r}"

    def score_of(res):
        return res['relevance_score']

    # max()/min() return the first extreme element, so ties resolve to the
    # earliest result, the same as a strict-comparison scan.
    best = max(results, key=score_of)
    worst = min(results, key=score_of)

    print(results)
    assert best['index'] == idx_high
    assert worst['index'] == idx_low
|
||||
|
||||
@step('adding special tokens')
|
||||
def step_tokenize_set_add_special(context):
|
||||
|
@ -1362,6 +1412,8 @@ def start_server_background(context):
|
|||
server_args.append('--cont-batching')
|
||||
if context.server_embeddings:
|
||||
server_args.append('--embedding')
|
||||
if context.server_reranking:
|
||||
server_args.append('--reranking')
|
||||
if context.server_metrics:
|
||||
server_args.append('--metrics')
|
||||
if context.model_alias:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue