server: tests - print only in case of DEBUG

This commit is contained in:
Pierrick HYMBERT 2024-02-24 11:13:43 +01:00
parent 60781f0a2b
commit a779a4bf9c

View file

@ -146,11 +146,13 @@ async def step_request_completion(context, api_error):
expect_api_error = api_error == 'raised' expect_api_error = api_error == 'raised'
completion = await request_completion(context.prompts.pop(), completion = await request_completion(context.prompts.pop(),
context.base_url, context.base_url,
debug=context.debug,
n_predict=context.n_predict, n_predict=context.n_predict,
server_seed=context.server_seed, server_seed=context.server_seed,
expect_api_error=expect_api_error, expect_api_error=expect_api_error,
user_api_key=context.user_api_key) user_api_key=context.user_api_key)
context.tasks_result.append(completion) context.tasks_result.append(completion)
if context.debug:
print(f"Completion response: {completion}") print(f"Completion response: {completion}")
if expect_api_error: if expect_api_error:
assert completion == 401, f"completion must be an 401 status code: {completion}" assert completion == 401, f"completion must be an 401 status code: {completion}"
@ -214,6 +216,7 @@ def step_server_api_key(context, server_api_key):
@step(u'an OAI compatible chat completions request with {api_error} api error') @step(u'an OAI compatible chat completions request with {api_error} api error')
@async_run_until_complete @async_run_until_complete
async def step_oai_chat_completions(context, api_error): async def step_oai_chat_completions(context, api_error):
if context.debug:
print(f"Submitting OAI compatible completions request...") print(f"Submitting OAI compatible completions request...")
expect_api_error = api_error == 'raised' expect_api_error = api_error == 'raised'
completion = await oai_chat_completions(context.prompts.pop(), completion = await oai_chat_completions(context.prompts.pop(),
@ -236,10 +239,12 @@ async def step_oai_chat_completions(context, api_error):
expect_api_error=expect_api_error) expect_api_error=expect_api_error)
context.tasks_result.append(completion) context.tasks_result.append(completion)
if context.debug:
print(f"Completion response: {completion}") print(f"Completion response: {completion}")
if expect_api_error: if expect_api_error:
assert completion == 401, f"completion must be an 401 status code: {completion}" assert completion == 401, f"completion must be an 401 status code: {completion}"
if context.debug:
print(f"Completion response: {completion}") print(f"Completion response: {completion}")
@ -260,6 +265,7 @@ async def step_concurrent_completion_requests(context):
request_completion, request_completion,
# prompt is inserted automatically # prompt is inserted automatically
context.base_url, context.base_url,
debug=context.debug,
n_predict=context.n_predict if hasattr(context, 'n_predict') else None, n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
server_seed=context.server_seed if hasattr(context, 'server_seed') else None, server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
user_api_key=context.user_api_key if hasattr(context, user_api_key=context.user_api_key if hasattr(context,
@ -397,6 +403,7 @@ def step_check_options_header_value(context, cors_header, cors_header_value):
async def concurrent_completion_requests(context, f_completion, *args, **kwargs): async def concurrent_completion_requests(context, f_completion, *args, **kwargs):
n_prompts = len(context.prompts) n_prompts = len(context.prompts)
if context.debug:
print(f"starting {n_prompts} concurrent completion requests...") print(f"starting {n_prompts} concurrent completion requests...")
assert n_prompts > 0 assert n_prompts > 0
for prompt_no in range(n_prompts): for prompt_no in range(n_prompts):
@ -407,16 +414,19 @@ async def concurrent_completion_requests(context, f_completion, *args, **kwargs)
async def request_completion(prompt, async def request_completion(prompt,
base_url, base_url,
debug=False,
n_predict=None, n_predict=None,
server_seed=None, server_seed=None,
expect_api_error=None, expect_api_error=None,
user_api_key=None): user_api_key=None):
if debug:
print(f"Sending completion request: {prompt}") print(f"Sending completion request: {prompt}")
origin = "my.super.domain" origin = "my.super.domain"
headers = { headers = {
'Origin': origin 'Origin': origin
} }
if user_api_key is not None: if user_api_key is not None:
if debug:
print(f"Set user_api_key: {user_api_key}") print(f"Set user_api_key: {user_api_key}")
headers['Authorization'] = f'Bearer {user_api_key}' headers['Authorization'] = f'Bearer {user_api_key}'
@ -440,12 +450,14 @@ async def oai_chat_completions(user_prompt,
system_prompt, system_prompt,
base_url, base_url,
async_client, async_client,
debug=False,
model=None, model=None,
n_predict=None, n_predict=None,
enable_streaming=None, enable_streaming=None,
server_seed=None, server_seed=None,
user_api_key=None, user_api_key=None,
expect_api_error=None): expect_api_error=None):
if debug:
print(f"Sending OAI Chat completions request: {user_prompt}") print(f"Sending OAI Chat completions request: {user_prompt}")
# openai client always expects an api key # openai client always expects an api key
user_api_key = user_api_key if user_api_key is not None else 'nope' user_api_key = user_api_key if user_api_key is not None else 'nope'
@ -548,6 +560,7 @@ async def oai_chat_completions(user_prompt,
'predicted_n': chat_completion.usage.completion_tokens 'predicted_n': chat_completion.usage.completion_tokens
} }
} }
if debug:
print("OAI response formatted to llama.cpp:", completion_response) print("OAI response formatted to llama.cpp:", completion_response)
return completion_response return completion_response
@ -579,6 +592,7 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
async def gather_tasks_results(context): async def gather_tasks_results(context):
n_tasks = len(context.concurrent_tasks) n_tasks = len(context.concurrent_tasks)
if context.debug:
print(f"Waiting for all {n_tasks} tasks results...") print(f"Waiting for all {n_tasks} tasks results...")
for task_no in range(n_tasks): for task_no in range(n_tasks):
context.tasks_result.append(await context.concurrent_tasks.pop()) context.tasks_result.append(await context.concurrent_tasks.pop())