From 11adf1d8644c5343beec221b84613d76f78d517b Mon Sep 17 00:00:00 2001
From: Pierrick HYMBERT
Date: Tue, 20 Feb 2024 22:00:09 +0100
Subject: [PATCH] server: tests: add OAI multi user scenario

---
 examples/server/tests/features/server.feature | 23 +++++
 examples/server/tests/features/steps/steps.py | 92 +++++++++++--------
 2 files changed, 76 insertions(+), 39 deletions(-)

diff --git a/examples/server/tests/features/server.feature b/examples/server/tests/features/server.feature
index a14d1459a..78ba2bec9 100644
--- a/examples/server/tests/features/server.feature
+++ b/examples/server/tests/features/server.feature
@@ -57,3 +57,26 @@ Feature: llama.cpp server
     Then the server is busy
     Then the server is idle
     Then all prompts are predicted
+
+
+  Scenario: Multi users OAI Compatibility
+    Given a system prompt "You are an AI assistant."
+    And a model tinyllama-2
+    And 1024 max tokens to predict
+    And streaming is enabled
+    Given a prompt:
+      """
+      Write a very long story about AI.
+      """
+    And a prompt:
+      """
+      Write another very long music lyrics.
+      """
+    And a prompt:
+      """
+      Write yet another very long music lyrics.
+      """
+    Given concurrent OAI completions requests
+    Then the server is busy
+    Then the server is idle
+    Then all prompts are predicted
\ No newline at end of file
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index f9823b51f..6d714ae92 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -16,6 +16,7 @@ def step_server_config(context, server_fqdn, server_port, n_slots):
     context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
 
     context.completions = []
+    context.completion_threads = []
     context.prompts = []
 
     openai.api_key = 'llama.cpp'
@@ -76,11 +77,58 @@ def step_max_tokens(context, max_tokens):
 
 @step(u'streaming is {enable_streaming}')
 def step_streaming(context, enable_streaming):
-    context.enable_streaming = bool(enable_streaming)
+    context.enable_streaming = enable_streaming == 'enabled'
 
 
 @step(u'an OAI compatible chat completions request')
 def step_oai_chat_completions(context):
+    oai_chat_completions(context, context.user_prompt)
+
+
+@step(u'a prompt')
+def step_a_prompt(context):
+    context.prompts.append(context.text)
+
+
+@step(u'concurrent completion requests')
+def step_concurrent_completion_requests(context):
+    concurrent_requests(context, request_completion)
+
+
+@step(u'concurrent OAI completions requests')
+def step_concurrent_oai_chat_completions(context):
+    concurrent_requests(context, oai_chat_completions)
+
+
+@step(u'all prompts are predicted')
+def step_all_prompts_are_predicted(context):
+    for completion_thread in context.completion_threads:
+        completion_thread.join()
+    for completion in context.completions:
+        assert_n_tokens_predicted(completion)
+
+
+def concurrent_requests(context, f_completion):
+    context.completions.clear()
+    context.completion_threads.clear()
+    for prompt in context.prompts:
+        completion_thread = threading.Thread(target=f_completion, args=(context, prompt))
+        completion_thread.start()
+        context.completion_threads.append(completion_thread)
+    context.prompts.clear()
+
+
+def request_completion(context, prompt, n_predict=None):
+    response = requests.post(f'{context.base_url}/completion', json={
+        "prompt": prompt,
+        "n_predict": int(n_predict) if n_predict is not None else 4096,
+    })
+    status_code = response.status_code
+    assert status_code == 200
+    context.completions.append(response.json())
+
+
+def oai_chat_completions(context, user_prompt):
     chat_completion = openai.Completion.create(
         messages=[
             {
@@ -89,7 +137,7 @@ def step_oai_chat_completions(context):
             },
             {
                 "role": "user",
-                "content": context.user_prompt,
+                "content": user_prompt,
             }
         ],
         model=context.model,
@@ -120,39 +168,6 @@ def step_oai_chat_completions(context):
     })
 
 
-@step(u'a prompt')
-def step_a_prompt(context):
-    context.prompts.append(context.text)
-
-
-@step(u'concurrent completion requests')
-def step_n_concurrent_prompts(context):
-    context.completions.clear()
-    context.completion_threads = []
-    for prompt in context.prompts:
-        completion_thread = threading.Thread(target=request_completion, args=(context, prompt))
-        completion_thread.start()
-        context.completion_threads.append(completion_thread)
-
-
-@step(u'all prompts are predicted')
-def step_all_prompts_must_be_predicted(context):
-    for completion_thread in context.completion_threads:
-        completion_thread.join()
-    for completion in context.completions:
-        assert_n_tokens_predicted(completion)
-
-
-def request_completion(context, prompt, n_predict=None):
-    response = requests.post(f'{context.base_url}/completion', json={
-        "prompt": prompt,
-        "n_predict": int(n_predict) if n_predict is not None else 4096,
-    })
-    status_code = response.status_code
-    assert status_code == 200
-    context.completions.append(response.json())
-
-
 def assert_n_tokens_predicted(completion_response, expected_predicted_n=None):
     content = completion_response['content']
     n_predicted = completion_response['timings']['predicted_n']
@@ -163,10 +178,9 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None):
 
 
 def wait_for_health_status(context, expected_http_status_code, expected_health_status, params=None):
-    status_code = 500
-    while status_code != expected_http_status_code:
+    while True:
         health_response = requests.get(f'{context.base_url}/health', params)
         status_code = health_response.status_code
         health = health_response.json()
-        if status_code != expected_http_status_code or health['status'] != expected_health_status:
-            time.sleep(0.001)
+        if status_code == expected_http_status_code and health['status'] == expected_health_status:
+            break
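
Reviewer note, not part of the patch: the new steps fan each queued prompt out on its own thread and join all threads before asserting the predicted token counts. Below is a minimal, self-contained sketch of that pattern; the behave context and the HTTP calls are replaced with hypothetical stand-ins (fake_completion, the hard-coded prompts), and only the threading logic mirrors the added concurrent_requests and "all prompts are predicted" helpers.

# Minimal sketch (assumption: no real server or behave context; fake_completion
# is a hypothetical stand-in for request_completion / oai_chat_completions).
import threading

completions = []          # shared result list, like context.completions
completion_threads = []   # one thread per queued prompt, like context.completion_threads


def fake_completion(prompt):
    # A real run would POST the prompt to the llama.cpp server here.
    completions.append({'content': f'echo: {prompt}',
                        'timings': {'predicted_n': 1}})


prompts = ['Write a very long story about AI.',
           'Write another very long music lyrics.']

for prompt in prompts:                      # mirrors concurrent_requests()
    thread = threading.Thread(target=fake_completion, args=(prompt,))
    thread.start()
    completion_threads.append(thread)

for thread in completion_threads:           # mirrors "all prompts are predicted"
    thread.join()
assert len(completions) == len(prompts)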