diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index bbe7c3f6c..50f2b641e 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -146,12 +146,14 @@ async def step_request_completion(context, api_error):
     expect_api_error = api_error == 'raised'
     completion = await request_completion(context.prompts.pop(),
                                           context.base_url,
+                                          debug=context.debug,
                                           n_predict=context.n_predict,
                                           server_seed=context.server_seed,
                                           expect_api_error=expect_api_error,
                                           user_api_key=context.user_api_key)
     context.tasks_result.append(completion)
-    print(f"Completion response: {completion}")
+    if context.debug:
+        print(f"Completion response: {completion}")
     if expect_api_error:
         assert completion == 401, f"completion must be an 401 status code: {completion}"
 
@@ -214,7 +216,8 @@ def step_server_api_key(context, server_api_key):
 @step(u'an OAI compatible chat completions request with {api_error} api error')
 @async_run_until_complete
 async def step_oai_chat_completions(context, api_error):
-    print(f"Submitting OAI compatible completions request...")
+    if context.debug:
+        print(f"Submitting OAI compatible completions request...")
     expect_api_error = api_error == 'raised'
     completion = await oai_chat_completions(context.prompts.pop(),
                                             context.system_prompt,
@@ -236,11 +239,13 @@ async def step_oai_chat_completions(context, api_error):
                                             expect_api_error=expect_api_error)
     context.tasks_result.append(completion)
-    print(f"Completion response: {completion}")
+    if context.debug:
+        print(f"Completion response: {completion}")
     if expect_api_error:
         assert completion == 401, f"completion must be an 401 status code: {completion}"
-    print(f"Completion response: {completion}")
+    if context.debug:
+        print(f"Completion response: {completion}")
 
 
 @step(u'a prompt')
@@ -260,6 +265,7 @@ async def step_concurrent_completion_requests(context):
                                          request_completion,
                                          # prompt is inserted automatically
                                          context.base_url,
+                                         debug=context.debug,
                                          n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
                                          server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
                                          user_api_key=context.user_api_key if hasattr(context,
@@ -397,7 +403,8 @@ def step_check_options_header_value(context, cors_header, cors_header_value):
 
 async def concurrent_completion_requests(context, f_completion, *args, **kwargs):
     n_prompts = len(context.prompts)
-    print(f"starting {n_prompts} concurrent completion requests...")
+    if context.debug:
+        print(f"starting {n_prompts} concurrent completion requests...")
     assert n_prompts > 0
     for prompt_no in range(n_prompts):
         shifted_args = [context.prompts.pop(), *args]
@@ -407,17 +414,20 @@ async def concurrent_completion_requests(context, f_completion, *args, **kwargs):
 
 async def request_completion(prompt,
                              base_url,
+                             debug=False,
                              n_predict=None,
                              server_seed=None,
                              expect_api_error=None,
                              user_api_key=None):
-    print(f"Sending completion request: {prompt}")
+    if debug:
+        print(f"Sending completion request: {prompt}")
     origin = "my.super.domain"
     headers = {
         'Origin': origin
     }
     if user_api_key is not None:
-        print(f"Set user_api_key: {user_api_key}")
+        if debug:
+            print(f"Set user_api_key: {user_api_key}")
         headers['Authorization'] = f'Bearer {user_api_key}'
 
     async with aiohttp.ClientSession() as session:
@@ -440,13 +450,15 @@ async def oai_chat_completions(user_prompt,
                                system_prompt,
                                base_url,
                                async_client,
+                               debug=False,
                                model=None,
                                n_predict=None,
                                enable_streaming=None,
                                server_seed=None,
                                user_api_key=None,
                                expect_api_error=None):
-    print(f"Sending OAI Chat completions request: {user_prompt}")
+    if debug:
+        print(f"Sending OAI Chat completions request: {user_prompt}")
     # openai client always expects an api key
     user_api_key = user_api_key if user_api_key is not None else 'nope'
     seed = server_seed if server_seed is not None else 42
@@ -548,7 +560,8 @@ async def oai_chat_completions(user_prompt,
                 'predicted_n': chat_completion.usage.completion_tokens
             }
         }
-    print("OAI response formatted to llama.cpp:", completion_response)
+    if debug:
+        print("OAI response formatted to llama.cpp:", completion_response)
     return completion_response
 
 
@@ -579,7 +592,8 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
 
 async def gather_tasks_results(context):
     n_tasks = len(context.concurrent_tasks)
-    print(f"Waiting for all {n_tasks} tasks results...")
+    if context.debug:
+        print(f"Waiting for all {n_tasks} tasks results...")
     for task_no in range(n_tasks):
         context.tasks_result.append(await context.concurrent_tasks.pop())
     n_completions = len(context.tasks_result)
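
Note: the steps above read a `context.debug` flag that is not defined in this diff. A minimal sketch of how it could be initialized in behave's `environment.py` hooks, assuming a `DEBUG` environment variable controls it (the hook file and variable name are illustrative, not part of this change):

import os


def before_scenario(context, scenario):
    # Hypothetical wiring: enable verbose step logging only when DEBUG=ON is exported.
    context.debug = os.environ.get('DEBUG', 'OFF') == 'ON'

With something like this in place, running the suite as `DEBUG=ON behave ...` would restore the print output without touching the step definitions.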