server: tests: add OAI multi user scenario
parent 9b7ea97979
commit 11adf1d864
2 changed files with 76 additions and 39 deletions
Changed file 1 of 2 (Gherkin feature file):

@@ -57,3 +57,26 @@ Feature: llama.cpp server
     Then the server is busy
     Then the server is idle
     Then all prompts are predicted
+
+
+  Scenario: Multi users OAI Compatibility
+    Given a system prompt "You are an AI assistant."
+    And a model tinyllama-2
+    And 1024 max tokens to predict
+    And streaming is enabled
+    Given a prompt:
+      """
+      Write a very long story about AI.
+      """
+    And a prompt:
+      """
+      Write another very long music lyrics.
+      """
+    And a prompt:
+      """
+      Write yet another very long music lyrics.
+      """
+    Given concurrent OAI completions requests
+    Then the server is busy
+    Then the server is idle
+    Then all prompts are predicted
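The scenario above relies on behave's multiline docstring support: the text between the triple-quote markers is attached to the matching step as context.text. A minimal sketch of that wiring, using the same step name and context attribute as the step definitions in the second changed file:

# Sketch only: how behave feeds the """-delimited prompt blocks into the steps.
from behave import step

@step(u'a prompt')
def step_a_prompt(context):
    # behave exposes the docstring attached to "Given a prompt:" as context.text
    context.prompts.append(context.text)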
Changed file 2 of 2 (behave step definitions, Python):

@@ -16,6 +16,7 @@ def step_server_config(context, server_fqdn, server_port, n_slots):
     context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
 
     context.completions = []
     context.completion_threads = []
+    context.prompts = []
 
     openai.api_key = 'llama.cpp'
@@ -76,11 +77,58 @@ def step_max_tokens(context, max_tokens):
 
 @step(u'streaming is {enable_streaming}')
 def step_streaming(context, enable_streaming):
-    context.enable_streaming = bool(enable_streaming)
+    context.enable_streaming = enable_streaming == 'enabled' or bool(enable_streaming)
 
 
 @step(u'an OAI compatible chat completions request')
 def step_oai_chat_completions(context):
+    oai_chat_completions(context, context.user_prompt)
+
+
+@step(u'a prompt')
+def step_a_prompt(context):
+    context.prompts.append(context.text)
+
+
+@step(u'concurrent completion requests')
+def step_concurrent_completion_requests(context):
+    concurrent_requests(context, request_completion)
+
+
+@step(u'concurrent OAI completions requests')
+def step_oai_chat_completions(context):
+    concurrent_requests(context, oai_chat_completions)
+
+
+@step(u'all prompts are predicted')
+def step_all_prompts_are_predicted(context):
+    for completion_thread in context.completion_threads:
+        completion_thread.join()
+    for completion in context.completions:
+        assert_n_tokens_predicted(completion)
+
+
+def concurrent_requests(context, f_completion):
+    context.completions.clear()
+    context.completion_threads.clear()
+    for prompt in context.prompts:
+        completion_thread = threading.Thread(target=f_completion, args=(context, prompt))
+        completion_thread.start()
+        context.completion_threads.append(completion_thread)
+    context.prompts.clear()
+
+
+def request_completion(context, prompt, n_predict=None):
+    response = requests.post(f'{context.base_url}/completion', json={
+        "prompt": prompt,
+        "n_predict": int(n_predict) if n_predict is not None else 4096,
+    })
+    status_code = response.status_code
+    assert status_code == 200
+    context.completions.append(response.json())
+
+
+def oai_chat_completions(context, user_prompt):
     chat_completion = openai.Completion.create(
         messages=[
             {
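The concurrency helpers above follow a plain fan-out/join pattern: one thread per queued prompt, results collected into context.completions, and the "all prompts are predicted" step joins every thread before asserting. A self-contained sketch of the same pattern with a stubbed completion function (fake_completion and the SimpleNamespace context are illustrative, not part of the test suite):

import threading
from types import SimpleNamespace

def fake_completion(context, prompt):
    # Stand-in for request_completion / oai_chat_completions above.
    context.completions.append({"content": prompt.upper(),
                                "timings": {"predicted_n": len(prompt.split())}})

ctx = SimpleNamespace(prompts=["a story", "some lyrics"], completions=[], completion_threads=[])
for prompt in ctx.prompts:
    thread = threading.Thread(target=fake_completion, args=(ctx, prompt))
    thread.start()
    ctx.completion_threads.append(thread)
ctx.prompts.clear()

for thread in ctx.completion_threads:
    thread.join()  # mirrors step_all_prompts_are_predicted
assert all(c["timings"]["predicted_n"] > 0 for c in ctx.completions)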
@@ -89,7 +137,7 @@ def step_oai_chat_completions(context):
             },
             {
                 "role": "user",
-                "content": context.user_prompt,
+                "content": user_prompt,
             }
         ],
         model=context.model,
@@ -120,39 +168,6 @@ def step_oai_chat_completions(context):
     })
 
 
-@step(u'a prompt')
-def step_a_prompt(context):
-    context.prompts.append(context.text)
-
-
-@step(u'concurrent completion requests')
-def step_n_concurrent_prompts(context):
-    context.completions.clear()
-    context.completion_threads = []
-    for prompt in context.prompts:
-        completion_thread = threading.Thread(target=request_completion, args=(context, prompt))
-        completion_thread.start()
-        context.completion_threads.append(completion_thread)
-
-
-@step(u'all prompts are predicted')
-def step_all_prompts_must_be_predicted(context):
-    for completion_thread in context.completion_threads:
-        completion_thread.join()
-    for completion in context.completions:
-        assert_n_tokens_predicted(completion)
-
-
-def request_completion(context, prompt, n_predict=None):
-    response = requests.post(f'{context.base_url}/completion', json={
-        "prompt": prompt,
-        "n_predict": int(n_predict) if n_predict is not None else 4096,
-    })
-    status_code = response.status_code
-    assert status_code == 200
-    context.completions.append(response.json())
-
-
 def assert_n_tokens_predicted(completion_response, expected_predicted_n=None):
     content = completion_response['content']
     n_predicted = completion_response['timings']['predicted_n']
@@ -163,10 +178,9 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None):
 
 
 def wait_for_health_status(context, expected_http_status_code, expected_health_status, params=None):
-    status_code = 500
-    while status_code != expected_http_status_code:
+    while True:
         health_response = requests.get(f'{context.base_url}/health', params)
         status_code = health_response.status_code
         health = health_response.json()
-        if status_code != expected_http_status_code or health['status'] != expected_health_status:
-            time.sleep(0.001)
+        if status_code == expected_http_status_code and health['status'] == expected_health_status:
+            break
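The reworked wait_for_health_status polls /health until both the HTTP status code and the reported status field match, with no sleep between attempts. The "server is busy" / "server is idle" steps used in the scenarios presumably call it along these lines; the 200/'ok' pair below is an assumption about the llama.cpp health endpoint, not something shown in this diff:

from behave import step

@step(u'the server is idle')
def step_server_idle(context):
    # Assumed mapping: idle means /health answers 200 with {"status": "ok"}.
    wait_for_health_status(context, 200, 'ok')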