server: tests - print only in case of DEBUG
This commit is contained in:
parent 60781f0a2b
commit a779a4bf9c
1 changed file with 24 additions and 10 deletions
@@ -146,12 +146,14 @@ async def step_request_completion(context, api_error):
     expect_api_error = api_error == 'raised'
     completion = await request_completion(context.prompts.pop(),
                                           context.base_url,
+                                          debug=context.debug,
                                           n_predict=context.n_predict,
                                           server_seed=context.server_seed,
                                           expect_api_error=expect_api_error,
                                           user_api_key=context.user_api_key)
     context.tasks_result.append(completion)
-    print(f"Completion response: {completion}")
+    if context.debug:
+        print(f"Completion response: {completion}")
     if expect_api_error:
         assert completion == 401, f"completion must be an 401 status code: {completion}"
 
@@ -214,7 +216,8 @@ def step_server_api_key(context, server_api_key):
 @step(u'an OAI compatible chat completions request with {api_error} api error')
 @async_run_until_complete
 async def step_oai_chat_completions(context, api_error):
-    print(f"Submitting OAI compatible completions request...")
+    if context.debug:
+        print(f"Submitting OAI compatible completions request...")
     expect_api_error = api_error == 'raised'
     completion = await oai_chat_completions(context.prompts.pop(),
                                             context.system_prompt,
@@ -236,11 +239,13 @@ async def step_oai_chat_completions(context, api_error):
 
                                             expect_api_error=expect_api_error)
     context.tasks_result.append(completion)
-    print(f"Completion response: {completion}")
+    if context.debug:
+        print(f"Completion response: {completion}")
     if expect_api_error:
         assert completion == 401, f"completion must be an 401 status code: {completion}"
 
-    print(f"Completion response: {completion}")
+    if context.debug:
+        print(f"Completion response: {completion}")
 
 
 @step(u'a prompt')
@@ -260,6 +265,7 @@ async def step_concurrent_completion_requests(context):
                                          request_completion,
                                          # prompt is inserted automatically
                                          context.base_url,
+                                         debug=context.debug,
                                          n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
                                          server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
                                          user_api_key=context.user_api_key if hasattr(context,
@@ -397,7 +403,8 @@ def step_check_options_header_value(context, cors_header, cors_header_value):
 
 async def concurrent_completion_requests(context, f_completion, *args, **kwargs):
     n_prompts = len(context.prompts)
-    print(f"starting {n_prompts} concurrent completion requests...")
+    if context.debug:
+        print(f"starting {n_prompts} concurrent completion requests...")
     assert n_prompts > 0
     for prompt_no in range(n_prompts):
         shifted_args = [context.prompts.pop(), *args]
@@ -407,17 +414,20 @@ async def concurrent_completion_requests(context, f_completion, *args, **kwargs)
 
 async def request_completion(prompt,
                              base_url,
+                             debug=False,
                              n_predict=None,
                              server_seed=None,
                              expect_api_error=None,
                              user_api_key=None):
-    print(f"Sending completion request: {prompt}")
+    if debug:
+        print(f"Sending completion request: {prompt}")
     origin = "my.super.domain"
     headers = {
         'Origin': origin
     }
     if user_api_key is not None:
-        print(f"Set user_api_key: {user_api_key}")
+        if debug:
+            print(f"Set user_api_key: {user_api_key}")
         headers['Authorization'] = f'Bearer {user_api_key}'
 
     async with aiohttp.ClientSession() as session:
@@ -440,13 +450,15 @@ async def oai_chat_completions(user_prompt,
                               system_prompt,
                               base_url,
                               async_client,
+                              debug=False,
                               model=None,
                               n_predict=None,
                               enable_streaming=None,
                               server_seed=None,
                               user_api_key=None,
                               expect_api_error=None):
-    print(f"Sending OAI Chat completions request: {user_prompt}")
+    if debug:
+        print(f"Sending OAI Chat completions request: {user_prompt}")
     # openai client always expects an api key
     user_api_key = user_api_key if user_api_key is not None else 'nope'
     seed = server_seed if server_seed is not None else 42
@@ -548,7 +560,8 @@ async def oai_chat_completions(user_prompt,
             'predicted_n': chat_completion.usage.completion_tokens
         }
     }
-    print("OAI response formatted to llama.cpp:", completion_response)
+    if debug:
+        print("OAI response formatted to llama.cpp:", completion_response)
     return completion_response
 
 
@@ -579,7 +592,8 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
 
 async def gather_tasks_results(context):
     n_tasks = len(context.concurrent_tasks)
-    print(f"Waiting for all {n_tasks} tasks results...")
+    if context.debug:
+        print(f"Waiting for all {n_tasks} tasks results...")
     for task_no in range(n_tasks):
         context.tasks_result.append(await context.concurrent_tasks.pop())
     n_completions = len(context.tasks_result)
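The gated prints above fire only if the behave context actually carries a `debug` attribute. This commit does not show where that flag is set; below is a minimal sketch of how it could be wired up in the suite's `environment.py`. The `before_scenario(context, scenario)` hook is behave's standard API, but sourcing the flag from a `DEBUG` environment variable with an `ON` value is an illustrative assumption, not part of this diff.

    # environment.py -- behave loads this module automatically.
    # Illustrative sketch (assumption): populate context.debug from a
    # DEBUG env var so the print() calls gated above stay silent by default.
    import os


    def before_scenario(context, scenario):
        # behave invokes this hook before each scenario; treating
        # DEBUG=ON as the opt-in switch is a hypothetical convention
        context.debug = os.environ.get('DEBUG') == 'ON'

With wiring like this, a plain `behave` run stays quiet, while `DEBUG=ON behave` restores the verbose request/response traces that these prints used to emit unconditionally.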