server: tests: add concurrent embedding test for issue #5655

allow enabling VERBOSE mode via the DEBUG environment variable
This commit is contained in:
Pierrick HYMBERT 2024-02-23 18:41:11 +01:00
parent 30f802d0d7
commit 6c0e6f4f9c
7 changed files with 117 additions and 50 deletions

View file

@ -12,8 +12,9 @@ Server tests scenario using [BDD](https://en.wikipedia.org/wiki/Behavior-driven_
3. Start the test: `./tests.sh` 3. Start the test: `./tests.sh`
It's possible to override some scenario steps values with environment variables: It's possible to override some scenario steps values with environment variables:
- `$PORT` -> `context.server_port` to set the listening port of the server during scenario, default: `8080` - `PORT` -> `context.server_port` to set the listening port of the server during scenario, default: `8080`
- `$LLAMA_SERVER_BIN_PATH` -> to change the server binary path, default: `../../../build/bin/server` - `LLAMA_SERVER_BIN_PATH` -> to change the server binary path, default: `../../../build/bin/server`
- `DEBUG` -> "ON" to enable server verbose mode `--verbose`
### Run @bug, @wip or @wrong_usage annotated scenario ### Run @bug, @wip or @wrong_usage annotated scenario
@ -23,4 +24,4 @@ Feature or Scenario must be annotated with `@llama.cpp` to be included in the de
- `@wip` to focus on a scenario working in progress - `@wip` to focus on a scenario working in progress
To run a scenario annotated with `@bug`, start: To run a scenario annotated with `@bug`, start:
`./tests.sh --tags bug` `DEBUG=ON ./tests.sh --no-skipped --tags bug`

View file

@ -24,7 +24,7 @@ def after_scenario(context, scenario):
for line in f: for line in f:
print(line) print(line)
if not is_server_listening(context.server_fqdn, context.server_port): if not is_server_listening(context.server_fqdn, context.server_port):
print("ERROR: Server has crashed") print("\x1b[33;101mERROR: Server stopped listening\x1b[0m")
if not pid_exists(context.server_process.pid): if not pid_exists(context.server_process.pid):
assert False, f"Server not running pid={context.server_process.pid} ..." assert False, f"Server not running pid={context.server_process.pid} ..."
@ -41,7 +41,7 @@ def after_scenario(context, scenario):
time.sleep(0.1) time.sleep(0.1)
attempts += 1 attempts += 1
if attempts > 5: if attempts > 5:
print(f"Server dandling exits, killing all {context.server_path} ...") print(f"Server dangling exits, killing all {context.server_path} ...")
process = subprocess.run(['killall', '-9', context.server_path], process = subprocess.run(['killall', '-9', context.server_path],
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
universal_newlines=True) universal_newlines=True)

View file

@ -0,0 +1,36 @@
# List of ongoing issues
# Tagged @bug: these scenarios are excluded from the default run in tests.sh
# and executed on demand, e.g.: DEBUG=ON ./tests.sh --no-skipped --tags bug
@bug
Feature: Issues
# Issue #5655
# Reproduces concurrent embedding requests against a 2-slot server with
# continuous batching, using the tiny stories260K model and a small KV cache.
Scenario: Multi users embeddings
Given a server listening on localhost:8080
And a model file stories260K.gguf
And a model alias tinyllama-2
And 42 as server seed
And 64 KV cache size
And 2 slots
And continuous batching
And embeddings extraction
Then the server is starting
Then the server is healthy
Given a prompt:
"""
Write a very long story about AI.
"""
And a prompt:
"""
Write another very long music lyrics.
"""
And a prompt:
"""
Write a very long poem.
"""
And a prompt:
"""
Write a very long joke.
"""
# One embedding request is issued per prompt queued above, all concurrently.
Given concurrent embedding requests
Then the server is busy
Then the server is idle
Then all embeddings are generated

View file

@ -35,8 +35,8 @@ def step_server_config(context, server_fqdn, server_port):
context.server_seed = None context.server_seed = None
context.user_api_key = None context.user_api_key = None
context.completions = [] context.tasks_result = []
context.concurrent_completion_tasks = [] context.concurrent_tasks = []
context.prompts = [] context.prompts = []
@ -149,7 +149,7 @@ async def step_request_completion(context, api_error):
server_seed=context.server_seed, server_seed=context.server_seed,
expect_api_error=expect_api_error, expect_api_error=expect_api_error,
user_api_key=context.user_api_key) user_api_key=context.user_api_key)
context.completions.append(completion) context.tasks_result.append(completion)
print(f"Completion response: {completion}") print(f"Completion response: {completion}")
if expect_api_error: if expect_api_error:
assert completion == 401, f"completion must be an 401 status code: {completion}" assert completion == 401, f"completion must be an 401 status code: {completion}"
@ -157,12 +157,12 @@ async def step_request_completion(context, api_error):
@step(u'{predicted_n} tokens are predicted matching {re_content}') @step(u'{predicted_n} tokens are predicted matching {re_content}')
def step_n_tokens_predicted_with_content(context, predicted_n, re_content): def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
assert_n_tokens_predicted(context.completions.pop(), int(predicted_n), re_content) assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n), re_content)
@step(u'{predicted_n} tokens are predicted') @step(u'{predicted_n} tokens are predicted')
def step_n_tokens_predicted(context, predicted_n): def step_n_tokens_predicted(context, predicted_n):
assert_n_tokens_predicted(context.completions.pop(), int(predicted_n)) assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n))
@step(u'a user prompt {user_prompt}') @step(u'a user prompt {user_prompt}')
@ -195,13 +195,13 @@ def step_user_api_key(context, user_api_key):
context.user_api_key = user_api_key context.user_api_key = user_api_key
@step(u'a user api key ') @step(u'no user api key')
def step_no_user_api_key(context): def step_no_user_api_key(context):
context.user_api_key = None context.user_api_key = None
@step(u'no user api key') @step(u'a user api key ')
def step_no_user_api_key(context): def step_no_user_api_key_space(context):
context.user_api_key = None context.user_api_key = None
@ -234,7 +234,7 @@ async def step_oai_chat_completions(context, api_error):
if hasattr(context, 'user_api_key') else None, if hasattr(context, 'user_api_key') else None,
expect_api_error=expect_api_error) expect_api_error=expect_api_error)
context.completions.append(completion) context.tasks_result.append(completion)
print(f"Completion response: {completion}") print(f"Completion response: {completion}")
if expect_api_error: if expect_api_error:
assert completion == 401, f"completion must be an 401 status code: {completion}" assert completion == 401, f"completion must be an 401 status code: {completion}"
@ -285,47 +285,38 @@ async def step_oai_chat_completions(context):
if hasattr(context, 'user_api_key') else None) if hasattr(context, 'user_api_key') else None)
@async_run_until_complete
@step(u'all prompts are predicted') @step(u'all prompts are predicted')
async def step_impl(context): @async_run_until_complete
async def step_all_prompts_are_predicted(context):
await all_prompts_are_predicted(context) await all_prompts_are_predicted(context)
@step(u'all prompts are predicted with {n_predict} tokens') @step(u'all prompts are predicted with {n_predict} tokens')
@async_run_until_complete @async_run_until_complete
async def step_all_prompts_are_predicted(context, n_predict): async def step_all_prompts_are_predicted_with_n_tokens(context, n_predict):
expected_predicted_n = int(n_predict) expected_predicted_n = int(n_predict)
await all_prompts_are_predicted(context, expected_predicted_n) await all_prompts_are_predicted(context, expected_predicted_n)
async def all_prompts_are_predicted(context, expected_predicted_n=None): async def all_prompts_are_predicted(context, expected_predicted_n=None):
n_completions = await gather_concurrent_completions_tasks(context) n_completions = await gather_tasks_results(context)
assert n_completions > 0 assert n_completions > 0
for i in range(n_completions): for i in range(n_completions):
assert_n_tokens_predicted(context.completions.pop(), expected_predicted_n=expected_predicted_n) assert_n_tokens_predicted(context.tasks_result.pop(), expected_predicted_n=expected_predicted_n)
assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"
@step(u'embeddings are computed for') @step(u'embeddings are computed for')
@async_run_until_complete @async_run_until_complete
async def step_compute_embedding(context): async def step_compute_embedding(context):
async with aiohttp.ClientSession() as session: content = context.text
async with session.post(f'{context.base_url}/embedding', base_url = context.base_url
json={ context.embeddings = await request_embedding(content, base_url)
"content": context.text,
}) as response:
assert response.status == 200
response_json = await response.json()
context.embeddings = response_json['embedding']
@step(u'embeddings are generated') @step(u'embeddings are generated')
def step_compute_embeddings(context): def step_assert_embeddings(context):
assert len(context.embeddings) > 0 assert_embeddings(context.embeddings)
embeddings_computed = False
for emb in context.embeddings:
if emb != 0:
embeddings_computed = True
assert embeddings_computed, f"Embeddings: {context.embeddings}"
@step(u'an OAI compatible embeddings computation request for') @step(u'an OAI compatible embeddings computation request for')
@ -341,6 +332,24 @@ def step_oai_compute_embedding(context):
context.embeddings = embeddings context.embeddings = embeddings
@step(u'concurrent embedding requests')
@async_run_until_complete()
async def step_concurrent_embedding_requests(context):
    """Fire one embedding request per queued prompt, all running concurrently."""
    # concurrent_completion_requests pops each prompt and prepends it to the
    # positional arguments, so only the remaining request_embedding argument
    # (the base URL) is passed here.
    extra_args = (context.base_url,)
    await concurrent_completion_requests(context, request_embedding, *extra_args)
@step(u'all embeddings are generated')
@async_run_until_complete()
async def all_embeddings_are_generated(context):
    """Await every pending concurrent task and check each returned embedding."""
    completed = await gather_tasks_results(context)
    # At least one request must have been issued, otherwise the step is a no-op.
    assert completed > 0
    for _ in range(completed):
        assert_embeddings(context.tasks_result.pop())
@step(u'tokenizing') @step(u'tokenizing')
@async_run_until_complete @async_run_until_complete
async def step_tokenize(context): async def step_tokenize(context):
@ -391,7 +400,7 @@ async def concurrent_completion_requests(context, f_completion, *args, **kwargs)
assert n_prompts > 0 assert n_prompts > 0
for prompt_no in range(n_prompts): for prompt_no in range(n_prompts):
shifted_args = [context.prompts.pop(), *args] shifted_args = [context.prompts.pop(), *args]
context.concurrent_completion_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs))) context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
@ -540,6 +549,17 @@ async def oai_chat_completions(user_prompt,
return completion_response return completion_response
async def request_embedding(content, base_url):
    """POST *content* to the server's /embedding endpoint and return the vector.

    Asserts the HTTP status is 200 before parsing the JSON body.
    """
    payload = {"content": content}
    async with aiohttp.ClientSession() as session:
        async with session.post(f'{base_url}/embedding', json=payload) as response:
            assert response.status == 200
            body = await response.json()
            return body['embedding']
def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None): def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
content = completion_response['content'] content = completion_response['content']
n_predicted = completion_response['timings']['predicted_n'] n_predicted = completion_response['timings']['predicted_n']
@ -554,12 +574,12 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
f' ```\n{content}\n``` do not match /{re_content}/') f' ```\n{content}\n``` do not match /{re_content}/')
async def gather_concurrent_completions_tasks(context): async def gather_tasks_results(context):
n_completion_tasks = len(context.concurrent_completion_tasks) n_tasks = len(context.concurrent_tasks)
print(f"Waiting for all {n_completion_tasks} completion responses...") print(f"Waiting for all {n_tasks} tasks results...")
for task_no in range(n_completion_tasks): for task_no in range(n_tasks):
context.completions.append(await context.concurrent_completion_tasks.pop()) context.tasks_result.append(await context.concurrent_tasks.pop())
n_completions = len(context.completions) n_completions = len(context.tasks_result)
return n_completions return n_completions
@ -602,16 +622,25 @@ async def wait_for_health_status(context,
if counter >= timeout: if counter >= timeout:
# Sometimes health requests are triggered after completions are predicted # Sometimes health requests are triggered after completions are predicted
if expected_http_status_code == 503: if expected_http_status_code == 503:
if len(context.completions) == 0: if len(context.tasks_result) == 0:
print("\x1b[5;37;43mWARNING: forcing concurrents completions tasks," print("\x1b[5;37;43mWARNING: forcing concurrent tasks,"
" busy health check missed, probably too fast inference\x1b[0m") " busy health check missed, probably too fast inference\x1b[0m")
n_completions = await gather_concurrent_completions_tasks(context) n_completions = await gather_tasks_results(context)
if n_completions > 0: if n_completions > 0:
return return
assert False, 'timeout exceeded' assert False, 'timeout exceeded'
def assert_embeddings(embeddings):
    """Assert that *embeddings* is a non-empty vector with at least one
    non-zero component.

    An all-zero vector means the server returned a placeholder instead of a
    computed embedding, so it is treated as a failure.

    :param embeddings: sequence of numeric embedding components
    :raises AssertionError: if the vector is empty or entirely zero
    """
    assert len(embeddings) > 0
    # any() short-circuits on the first non-zero component instead of the
    # previous manual flag-accumulation loop that always scanned the whole
    # vector; the failure message is unchanged.
    assert any(emb != 0 for emb in embeddings), f"Embeddings: {embeddings}"
async def request_slots_status(context, expected_slots): async def request_slots_status(context, expected_slots):
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with await session.get(f'{context.base_url}/slots') as slots_response: async with await session.get(f'{context.base_url}/slots') as slots_response:
@ -652,6 +681,8 @@ def start_server_background(context):
server_args.extend(['--n-predict', context.n_server_predict]) server_args.extend(['--n-predict', context.n_server_predict])
if context.server_api_key is not None: if context.server_api_key is not None:
server_args.extend(['--api-key', context.server_api_key]) server_args.extend(['--api-key', context.server_api_key])
if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
server_args.append('--verbose')
print(f"starting server with: {context.server_path}", *server_args) print(f"starting server with: {context.server_path}", *server_args)
context.server_process = subprocess.Popen( context.server_process = subprocess.Popen(
[str(arg) for arg in [context.server_path, *server_args]], [str(arg) for arg in [context.server_path, *server_args]],

View file

@ -4,12 +4,10 @@ Feature: Wrong usage of llama.cpp server
#3969 The user must always set --n-predict option #3969 The user must always set --n-predict option
# to cap the number of tokens any completion request can generate # to cap the number of tokens any completion request can generate
# or pass n_predict or max_tokens in the request. # or pass n_predict/max_tokens in the request.
Scenario: Infinite loop Scenario: Infinite loop
Given a server listening on localhost:8080 Given a server listening on localhost:8080
And a model file stories260K.gguf And a model file stories260K.gguf
And 1 slots
And 32 KV cache size
# Uncomment below to fix the issue # Uncomment below to fix the issue
#And 64 server max tokens to predict #And 64 server max tokens to predict
Then the server is starting Then the server is starting
@ -17,6 +15,7 @@ Feature: Wrong usage of llama.cpp server
""" """
Go to: infinite loop Go to: infinite loop
""" """
# Uncomment below to fix the issue
#And 128 max tokens to predict
Given concurrent completion requests Given concurrent completion requests
Then all prompts are predicted Then all prompts are predicted

View file

@ -5,7 +5,7 @@ set -eu
if [ $# -lt 1 ] if [ $# -lt 1 ]
then then
# Start @llama.cpp scenario # Start @llama.cpp scenario
behave --summary --stop --no-capture --tags llama.cpp behave --summary --stop --no-capture --exclude 'issues|wrong_usages' --tags llama.cpp
else else
behave "$@" behave "$@"
fi fi