server: tests: refactor steps and vocabulary
parent 6c95ec6587
commit 56583bee41
2 changed files with 143 additions and 156 deletions
@@ -1,39 +1,58 @@
 Feature: llama.cpp server

-  Background: The server is started and ready to accept prompts
-    When wait for the server to be started
-    Then wait for the server to be healthy
+  Background: Server startup
+    Given a server listening on localhost:8080 with 2 slots
+    Then the server is starting
+    Then the server is healthy

-  Scenario: Health endpoint
-    Given an health liveness probe
-    Then the server must be healthy
+  Scenario: Health
+    When the server is healthy
+    Then the server is ready

-  Scenario Outline: run a completion request
-    Given a prompt <prompt>
-    When we request a completion
-    Then tokens are predicted
+  Scenario Outline: Completion
+    Given a <prompt> completion request with maximum <n_predict> tokens
+    Then <predicted_n> tokens are predicted

     Examples: Prompts
-      | prompt       |
-      | I believe    |
-      | Write a joke |
+      | prompt                           | n_predict | predicted_n |
+      | I believe the meaning of life is | 128       | 128         |
+      | Write a joke about AI            | 512       | 512         |

-  Scenario Outline: run a completion on the OAI endpoint
+  Scenario Outline: OAI Compatibility
     Given a system prompt <system_prompt>
-    And a user prompt <user_prompt>
-    And a model <model>
-    When we request the oai completions endpoint
-    Then the oai response contains completion tokens
+    And a user prompt <user_prompt>
+    And a model <model>
+    And <max_tokens> max tokens to predict
+    Given an OAI compatible chat completions request
+    Then <predicted_n> tokens are predicted

     Examples: Prompts
-      | model       | system_prompt               | user_prompt                         |
-      | tinyllama-2 | You are ChatGPT.            | Say hello                           |
-      | tinyllama-2 | You are a coding assistant. | Write the fibonacci function in c++ |
+      | model        | system_prompt               | user_prompt                          | max_tokens | predicted_n |
+      | llama-2      | You are ChatGPT.            | Say hello.                           | 64         | 64          |
+      | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 512        | 512         |

-  Scenario: Health endpoint during processing with concurrent requests
-    Given 2 slow concurrent prompts
-    Then wait for all slots processing
-    Then the server is overloaded
-    When wait for all slots idle
-    Then all prompts must be predicted
+  Scenario: Multi users
+    Given a prompt:
+      """
+      Write a formal complaint email to Air France about my delayed
+      baggage from my flight on Tuesday, January 17th, from Paris to Toulouse. Be verbose.
+      """
+    And a prompt:
+      """
+      Translate the following War & Peace chapter into Russian: WELL, PRINCE,
+      Genoa and Lucca are now no more than private estates of the Bonaparte
+      family. No, I warn you, that if you do not tell me we are at war,
+      if you again allow yourself to palliate all the infamies and atrocities
+      of this Antichrist (upon my word, I believe he is), I don’t know you
+      in future, you are no longer my friend, no longer my faithful slave,
+      as you say. There, how do you do, how do you do? I see I’m scaring you,
+      sit down and talk to me.” These words were uttered in July 1805 by
+      Anna Pavlovna Scherer, a distinguished lady of the court,
+      and confidential maid-of-honour to the Empress Marya Fyodorovna.
+      It was her greeting to Prince Vassily, a man high in rank
+      and office, who was the first to arrive at her soirée.
+      """
+    Given concurrent completion requests
+    Then the server is busy
+    Then the server is idle
+    Then all prompts are predicted
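At the HTTP level, the new "completion request with maximum <n_predict> tokens" vocabulary amounts to a single POST against the server's /completion endpoint, which is what the refactored steps below implement. A minimal standalone sketch, assuming a llama.cpp server already listening on localhost:8080 (the helper name is illustrative, not from the test suite):

# Illustrative sketch (not from the test suite): one completion request capped
# at n_predict tokens, checked the same way the steps below do.
import requests

def sketch_completion(base_url="http://localhost:8080",
                      prompt="I believe the meaning of life is",
                      n_predict=128):
    response = requests.post(f"{base_url}/completion",
                             json={"prompt": prompt, "n_predict": n_predict})
    assert response.status_code == 200
    data = response.json()
    # The server reports the number of generated tokens in timings.predicted_n.
    assert len(data["content"]) > 0
    return data["timings"]["predicted_n"]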
@@ -6,82 +6,52 @@ from contextlib import closing
 import openai
 import requests
 from behave import step
-from behave.api.async_step import async_run_until_complete
-
-base_fqdn = 'localhost'
-base_port = 8080
-base_url = f"http://{base_fqdn}:{base_port}"
-
-openai.api_key = 'llama.cpp'
-openai.api_base = f"{base_url}/v1/chat"
-
-slow_prompt = 'say hello ' * 10
-fast_prompt = 'Write a joke'
-
-n_slots = 2
-
-
-@step(u'wait for the server to be started')
-def step_wait_for_the_server_to_be_started(context):
-    server_started = False
-    while not server_started:
-        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
-            result = sock.connect_ex((base_fqdn, base_port))
-            if result != 0:
-                print("server not ready: ", base_fqdn, base_port, result)
-                time.sleep(1)
-            else:
-                return 0
+@step(u"a server listening on {server_fqdn}:{server_port} with {n_slots} slots")
+def step_server_config(context, server_fqdn, server_port, n_slots):
+    context.server_fqdn = server_fqdn
+    context.server_port = int(server_port)
+    context.n_slots = int(n_slots)
+    context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
+
+    context.completions = []
+    context.prompts = []
+
+    openai.api_key = 'llama.cpp'
+    openai.api_base = f'{context.base_url}/v1/chat'


-@step(u'wait for the server to be healthy')
-def step_wait_for_the_server_to_be_healthy(context):
-    status_code = 500
-    while status_code != 200:
-        status_code = requests.get(f'{base_url}/health').status_code
-        if status_code != 200:
-            time.sleep(1)
+@step(u"the server is {expecting_status}")
+def step_wait_for_the_server_to_be_started(context, expecting_status):
+    match expecting_status:
+        case 'starting':
+            server_started = False
+            while not server_started:
+                with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
+                    result = sock.connect_ex((context.server_fqdn, context.server_port))
+                    if result == 0:
+                        return 0
+        case 'loading model':
+            wait_for_health_status(context, 503, 'loading model')
+        case 'healthy':
+            wait_for_health_status(context, 200, 'ok')
+        case 'ready' | 'idle':
+            wait_for_health_status(context, 200, 'ok', params={'fail_on_no_slot': True})
+        case 'busy':
+            wait_for_health_status(context, 503, 'no slot available', params={'fail_on_no_slot': True})
+        case _:
+            assert False, "unknown status"


-@step(u'an health liveness probe')
-def step_an_health_liveness_probe(context):
-    response = requests.get(f'{base_url}/health')
-    context.status_code = response.status_code
-    context.response_data = response.json()
+@step(u'a {prompt} completion request with maximum {n_predict} tokens')
+def step_request_completion(context, prompt, n_predict):
+    request_completion(context, prompt, n_predict)


-@step(u'the server must be healthy')
-def step_server_healthy(context):
-    assert context.status_code == 200
-    assert context.response_data['status'] == 'ok'
-
-
-@step(u'the server is overloaded')
-@async_run_until_complete()
-async def step_server_overloaded(context):
-    response = requests.get(f'{base_url}/health?fail_on_no_slot')
-    assert response.status_code == 503
-    assert response.json()['status'] == 'no slot available'
-
-
-@step(u'a prompt {prompt}')
-def step_prompt(context, prompt):
-    context.prompt = prompt
-
-
-@step(u'we request a completion')
-def step_request_completion(context):
-    response = requests.post(f'{base_url}/completion', json={
-        "prompt": context.prompt
-    })
-    status_code = response.status_code
-    assert status_code == 200
-    context.response_data = response.json()
-
-
-@step(u'tokens are predicted')
-def step_request_completion(context):
-    prompt_predicted(context.response_data)
+@step(u'{predicted_n} tokens are predicted')
+def step_n_tokens_predicted(context, predicted_n):
+    assert_n_tokens_predicted(context.completions[0], int(predicted_n))


 @step(u'a user prompt {user_prompt}')
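The `the server is {expecting_status}` step above folds the old per-state steps into one vocabulary keyed on the /health endpoint: 503 with 'loading model' while the model loads, 200 with 'ok' once healthy, and, when the fail_on_no_slot parameter is set, 200/'ok' only while a slot is free versus 503/'no slot available' when every slot is busy. A condensed sketch of that polling loop, assuming localhost:8080; the helper name is illustrative and simply mirrors wait_for_health_status further down in this diff:

# Illustrative polling helper (not from the test suite), mirroring
# wait_for_health_status defined later in this commit.
import time
import requests

def wait_for_state(base_url, expected_code, expected_status, fail_on_no_slot=False):
    params = {'fail_on_no_slot': True} if fail_on_no_slot else None
    while True:
        response = requests.get(f"{base_url}/health", params=params)
        if (response.status_code == expected_code
                and response.json().get('status') == expected_status):
            return
        time.sleep(0.5)

# "the server is ready" -> wait_for_state(url, 200, 'ok', fail_on_no_slot=True)
# "the server is busy"  -> wait_for_state(url, 503, 'no slot available', fail_on_no_slot=True)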
@@ -99,9 +69,14 @@ def step_model(context, model):
     context.model = model


-@step(u'we request the oai completions endpoint')
-def step_oai_completions(context):
-    context.chat_completion = openai.Completion.create(
+@step(u'{max_tokens} max tokens to predict')
+def step_max_tokens(context, max_tokens):
+    context.max_tokens = int(max_tokens)
+
+
+@step(u'an OAI compatible chat completions request')
+def step_oai_chat_completions(context):
+    chat_completion = openai.Completion.create(
         messages=[
             {
                 "role": "system",
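The OAI compatibility steps drive the server through the legacy openai 0.x Python client: api_base is pointed at {base_url}/v1/chat, so Completion.create effectively targets the server's chat completions route and forwards whatever keyword arguments it is given. A standalone sketch of the same call, assuming openai<1.0 is installed and a server on localhost:8080, with the dummy API key taken from step_server_config above:

# Illustrative sketch (not from the test suite): the same chat completion call
# the step above performs, against the llama.cpp OAI-compatible endpoint.
import openai  # legacy 0.x client, as used by these steps

openai.api_key = 'llama.cpp'  # dummy key, as in step_server_config above
openai.api_base = 'http://localhost:8080/v1/chat'

chat_completion = openai.Completion.create(
    messages=[
        {"role": "system", "content": "You are ChatGPT."},
        {"role": "user", "content": "Say hello."},
    ],
    model="llama-2",
    max_tokens=64,
)
# The steps read the reply content and the generated token count like this:
print(chat_completion.choices[0].message)
print(chat_completion.usage.completion_tokens)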
@@ -113,70 +88,63 @@ def step_oai_completions(context):
             }
         ],
         model=context.model,
+        max_tokens=context.max_tokens
     )
-
-
-@step(u'the oai response contains completion tokens')
-def step_oai_response_has_completion_tokens(context):
-    assert len(context.chat_completion.choices) == 1
-    assert len(context.chat_completion.choices[0].message) > 0
-    assert context.chat_completion.usage.completion_tokens > 0
-
-
-def async_prompt(context, prompt):
-    response = requests.post(f'{base_url}/completion', json={
-        "prompt": prompt
+
+    context.completions.append({
+        'content': chat_completion.choices[0].message,
+        'timings': {
+            'predicted_n': chat_completion.usage.completion_tokens
+        }
     })
-
-    context.async_responses.append(response)
+
+
+@step(u'a prompt')
+def step_a_prompt(context):
+    context.prompts.append(context.text)


-@step(u'{n_prompt} {prompt_type} concurrent prompts')
-def step_n_concurrent_prompts(context, n_prompt, prompt_type):
-    prompt = fast_prompt
-    if prompt_type == 'slow':
-        prompt = slow_prompt
-    context.async_responses = []
-    context.threads = []
-    for i in range(int(n_prompt)):
-        thread = threading.Thread(target=async_prompt, args=(context, prompt))
-        thread.start()
-        context.threads.append(thread)
+@step(u'concurrent completion requests')
+def step_n_concurrent_prompts(context):
+    context.completions.clear()
+    context.completion_threads = []
+    for prompt in context.prompts:
+        completion_thread = threading.Thread(target=request_completion, args=(context, prompt))
+        completion_thread.start()
+        context.completion_threads.append(completion_thread)


-def wait_for_slots_processing(context, expected_slots_processing):
-    while True:
-        health = requests.get(f'{base_url}/health').json()
-        if 'slots_processing' in health:  # FIXME when #5594 is merged
-            slots_processing = health['slots_processing']
-        else:
-            slots_processing = 0
-        if slots_processing == expected_slots_processing:
-            break
-        else:
-            time.sleep(0.2)
-
-
-@step(u'wait for all slots processing')
-def step_wait_for_all_slots_processing(context):
-    wait_for_slots_processing(context, n_slots)
-
-
-@step(u'wait for all slots idle')
-def step_wait_for_all_slots_idle(context):
-    wait_for_slots_processing(context, 0)
-
-
-@step(u'all prompts must be predicted')
+@step(u'all prompts are predicted')
 def step_all_prompts_must_be_predicted(context):
-    for thread in context.threads:
-        thread.join()
-    for async_response in context.async_responses:
-        assert async_response.status_code == 200
-        response_data = async_response.json()
-        prompt_predicted(response_data)
+    for completion_thread in context.completion_threads:
+        completion_thread.join()
+    for completion in context.completions:
+        assert_n_tokens_predicted(completion)


-def prompt_predicted(response_data):
-    assert len(response_data['content']) > 0
-    assert response_data['timings']['predicted_n'] > 0
+def request_completion(context, prompt, n_predict=None):
+    response = requests.post(f'{context.base_url}/completion', json={
+        "prompt": prompt,
+        "n_predict": int(n_predict) if n_predict is not None else 4096,
+    })
+    status_code = response.status_code
+    assert status_code == 200
+    context.completions.append(response.json())
+
+
+def assert_n_tokens_predicted(completion_response, expected_predicted_n=None):
+    content = completion_response['content']
+    n_predicted = completion_response['timings']['predicted_n']
+    assert len(content) > 0, "no token predicted"
+    if expected_predicted_n is not None:
+        assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
+                                                     f' "{n_predicted}" <> "{expected_predicted_n}"')
+
+
+def wait_for_health_status(context, expected_http_status_code, expected_health_status, params=None):
+    status_code = 500
+    while status_code != expected_http_status_code:
+        health_response = requests.get(f'{context.base_url}/health', params)
+        status_code = health_response.status_code
+        health = health_response.json()
+        if status_code != expected_http_status_code or health['status'] != expected_health_status:
+            time.sleep(0.001)
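Taken together, the refactored helpers form a small recipe for exercising several slots at once: queue prompts, start one request_completion thread per prompt, join the threads, then assert on each stored result. A self-contained sketch of that flow, assuming a server on localhost:8080 with at least two slots; the function and defaults are illustrative, not part of the suite:

# Illustrative end-to-end sketch (not from the test suite): concurrent
# completions followed by per-response assertions, as the steps above do.
import threading
import requests

def run_concurrent_completions(base_url="http://localhost:8080",
                               prompts=("Write a joke about AI",
                                        "I believe the meaning of life is"),
                               n_predict=128):
    completions = []

    def worker(prompt):
        response = requests.post(f"{base_url}/completion",
                                 json={"prompt": prompt, "n_predict": n_predict})
        assert response.status_code == 200
        completions.append(response.json())  # list.append is thread-safe under the GIL

    threads = [threading.Thread(target=worker, args=(p,)) for p in prompts]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    for completion in completions:
        assert len(completion['content']) > 0, "no token predicted"
    return completions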