server: tests: add OAI multi user scenario

Pierrick HYMBERT 2024-02-20 22:00:09 +01:00
parent 9b7ea97979
commit 11adf1d864
2 changed files with 76 additions and 39 deletions

@@ -57,3 +57,26 @@ Feature: llama.cpp server
     Then the server is busy
     Then the server is idle
     Then all prompts are predicted
+
+  Scenario: Multi users OAI Compatibility
+    Given a system prompt "You are an AI assistant."
+    And a model tinyllama-2
+    And 1024 max tokens to predict
+    And streaming is enabled
+    Given a prompt:
+      """
+      Write a very long story about AI.
+      """
+    And a prompt:
+      """
+      Write another very long music lyrics.
+      """
+    And a prompt:
+      """
+      Write yet another very long music lyrics.
+      """
+    Given concurrent OAI completions requests
+    Then the server is busy
+    Then the server is idle
+    Then all prompts are predicted
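The new scenario sends three chat prompts to the server at the same time and then verifies that every completion finished. Outside of behave, the same flow can be reproduced with a short script along these lines; this is a minimal sketch, assuming the server listens on localhost:8080 and exposes the OAI-compatible /v1/chat/completions route, and it skips streaming for simplicity (address, route, and the non-streaming choice are assumptions, not taken from the diff):

# Standalone sketch of the "Multi users OAI Compatibility" flow.
# Assumptions (not from the diff): server at localhost:8080,
# OAI-compatible route /v1/chat/completions, non-streaming responses.
import threading
import requests

BASE_URL = "http://localhost:8080"  # assumed server address
PROMPTS = [
    "Write a very long story about AI.",
    "Write another very long music lyrics.",
    "Write yet another very long music lyrics.",
]

completions = []

def oai_chat_completion(prompt):
    # One chat completion request per user prompt, reusing the scenario's
    # system prompt, model name, and max tokens.
    response = requests.post(f"{BASE_URL}/v1/chat/completions", json={
        "model": "tinyllama-2",
        "max_tokens": 1024,
        "messages": [
            {"role": "system", "content": "You are an AI assistant."},
            {"role": "user", "content": prompt},
        ],
    })
    assert response.status_code == 200
    completions.append(response.json())

# Fire all prompts concurrently, then wait for every thread to finish.
threads = [threading.Thread(target=oai_chat_completion, args=(p,)) for p in PROMPTS]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert len(completions) == len(PROMPTS)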

@@ -16,6 +16,7 @@ def step_server_config(context, server_fqdn, server_port, n_slots):
     context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
     context.completions = []
+    context.completion_threads = []
     context.prompts = []
     openai.api_key = 'llama.cpp'
@@ -76,11 +77,58 @@ def step_max_tokens(context, max_tokens):
 @step(u'streaming is {enable_streaming}')
 def step_streaming(context, enable_streaming):
-    context.enable_streaming = bool(enable_streaming)
+    context.enable_streaming = enable_streaming == 'enabled' or bool(enable_streaming)
 
 @step(u'an OAI compatible chat completions request')
 def step_oai_chat_completions(context):
+    oai_chat_completions(context, context.user_prompt)
+
+@step(u'a prompt')
+def step_a_prompt(context):
+    context.prompts.append(context.text)
+
+@step(u'concurrent completion requests')
+def step_concurrent_completion_requests(context):
+    concurrent_requests(context, request_completion)
+
+@step(u'concurrent OAI completions requests')
+def step_oai_chat_completions(context):
+    concurrent_requests(context, oai_chat_completions)
+
+@step(u'all prompts are predicted')
+def step_all_prompts_are_predicted(context):
+    for completion_thread in context.completion_threads:
+        completion_thread.join()
+    for completion in context.completions:
+        assert_n_tokens_predicted(completion)
+
+def concurrent_requests(context, f_completion):
+    context.completions.clear()
+    context.completion_threads.clear()
+    for prompt in context.prompts:
+        completion_thread = threading.Thread(target=f_completion, args=(context, prompt))
+        completion_thread.start()
+        context.completion_threads.append(completion_thread)
+    context.prompts.clear()
+
+def request_completion(context, prompt, n_predict=None):
+    response = requests.post(f'{context.base_url}/completion', json={
+        "prompt": prompt,
+        "n_predict": int(n_predict) if n_predict is not None else 4096,
+    })
+    status_code = response.status_code
+    assert status_code == 200
+    context.completions.append(response.json())
+
+def oai_chat_completions(context, user_prompt):
     chat_completion = openai.Completion.create(
         messages=[
             {
@@ -89,7 +137,7 @@ def step_oai_chat_completions(context):
             },
             {
                 "role": "user",
-                "content": context.user_prompt,
+                "content": user_prompt,
             }
         ],
         model=context.model,
@@ -120,39 +168,6 @@ def step_oai_chat_completions(context):
     })
-
-@step(u'a prompt')
-def step_a_prompt(context):
-    context.prompts.append(context.text)
-
-@step(u'concurrent completion requests')
-def step_n_concurrent_prompts(context):
-    context.completions.clear()
-    context.completion_threads = []
-    for prompt in context.prompts:
-        completion_thread = threading.Thread(target=request_completion, args=(context, prompt))
-        completion_thread.start()
-        context.completion_threads.append(completion_thread)
-
-@step(u'all prompts are predicted')
-def step_all_prompts_must_be_predicted(context):
-    for completion_thread in context.completion_threads:
-        completion_thread.join()
-    for completion in context.completions:
-        assert_n_tokens_predicted(completion)
-
-def request_completion(context, prompt, n_predict=None):
-    response = requests.post(f'{context.base_url}/completion', json={
-        "prompt": prompt,
-        "n_predict": int(n_predict) if n_predict is not None else 4096,
-    })
-    status_code = response.status_code
-    assert status_code == 200
-    context.completions.append(response.json())
-
 def assert_n_tokens_predicted(completion_response, expected_predicted_n=None):
     content = completion_response['content']
     n_predicted = completion_response['timings']['predicted_n']
@@ -163,10 +178,9 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None):
 def wait_for_health_status(context, expected_http_status_code, expected_health_status, params=None):
-    status_code = 500
-    while status_code != expected_http_status_code:
+    while True:
         health_response = requests.get(f'{context.base_url}/health', params)
         status_code = health_response.status_code
         health = health_response.json()
-        if status_code != expected_http_status_code or health['status'] != expected_health_status:
-            time.sleep(0.001)
+        if status_code == expected_http_status_code and health['status'] == expected_health_status:
+            break
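As read from the new side of the last hunk, the health polling helper now loops until both the HTTP status code and the reported health status match the expected values. The sketch below restates that loop in a self-contained form; the time.sleep between attempts and the plain base_url parameter (in place of the behave context) are illustrative assumptions and are not taken verbatim from the diff.

# Hedged sketch of the reworked polling loop; the sleep and the base_url
# parameter are illustrative assumptions, not shown in the hunk above.
import time
import requests

def wait_for_health_status(base_url, expected_http_status_code, expected_health_status, params=None):
    while True:
        health_response = requests.get(f'{base_url}/health', params)
        status_code = health_response.status_code
        health = health_response.json()
        if status_code == expected_http_status_code and health['status'] == expected_health_status:
            break
        time.sleep(0.001)  # assumption: brief pause between polls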