server: tests: adding OAI compatible embedding concurrent endpoint

Pierrick HYMBERT 2024-02-24 18:06:32 +01:00
parent 09b77b4c9e
commit 466987eb7b
2 changed files with 104 additions and 43 deletions


@@ -98,3 +98,26 @@ Feature: Parallel
     Then the server is busy
     Then the server is idle
     Then all embeddings are generated
+
+  Scenario: Multi users OAI compatibility embeddings
+    Given a prompt:
+      """
+      In which country Paris is located ?
+      """
+    And a prompt:
+      """
+      Is Madrid the capital of Spain ?
+      """
+    And a prompt:
+      """
+      What is the biggest US city ?
+      """
+    And a prompt:
+      """
+      What is the capital of Bulgaria ?
+      """
+    And a model tinyllama-2
+    Given concurrent OAI embedding requests
+    Then the server is busy
+    Then the server is idle
+    Then all embeddings are generated
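
A note on the scenario's plumbing: each "Given a prompt:" block carries its text in a Gherkin docstring, which behave exposes to the step implementation as context.text. The step that consumes these docstrings is not part of this diff; a minimal sketch of what it presumably does, queuing each prompt on context.prompts for the concurrent step defined below to pop (the step name and body are illustrative, not taken from the commit):

# Sketch only: assumes each docstring prompt is queued on context.prompts,
# which concurrent_requests later reads (see len(context.prompts) below).
from behave import step

@step(u'a prompt:')
def step_a_prompt(context):
    # behave places the triple-quoted docstring in context.text
    context.prompts.append(context.text)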


@@ -261,35 +261,35 @@ def step_a_prompt_prompt(context, prompt):
 @step(u'concurrent completion requests')
 @async_run_until_complete()
 async def step_concurrent_completion_requests(context):
-    await concurrent_completion_requests(context,
-                                         request_completion,
-                                         # prompt is inserted automatically
-                                         context.base_url,
-                                         debug=context.debug,
-                                         n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
-                                         server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
-                                         user_api_key=context.user_api_key if hasattr(context,
-                                                                                      'user_api_key') else None)
+    await concurrent_requests(context,
+                              request_completion,
+                              # prompt is inserted automatically
+                              context.base_url,
+                              debug=context.debug,
+                              n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
+                              server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
+                              user_api_key=context.user_api_key if hasattr(context,
+                                                                           'user_api_key') else None)


 @step(u'concurrent OAI completions requests')
 @async_run_until_complete
 async def step_oai_chat_completions(context):
-    await concurrent_completion_requests(context, oai_chat_completions,
-                                         # user_prompt is inserted automatically
-                                         context.system_prompt,
-                                         context.base_url,
-                                         True, # async_client
-                                         model=context.model
-                                         if hasattr(context, 'model') else None,
-                                         n_predict=context.n_predict
-                                         if hasattr(context, 'n_predict') else None,
-                                         enable_streaming=context.enable_streaming
-                                         if hasattr(context, 'enable_streaming') else None,
-                                         server_seed=context.server_seed
-                                         if hasattr(context, 'server_seed') else None,
-                                         user_api_key=context.user_api_key
-                                         if hasattr(context, 'user_api_key') else None)
+    await concurrent_requests(context, oai_chat_completions,
+                              # user_prompt is inserted automatically
+                              context.system_prompt,
+                              context.base_url,
+                              True, # async_client
+                              model=context.model
+                              if hasattr(context, 'model') else None,
+                              n_predict=context.n_predict
+                              if hasattr(context, 'n_predict') else None,
+                              enable_streaming=context.enable_streaming
+                              if hasattr(context, 'enable_streaming') else None,
+                              server_seed=context.server_seed
+                              if hasattr(context, 'server_seed') else None,
+                              user_api_key=context.user_api_key
+                              if hasattr(context, 'user_api_key') else None)


 @step(u'all prompts are predicted')
@@ -316,9 +316,7 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
 @step(u'embeddings are computed for')
 @async_run_until_complete
 async def step_compute_embedding(context):
-    content = context.text
-    base_url = context.base_url
-    context.embeddings = await request_embedding(content, base_url)
+    context.embeddings = await request_embedding(context.text, base_url=context.base_url)


 @step(u'embeddings are generated')
@@ -327,25 +325,32 @@ def step_assert_embeddings(context):
 @step(u'an OAI compatible embeddings computation request for')
-def step_oai_compute_embedding(context):
-    openai.api_key = 'nope' # openai client always expects an api_keu
-    if context.user_api_key is not None:
-        openai.api_key = context.user_api_key
-    openai.api_base = f'{context.base_url}/v1'
-    embeddings = openai.Embedding.create(
-        model=context.model,
-        input=context.text,
-    )
-    context.embeddings = embeddings
+@async_run_until_complete
+async def step_oai_compute_embeddings(context):
+    context.embeddings = await request_oai_embeddings(context.text,
+                                                      base_url=context.base_url,
+                                                      user_api_key=context.user_api_key,
+                                                      model=context.model)


 @step(u'concurrent embedding requests')
 @async_run_until_complete()
 async def step_concurrent_embedding_requests(context):
-    await concurrent_completion_requests(context,
-                                         request_embedding,
-                                         # prompt is inserted automatically
-                                         context.base_url)
+    await concurrent_requests(context,
+                              request_embedding,
+                              # prompt is inserted automatically
+                              base_url=context.base_url)
+
+
+@step(u'concurrent OAI embedding requests')
+@async_run_until_complete()
+async def step_concurrent_oai_embedding_requests(context):
+    await concurrent_requests(context,
+                              request_oai_embeddings,
+                              # prompt is inserted automatically
+                              base_url=context.base_url,
+                              async_client=True,
+                              model=context.model)


 @step(u'all embeddings are generated')
@@ -401,7 +406,7 @@ def step_check_options_header_value(context, cors_header, cors_header_value):
     assert context.options_response.headers[cors_header] == cors_header_value


-async def concurrent_completion_requests(context, f_completion, *args, **kwargs):
+async def concurrent_requests(context, f_completion, *args, **kwargs):
     n_prompts = len(context.prompts)
     if context.debug:
         print(f"starting {n_prompts} concurrent completion requests...")
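
The diff truncates concurrent_requests after these first lines. A plausible continuation, assuming the helper turns each queued prompt into an asyncio task, passes it as the first positional argument of f_completion (matching the "# prompt is inserted automatically" comments above), and parks the tasks on the context for later steps such as "all embeddings are generated" to await; a sketch under those assumptions, not the commit's exact body:

import asyncio

async def concurrent_requests(context, f_completion, *args, **kwargs):
    n_prompts = len(context.prompts)
    if context.debug:
        print(f"starting {n_prompts} concurrent completion requests...")
    # assumption: tasks are kept on the context so a later step can gather
    # and assert on their results
    context.concurrent_tasks = []
    for _ in range(n_prompts):
        prompt = context.prompts.pop()
        context.concurrent_tasks.append(
            asyncio.create_task(f_completion(prompt, *args, **kwargs)))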
@@ -565,7 +570,7 @@ async def oai_chat_completions(user_prompt,
     return completion_response


-async def request_embedding(content, base_url):
+async def request_embedding(content, base_url=None):
     async with aiohttp.ClientSession() as session:
         async with session.post(f'{base_url}/embedding',
                                 json={
@@ -576,6 +581,39 @@ async def request_embedding(content, base_url):
             return response_json['embedding']


+async def request_oai_embeddings(input,
+                                 base_url=None, user_api_key=None,
+                                 model=None, async_client=False):
+    # openai client always expects an api_key
+    user_api_key = user_api_key if user_api_key is not None else 'nope'
+    if async_client:
+        origin = 'llama.cpp'
+        if user_api_key is not None:
+            headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
+        async with aiohttp.ClientSession() as session:
+            async with session.post(f'{base_url}/v1/embeddings',
+                                    json={
+                                        "input": input,
+                                        "model": model,
+                                    },
+                                    headers=headers) as response:
+                assert response.status == 200, f"received status code not expected: {response.status}"
+                assert response.headers['Access-Control-Allow-Origin'] == origin
+                assert response.headers['Content-Type'] == "application/json; charset=utf-8"
+                response_json = await response.json()
+                assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
+                assert response_json['object'] == 'list'
+                return response_json['data']
+    else:
+        openai.api_key = user_api_key
+        openai.api_base = f'{base_url}/v1'
+        embeddings = openai.Embedding.create(
+            model=model,
+            input=input,
+        )
+        return embeddings
+
+
 def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
     content = completion_response['content']
     n_predicted = completion_response['timings']['predicted_n']
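
For reference, the new request_oai_embeddings helper can also be exercised outside behave. A usage sketch with the function above in scope; the base_url and model values are assumptions (a locally running llama.cpp server and the scenario's model name), not taken from this diff:

import asyncio

async def main():
    # async_client=True takes the aiohttp path and returns the OAI-style
    # 'data' list of embedding objects
    data = await request_oai_embeddings("What is the capital of Bulgaria ?",
                                        base_url="http://localhost:8080",  # assumed server address
                                        model="tinyllama-2",  # assumed model name
                                        async_client=True)
    print(f"received {len(data)} embedding entries")

asyncio.run(main())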