From 78e3cb3cf2cffac8b6ae237e0277ac59981634d7 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Wed, 20 Nov 2024 21:35:31 +0100
Subject: [PATCH] add parallel completion test

---
 examples/server/tests/unit/test_completion.py | 115 +++++++++++++++---
 examples/server/tests/utils.py                |  55 +++++----
 2 files changed, 129 insertions(+), 41 deletions(-)

diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py
index 286b5e4c6..90ad741cd 100644
--- a/examples/server/tests/unit/test_completion.py
+++ b/examples/server/tests/unit/test_completion.py
@@ -1,4 +1,5 @@
 import pytest
+import time
 from openai import OpenAI
 from utils import *
 
@@ -10,7 +11,6 @@ def create_server():
     global server
     server = ServerPreset.tinyllama2()
 
-
 @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
     ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
     ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
@@ -52,24 +52,6 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
             content += data["content"]
 
 
-# FIXME: This test is not working because /completions endpoint is not OAI-compatible
-@pytest.mark.skip(reason="Only /chat/completions is OAI-compatible for now")
-def test_completion_with_openai_library():
-    global server
-    server.start()
-    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
-    res = client.completions.create(
-        model="gpt-3.5-turbo-instruct",
-        prompt="I believe the meaning of life is",
-        max_tokens=8,
-        seed=42,
-        temperature=0.8,
-    )
-    print(res)
-    assert res.choices[0].finish_reason == "length"
-    assert match_regex("(going|bed)+", res.choices[0].text)
-
-
 @pytest.mark.parametrize("n_slots", [1, 2])
 def test_consistent_result_same_seed(n_slots: int):
     global server
@@ -121,4 +103,97 @@ def test_consistent_result_different_batch_size(n_batch: int, temperature: float
             assert res.body["content"] == last_res.body["content"]
         last_res = res
 
-# TODO: add completion with tokens as input, mixed token+string input
+
+def test_completion_with_tokens_input():
+    global server
+    server.temperature = 0.0
+    server.start()
+    prompt_str = "I believe the meaning of life is"
+    res = server.make_request("POST", "/tokenize", data={
+        "content": prompt_str,
+        "add_special": True,
+    })
+    assert res.status_code == 200
+    tokens = res.body["tokens"]
+
+    # single completion
+    res = server.make_request("POST", "/completion", data={
+        "prompt": tokens,
+    })
+    assert res.status_code == 200
+    assert type(res.body["content"]) == str
+
+    # batch completion
+    res = server.make_request("POST", "/completion", data={
+        "prompt": [tokens, tokens],
+    })
+    assert res.status_code == 200
+    assert type(res.body) == list
+    assert len(res.body) == 2
+    assert res.body[0]["content"] == res.body[1]["content"]
+
+    # mixed string and tokens
+    res = server.make_request("POST", "/completion", data={
+        "prompt": [tokens, prompt_str],
+    })
+    assert res.status_code == 200
+    assert type(res.body) == list
+    assert len(res.body) == 2
+    assert res.body[0]["content"] == res.body[1]["content"]
+
+    # mixed string and tokens in one sequence
+    res = server.make_request("POST", "/completion", data={
+        "prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
+    })
+    assert res.status_code == 200
+    assert type(res.body["content"]) == str
+
+
+@pytest.mark.parametrize("n_slots,n_requests", [
+    (1, 3),
+    (2, 2),
+    (2, 4),
+    (4, 2),  # some slots must be idle
+    (4, 6),
+])
+def test_completion_parallel_slots(n_slots: int, n_requests: int):
+    global server
+    server.n_slots = n_slots
+    server.temperature = 0.0
+    server.start()
+
+    PROMPTS = [
+        ("Write a very long book.", "(very|special|big)+"),
+        ("Write another a poem.", "(small|house)+"),
+        ("What is LLM?", "(Dad|said)+"),
+        ("The sky is blue and I love it.", "(climb|leaf)+"),
+        ("Write another very long music lyrics.", "(friends|step|sky)+"),
+        ("Write a very long joke.", "(cat|Whiskers)+"),
+    ]
+    def check_slots_status():
+        should_all_slots_busy = n_requests >= n_slots
+        time.sleep(0.1)
+        res = server.make_request("GET", "/slots")
+        n_busy = sum([1 for slot in res.body if slot["is_processing"]])
+        if should_all_slots_busy:
+            assert n_busy == n_slots
+        else:
+            assert n_busy <= n_slots
+
+    tasks = []
+    for i in range(n_requests):
+        prompt, re_content = PROMPTS[i % len(PROMPTS)]
+        tasks.append((server.make_request, ("POST", "/completion", {
+            "prompt": prompt,
+            "seed": 42,
+            "temperature": 1.0,
+        })))
+    tasks.append((check_slots_status, ()))
+    results = parallel_function_calls(tasks)
+
+    # check results
+    for i in range(n_requests):
+        prompt, re_content = PROMPTS[i % len(PROMPTS)]
+        res = results[i]
+        assert res.status_code == 200
+        assert match_regex(re_content, res.body["content"])
diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py
index 0808a92b2..75ada2913 100644
--- a/examples/server/tests/utils.py
+++ b/examples/server/tests/utils.py
@@ -11,6 +11,7 @@ import sys
 import threading
 import requests
 import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import (
     Any,
     Callable,
@@ -19,7 +20,7 @@ from typing import (
     Iterator,
     List,
     Literal,
-    Sequence,
+    Tuple,
     Set,
 )
 from re import RegexFlag
@@ -28,7 +29,7 @@ from re import RegexFlag
 class ServerResponse:
     headers: dict
     status_code: int
-    body: dict
+    body: dict | Any
 
 
 class ServerProcess:
@@ -322,30 +323,42 @@ class ServerPreset:
         return server
 
 
-def multiple_post_requests(
-    server: ServerProcess, path: str, data: Sequence[dict], headers: dict | None = None
-) -> Sequence[ServerResponse]:
-    def worker(data_chunk):
+def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]:
+    """
+    Run multiple functions in parallel and return results in the same order as calls. Equivalent to Promise.all in JS.
+ + Example usage: + + results = parallel_function_calls([ + (func1, (arg1, arg2)), + (func2, (arg3, arg4)), + ]) + """ + results = [None] * len(function_list) + exceptions = [] + + def worker(index, func, args): try: - return server.make_request("POST", path, data=data_chunk, headers=headers) + result = func(*args) + results[index] = result except Exception as e: - print(f"Error occurred: {e}", file=sys.stderr) - os._exit(1) # terminate main thread + exceptions.append((index, str(e))) - threads = [] - results = [] + with ThreadPoolExecutor() as executor: + futures = [] + for i, (func, args) in enumerate(function_list): + future = executor.submit(worker, i, func, args) + futures.append(future) - def thread_target(data_chunk): - result = worker(data_chunk) - results.append(result) + # Wait for all futures to complete + for future in as_completed(futures): + pass - for chunk in data: - thread = threading.Thread(target=thread_target, args=(chunk,)) - threads.append(thread) - thread.start() - - for thread in threads: - thread.join() + # Check if there were any exceptions + if exceptions: + print("Exceptions occurred:") + for index, error in exceptions: + print(f"Function at index {index}: {error}") return results
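
A quick sanity sketch of the new helper, outside the diff. It assumes utils.py is importable from the test directory; the slow/fast workers are invented for illustration and are not part of the patch. It demonstrates the Promise.all-style property the docstring promises: results come back in call order, not completion order.

import time
from utils import parallel_function_calls  # the helper added above

def slow(label: str) -> str:
    time.sleep(0.2)                # finishes last...
    return f"slow:{label}"

def fast(label: str) -> str:
    return f"fast:{label}"

results = parallel_function_calls([
    (slow, ("a",)),
    (fast, ("b",)),
])
assert results == ["slow:a", "fast:b"]  # ...but is still first in the output

The ordering holds by construction: each worker writes into results[index], so completion order never matters.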
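The round trip exercised by test_completion_with_tokens_input can also be reproduced against a running llama-server with plain requests. In this sketch the host/port and the n_predict value are assumptions; only the /tokenize and /completion payload shapes come from the test itself.

import requests

base_url = "http://localhost:8080"  # assumed; use your server's host/port

# 1) tokenize the prompt (add_special prepends BOS/special tokens)
tok = requests.post(f"{base_url}/tokenize", json={
    "content": "I believe the meaning of life is",
    "add_special": True,
}).json()

# 2) feed the raw token ids back as the prompt
res = requests.post(f"{base_url}/completion", json={
    "prompt": tok["tokens"],  # a list of ints instead of a string
    "n_predict": 8,
}).json()
print(res["content"])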
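Likewise, check_slots_status in the parallel test boils down to the probe below: while completions are in flight, /slots reports one entry per slot with an is_processing flag. This assumes the same local server and that the slots endpoint is enabled in your server configuration.

import requests

slots = requests.get("http://localhost:8080/slots").json()
n_busy = sum(1 for slot in slots if slot["is_processing"])
print(f"{n_busy}/{len(slots)} slots busy")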