server: tests - print only in case of DEBUG

This commit is contained in:
Pierrick HYMBERT 2024-02-24 11:13:43 +01:00
parent 60781f0a2b
commit a779a4bf9c

View file

@ -146,12 +146,14 @@ async def step_request_completion(context, api_error):
expect_api_error = api_error == 'raised'
completion = await request_completion(context.prompts.pop(),
context.base_url,
debug=context.debug,
n_predict=context.n_predict,
server_seed=context.server_seed,
expect_api_error=expect_api_error,
user_api_key=context.user_api_key)
context.tasks_result.append(completion)
print(f"Completion response: {completion}")
if context.debug:
print(f"Completion response: {completion}")
if expect_api_error:
assert completion == 401, f"completion must be an 401 status code: {completion}"
@ -214,7 +216,8 @@ def step_server_api_key(context, server_api_key):
@step(u'an OAI compatible chat completions request with {api_error} api error')
@async_run_until_complete
async def step_oai_chat_completions(context, api_error):
print(f"Submitting OAI compatible completions request...")
if context.debug:
print(f"Submitting OAI compatible completions request...")
expect_api_error = api_error == 'raised'
completion = await oai_chat_completions(context.prompts.pop(),
context.system_prompt,
@ -236,11 +239,13 @@ async def step_oai_chat_completions(context, api_error):
expect_api_error=expect_api_error)
context.tasks_result.append(completion)
print(f"Completion response: {completion}")
if context.debug:
print(f"Completion response: {completion}")
if expect_api_error:
assert completion == 401, f"completion must be an 401 status code: {completion}"
print(f"Completion response: {completion}")
if context.debug:
print(f"Completion response: {completion}")
@step(u'a prompt')
@ -260,6 +265,7 @@ async def step_concurrent_completion_requests(context):
request_completion,
# prompt is inserted automatically
context.base_url,
debug=context.debug,
n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
user_api_key=context.user_api_key if hasattr(context,
@ -397,7 +403,8 @@ def step_check_options_header_value(context, cors_header, cors_header_value):
async def concurrent_completion_requests(context, f_completion, *args, **kwargs):
n_prompts = len(context.prompts)
print(f"starting {n_prompts} concurrent completion requests...")
if context.debug:
print(f"starting {n_prompts} concurrent completion requests...")
assert n_prompts > 0
for prompt_no in range(n_prompts):
shifted_args = [context.prompts.pop(), *args]
@ -407,17 +414,20 @@ async def concurrent_completion_requests(context, f_completion, *args, **kwargs)
async def request_completion(prompt,
base_url,
debug=False,
n_predict=None,
server_seed=None,
expect_api_error=None,
user_api_key=None):
print(f"Sending completion request: {prompt}")
if debug:
print(f"Sending completion request: {prompt}")
origin = "my.super.domain"
headers = {
'Origin': origin
}
if user_api_key is not None:
print(f"Set user_api_key: {user_api_key}")
if debug:
print(f"Set user_api_key: {user_api_key}")
headers['Authorization'] = f'Bearer {user_api_key}'
async with aiohttp.ClientSession() as session:
@ -440,13 +450,15 @@ async def oai_chat_completions(user_prompt,
system_prompt,
base_url,
async_client,
debug=False,
model=None,
n_predict=None,
enable_streaming=None,
server_seed=None,
user_api_key=None,
expect_api_error=None):
print(f"Sending OAI Chat completions request: {user_prompt}")
if debug:
print(f"Sending OAI Chat completions request: {user_prompt}")
# openai client always expects an api key
user_api_key = user_api_key if user_api_key is not None else 'nope'
seed = server_seed if server_seed is not None else 42
@ -548,7 +560,8 @@ async def oai_chat_completions(user_prompt,
'predicted_n': chat_completion.usage.completion_tokens
}
}
print("OAI response formatted to llama.cpp:", completion_response)
if debug:
print("OAI response formatted to llama.cpp:", completion_response)
return completion_response
@ -579,7 +592,8 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
async def gather_tasks_results(context):
n_tasks = len(context.concurrent_tasks)
print(f"Waiting for all {n_tasks} tasks results...")
if context.debug:
print(f"Waiting for all {n_tasks} tasks results...")
for task_no in range(n_tasks):
context.tasks_result.append(await context.concurrent_tasks.pop())
n_completions = len(context.tasks_result)