server: tests: add OAI multi user scenario

parent 9b7ea97979
commit 11adf1d864

2 changed files with 76 additions and 39 deletions
@@ -57,3 +57,26 @@ Feature: llama.cpp server
     Then the server is busy
     Then the server is idle
     Then all prompts are predicted
+
+
+  Scenario: Multi users OAI Compatibility
+    Given a system prompt "You are an AI assistant."
+    And a model tinyllama-2
+    And 1024 max tokens to predict
+    And streaming is enabled
+    Given a prompt:
+      """
+      Write a very long story about AI.
+      """
+    And a prompt:
+      """
+      Write another very long music lyrics.
+      """
+    And a prompt:
+      """
+      Write yet another very long music lyrics.
+      """
+    Given concurrent OAI completions requests
+    Then the server is busy
+    Then the server is idle
+    Then all prompts are predicted
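The new scenario queues three docstring prompts and then fires them at the server's OpenAI-compatible API at the same time. As a rough standalone illustration of the traffic it generates, here is a minimal Python sketch, assuming the server listens on localhost:8080 and exposes /v1/chat/completions (address and endpoint are assumptions, not taken from this diff):

# Minimal sketch of the "multi user" traffic the scenario generates.
# Assumes a llama.cpp server on localhost:8080 with an OpenAI-compatible
# /v1/chat/completions endpoint; both are assumptions made for this sketch.
import threading
import requests

BASE_URL = "http://localhost:8080"  # assumed address

prompts = [
    "Write a very long story about AI.",
    "Write another very long music lyrics.",
    "Write yet another very long music lyrics.",
]

results = []

def chat(prompt):
    # One simulated user: a single chat completion request.
    response = requests.post(f"{BASE_URL}/v1/chat/completions", json={
        "model": "tinyllama-2",
        "max_tokens": 1024,
        "messages": [
            {"role": "system", "content": "You are an AI assistant."},
            {"role": "user", "content": prompt},
        ],
    })
    assert response.status_code == 200
    results.append(response.json())

# Fire all prompts concurrently, then wait for every request to finish.
threads = [threading.Thread(target=chat, args=(p,)) for p in prompts]
for t in threads:
    t.start()
for t in threads:
    t.join()

assert len(results) == len(prompts)

The Gherkin steps express the same flow declaratively: prompts are collected by the "And a prompt:" docstrings and dispatched by "Given concurrent OAI completions requests".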
@@ -16,6 +16,7 @@ def step_server_config(context, server_fqdn, server_port, n_slots):
     context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
 
     context.completions = []
+    context.completion_threads = []
     context.prompts = []
 
     openai.api_key = 'llama.cpp'
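This hunk adds a completion_threads list next to the existing completions and prompts lists kept on behave's per-scenario context object, all initialized in the server-configuration step. An alternative, sketched below purely for illustration (it is not what this diff does), would be to reset that state in a before_scenario hook in behave's environment.py:

# Illustrative alternative (not part of this diff): reset per-scenario state
# in behave's environment.py instead of inside the server-config step.
# before_scenario is behave's standard hook name; the attribute names
# mirror the ones used in the step definitions.

def before_scenario(context, scenario):
    context.completions = []
    context.completion_threads = []
    context.prompts = []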
@@ -76,11 +77,58 @@ def step_max_tokens(context, max_tokens):
 
 @step(u'streaming is {enable_streaming}')
 def step_streaming(context, enable_streaming):
-    context.enable_streaming = bool(enable_streaming)
+    context.enable_streaming = enable_streaming == 'enabled' or bool(enable_streaming)
 
 
 @step(u'an OAI compatible chat completions request')
 def step_oai_chat_completions(context):
+    oai_chat_completions(context, context.user_prompt)
+
+
+@step(u'a prompt')
+def step_a_prompt(context):
+    context.prompts.append(context.text)
+
+
+@step(u'concurrent completion requests')
+def step_concurrent_completion_requests(context):
+    concurrent_requests(context, request_completion)
+
+
+@step(u'concurrent OAI completions requests')
+def step_oai_chat_completions(context):
+    concurrent_requests(context, oai_chat_completions)
+
+
+@step(u'all prompts are predicted')
+def step_all_prompts_are_predicted(context):
+    for completion_thread in context.completion_threads:
+        completion_thread.join()
+    for completion in context.completions:
+        assert_n_tokens_predicted(completion)
+
+
+def concurrent_requests(context, f_completion):
+    context.completions.clear()
+    context.completion_threads.clear()
+    for prompt in context.prompts:
+        completion_thread = threading.Thread(target=f_completion, args=(context, prompt))
+        completion_thread.start()
+        context.completion_threads.append(completion_thread)
+    context.prompts.clear()
+
+
+def request_completion(context, prompt, n_predict=None):
+    response = requests.post(f'{context.base_url}/completion', json={
+        "prompt": prompt,
+        "n_predict": int(n_predict) if n_predict is not None else 4096,
+    })
+    status_code = response.status_code
+    assert status_code == 200
+    context.completions.append(response.json())
+
+
+def oai_chat_completions(context, user_prompt):
     chat_completion = openai.Completion.create(
         messages=[
             {
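The refactor above splits the request logic out of the step definitions: concurrent_requests starts one worker thread per queued prompt and records it in context.completion_threads so that the "all prompts are predicted" step can join the threads and check each completion afterwards. The same fan-out could also be written with the standard library's ThreadPoolExecutor; the sketch below is only a comparison under that assumption, not what the diff uses:

# Sketch only: the same fan-out/join pattern with concurrent.futures,
# shown for comparison; the diff itself uses threading.Thread plus an
# explicit join in the "all prompts are predicted" step.
from concurrent.futures import ThreadPoolExecutor

def concurrent_requests_with_executor(context, f_completion):
    context.completions.clear()
    # Take ownership of the queued prompts, leaving the queue empty.
    prompts, context.prompts = context.prompts, []
    with ThreadPoolExecutor(max_workers=max(len(prompts), 1)) as executor:
        futures = [executor.submit(f_completion, context, p) for p in prompts]
        for future in futures:
            future.result()  # re-raises any assertion error from a worker

One reason to keep the raw threads used in the diff is that the scenario asserts the server is busy after the requests are fired but before they are waited on; an executor's with-block would not return until every request had finished.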
@@ -89,7 +137,7 @@ def step_oai_chat_completions(context):
             },
             {
                 "role": "user",
-                "content": context.user_prompt,
+                "content": user_prompt,
             }
         ],
         model=context.model,
@@ -120,39 +168,6 @@ def step_oai_chat_completions(context):
     })
-
-
-@step(u'a prompt')
-def step_a_prompt(context):
-    context.prompts.append(context.text)
-
-
-@step(u'concurrent completion requests')
-def step_n_concurrent_prompts(context):
-    context.completions.clear()
-    context.completion_threads = []
-    for prompt in context.prompts:
-        completion_thread = threading.Thread(target=request_completion, args=(context, prompt))
-        completion_thread.start()
-        context.completion_threads.append(completion_thread)
-
-
-@step(u'all prompts are predicted')
-def step_all_prompts_must_be_predicted(context):
-    for completion_thread in context.completion_threads:
-        completion_thread.join()
-    for completion in context.completions:
-        assert_n_tokens_predicted(completion)
-
-
-def request_completion(context, prompt, n_predict=None):
-    response = requests.post(f'{context.base_url}/completion', json={
-        "prompt": prompt,
-        "n_predict": int(n_predict) if n_predict is not None else 4096,
-    })
-    status_code = response.status_code
-    assert status_code == 200
-    context.completions.append(response.json())
 
 
 def assert_n_tokens_predicted(completion_response, expected_predicted_n=None):
     content = completion_response['content']
     n_predicted = completion_response['timings']['predicted_n']
@@ -163,10 +178,9 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None):
 
 
 def wait_for_health_status(context, expected_http_status_code, expected_health_status, params=None):
-    status_code = 500
-    while status_code != expected_http_status_code:
+    while True:
         health_response = requests.get(f'{context.base_url}/health', params)
         status_code = health_response.status_code
         health = health_response.json()
-        if status_code != expected_http_status_code or health['status'] != expected_health_status:
-            time.sleep(0.001)
+        if status_code == expected_http_status_code and health['status'] == expected_health_status:
+            break
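The last hunk simplifies wait_for_health_status into a loop that polls /health until both the HTTP status code and the reported health status match the expected values. A variant with a short delay between polls and an overall timeout might look like the following sketch (illustration only; the timeout parameter is an assumption, not part of this diff):

# Illustrative variant (not from this diff): poll /health with a small delay
# and an overall timeout instead of looping until the state is reached.
import time
import requests

def wait_for_health_status_with_timeout(base_url, expected_http_status_code,
                                        expected_health_status, params=None,
                                        timeout_seconds=30):
    deadline = time.time() + timeout_seconds
    while time.time() < deadline:
        health_response = requests.get(f'{base_url}/health', params)
        health = health_response.json()
        if (health_response.status_code == expected_http_status_code
                and health['status'] == expected_health_status):
            return
        time.sleep(0.1)  # avoid hammering the endpoint
    raise TimeoutError(f"server never reported health status "
                       f"'{expected_health_status}'")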