server: tests: add OAI multi user scenario

Author: Pierrick HYMBERT
Date: 2024-02-20 22:00:09 +01:00
parent 9b7ea97979
commit 11adf1d864
2 changed files with 76 additions and 39 deletions

@@ -57,3 +57,26 @@ Feature: llama.cpp server
    Then the server is busy
    Then the server is idle
    Then all prompts are predicted

  Scenario: Multi users OAI Compatibility
    Given a system prompt "You are an AI assistant."
    And a model tinyllama-2
    And 1024 max tokens to predict
    And streaming is enabled
    Given a prompt:
      """
      Write a very long story about AI.
      """
    And a prompt:
      """
      Write another very long music lyrics.
      """
    And a prompt:
      """
      Write yet another very long music lyrics.
      """
    Given concurrent OAI completions requests
    Then the server is busy
    Then the server is idle
    Then all prompts are predicted
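The scenario above drives the server's OpenAI-compatible chat completions API with several prompts in flight at once. As a point of reference only, a minimal, non-streaming sketch of a single such request against a locally running server follows; the host, port and the /v1/chat/completions path are illustrative assumptions and not part of this commit (the step definitions below go through the openai client and context.base_url instead).

# Hedged sketch: one OAI-style chat completion against a local server.
# The host/port are placeholders; the test suite derives the base URL
# from the behave context (context.base_url).
import requests

base_url = 'http://localhost:8080'  # assumed server address
response = requests.post(f'{base_url}/v1/chat/completions', json={
    "model": "tinyllama-2",
    "messages": [
        {"role": "system", "content": "You are an AI assistant."},
        {"role": "user", "content": "Write a very long story about AI."},
    ],
    "max_tokens": 1024,
})
assert response.status_code == 200
print(response.json()["choices"][0]["message"]["content"])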

@@ -16,6 +16,7 @@ def step_server_config(context, server_fqdn, server_port, n_slots):
    context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
    context.completions = []
    context.completion_threads = []
    context.prompts = []
    openai.api_key = 'llama.cpp'
@@ -76,11 +77,58 @@ def step_max_tokens(context, max_tokens):
@step(u'streaming is {enable_streaming}')
def step_streaming(context, enable_streaming):
    context.enable_streaming = bool(enable_streaming)
    context.enable_streaming = enable_streaming == 'enabled' or bool(enable_streaming)

@step(u'an OAI compatible chat completions request')
def step_oai_chat_completions(context):
    oai_chat_completions(context, context.user_prompt)

@step(u'a prompt')
def step_a_prompt(context):
    context.prompts.append(context.text)

@step(u'concurrent completion requests')
def step_concurrent_completion_requests(context):
    concurrent_requests(context, request_completion)

@step(u'concurrent OAI completions requests')
def step_oai_chat_completions(context):
    concurrent_requests(context, oai_chat_completions)

@step(u'all prompts are predicted')
def step_all_prompts_are_predicted(context):
    for completion_thread in context.completion_threads:
        completion_thread.join()
    for completion in context.completions:
        assert_n_tokens_predicted(completion)

def concurrent_requests(context, f_completion):
    context.completions.clear()
    context.completion_threads.clear()
    for prompt in context.prompts:
        completion_thread = threading.Thread(target=f_completion, args=(context, prompt))
        completion_thread.start()
        context.completion_threads.append(completion_thread)
    context.prompts.clear()

def request_completion(context, prompt, n_predict=None):
    response = requests.post(f'{context.base_url}/completion', json={
        "prompt": prompt,
        "n_predict": int(n_predict) if n_predict is not None else 4096,
    })
    status_code = response.status_code
    assert status_code == 200
    context.completions.append(response.json())

def oai_chat_completions(context, user_prompt):
    chat_completion = openai.Completion.create(
        messages=[
            {
@@ -89,7 +137,7 @@ def step_oai_chat_completions(context):
            },
            {
                "role": "user",
                "content": context.user_prompt,
                "content": user_prompt,
            }
        ],
        model=context.model,
@@ -120,39 +168,6 @@ def step_oai_chat_completions(context):
    })

@step(u'a prompt')
def step_a_prompt(context):
    context.prompts.append(context.text)

@step(u'concurrent completion requests')
def step_n_concurrent_prompts(context):
    context.completions.clear()
    context.completion_threads = []
    for prompt in context.prompts:
        completion_thread = threading.Thread(target=request_completion, args=(context, prompt))
        completion_thread.start()
        context.completion_threads.append(completion_thread)

@step(u'all prompts are predicted')
def step_all_prompts_must_be_predicted(context):
    for completion_thread in context.completion_threads:
        completion_thread.join()
    for completion in context.completions:
        assert_n_tokens_predicted(completion)

def request_completion(context, prompt, n_predict=None):
    response = requests.post(f'{context.base_url}/completion', json={
        "prompt": prompt,
        "n_predict": int(n_predict) if n_predict is not None else 4096,
    })
    status_code = response.status_code
    assert status_code == 200
    context.completions.append(response.json())

def assert_n_tokens_predicted(completion_response, expected_predicted_n=None):
    content = completion_response['content']
    n_predicted = completion_response['timings']['predicted_n']
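The helper above pulls two fields out of the JSON returned by /completion. A hedged sketch of the response shape it relies on, with made-up values for illustration:

# Illustrative shape only; the values are invented, not from a real run.
example_completion_response = {
    "content": "Once upon a time ...",   # generated text
    "timings": {"predicted_n": 1024},    # tokens actually predicted
}
content = example_completion_response['content']
n_predicted = example_completion_response['timings']['predicted_n']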
@@ -163,10 +178,9 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None):
def wait_for_health_status(context, expected_http_status_code, expected_health_status, params=None):
    status_code = 500
    while status_code != expected_http_status_code:
    while True:
        health_response = requests.get(f'{context.base_url}/health', params)
        status_code = health_response.status_code
        health = health_response.json()
        if status_code != expected_http_status_code or health['status'] != expected_health_status:
            time.sleep(0.001)
        if status_code == expected_http_status_code and health['status'] == expected_health_status:
            break
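
Taken together, the new steps implement a simple fan-out/join pattern: wait for the server to report healthy, start one thread per queued prompt, join them all, then check every completion. A standalone sketch of the same pattern outside behave follows; the base URL, the 'ok' health status value, the prompts and the small n_predict are illustrative assumptions, not part of this commit.

# Hedged, standalone sketch of the fan-out/join pattern used by the new steps.
# base_url, the expected health status and the prompts are placeholders.
import threading
import time
import requests

base_url = 'http://localhost:8080'  # assumed server address
prompts = [
    "Write a very long story about AI.",
    "Write another very long music lyrics.",
]
completions = []

# Poll /health until the server is ready, like wait_for_health_status does.
while requests.get(f'{base_url}/health').json().get('status') != 'ok':  # 'ok' is assumed
    time.sleep(0.1)

def request_completion(prompt):
    # Same endpoint and fields as request_completion in the steps file.
    response = requests.post(f'{base_url}/completion',
                             json={"prompt": prompt, "n_predict": 64})
    assert response.status_code == 200
    completions.append(response.json())

# One thread per prompt, started together, then joined before the checks.
threads = [threading.Thread(target=request_completion, args=(p,)) for p in prompts]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Mirrors assert_n_tokens_predicted: every completion produced some tokens.
for completion in completions:
    assert completion['timings']['predicted_n'] > 0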