add parallel completion test

This commit is contained in:
Xuan Son Nguyen 2024-11-20 21:35:31 +01:00
parent 1c2f0f708c
commit 78e3cb3cf2
2 changed files with 129 additions and 41 deletions

View file

@ -1,4 +1,5 @@
import pytest
import time
from openai import OpenAI
from utils import *
@ -10,7 +11,6 @@ def create_server():
global server
server = ServerPreset.tinyllama2()
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
@ -52,24 +52,6 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
content += data["content"]
# FIXME: This test is not working because /completions endpoint is not OAI-compatible
@pytest.mark.skip(reason="Only /chat/completions is OAI-compatible for now")
def test_completion_with_openai_library():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
res = client.completions.create(
model="gpt-3.5-turbo-instruct",
prompt="I believe the meaning of life is",
max_tokens=8,
seed=42,
temperature=0.8,
)
print(res)
assert res.choices[0].finish_reason == "length"
assert match_regex("(going|bed)+", res.choices[0].text)
@pytest.mark.parametrize("n_slots", [1, 2])
def test_consistent_result_same_seed(n_slots: int):
global server
@ -121,4 +103,97 @@ def test_consistent_result_different_batch_size(n_batch: int, temperature: float
assert res.body["content"] == last_res.body["content"]
last_res = res
# TODO: add completion with tokens as input, mixed token+string input
def test_completion_with_tokens_input():
global server
server.temperature = 0.0
server.start()
prompt_str = "I believe the meaning of life is"
res = server.make_request("POST", "/tokenize", data={
"content": prompt_str,
"add_special": True,
})
assert res.status_code == 200
tokens = res.body["tokens"]
# single completion
res = server.make_request("POST", "/completion", data={
"prompt": tokens,
})
assert res.status_code == 200
assert type(res.body["content"]) == str
# batch completion
res = server.make_request("POST", "/completion", data={
"prompt": [tokens, tokens],
})
assert res.status_code == 200
assert type(res.body) == list
assert len(res.body) == 2
assert res.body[0]["content"] == res.body[1]["content"]
# mixed string and tokens
res = server.make_request("POST", "/completion", data={
"prompt": [tokens, prompt_str],
})
assert res.status_code == 200
assert type(res.body) == list
assert len(res.body) == 2
assert res.body[0]["content"] == res.body[1]["content"]
# mixed string and tokens in one sequence
res = server.make_request("POST", "/completion", data={
"prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
})
assert res.status_code == 200
assert type(res.body["content"]) == str
@pytest.mark.parametrize("n_slots,n_requests", [
(1, 3),
(2, 2),
(2, 4),
(4, 2), # some slots must be idle
(4, 6),
])
def test_completion_parallel_slots(n_slots: int, n_requests: int):
global server
server.n_slots = n_slots
server.temperature = 0.0
server.start()
PROMPTS = [
("Write a very long book.", "(very|special|big)+"),
("Write another a poem.", "(small|house)+"),
("What is LLM?", "(Dad|said)+"),
("The sky is blue and I love it.", "(climb|leaf)+"),
("Write another very long music lyrics.", "(friends|step|sky)+"),
("Write a very long joke.", "(cat|Whiskers)+"),
]
def check_slots_status():
should_all_slots_busy = n_requests >= n_slots
time.sleep(0.1)
res = server.make_request("GET", "/slots")
n_busy = sum([1 for slot in res.body if slot["is_processing"]])
if should_all_slots_busy:
assert n_busy == n_slots
else:
assert n_busy <= n_slots
tasks = []
for i in range(n_requests):
prompt, re_content = PROMPTS[i % len(PROMPTS)]
tasks.append((server.make_request, ("POST", "/completion", {
"prompt": prompt,
"seed": 42,
"temperature": 1.0,
})))
tasks.append((check_slots_status, ()))
results = parallel_function_calls(tasks)
# check results
for i in range(n_requests):
prompt, re_content = PROMPTS[i % len(PROMPTS)]
res = results[i]
assert res.status_code == 200
assert match_regex(re_content, res.body["content"])

View file

@ -11,6 +11,7 @@ import sys
import threading
import requests
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import (
Any,
Callable,
@ -19,7 +20,7 @@ from typing import (
Iterator,
List,
Literal,
Sequence,
Tuple,
Set,
)
from re import RegexFlag
@ -28,7 +29,7 @@ from re import RegexFlag
class ServerResponse:
headers: dict
status_code: int
body: dict
body: dict | Any
class ServerProcess:
@ -322,30 +323,42 @@ class ServerPreset:
return server
def multiple_post_requests(
server: ServerProcess, path: str, data: Sequence[dict], headers: dict | None = None
) -> Sequence[ServerResponse]:
def worker(data_chunk):
def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]:
"""
Run multiple functions in parallel and return results in the same order as calls. Equivalent to Promise.all in JS.
Example usage:
results = parallel_function_calls([
(func1, (arg1, arg2)),
(func2, (arg3, arg4)),
])
"""
results = [None] * len(function_list)
exceptions = []
def worker(index, func, args):
try:
return server.make_request("POST", path, data=data_chunk, headers=headers)
result = func(*args)
results[index] = result
except Exception as e:
print(f"Error occurred: {e}", file=sys.stderr)
os._exit(1) # terminate main thread
exceptions.append((index, str(e)))
threads = []
results = []
with ThreadPoolExecutor() as executor:
futures = []
for i, (func, args) in enumerate(function_list):
future = executor.submit(worker, i, func, args)
futures.append(future)
def thread_target(data_chunk):
result = worker(data_chunk)
results.append(result)
# Wait for all futures to complete
for future in as_completed(futures):
pass
for chunk in data:
thread = threading.Thread(target=thread_target, args=(chunk,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
# Check if there were any exceptions
if exceptions:
print("Exceptions occurred:")
for index, error in exceptions:
print(f"Function at index {index}: {error}")
return results