server: tests: adding OAI compatible embedding concurrent endpoint
This commit is contained in:
parent
09b77b4c9e
commit
466987eb7b
2 changed files with 104 additions and 43 deletions
|
@ -98,3 +98,26 @@ Feature: Parallel
|
||||||
Then the server is busy
|
Then the server is busy
|
||||||
Then the server is idle
|
Then the server is idle
|
||||||
Then all embeddings are generated
|
Then all embeddings are generated
|
||||||
|
|
||||||
|
Scenario: Multi users OAI compatibility embeddings
|
||||||
|
Given a prompt:
|
||||||
|
"""
|
||||||
|
In which country Paris is located ?
|
||||||
|
"""
|
||||||
|
And a prompt:
|
||||||
|
"""
|
||||||
|
Is Madrid the capital of Spain ?
|
||||||
|
"""
|
||||||
|
And a prompt:
|
||||||
|
"""
|
||||||
|
What is the biggest US city ?
|
||||||
|
"""
|
||||||
|
And a prompt:
|
||||||
|
"""
|
||||||
|
What is the capital of Bulgaria ?
|
||||||
|
"""
|
||||||
|
And a model tinyllama-2
|
||||||
|
Given concurrent OAI embedding requests
|
||||||
|
Then the server is busy
|
||||||
|
Then the server is idle
|
||||||
|
Then all embeddings are generated
|
||||||
|
|
|
@ -261,35 +261,35 @@ def step_a_prompt_prompt(context, prompt):
|
||||||
@step(u'concurrent completion requests')
|
@step(u'concurrent completion requests')
|
||||||
@async_run_until_complete()
|
@async_run_until_complete()
|
||||||
async def step_concurrent_completion_requests(context):
|
async def step_concurrent_completion_requests(context):
|
||||||
await concurrent_completion_requests(context,
|
await concurrent_requests(context,
|
||||||
request_completion,
|
request_completion,
|
||||||
# prompt is inserted automatically
|
# prompt is inserted automatically
|
||||||
context.base_url,
|
context.base_url,
|
||||||
debug=context.debug,
|
debug=context.debug,
|
||||||
n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
|
n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
|
||||||
server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
|
server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
|
||||||
user_api_key=context.user_api_key if hasattr(context,
|
user_api_key=context.user_api_key if hasattr(context,
|
||||||
'user_api_key') else None)
|
'user_api_key') else None)
|
||||||
|
|
||||||
|
|
||||||
@step(u'concurrent OAI completions requests')
|
@step(u'concurrent OAI completions requests')
|
||||||
@async_run_until_complete
|
@async_run_until_complete
|
||||||
async def step_oai_chat_completions(context):
|
async def step_oai_chat_completions(context):
|
||||||
await concurrent_completion_requests(context, oai_chat_completions,
|
await concurrent_requests(context, oai_chat_completions,
|
||||||
# user_prompt is inserted automatically
|
# user_prompt is inserted automatically
|
||||||
context.system_prompt,
|
context.system_prompt,
|
||||||
context.base_url,
|
context.base_url,
|
||||||
True, # async_client
|
True, # async_client
|
||||||
model=context.model
|
model=context.model
|
||||||
if hasattr(context, 'model') else None,
|
if hasattr(context, 'model') else None,
|
||||||
n_predict=context.n_predict
|
n_predict=context.n_predict
|
||||||
if hasattr(context, 'n_predict') else None,
|
if hasattr(context, 'n_predict') else None,
|
||||||
enable_streaming=context.enable_streaming
|
enable_streaming=context.enable_streaming
|
||||||
if hasattr(context, 'enable_streaming') else None,
|
if hasattr(context, 'enable_streaming') else None,
|
||||||
server_seed=context.server_seed
|
server_seed=context.server_seed
|
||||||
if hasattr(context, 'server_seed') else None,
|
if hasattr(context, 'server_seed') else None,
|
||||||
user_api_key=context.user_api_key
|
user_api_key=context.user_api_key
|
||||||
if hasattr(context, 'user_api_key') else None)
|
if hasattr(context, 'user_api_key') else None)
|
||||||
|
|
||||||
|
|
||||||
@step(u'all prompts are predicted')
|
@step(u'all prompts are predicted')
|
||||||
|
@ -316,9 +316,7 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
|
||||||
@step(u'embeddings are computed for')
|
@step(u'embeddings are computed for')
|
||||||
@async_run_until_complete
|
@async_run_until_complete
|
||||||
async def step_compute_embedding(context):
|
async def step_compute_embedding(context):
|
||||||
content = context.text
|
context.embeddings = await request_embedding(context.text, base_url=context.base_url)
|
||||||
base_url = context.base_url
|
|
||||||
context.embeddings = await request_embedding(content, base_url)
|
|
||||||
|
|
||||||
|
|
||||||
@step(u'embeddings are generated')
|
@step(u'embeddings are generated')
|
||||||
|
@ -327,25 +325,32 @@ def step_assert_embeddings(context):
|
||||||
|
|
||||||
|
|
||||||
@step(u'an OAI compatible embeddings computation request for')
|
@step(u'an OAI compatible embeddings computation request for')
|
||||||
def step_oai_compute_embedding(context):
|
@async_run_until_complete
|
||||||
openai.api_key = 'nope' # openai client always expects an api_keu
|
async def step_oai_compute_embeddings(context):
|
||||||
if context.user_api_key is not None:
|
context.embeddings = await request_oai_embeddings(context.text,
|
||||||
openai.api_key = context.user_api_key
|
base_url=context.base_url,
|
||||||
openai.api_base = f'{context.base_url}/v1'
|
user_api_key=context.user_api_key,
|
||||||
embeddings = openai.Embedding.create(
|
model=context.model)
|
||||||
model=context.model,
|
|
||||||
input=context.text,
|
|
||||||
)
|
|
||||||
context.embeddings = embeddings
|
|
||||||
|
|
||||||
|
|
||||||
@step(u'concurrent embedding requests')
|
@step(u'concurrent embedding requests')
|
||||||
@async_run_until_complete()
|
@async_run_until_complete()
|
||||||
async def step_concurrent_embedding_requests(context):
|
async def step_concurrent_embedding_requests(context):
|
||||||
await concurrent_completion_requests(context,
|
await concurrent_requests(context,
|
||||||
request_embedding,
|
request_embedding,
|
||||||
# prompt is inserted automatically
|
# prompt is inserted automatically
|
||||||
context.base_url)
|
base_url=context.base_url)
|
||||||
|
|
||||||
|
|
||||||
|
@step(u'concurrent OAI embedding requests')
|
||||||
|
@async_run_until_complete()
|
||||||
|
async def step_concurrent_oai_embedding_requests(context):
|
||||||
|
await concurrent_requests(context,
|
||||||
|
request_oai_embeddings,
|
||||||
|
# prompt is inserted automatically
|
||||||
|
base_url=context.base_url,
|
||||||
|
async_client=True,
|
||||||
|
model=context.model)
|
||||||
|
|
||||||
|
|
||||||
@step(u'all embeddings are generated')
|
@step(u'all embeddings are generated')
|
||||||
|
@ -401,7 +406,7 @@ def step_check_options_header_value(context, cors_header, cors_header_value):
|
||||||
assert context.options_response.headers[cors_header] == cors_header_value
|
assert context.options_response.headers[cors_header] == cors_header_value
|
||||||
|
|
||||||
|
|
||||||
async def concurrent_completion_requests(context, f_completion, *args, **kwargs):
|
async def concurrent_requests(context, f_completion, *args, **kwargs):
|
||||||
n_prompts = len(context.prompts)
|
n_prompts = len(context.prompts)
|
||||||
if context.debug:
|
if context.debug:
|
||||||
print(f"starting {n_prompts} concurrent completion requests...")
|
print(f"starting {n_prompts} concurrent completion requests...")
|
||||||
|
@ -565,7 +570,7 @@ async def oai_chat_completions(user_prompt,
|
||||||
return completion_response
|
return completion_response
|
||||||
|
|
||||||
|
|
||||||
async def request_embedding(content, base_url):
|
async def request_embedding(content, base_url=None):
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(f'{base_url}/embedding',
|
async with session.post(f'{base_url}/embedding',
|
||||||
json={
|
json={
|
||||||
|
@ -576,6 +581,39 @@ async def request_embedding(content, base_url):
|
||||||
return response_json['embedding']
|
return response_json['embedding']
|
||||||
|
|
||||||
|
|
||||||
|
async def request_oai_embeddings(input,
|
||||||
|
base_url=None, user_api_key=None,
|
||||||
|
model=None, async_client=False):
|
||||||
|
# openai client always expects an api_key
|
||||||
|
user_api_key = user_api_key if user_api_key is not None else 'nope'
|
||||||
|
if async_client:
|
||||||
|
origin = 'llama.cpp'
|
||||||
|
if user_api_key is not None:
|
||||||
|
headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(f'{base_url}/v1/embeddings',
|
||||||
|
json={
|
||||||
|
"input": input,
|
||||||
|
"model": model,
|
||||||
|
},
|
||||||
|
headers=headers) as response:
|
||||||
|
assert response.status == 200, f"received status code not expected: {response.status}"
|
||||||
|
assert response.headers['Access-Control-Allow-Origin'] == origin
|
||||||
|
assert response.headers['Content-Type'] == "application/json; charset=utf-8"
|
||||||
|
response_json = await response.json()
|
||||||
|
assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
|
||||||
|
assert response_json['object'] == 'list'
|
||||||
|
return response_json['data']
|
||||||
|
else:
|
||||||
|
openai.api_key = user_api_key
|
||||||
|
openai.api_base = f'{base_url}/v1'
|
||||||
|
embeddings = openai.Embedding.create(
|
||||||
|
model=model,
|
||||||
|
input=input,
|
||||||
|
)
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
|
def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
|
||||||
content = completion_response['content']
|
content = completion_response['content']
|
||||||
n_predicted = completion_response['timings']['predicted_n']
|
n_predicted = completion_response['timings']['predicted_n']
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue