server: tests: switch to asyncio for concurrent tests, match result content with regex
This commit is contained in:
parent
016b221549
commit
e43406e36d
6 changed files with 393 additions and 161 deletions
54
examples/server/tests/features/parallel.feature
Normal file
54
examples/server/tests/features/parallel.feature
Normal file
|
@ -0,0 +1,54 @@
|
|||
@llama.cpp
|
||||
Feature: Parallel
|
||||
|
||||
Background: Server startup
|
||||
Given a server listening on localhost:8080
|
||||
And a model file stories260K.gguf
|
||||
And a model alias tinyllama-2
|
||||
And 42 as server seed
|
||||
And 32 KV cache size
|
||||
And 2 slots
|
||||
And continuous batching
|
||||
Then the server is starting
|
||||
Then the server is healthy
|
||||
|
||||
Scenario Outline: Multi users completion
|
||||
Given a prompt:
|
||||
"""
|
||||
Write a very long story about AI.
|
||||
"""
|
||||
And a prompt:
|
||||
"""
|
||||
Write another very long music lyrics.
|
||||
"""
|
||||
And <n_predict> max tokens to predict
|
||||
Given concurrent completion requests
|
||||
Then the server is busy
|
||||
Then the server is idle
|
||||
And all slots are idle
|
||||
Then all prompts are predicted with <n_predict> tokens
|
||||
Examples:
|
||||
| n_predict |
|
||||
| 512 |
|
||||
|
||||
Scenario Outline: Multi users OAI completions compatibility
|
||||
Given a system prompt You are a writer.
|
||||
And a model tinyllama-2
|
||||
Given a prompt:
|
||||
"""
|
||||
Write a very long book.
|
||||
"""
|
||||
And a prompt:
|
||||
"""
|
||||
Write another a poem.
|
||||
"""
|
||||
And <n_predict> max tokens to predict
|
||||
And streaming is <streaming>
|
||||
Given concurrent OAI completions requests
|
||||
Then the server is busy
|
||||
Then the server is idle
|
||||
Then all prompts are predicted with <n_predict> tokens
|
||||
Examples:
|
||||
| streaming | n_predict |
|
||||
| disabled | 512 |
|
||||
#| enabled | 512 | FIXME: phymbert: need to investigate why in aiohttp with streaming only one token is generated
|
|
@ -48,4 +48,3 @@ Feature: Security
|
|||
| origin | Access-Control-Allow-Credentials | true |
|
||||
| web.mydomain.fr | Access-Control-Allow-Methods | POST |
|
||||
| web.mydomain.fr | Access-Control-Allow-Headers | * |
|
||||
|
||||
|
|
|
@ -20,12 +20,12 @@ Feature: llama.cpp server
|
|||
Given a prompt <prompt>
|
||||
And <n_predict> max tokens to predict
|
||||
And a completion request with no api error
|
||||
Then <n_predicted> tokens are predicted with content: <content>
|
||||
Then <n_predicted> tokens are predicted matching <re_content>
|
||||
|
||||
Examples: Prompts
|
||||
| prompt | n_predict | content | n_predicted |
|
||||
| I believe the meaning of life is | 8 | <space>going to read. | 8 |
|
||||
| Write a joke about AI | 64 | tion came to the park. And all his friends were very scared and did not | 32 |
|
||||
| prompt | n_predict | re_content | n_predicted |
|
||||
| I believe the meaning of life is | 8 | read | 8 |
|
||||
| Write a joke about AI | 64 | (park<or>friends<or>scared)+ | 32 |
|
||||
|
||||
Scenario Outline: OAI Compatibility
|
||||
Given a model <model>
|
||||
|
@ -34,12 +34,12 @@ Feature: llama.cpp server
|
|||
And <max_tokens> max tokens to predict
|
||||
And streaming is <enable_streaming>
|
||||
Given an OAI compatible chat completions request with no api error
|
||||
Then <n_predicted> tokens are predicted with content: <content>
|
||||
Then <n_predicted> tokens are predicted matching <re_content>
|
||||
|
||||
Examples: Prompts
|
||||
| model | system_prompt | user_prompt | max_tokens | content | n_predicted | enable_streaming |
|
||||
| llama-2 | Book | What is the best book | 8 | "Mom, what' | 8 | disabled |
|
||||
| codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64 | "Hey," said the bird.<LF>The bird was very happy and thanked the bird for hel | 32 | enabled |
|
||||
| model | system_prompt | user_prompt | max_tokens | re_content | n_predicted | enable_streaming |
|
||||
| llama-2 | Book | What is the best book | 8 | (Mom<or>what)+ | 8 | disabled |
|
||||
| codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64 | (thanks<or>happy<or>bird)+ | 32 | enabled |
|
||||
|
||||
Scenario: Embedding
|
||||
When embeddings are computed for:
|
||||
|
|
|
@ -1,22 +1,27 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
import subprocess
|
||||
import threading
|
||||
from contextlib import closing
|
||||
from re import RegexFlag
|
||||
|
||||
import aiohttp
|
||||
import openai
|
||||
import requests
|
||||
from behave import step
|
||||
from behave.api.async_step import async_run_until_complete
|
||||
|
||||
|
||||
@step(
|
||||
u"a server listening on {server_fqdn}:{server_port}")
|
||||
@step(u"a server listening on {server_fqdn}:{server_port}")
|
||||
def step_server_config(context, server_fqdn, server_port):
|
||||
context.server_fqdn = server_fqdn
|
||||
context.server_port = int(server_port)
|
||||
|
||||
context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
|
||||
|
||||
context.server_continuous_batching = False
|
||||
context.model_alias = None
|
||||
context.n_ctx = None
|
||||
context.n_predict = None
|
||||
|
@ -27,7 +32,7 @@ def step_server_config(context, server_fqdn, server_port):
|
|||
context.user_api_key = None
|
||||
|
||||
context.completions = []
|
||||
context.completion_threads = []
|
||||
context.concurrent_completion_tasks = []
|
||||
context.prompts = []
|
||||
|
||||
|
||||
|
@ -61,39 +66,50 @@ def step_server_n_predict(context, n_predict):
|
|||
context.n_server_predict = int(n_predict)
|
||||
|
||||
|
||||
@step(u'continuous batching')
|
||||
def step_server_continuous_batching(context):
|
||||
context.server_continuous_batching = True
|
||||
|
||||
|
||||
@step(u"the server is starting")
|
||||
def step_start_server(context):
|
||||
start_server_background(context)
|
||||
while True:
|
||||
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
|
||||
result = sock.connect_ex((context.server_fqdn, context.server_port))
|
||||
if result == 0:
|
||||
return
|
||||
|
||||
|
||||
@step(u"the server is {expecting_status}")
|
||||
def step_wait_for_the_server_to_be_started(context, expecting_status):
|
||||
@async_run_until_complete
|
||||
async def step_wait_for_the_server_to_be_started(context, expecting_status):
|
||||
match expecting_status:
|
||||
case 'starting':
|
||||
start_server_background(context)
|
||||
server_started = False
|
||||
while not server_started:
|
||||
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
|
||||
result = sock.connect_ex((context.server_fqdn, context.server_port))
|
||||
if result == 0:
|
||||
return 0
|
||||
case 'loading model':
|
||||
wait_for_health_status(context, 503, 'loading model')
|
||||
case 'healthy':
|
||||
wait_for_health_status(context, 200, 'ok')
|
||||
await wait_for_health_status(context.base_url, 200, 'ok')
|
||||
|
||||
case 'ready' | 'idle':
|
||||
wait_for_health_status(context, 200, 'ok',
|
||||
params={'fail_on_no_slot': True},
|
||||
slots_idle=context.n_slots,
|
||||
slots_processing=0)
|
||||
request_slots_status(context, [{'id': slot_id, 'state': 0} for slot_id in range(context.n_slots)])
|
||||
await wait_for_health_status(context.base_url, 200, 'ok',
|
||||
params={'fail_on_no_slot': 0, 'include_slots': 0},
|
||||
slots_idle=context.n_slots,
|
||||
slots_processing=0,
|
||||
expected_slots=[{'id': slot_id, 'state': 0}
|
||||
for slot_id in range(context.n_slots)])
|
||||
case 'busy':
|
||||
wait_for_health_status(context, 503, 'no slot available',
|
||||
params={'fail_on_no_slot': True},
|
||||
slots_idle=0,
|
||||
slots_processing=context.n_slots)
|
||||
request_slots_status(context, [{'id': slot_id, 'state': 1} for slot_id in range(context.n_slots)])
|
||||
await wait_for_health_status(context.base_url, 503,
|
||||
'no slot available',
|
||||
params={'fail_on_no_slot': 0, 'include_slots': 0},
|
||||
slots_idle=0,
|
||||
slots_processing=context.n_slots,
|
||||
expected_slots=[{'id': slot_id, 'state': 1}
|
||||
for slot_id in range(context.n_slots)])
|
||||
case _:
|
||||
assert False, "unknown status"
|
||||
|
||||
|
||||
@step(u'all slots are {expected_slot_status_string}')
|
||||
def step_all_slots_status(context, expected_slot_status_string):
|
||||
@async_run_until_complete
|
||||
async def step_all_slots_status(context, expected_slot_status_string):
|
||||
match expected_slot_status_string:
|
||||
case 'idle':
|
||||
expected_slot_status = 0
|
||||
|
@ -102,36 +118,40 @@ def step_all_slots_status(context, expected_slot_status_string):
|
|||
case _:
|
||||
assert False, "unknown status"
|
||||
|
||||
expected_slots = []
|
||||
for slot_id in range(context.n_slots):
|
||||
expected_slots.append({
|
||||
'id': slot_id,
|
||||
'state': expected_slot_status
|
||||
})
|
||||
request_slots_status(context, expected_slots)
|
||||
expected_slots = [{'id': slot_id, 'state': expected_slot_status}
|
||||
for slot_id in range(context.n_slots)]
|
||||
await request_slots_status(context, expected_slots)
|
||||
|
||||
|
||||
@step(u'a completion request with {api_error} api error')
|
||||
def step_request_completion(context, api_error):
|
||||
request_completion(context, context.prompts.pop(),
|
||||
n_predict=context.n_predict,
|
||||
expect_api_error=api_error == 'raised')
|
||||
@async_run_until_complete
|
||||
async def step_request_completion(context, api_error):
|
||||
expect_api_error = api_error == 'raised'
|
||||
completion = await request_completion(context.prompts.pop(),
|
||||
context.base_url,
|
||||
n_predict=context.n_predict,
|
||||
server_seed=context.server_seed,
|
||||
expect_api_error=expect_api_error,
|
||||
user_api_key=context.user_api_key)
|
||||
context.completions.append(completion)
|
||||
print(f"Completion response: {completion}")
|
||||
if expect_api_error:
|
||||
assert completion == 401, f"completion must be an 401 status code: {completion}"
|
||||
|
||||
|
||||
@step(u'{predicted_n} tokens are predicted with content: {content}')
|
||||
def step_n_tokens_predicted_with_content(context, predicted_n, content):
|
||||
assert_n_tokens_predicted(context.completions[0], int(predicted_n), content)
|
||||
@step(u'{predicted_n} tokens are predicted matching {re_content}')
|
||||
def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
|
||||
assert_n_tokens_predicted(context.completions.pop(), int(predicted_n), re_content)
|
||||
|
||||
|
||||
@step(u'{predicted_n} tokens are predicted')
|
||||
def step_n_tokens_predicted(context, predicted_n):
|
||||
if int(predicted_n) > 0:
|
||||
assert_n_tokens_predicted(context.completions[0], int(predicted_n))
|
||||
assert_n_tokens_predicted(context.completions.pop(), int(predicted_n))
|
||||
|
||||
|
||||
@step(u'a user prompt {user_prompt}')
|
||||
def step_user_prompt(context, user_prompt):
|
||||
context.user_prompt = user_prompt
|
||||
context.prompts.append(user_prompt)
|
||||
|
||||
|
||||
@step(u'a system prompt {system_prompt}')
|
||||
|
@ -151,7 +171,7 @@ def step_max_tokens(context, max_tokens):
|
|||
|
||||
@step(u'streaming is {enable_streaming}')
|
||||
def step_streaming(context, enable_streaming):
|
||||
context.enable_streaming = enable_streaming == 'enabled' or bool(enable_streaming)
|
||||
context.enable_streaming = enable_streaming == 'enabled'
|
||||
|
||||
|
||||
@step(u'a user api key {user_api_key}')
|
||||
|
@ -175,8 +195,35 @@ def step_server_api_key(context, server_api_key):
|
|||
|
||||
|
||||
@step(u'an OAI compatible chat completions request with {api_error} api error')
|
||||
def step_oai_chat_completions(context, api_error):
|
||||
oai_chat_completions(context, context.user_prompt, api_error=api_error == 'raised')
|
||||
@async_run_until_complete
|
||||
async def step_oai_chat_completions(context, api_error):
|
||||
print(f"Submitting OAI compatible completions request...")
|
||||
expect_api_error = api_error == 'raised'
|
||||
completion = await oai_chat_completions(context.prompts.pop(),
|
||||
context.system_prompt,
|
||||
context.base_url,
|
||||
False,
|
||||
model=context.model if hasattr(context, 'model') else None,
|
||||
|
||||
n_predict=context.n_predict
|
||||
if hasattr(context, 'n_predict') else None,
|
||||
|
||||
enable_streaming=context.enable_streaming
|
||||
if hasattr(context, 'enable_streaming') else None,
|
||||
|
||||
server_seed=context.server_seed
|
||||
if hasattr(context, 'server_seed') else None,
|
||||
|
||||
user_api_key=context.user_api_key
|
||||
if hasattr(context, 'user_api_key') else None,
|
||||
|
||||
expect_api_error=expect_api_error)
|
||||
context.completions.append(completion)
|
||||
print(f"Completion response: {completion}")
|
||||
if expect_api_error:
|
||||
assert completion == 401, f"completion must be an 401 status code: {completion}"
|
||||
|
||||
print(f"Completion response: {completion}")
|
||||
|
||||
|
||||
@step(u'a prompt')
|
||||
|
@ -190,22 +237,49 @@ def step_a_prompt_prompt(context, prompt):
|
|||
|
||||
|
||||
@step(u'concurrent completion requests')
|
||||
def step_concurrent_completion_requests(context):
|
||||
concurrent_requests(context, request_completion)
|
||||
@async_run_until_complete()
|
||||
async def step_concurrent_completion_requests(context):
|
||||
await concurrent_completion_requests(context,
|
||||
request_completion,
|
||||
# prompt is inserted automatically
|
||||
context.base_url,
|
||||
n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
|
||||
server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
|
||||
user_api_key=context.user_api_key if hasattr(context,
|
||||
'user_api_key') else None)
|
||||
|
||||
|
||||
@step(u'concurrent OAI completions requests')
|
||||
def step_oai_chat_completions(context):
|
||||
concurrent_requests(context, oai_chat_completions)
|
||||
@async_run_until_complete
|
||||
async def step_oai_chat_completions(context):
|
||||
await concurrent_completion_requests(context, oai_chat_completions,
|
||||
# user_prompt is inserted automatically
|
||||
context.system_prompt,
|
||||
context.base_url,
|
||||
True, # async_client
|
||||
model=context.model
|
||||
if hasattr(context, 'model') else None,
|
||||
n_predict=context.n_predict
|
||||
if hasattr(context, 'n_predict') else None,
|
||||
enable_streaming=context.enable_streaming
|
||||
if hasattr(context, 'enable_streaming') else None,
|
||||
server_seed=context.server_seed
|
||||
if hasattr(context, 'server_seed') else None,
|
||||
user_api_key=context.user_api_key
|
||||
if hasattr(context, 'user_api_key') else None)
|
||||
|
||||
|
||||
@step(u'all prompts are predicted')
|
||||
def step_all_prompts_are_predicted(context):
|
||||
for completion_thread in context.completion_threads:
|
||||
completion_thread.join()
|
||||
assert len(context.completions) == len(context.completion_threads)
|
||||
for completion in context.completions:
|
||||
assert_n_tokens_predicted(completion)
|
||||
@step(u'all prompts are predicted with {n_predict} tokens')
|
||||
@async_run_until_complete
|
||||
async def step_all_prompts_are_predicted(context, n_predict):
|
||||
n_completion_tasks = len(context.concurrent_completion_tasks)
|
||||
print(f"Waiting for all {n_completion_tasks} completion responses...")
|
||||
for task_no in range(n_completion_tasks):
|
||||
context.completions.append(await context.concurrent_completion_tasks.pop())
|
||||
n_completions = len(context.completions)
|
||||
assert n_completions > 0
|
||||
for i in range(n_completions):
|
||||
assert_n_tokens_predicted(context.completions.pop(), expected_predicted_n=int(n_predict))
|
||||
|
||||
|
||||
@step(u'embeddings are computed for')
|
||||
|
@ -269,126 +343,228 @@ def step_check_options_header_value(context, cors_header, cors_header_value):
|
|||
assert context.options_response.headers[cors_header] == cors_header_value
|
||||
|
||||
|
||||
def concurrent_requests(context, f_completion, *argv):
|
||||
context.completions.clear()
|
||||
context.completion_threads.clear()
|
||||
for prompt in context.prompts:
|
||||
completion_thread = threading.Thread(target=f_completion, args=(context, prompt, *argv))
|
||||
completion_thread.start()
|
||||
context.completion_threads.append(completion_thread)
|
||||
context.prompts.clear()
|
||||
async def concurrent_completion_requests(context, f_completion, *args, **kwargs):
|
||||
n_prompts = len(context.prompts)
|
||||
print(f"starting {n_prompts} concurrent completion requests...")
|
||||
assert n_prompts > 0
|
||||
for prompt_no in range(n_prompts):
|
||||
shifted_args = [context.prompts.pop(), *args]
|
||||
context.concurrent_completion_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
def request_completion(context, prompt, n_predict=None, expect_api_error=None):
|
||||
async def request_completion(prompt,
|
||||
base_url,
|
||||
n_predict=None,
|
||||
server_seed=None,
|
||||
expect_api_error=None,
|
||||
user_api_key=None):
|
||||
print(f"Sending completion request: {prompt}")
|
||||
origin = "my.super.domain"
|
||||
headers = {
|
||||
'Origin': origin
|
||||
}
|
||||
if context.user_api_key is not None:
|
||||
print(f"Set user_api_key: {context.user_api_key}")
|
||||
headers['Authorization'] = f'Bearer {context.user_api_key}'
|
||||
if user_api_key is not None:
|
||||
print(f"Set user_api_key: {user_api_key}")
|
||||
headers['Authorization'] = f'Bearer {user_api_key}'
|
||||
|
||||
response = requests.post(f'{context.base_url}/completion',
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"n_predict": int(n_predict) if n_predict is not None else context.n_predict,
|
||||
"seed": context.server_seed if context.server_seed is not None else 42
|
||||
},
|
||||
headers=headers)
|
||||
if expect_api_error is not None and not expect_api_error:
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Access-Control-Allow-Origin'] == origin
|
||||
context.completions.append(response.json())
|
||||
else:
|
||||
assert response.status_code == 401
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(f'{base_url}/completion',
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"n_predict": int(n_predict) if n_predict is not None else -1,
|
||||
"seed": server_seed if server_seed is not None else 42
|
||||
},
|
||||
headers=headers) as response:
|
||||
if expect_api_error is None or not expect_api_error:
|
||||
assert response.status == 200
|
||||
assert response.headers['Access-Control-Allow-Origin'] == origin
|
||||
return await response.json()
|
||||
else:
|
||||
return response.status
|
||||
|
||||
|
||||
def oai_chat_completions(context, user_prompt, api_error=None):
|
||||
openai.api_key = 'nope' # openai client always expects an api_keu
|
||||
if context.user_api_key is not None:
|
||||
openai.api_key = context.user_api_key
|
||||
openai.api_base = f'{context.base_url}/v1/chat'
|
||||
try:
|
||||
chat_completion = openai.Completion.create(
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": context.system_prompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
}
|
||||
],
|
||||
model=context.model,
|
||||
max_tokens=context.n_predict,
|
||||
stream=context.enable_streaming,
|
||||
seed=context.server_seed if context.server_seed is not None else 42
|
||||
)
|
||||
except openai.error.APIError:
|
||||
if api_error is not None and api_error:
|
||||
return
|
||||
if context.enable_streaming:
|
||||
completion_response = {
|
||||
'content': '',
|
||||
'timings': {
|
||||
'predicted_n': 0
|
||||
async def oai_chat_completions(user_prompt,
|
||||
system_prompt,
|
||||
base_url,
|
||||
async_client,
|
||||
model=None,
|
||||
n_predict=None,
|
||||
enable_streaming=None,
|
||||
server_seed=None,
|
||||
user_api_key=None,
|
||||
expect_api_error=None):
|
||||
print(f"Sending OAI Chat completions request: {user_prompt}")
|
||||
# openai client always expects an api key
|
||||
user_api_key = user_api_key if user_api_key is not None else 'nope'
|
||||
seed = server_seed if server_seed is not None else 42
|
||||
enable_streaming = enable_streaming if enable_streaming is not None else False
|
||||
payload = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
}
|
||||
],
|
||||
"model": model,
|
||||
"max_tokens": n_predict,
|
||||
"stream": enable_streaming,
|
||||
"seed": seed
|
||||
}
|
||||
completion_response = {
|
||||
'content': '',
|
||||
'timings': {
|
||||
'predicted_n': 0
|
||||
}
|
||||
for chunk in chat_completion:
|
||||
assert len(chunk.choices) == 1
|
||||
delta = chunk.choices[0].delta
|
||||
if 'content' in delta:
|
||||
completion_response['content'] += delta['content']
|
||||
completion_response['timings']['predicted_n'] += 1
|
||||
context.completions.append(completion_response)
|
||||
}
|
||||
if async_client:
|
||||
origin = 'llama.cpp'
|
||||
headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(f'{base_url}/v1/chat/completions',
|
||||
json=payload,
|
||||
headers=headers) as response:
|
||||
if enable_streaming:
|
||||
print("payload", payload)
|
||||
assert response.status == 200
|
||||
assert response.headers['Access-Control-Allow-Origin'] == origin
|
||||
assert response.headers['Content-Type'] == "text/event-stream"
|
||||
|
||||
async for line_in_bytes in response.content:
|
||||
line = line_in_bytes.decode('utf8')
|
||||
event_data = line.split(': ', 1)
|
||||
assert event_data[0] == 'data', f'{event_data}'
|
||||
chunk_raw = event_data[1]
|
||||
|
||||
chunk = json.loads(chunk_raw)
|
||||
assert len(chunk['choices']) == 1
|
||||
delta = chunk['choices'][0]['delta']
|
||||
if 'content' in delta:
|
||||
completion_response['content'] += delta['content']
|
||||
completion_response['timings']['predicted_n'] += 1
|
||||
print(f"XXXXXXXXXXXXXXXXXcompletion_response: {completion_response}")
|
||||
else:
|
||||
print(f"raw completion response: {response}")
|
||||
if expect_api_error is None or not expect_api_error:
|
||||
assert response.status == 200
|
||||
assert response.headers['Access-Control-Allow-Origin'] == origin
|
||||
assert response.headers['Content-Type'] == "application/json; charset=utf-8"
|
||||
chat_completion_raw = await response.json()
|
||||
completion_response = {
|
||||
'content': chat_completion_raw['choices'][0]['message'],
|
||||
'timings': {
|
||||
'predicted_n': chat_completion_raw['usage']['completion_tokens']
|
||||
}
|
||||
}
|
||||
else:
|
||||
return response.status
|
||||
else:
|
||||
assert len(chat_completion.choices) == 1
|
||||
context.completions.append({
|
||||
'content': chat_completion.choices[0].message,
|
||||
'timings': {
|
||||
'predicted_n': chat_completion.usage.completion_tokens
|
||||
try:
|
||||
openai.api_key = user_api_key
|
||||
openai.api_base = f'{base_url}/v1/chat'
|
||||
chat_completion = openai.Completion.create(
|
||||
messages=payload['messages'],
|
||||
model=model,
|
||||
max_tokens=n_predict,
|
||||
stream=enable_streaming,
|
||||
seed=seed
|
||||
)
|
||||
except openai.error.APIError as e:
|
||||
if expect_api_error is not None and expect_api_error:
|
||||
return 401
|
||||
else:
|
||||
assert False, f'error raised: {e}'
|
||||
|
||||
if enable_streaming:
|
||||
for chunk in chat_completion:
|
||||
assert len(chunk.choices) == 1
|
||||
delta = chunk.choices[0].delta
|
||||
if 'content' in delta:
|
||||
completion_response['content'] += delta['content']
|
||||
completion_response['timings']['predicted_n'] += 1
|
||||
else:
|
||||
assert len(chat_completion.choices) == 1
|
||||
completion_response = {
|
||||
'content': chat_completion.choices[0].message.content,
|
||||
'timings': {
|
||||
'predicted_n': chat_completion.usage.completion_tokens
|
||||
}
|
||||
}
|
||||
})
|
||||
print("OAI response formatted to llama.cpp", completion_response)
|
||||
return completion_response
|
||||
|
||||
|
||||
def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, expected_content=None):
|
||||
def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
|
||||
content = completion_response['content']
|
||||
n_predicted = completion_response['timings']['predicted_n']
|
||||
assert len(content) > 0, "no token predicted"
|
||||
if expected_predicted_n is not None:
|
||||
assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
|
||||
f' {n_predicted} <> {expected_predicted_n}')
|
||||
if expected_content is not None:
|
||||
expected_content = expected_content.replace('<space>', ' ').replace('<LF>', '\n')
|
||||
assert content == expected_content, (f'invalid tokens predicted:'
|
||||
f' ```\n{content}\n``` <> ```\n{expected_content}\n```')
|
||||
if re_content is not None:
|
||||
re_content = '^.*' + re_content.replace('<or>', '|') + '.*$'
|
||||
assert re.match(re_content, content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL), (
|
||||
f'invalid tokens predicted:'
|
||||
f' ```\n{content}\n``` do not match /{re_content}/')
|
||||
|
||||
|
||||
def wait_for_health_status(context, expected_http_status_code,
|
||||
expected_health_status,
|
||||
params=None,
|
||||
slots_idle=None,
|
||||
slots_processing=None):
|
||||
while True:
|
||||
health_response = requests.get(f'{context.base_url}/health', params)
|
||||
status_code = health_response.status_code
|
||||
health = health_response.json()
|
||||
if (status_code == expected_http_status_code
|
||||
and health['status'] == expected_health_status
|
||||
and (slots_idle is None or health['slots_idle'] == slots_idle)
|
||||
and (slots_processing is None or health['slots_processing'] == slots_processing)):
|
||||
break
|
||||
async def wait_for_health_status(base_url,
|
||||
expected_http_status_code,
|
||||
expected_health_status,
|
||||
params=None,
|
||||
slots_idle=None,
|
||||
slots_processing=None,
|
||||
expected_slots=None):
|
||||
print(f"Starting checking for health for expected_health_status={expected_health_status}")
|
||||
timeout = 3 # seconds
|
||||
interval = 0.5
|
||||
counter = 0
|
||||
async with aiohttp.ClientSession() as session:
|
||||
while True:
|
||||
async with await session.get(f'{base_url}/health', params=params) as health_response:
|
||||
status_code = health_response.status
|
||||
health = await health_response.json()
|
||||
print(f"HEALTH - response for expected health status='{expected_health_status}' on "
|
||||
f"'{base_url}/health'?{params} is {health}")
|
||||
if (status_code == expected_http_status_code
|
||||
and health['status'] == expected_health_status
|
||||
and (slots_idle is None or health['slots_idle'] == slots_idle)
|
||||
and (slots_processing is None or health['slots_processing'] == slots_processing)):
|
||||
if expected_slots is not None:
|
||||
assert_slots_status(health['slots'], expected_slots)
|
||||
return
|
||||
if (status_code == expected_http_status_code
|
||||
and health['status'] == expected_health_status
|
||||
and (slots_idle is None or health['slots_idle'] == slots_idle)
|
||||
and (slots_processing is None or health['slots_processing'] == slots_processing)):
|
||||
if expected_slots is not None:
|
||||
assert_slots_status(health['slots'], expected_slots)
|
||||
return
|
||||
await asyncio.sleep(interval)
|
||||
counter += interval
|
||||
if counter >= timeout:
|
||||
assert False, 'timeout exceeded'
|
||||
|
||||
|
||||
def request_slots_status(context, expected_slots):
|
||||
slots_response = requests.get(f'{context.base_url}/slots')
|
||||
assert slots_response.status_code == 200
|
||||
slots = slots_response.json()
|
||||
async def request_slots_status(context, expected_slots):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with await session.get(f'{context.base_url}/slots') as slots_response:
|
||||
assert slots_response.status == 200
|
||||
slots = await slots_response.json()
|
||||
assert_slots_status(slots, expected_slots)
|
||||
|
||||
|
||||
def assert_slots_status(slots, expected_slots):
|
||||
assert len(slots) == len(expected_slots)
|
||||
for expected, slot in zip(expected_slots, slots):
|
||||
for slot_id, (expected, slot) in enumerate(zip(expected_slots, slots)):
|
||||
for key in expected:
|
||||
assert expected[key] == slot[key], f"expected[{key}] != slot[{key}]"
|
||||
assert expected[key] == slot[key], (f"invalid slot {slot_id}"
|
||||
f" expected[{key}] != slot[{key}]"
|
||||
f" = {expected[key]} != {slot[key]}")
|
||||
|
||||
|
||||
def start_server_background(context):
|
||||
|
@ -398,6 +574,8 @@ def start_server_background(context):
|
|||
server_args = [
|
||||
'--model', context.model_file
|
||||
]
|
||||
if context.server_continuous_batching:
|
||||
server_args.append('--cont-batching')
|
||||
if context.model_alias is not None:
|
||||
server_args.extend(['--alias', context.model_alias])
|
||||
if context.server_seed is not None:
|
||||
|
|
|
@ -1,2 +1,3 @@
|
|||
aiohttp~=3.9.3
|
||||
behave~=1.2.6
|
||||
openai~=0.25.0
|
||||
|
|
|
@ -3,4 +3,4 @@
|
|||
set -eu
|
||||
|
||||
# Start @llama.cpp scenario
|
||||
behave --summary --stop --tags llama.cpp
|
||||
behave --summary --stop --no-capture --tags llama.cpp
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue