server: tests: switch to asyncio for concurrent tests, match result content with regex
This commit is contained in:
parent
016b221549
commit
e43406e36d
6 changed files with 393 additions and 161 deletions
54
examples/server/tests/features/parallel.feature
Normal file
54
examples/server/tests/features/parallel.feature
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
@llama.cpp
|
||||||
|
Feature: Parallel
|
||||||
|
|
||||||
|
Background: Server startup
|
||||||
|
Given a server listening on localhost:8080
|
||||||
|
And a model file stories260K.gguf
|
||||||
|
And a model alias tinyllama-2
|
||||||
|
And 42 as server seed
|
||||||
|
And 32 KV cache size
|
||||||
|
And 2 slots
|
||||||
|
And continuous batching
|
||||||
|
Then the server is starting
|
||||||
|
Then the server is healthy
|
||||||
|
|
||||||
|
Scenario Outline: Multi users completion
|
||||||
|
Given a prompt:
|
||||||
|
"""
|
||||||
|
Write a very long story about AI.
|
||||||
|
"""
|
||||||
|
And a prompt:
|
||||||
|
"""
|
||||||
|
Write another very long music lyrics.
|
||||||
|
"""
|
||||||
|
And <n_predict> max tokens to predict
|
||||||
|
Given concurrent completion requests
|
||||||
|
Then the server is busy
|
||||||
|
Then the server is idle
|
||||||
|
And all slots are idle
|
||||||
|
Then all prompts are predicted with <n_predict> tokens
|
||||||
|
Examples:
|
||||||
|
| n_predict |
|
||||||
|
| 512 |
|
||||||
|
|
||||||
|
Scenario Outline: Multi users OAI completions compatibility
|
||||||
|
Given a system prompt You are a writer.
|
||||||
|
And a model tinyllama-2
|
||||||
|
Given a prompt:
|
||||||
|
"""
|
||||||
|
Write a very long book.
|
||||||
|
"""
|
||||||
|
And a prompt:
|
||||||
|
"""
|
||||||
|
Write another a poem.
|
||||||
|
"""
|
||||||
|
And <n_predict> max tokens to predict
|
||||||
|
And streaming is <streaming>
|
||||||
|
Given concurrent OAI completions requests
|
||||||
|
Then the server is busy
|
||||||
|
Then the server is idle
|
||||||
|
Then all prompts are predicted with <n_predict> tokens
|
||||||
|
Examples:
|
||||||
|
| streaming | n_predict |
|
||||||
|
| disabled | 512 |
|
||||||
|
#| enabled | 512 | FIXME: phymbert: need to investigate why in aiohttp with streaming only one token is generated
|
|
@ -48,4 +48,3 @@ Feature: Security
|
||||||
| origin | Access-Control-Allow-Credentials | true |
|
| origin | Access-Control-Allow-Credentials | true |
|
||||||
| web.mydomain.fr | Access-Control-Allow-Methods | POST |
|
| web.mydomain.fr | Access-Control-Allow-Methods | POST |
|
||||||
| web.mydomain.fr | Access-Control-Allow-Headers | * |
|
| web.mydomain.fr | Access-Control-Allow-Headers | * |
|
||||||
|
|
||||||
|
|
|
@ -20,12 +20,12 @@ Feature: llama.cpp server
|
||||||
Given a prompt <prompt>
|
Given a prompt <prompt>
|
||||||
And <n_predict> max tokens to predict
|
And <n_predict> max tokens to predict
|
||||||
And a completion request with no api error
|
And a completion request with no api error
|
||||||
Then <n_predicted> tokens are predicted with content: <content>
|
Then <n_predicted> tokens are predicted matching <re_content>
|
||||||
|
|
||||||
Examples: Prompts
|
Examples: Prompts
|
||||||
| prompt | n_predict | content | n_predicted |
|
| prompt | n_predict | re_content | n_predicted |
|
||||||
| I believe the meaning of life is | 8 | <space>going to read. | 8 |
|
| I believe the meaning of life is | 8 | read | 8 |
|
||||||
| Write a joke about AI | 64 | tion came to the park. And all his friends were very scared and did not | 32 |
|
| Write a joke about AI | 64 | (park<or>friends<or>scared)+ | 32 |
|
||||||
|
|
||||||
Scenario Outline: OAI Compatibility
|
Scenario Outline: OAI Compatibility
|
||||||
Given a model <model>
|
Given a model <model>
|
||||||
|
@ -34,12 +34,12 @@ Feature: llama.cpp server
|
||||||
And <max_tokens> max tokens to predict
|
And <max_tokens> max tokens to predict
|
||||||
And streaming is <enable_streaming>
|
And streaming is <enable_streaming>
|
||||||
Given an OAI compatible chat completions request with no api error
|
Given an OAI compatible chat completions request with no api error
|
||||||
Then <n_predicted> tokens are predicted with content: <content>
|
Then <n_predicted> tokens are predicted matching <re_content>
|
||||||
|
|
||||||
Examples: Prompts
|
Examples: Prompts
|
||||||
| model | system_prompt | user_prompt | max_tokens | content | n_predicted | enable_streaming |
|
| model | system_prompt | user_prompt | max_tokens | re_content | n_predicted | enable_streaming |
|
||||||
| llama-2 | Book | What is the best book | 8 | "Mom, what' | 8 | disabled |
|
| llama-2 | Book | What is the best book | 8 | (Mom<or>what)+ | 8 | disabled |
|
||||||
| codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64 | "Hey," said the bird.<LF>The bird was very happy and thanked the bird for hel | 32 | enabled |
|
| codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64 | (thanks<or>happy<or>bird)+ | 32 | enabled |
|
||||||
|
|
||||||
Scenario: Embedding
|
Scenario: Embedding
|
||||||
When embeddings are computed for:
|
When embeddings are computed for:
|
||||||
|
|
|
@ -1,22 +1,27 @@
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import socket
|
import socket
|
||||||
import subprocess
|
import subprocess
|
||||||
import threading
|
|
||||||
from contextlib import closing
|
from contextlib import closing
|
||||||
|
from re import RegexFlag
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
import openai
|
import openai
|
||||||
import requests
|
import requests
|
||||||
from behave import step
|
from behave import step
|
||||||
|
from behave.api.async_step import async_run_until_complete
|
||||||
|
|
||||||
|
|
||||||
@step(
|
@step(u"a server listening on {server_fqdn}:{server_port}")
|
||||||
u"a server listening on {server_fqdn}:{server_port}")
|
|
||||||
def step_server_config(context, server_fqdn, server_port):
|
def step_server_config(context, server_fqdn, server_port):
|
||||||
context.server_fqdn = server_fqdn
|
context.server_fqdn = server_fqdn
|
||||||
context.server_port = int(server_port)
|
context.server_port = int(server_port)
|
||||||
|
|
||||||
context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
|
context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
|
||||||
|
|
||||||
|
context.server_continuous_batching = False
|
||||||
context.model_alias = None
|
context.model_alias = None
|
||||||
context.n_ctx = None
|
context.n_ctx = None
|
||||||
context.n_predict = None
|
context.n_predict = None
|
||||||
|
@ -27,7 +32,7 @@ def step_server_config(context, server_fqdn, server_port):
|
||||||
context.user_api_key = None
|
context.user_api_key = None
|
||||||
|
|
||||||
context.completions = []
|
context.completions = []
|
||||||
context.completion_threads = []
|
context.concurrent_completion_tasks = []
|
||||||
context.prompts = []
|
context.prompts = []
|
||||||
|
|
||||||
|
|
||||||
|
@ -61,39 +66,50 @@ def step_server_n_predict(context, n_predict):
|
||||||
context.n_server_predict = int(n_predict)
|
context.n_server_predict = int(n_predict)
|
||||||
|
|
||||||
|
|
||||||
|
@step(u'continuous batching')
|
||||||
|
def step_server_continuous_batching(context):
|
||||||
|
context.server_continuous_batching = True
|
||||||
|
|
||||||
|
|
||||||
|
@step(u"the server is starting")
|
||||||
|
def step_start_server(context):
|
||||||
|
start_server_background(context)
|
||||||
|
while True:
|
||||||
|
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
|
||||||
|
result = sock.connect_ex((context.server_fqdn, context.server_port))
|
||||||
|
if result == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
@step(u"the server is {expecting_status}")
|
@step(u"the server is {expecting_status}")
|
||||||
def step_wait_for_the_server_to_be_started(context, expecting_status):
|
@async_run_until_complete
|
||||||
|
async def step_wait_for_the_server_to_be_started(context, expecting_status):
|
||||||
match expecting_status:
|
match expecting_status:
|
||||||
case 'starting':
|
|
||||||
start_server_background(context)
|
|
||||||
server_started = False
|
|
||||||
while not server_started:
|
|
||||||
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
|
|
||||||
result = sock.connect_ex((context.server_fqdn, context.server_port))
|
|
||||||
if result == 0:
|
|
||||||
return 0
|
|
||||||
case 'loading model':
|
|
||||||
wait_for_health_status(context, 503, 'loading model')
|
|
||||||
case 'healthy':
|
case 'healthy':
|
||||||
wait_for_health_status(context, 200, 'ok')
|
await wait_for_health_status(context.base_url, 200, 'ok')
|
||||||
|
|
||||||
case 'ready' | 'idle':
|
case 'ready' | 'idle':
|
||||||
wait_for_health_status(context, 200, 'ok',
|
await wait_for_health_status(context.base_url, 200, 'ok',
|
||||||
params={'fail_on_no_slot': True},
|
params={'fail_on_no_slot': 0, 'include_slots': 0},
|
||||||
slots_idle=context.n_slots,
|
slots_idle=context.n_slots,
|
||||||
slots_processing=0)
|
slots_processing=0,
|
||||||
request_slots_status(context, [{'id': slot_id, 'state': 0} for slot_id in range(context.n_slots)])
|
expected_slots=[{'id': slot_id, 'state': 0}
|
||||||
|
for slot_id in range(context.n_slots)])
|
||||||
case 'busy':
|
case 'busy':
|
||||||
wait_for_health_status(context, 503, 'no slot available',
|
await wait_for_health_status(context.base_url, 503,
|
||||||
params={'fail_on_no_slot': True},
|
'no slot available',
|
||||||
slots_idle=0,
|
params={'fail_on_no_slot': 0, 'include_slots': 0},
|
||||||
slots_processing=context.n_slots)
|
slots_idle=0,
|
||||||
request_slots_status(context, [{'id': slot_id, 'state': 1} for slot_id in range(context.n_slots)])
|
slots_processing=context.n_slots,
|
||||||
|
expected_slots=[{'id': slot_id, 'state': 1}
|
||||||
|
for slot_id in range(context.n_slots)])
|
||||||
case _:
|
case _:
|
||||||
assert False, "unknown status"
|
assert False, "unknown status"
|
||||||
|
|
||||||
|
|
||||||
@step(u'all slots are {expected_slot_status_string}')
|
@step(u'all slots are {expected_slot_status_string}')
|
||||||
def step_all_slots_status(context, expected_slot_status_string):
|
@async_run_until_complete
|
||||||
|
async def step_all_slots_status(context, expected_slot_status_string):
|
||||||
match expected_slot_status_string:
|
match expected_slot_status_string:
|
||||||
case 'idle':
|
case 'idle':
|
||||||
expected_slot_status = 0
|
expected_slot_status = 0
|
||||||
|
@ -102,36 +118,40 @@ def step_all_slots_status(context, expected_slot_status_string):
|
||||||
case _:
|
case _:
|
||||||
assert False, "unknown status"
|
assert False, "unknown status"
|
||||||
|
|
||||||
expected_slots = []
|
expected_slots = [{'id': slot_id, 'state': expected_slot_status}
|
||||||
for slot_id in range(context.n_slots):
|
for slot_id in range(context.n_slots)]
|
||||||
expected_slots.append({
|
await request_slots_status(context, expected_slots)
|
||||||
'id': slot_id,
|
|
||||||
'state': expected_slot_status
|
|
||||||
})
|
|
||||||
request_slots_status(context, expected_slots)
|
|
||||||
|
|
||||||
|
|
||||||
@step(u'a completion request with {api_error} api error')
|
@step(u'a completion request with {api_error} api error')
|
||||||
def step_request_completion(context, api_error):
|
@async_run_until_complete
|
||||||
request_completion(context, context.prompts.pop(),
|
async def step_request_completion(context, api_error):
|
||||||
n_predict=context.n_predict,
|
expect_api_error = api_error == 'raised'
|
||||||
expect_api_error=api_error == 'raised')
|
completion = await request_completion(context.prompts.pop(),
|
||||||
|
context.base_url,
|
||||||
|
n_predict=context.n_predict,
|
||||||
|
server_seed=context.server_seed,
|
||||||
|
expect_api_error=expect_api_error,
|
||||||
|
user_api_key=context.user_api_key)
|
||||||
|
context.completions.append(completion)
|
||||||
|
print(f"Completion response: {completion}")
|
||||||
|
if expect_api_error:
|
||||||
|
assert completion == 401, f"completion must be an 401 status code: {completion}"
|
||||||
|
|
||||||
|
|
||||||
@step(u'{predicted_n} tokens are predicted with content: {content}')
|
@step(u'{predicted_n} tokens are predicted matching {re_content}')
|
||||||
def step_n_tokens_predicted_with_content(context, predicted_n, content):
|
def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
|
||||||
assert_n_tokens_predicted(context.completions[0], int(predicted_n), content)
|
assert_n_tokens_predicted(context.completions.pop(), int(predicted_n), re_content)
|
||||||
|
|
||||||
|
|
||||||
@step(u'{predicted_n} tokens are predicted')
|
@step(u'{predicted_n} tokens are predicted')
|
||||||
def step_n_tokens_predicted(context, predicted_n):
|
def step_n_tokens_predicted(context, predicted_n):
|
||||||
if int(predicted_n) > 0:
|
assert_n_tokens_predicted(context.completions.pop(), int(predicted_n))
|
||||||
assert_n_tokens_predicted(context.completions[0], int(predicted_n))
|
|
||||||
|
|
||||||
|
|
||||||
@step(u'a user prompt {user_prompt}')
|
@step(u'a user prompt {user_prompt}')
|
||||||
def step_user_prompt(context, user_prompt):
|
def step_user_prompt(context, user_prompt):
|
||||||
context.user_prompt = user_prompt
|
context.prompts.append(user_prompt)
|
||||||
|
|
||||||
|
|
||||||
@step(u'a system prompt {system_prompt}')
|
@step(u'a system prompt {system_prompt}')
|
||||||
|
@ -151,7 +171,7 @@ def step_max_tokens(context, max_tokens):
|
||||||
|
|
||||||
@step(u'streaming is {enable_streaming}')
|
@step(u'streaming is {enable_streaming}')
|
||||||
def step_streaming(context, enable_streaming):
|
def step_streaming(context, enable_streaming):
|
||||||
context.enable_streaming = enable_streaming == 'enabled' or bool(enable_streaming)
|
context.enable_streaming = enable_streaming == 'enabled'
|
||||||
|
|
||||||
|
|
||||||
@step(u'a user api key {user_api_key}')
|
@step(u'a user api key {user_api_key}')
|
||||||
|
@ -175,8 +195,35 @@ def step_server_api_key(context, server_api_key):
|
||||||
|
|
||||||
|
|
||||||
@step(u'an OAI compatible chat completions request with {api_error} api error')
|
@step(u'an OAI compatible chat completions request with {api_error} api error')
|
||||||
def step_oai_chat_completions(context, api_error):
|
@async_run_until_complete
|
||||||
oai_chat_completions(context, context.user_prompt, api_error=api_error == 'raised')
|
async def step_oai_chat_completions(context, api_error):
|
||||||
|
print(f"Submitting OAI compatible completions request...")
|
||||||
|
expect_api_error = api_error == 'raised'
|
||||||
|
completion = await oai_chat_completions(context.prompts.pop(),
|
||||||
|
context.system_prompt,
|
||||||
|
context.base_url,
|
||||||
|
False,
|
||||||
|
model=context.model if hasattr(context, 'model') else None,
|
||||||
|
|
||||||
|
n_predict=context.n_predict
|
||||||
|
if hasattr(context, 'n_predict') else None,
|
||||||
|
|
||||||
|
enable_streaming=context.enable_streaming
|
||||||
|
if hasattr(context, 'enable_streaming') else None,
|
||||||
|
|
||||||
|
server_seed=context.server_seed
|
||||||
|
if hasattr(context, 'server_seed') else None,
|
||||||
|
|
||||||
|
user_api_key=context.user_api_key
|
||||||
|
if hasattr(context, 'user_api_key') else None,
|
||||||
|
|
||||||
|
expect_api_error=expect_api_error)
|
||||||
|
context.completions.append(completion)
|
||||||
|
print(f"Completion response: {completion}")
|
||||||
|
if expect_api_error:
|
||||||
|
assert completion == 401, f"completion must be an 401 status code: {completion}"
|
||||||
|
|
||||||
|
print(f"Completion response: {completion}")
|
||||||
|
|
||||||
|
|
||||||
@step(u'a prompt')
|
@step(u'a prompt')
|
||||||
|
@ -190,22 +237,49 @@ def step_a_prompt_prompt(context, prompt):
|
||||||
|
|
||||||
|
|
||||||
@step(u'concurrent completion requests')
|
@step(u'concurrent completion requests')
|
||||||
def step_concurrent_completion_requests(context):
|
@async_run_until_complete()
|
||||||
concurrent_requests(context, request_completion)
|
async def step_concurrent_completion_requests(context):
|
||||||
|
await concurrent_completion_requests(context,
|
||||||
|
request_completion,
|
||||||
|
# prompt is inserted automatically
|
||||||
|
context.base_url,
|
||||||
|
n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
|
||||||
|
server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
|
||||||
|
user_api_key=context.user_api_key if hasattr(context,
|
||||||
|
'user_api_key') else None)
|
||||||
|
|
||||||
|
|
||||||
@step(u'concurrent OAI completions requests')
|
@step(u'concurrent OAI completions requests')
|
||||||
def step_oai_chat_completions(context):
|
@async_run_until_complete
|
||||||
concurrent_requests(context, oai_chat_completions)
|
async def step_oai_chat_completions(context):
|
||||||
|
await concurrent_completion_requests(context, oai_chat_completions,
|
||||||
|
# user_prompt is inserted automatically
|
||||||
|
context.system_prompt,
|
||||||
|
context.base_url,
|
||||||
|
True, # async_client
|
||||||
|
model=context.model
|
||||||
|
if hasattr(context, 'model') else None,
|
||||||
|
n_predict=context.n_predict
|
||||||
|
if hasattr(context, 'n_predict') else None,
|
||||||
|
enable_streaming=context.enable_streaming
|
||||||
|
if hasattr(context, 'enable_streaming') else None,
|
||||||
|
server_seed=context.server_seed
|
||||||
|
if hasattr(context, 'server_seed') else None,
|
||||||
|
user_api_key=context.user_api_key
|
||||||
|
if hasattr(context, 'user_api_key') else None)
|
||||||
|
|
||||||
|
|
||||||
@step(u'all prompts are predicted')
|
@step(u'all prompts are predicted with {n_predict} tokens')
|
||||||
def step_all_prompts_are_predicted(context):
|
@async_run_until_complete
|
||||||
for completion_thread in context.completion_threads:
|
async def step_all_prompts_are_predicted(context, n_predict):
|
||||||
completion_thread.join()
|
n_completion_tasks = len(context.concurrent_completion_tasks)
|
||||||
assert len(context.completions) == len(context.completion_threads)
|
print(f"Waiting for all {n_completion_tasks} completion responses...")
|
||||||
for completion in context.completions:
|
for task_no in range(n_completion_tasks):
|
||||||
assert_n_tokens_predicted(completion)
|
context.completions.append(await context.concurrent_completion_tasks.pop())
|
||||||
|
n_completions = len(context.completions)
|
||||||
|
assert n_completions > 0
|
||||||
|
for i in range(n_completions):
|
||||||
|
assert_n_tokens_predicted(context.completions.pop(), expected_predicted_n=int(n_predict))
|
||||||
|
|
||||||
|
|
||||||
@step(u'embeddings are computed for')
|
@step(u'embeddings are computed for')
|
||||||
|
@ -269,126 +343,228 @@ def step_check_options_header_value(context, cors_header, cors_header_value):
|
||||||
assert context.options_response.headers[cors_header] == cors_header_value
|
assert context.options_response.headers[cors_header] == cors_header_value
|
||||||
|
|
||||||
|
|
||||||
def concurrent_requests(context, f_completion, *argv):
|
async def concurrent_completion_requests(context, f_completion, *args, **kwargs):
|
||||||
context.completions.clear()
|
n_prompts = len(context.prompts)
|
||||||
context.completion_threads.clear()
|
print(f"starting {n_prompts} concurrent completion requests...")
|
||||||
for prompt in context.prompts:
|
assert n_prompts > 0
|
||||||
completion_thread = threading.Thread(target=f_completion, args=(context, prompt, *argv))
|
for prompt_no in range(n_prompts):
|
||||||
completion_thread.start()
|
shifted_args = [context.prompts.pop(), *args]
|
||||||
context.completion_threads.append(completion_thread)
|
context.concurrent_completion_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
|
||||||
context.prompts.clear()
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
|
||||||
def request_completion(context, prompt, n_predict=None, expect_api_error=None):
|
async def request_completion(prompt,
|
||||||
|
base_url,
|
||||||
|
n_predict=None,
|
||||||
|
server_seed=None,
|
||||||
|
expect_api_error=None,
|
||||||
|
user_api_key=None):
|
||||||
|
print(f"Sending completion request: {prompt}")
|
||||||
origin = "my.super.domain"
|
origin = "my.super.domain"
|
||||||
headers = {
|
headers = {
|
||||||
'Origin': origin
|
'Origin': origin
|
||||||
}
|
}
|
||||||
if context.user_api_key is not None:
|
if user_api_key is not None:
|
||||||
print(f"Set user_api_key: {context.user_api_key}")
|
print(f"Set user_api_key: {user_api_key}")
|
||||||
headers['Authorization'] = f'Bearer {context.user_api_key}'
|
headers['Authorization'] = f'Bearer {user_api_key}'
|
||||||
|
|
||||||
response = requests.post(f'{context.base_url}/completion',
|
async with aiohttp.ClientSession() as session:
|
||||||
json={
|
async with session.post(f'{base_url}/completion',
|
||||||
"prompt": prompt,
|
json={
|
||||||
"n_predict": int(n_predict) if n_predict is not None else context.n_predict,
|
"prompt": prompt,
|
||||||
"seed": context.server_seed if context.server_seed is not None else 42
|
"n_predict": int(n_predict) if n_predict is not None else -1,
|
||||||
},
|
"seed": server_seed if server_seed is not None else 42
|
||||||
headers=headers)
|
},
|
||||||
if expect_api_error is not None and not expect_api_error:
|
headers=headers) as response:
|
||||||
assert response.status_code == 200
|
if expect_api_error is None or not expect_api_error:
|
||||||
assert response.headers['Access-Control-Allow-Origin'] == origin
|
assert response.status == 200
|
||||||
context.completions.append(response.json())
|
assert response.headers['Access-Control-Allow-Origin'] == origin
|
||||||
else:
|
return await response.json()
|
||||||
assert response.status_code == 401
|
else:
|
||||||
|
return response.status
|
||||||
|
|
||||||
|
|
||||||
def oai_chat_completions(context, user_prompt, api_error=None):
|
async def oai_chat_completions(user_prompt,
|
||||||
openai.api_key = 'nope' # openai client always expects an api_keu
|
system_prompt,
|
||||||
if context.user_api_key is not None:
|
base_url,
|
||||||
openai.api_key = context.user_api_key
|
async_client,
|
||||||
openai.api_base = f'{context.base_url}/v1/chat'
|
model=None,
|
||||||
try:
|
n_predict=None,
|
||||||
chat_completion = openai.Completion.create(
|
enable_streaming=None,
|
||||||
messages=[
|
server_seed=None,
|
||||||
{
|
user_api_key=None,
|
||||||
"role": "system",
|
expect_api_error=None):
|
||||||
"content": context.system_prompt,
|
print(f"Sending OAI Chat completions request: {user_prompt}")
|
||||||
},
|
# openai client always expects an api key
|
||||||
{
|
user_api_key = user_api_key if user_api_key is not None else 'nope'
|
||||||
"role": "user",
|
seed = server_seed if server_seed is not None else 42
|
||||||
"content": user_prompt,
|
enable_streaming = enable_streaming if enable_streaming is not None else False
|
||||||
}
|
payload = {
|
||||||
],
|
"messages": [
|
||||||
model=context.model,
|
{
|
||||||
max_tokens=context.n_predict,
|
"role": "system",
|
||||||
stream=context.enable_streaming,
|
"content": system_prompt,
|
||||||
seed=context.server_seed if context.server_seed is not None else 42
|
},
|
||||||
)
|
{
|
||||||
except openai.error.APIError:
|
"role": "user",
|
||||||
if api_error is not None and api_error:
|
"content": user_prompt,
|
||||||
return
|
|
||||||
if context.enable_streaming:
|
|
||||||
completion_response = {
|
|
||||||
'content': '',
|
|
||||||
'timings': {
|
|
||||||
'predicted_n': 0
|
|
||||||
}
|
}
|
||||||
|
],
|
||||||
|
"model": model,
|
||||||
|
"max_tokens": n_predict,
|
||||||
|
"stream": enable_streaming,
|
||||||
|
"seed": seed
|
||||||
|
}
|
||||||
|
completion_response = {
|
||||||
|
'content': '',
|
||||||
|
'timings': {
|
||||||
|
'predicted_n': 0
|
||||||
}
|
}
|
||||||
for chunk in chat_completion:
|
}
|
||||||
assert len(chunk.choices) == 1
|
if async_client:
|
||||||
delta = chunk.choices[0].delta
|
origin = 'llama.cpp'
|
||||||
if 'content' in delta:
|
headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
|
||||||
completion_response['content'] += delta['content']
|
async with aiohttp.ClientSession() as session:
|
||||||
completion_response['timings']['predicted_n'] += 1
|
async with session.post(f'{base_url}/v1/chat/completions',
|
||||||
context.completions.append(completion_response)
|
json=payload,
|
||||||
|
headers=headers) as response:
|
||||||
|
if enable_streaming:
|
||||||
|
print("payload", payload)
|
||||||
|
assert response.status == 200
|
||||||
|
assert response.headers['Access-Control-Allow-Origin'] == origin
|
||||||
|
assert response.headers['Content-Type'] == "text/event-stream"
|
||||||
|
|
||||||
|
async for line_in_bytes in response.content:
|
||||||
|
line = line_in_bytes.decode('utf8')
|
||||||
|
event_data = line.split(': ', 1)
|
||||||
|
assert event_data[0] == 'data', f'{event_data}'
|
||||||
|
chunk_raw = event_data[1]
|
||||||
|
|
||||||
|
chunk = json.loads(chunk_raw)
|
||||||
|
assert len(chunk['choices']) == 1
|
||||||
|
delta = chunk['choices'][0]['delta']
|
||||||
|
if 'content' in delta:
|
||||||
|
completion_response['content'] += delta['content']
|
||||||
|
completion_response['timings']['predicted_n'] += 1
|
||||||
|
print(f"XXXXXXXXXXXXXXXXXcompletion_response: {completion_response}")
|
||||||
|
else:
|
||||||
|
print(f"raw completion response: {response}")
|
||||||
|
if expect_api_error is None or not expect_api_error:
|
||||||
|
assert response.status == 200
|
||||||
|
assert response.headers['Access-Control-Allow-Origin'] == origin
|
||||||
|
assert response.headers['Content-Type'] == "application/json; charset=utf-8"
|
||||||
|
chat_completion_raw = await response.json()
|
||||||
|
completion_response = {
|
||||||
|
'content': chat_completion_raw['choices'][0]['message'],
|
||||||
|
'timings': {
|
||||||
|
'predicted_n': chat_completion_raw['usage']['completion_tokens']
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return response.status
|
||||||
else:
|
else:
|
||||||
assert len(chat_completion.choices) == 1
|
try:
|
||||||
context.completions.append({
|
openai.api_key = user_api_key
|
||||||
'content': chat_completion.choices[0].message,
|
openai.api_base = f'{base_url}/v1/chat'
|
||||||
'timings': {
|
chat_completion = openai.Completion.create(
|
||||||
'predicted_n': chat_completion.usage.completion_tokens
|
messages=payload['messages'],
|
||||||
|
model=model,
|
||||||
|
max_tokens=n_predict,
|
||||||
|
stream=enable_streaming,
|
||||||
|
seed=seed
|
||||||
|
)
|
||||||
|
except openai.error.APIError as e:
|
||||||
|
if expect_api_error is not None and expect_api_error:
|
||||||
|
return 401
|
||||||
|
else:
|
||||||
|
assert False, f'error raised: {e}'
|
||||||
|
|
||||||
|
if enable_streaming:
|
||||||
|
for chunk in chat_completion:
|
||||||
|
assert len(chunk.choices) == 1
|
||||||
|
delta = chunk.choices[0].delta
|
||||||
|
if 'content' in delta:
|
||||||
|
completion_response['content'] += delta['content']
|
||||||
|
completion_response['timings']['predicted_n'] += 1
|
||||||
|
else:
|
||||||
|
assert len(chat_completion.choices) == 1
|
||||||
|
completion_response = {
|
||||||
|
'content': chat_completion.choices[0].message.content,
|
||||||
|
'timings': {
|
||||||
|
'predicted_n': chat_completion.usage.completion_tokens
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
print("OAI response formatted to llama.cpp", completion_response)
|
||||||
|
return completion_response
|
||||||
|
|
||||||
|
|
||||||
def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, expected_content=None):
|
def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
|
||||||
content = completion_response['content']
|
content = completion_response['content']
|
||||||
n_predicted = completion_response['timings']['predicted_n']
|
n_predicted = completion_response['timings']['predicted_n']
|
||||||
assert len(content) > 0, "no token predicted"
|
assert len(content) > 0, "no token predicted"
|
||||||
if expected_predicted_n is not None:
|
if expected_predicted_n is not None:
|
||||||
assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
|
assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
|
||||||
f' {n_predicted} <> {expected_predicted_n}')
|
f' {n_predicted} <> {expected_predicted_n}')
|
||||||
if expected_content is not None:
|
if re_content is not None:
|
||||||
expected_content = expected_content.replace('<space>', ' ').replace('<LF>', '\n')
|
re_content = '^.*' + re_content.replace('<or>', '|') + '.*$'
|
||||||
assert content == expected_content, (f'invalid tokens predicted:'
|
assert re.match(re_content, content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL), (
|
||||||
f' ```\n{content}\n``` <> ```\n{expected_content}\n```')
|
f'invalid tokens predicted:'
|
||||||
|
f' ```\n{content}\n``` do not match /{re_content}/')
|
||||||
|
|
||||||
|
|
||||||
def wait_for_health_status(context, expected_http_status_code,
|
async def wait_for_health_status(base_url,
|
||||||
expected_health_status,
|
expected_http_status_code,
|
||||||
params=None,
|
expected_health_status,
|
||||||
slots_idle=None,
|
params=None,
|
||||||
slots_processing=None):
|
slots_idle=None,
|
||||||
while True:
|
slots_processing=None,
|
||||||
health_response = requests.get(f'{context.base_url}/health', params)
|
expected_slots=None):
|
||||||
status_code = health_response.status_code
|
print(f"Starting checking for health for expected_health_status={expected_health_status}")
|
||||||
health = health_response.json()
|
timeout = 3 # seconds
|
||||||
if (status_code == expected_http_status_code
|
interval = 0.5
|
||||||
and health['status'] == expected_health_status
|
counter = 0
|
||||||
and (slots_idle is None or health['slots_idle'] == slots_idle)
|
async with aiohttp.ClientSession() as session:
|
||||||
and (slots_processing is None or health['slots_processing'] == slots_processing)):
|
while True:
|
||||||
break
|
async with await session.get(f'{base_url}/health', params=params) as health_response:
|
||||||
|
status_code = health_response.status
|
||||||
|
health = await health_response.json()
|
||||||
|
print(f"HEALTH - response for expected health status='{expected_health_status}' on "
|
||||||
|
f"'{base_url}/health'?{params} is {health}")
|
||||||
|
if (status_code == expected_http_status_code
|
||||||
|
and health['status'] == expected_health_status
|
||||||
|
and (slots_idle is None or health['slots_idle'] == slots_idle)
|
||||||
|
and (slots_processing is None or health['slots_processing'] == slots_processing)):
|
||||||
|
if expected_slots is not None:
|
||||||
|
assert_slots_status(health['slots'], expected_slots)
|
||||||
|
return
|
||||||
|
if (status_code == expected_http_status_code
|
||||||
|
and health['status'] == expected_health_status
|
||||||
|
and (slots_idle is None or health['slots_idle'] == slots_idle)
|
||||||
|
and (slots_processing is None or health['slots_processing'] == slots_processing)):
|
||||||
|
if expected_slots is not None:
|
||||||
|
assert_slots_status(health['slots'], expected_slots)
|
||||||
|
return
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
counter += interval
|
||||||
|
if counter >= timeout:
|
||||||
|
assert False, 'timeout exceeded'
|
||||||
|
|
||||||
|
|
||||||
def request_slots_status(context, expected_slots):
|
async def request_slots_status(context, expected_slots):
|
||||||
slots_response = requests.get(f'{context.base_url}/slots')
|
async with aiohttp.ClientSession() as session:
|
||||||
assert slots_response.status_code == 200
|
async with await session.get(f'{context.base_url}/slots') as slots_response:
|
||||||
slots = slots_response.json()
|
assert slots_response.status == 200
|
||||||
|
slots = await slots_response.json()
|
||||||
|
assert_slots_status(slots, expected_slots)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_slots_status(slots, expected_slots):
|
||||||
assert len(slots) == len(expected_slots)
|
assert len(slots) == len(expected_slots)
|
||||||
for expected, slot in zip(expected_slots, slots):
|
for slot_id, (expected, slot) in enumerate(zip(expected_slots, slots)):
|
||||||
for key in expected:
|
for key in expected:
|
||||||
assert expected[key] == slot[key], f"expected[{key}] != slot[{key}]"
|
assert expected[key] == slot[key], (f"invalid slot {slot_id}"
|
||||||
|
f" expected[{key}] != slot[{key}]"
|
||||||
|
f" = {expected[key]} != {slot[key]}")
|
||||||
|
|
||||||
|
|
||||||
def start_server_background(context):
|
def start_server_background(context):
|
||||||
|
@ -398,6 +574,8 @@ def start_server_background(context):
|
||||||
server_args = [
|
server_args = [
|
||||||
'--model', context.model_file
|
'--model', context.model_file
|
||||||
]
|
]
|
||||||
|
if context.server_continuous_batching:
|
||||||
|
server_args.append('--cont-batching')
|
||||||
if context.model_alias is not None:
|
if context.model_alias is not None:
|
||||||
server_args.extend(['--alias', context.model_alias])
|
server_args.extend(['--alias', context.model_alias])
|
||||||
if context.server_seed is not None:
|
if context.server_seed is not None:
|
||||||
|
|
|
@ -1,2 +1,3 @@
|
||||||
|
aiohttp~=3.9.3
|
||||||
behave~=1.2.6
|
behave~=1.2.6
|
||||||
openai~=0.25.0
|
openai~=0.25.0
|
||||||
|
|
|
@ -3,4 +3,4 @@
|
||||||
set -eu
|
set -eu
|
||||||
|
|
||||||
# Start @llama.cpp scenario
|
# Start @llama.cpp scenario
|
||||||
behave --summary --stop --tags llama.cpp
|
behave --summary --stop --no-capture --tags llama.cpp
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue