add parallel completion test
This commit is contained in:
parent
1c2f0f708c
commit
78e3cb3cf2
2 changed files with 129 additions and 41 deletions
|
@ -1,4 +1,5 @@
|
|||
import pytest
|
||||
import time
|
||||
from openai import OpenAI
|
||||
from utils import *
|
||||
|
||||
|
@ -10,7 +11,6 @@ def create_server():
|
|||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
|
||||
("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
|
||||
("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
|
||||
|
@ -52,24 +52,6 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
|
|||
content += data["content"]
|
||||
|
||||
|
||||
# FIXME: This test is not working because /completions endpoint is not OAI-compatible
|
||||
@pytest.mark.skip(reason="Only /chat/completions is OAI-compatible for now")
|
||||
def test_completion_with_openai_library():
|
||||
global server
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
|
||||
res = client.completions.create(
|
||||
model="gpt-3.5-turbo-instruct",
|
||||
prompt="I believe the meaning of life is",
|
||||
max_tokens=8,
|
||||
seed=42,
|
||||
temperature=0.8,
|
||||
)
|
||||
print(res)
|
||||
assert res.choices[0].finish_reason == "length"
|
||||
assert match_regex("(going|bed)+", res.choices[0].text)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_slots", [1, 2])
|
||||
def test_consistent_result_same_seed(n_slots: int):
|
||||
global server
|
||||
|
@ -121,4 +103,97 @@ def test_consistent_result_different_batch_size(n_batch: int, temperature: float
|
|||
assert res.body["content"] == last_res.body["content"]
|
||||
last_res = res
|
||||
|
||||
# TODO: add completion with tokens as input, mixed token+string input
|
||||
|
||||
def test_completion_with_tokens_input():
|
||||
global server
|
||||
server.temperature = 0.0
|
||||
server.start()
|
||||
prompt_str = "I believe the meaning of life is"
|
||||
res = server.make_request("POST", "/tokenize", data={
|
||||
"content": prompt_str,
|
||||
"add_special": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
tokens = res.body["tokens"]
|
||||
|
||||
# single completion
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": tokens,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert type(res.body["content"]) == str
|
||||
|
||||
# batch completion
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": [tokens, tokens],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert type(res.body) == list
|
||||
assert len(res.body) == 2
|
||||
assert res.body[0]["content"] == res.body[1]["content"]
|
||||
|
||||
# mixed string and tokens
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": [tokens, prompt_str],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert type(res.body) == list
|
||||
assert len(res.body) == 2
|
||||
assert res.body[0]["content"] == res.body[1]["content"]
|
||||
|
||||
# mixed string and tokens in one sequence
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert type(res.body["content"]) == str
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_slots,n_requests", [
|
||||
(1, 3),
|
||||
(2, 2),
|
||||
(2, 4),
|
||||
(4, 2), # some slots must be idle
|
||||
(4, 6),
|
||||
])
|
||||
def test_completion_parallel_slots(n_slots: int, n_requests: int):
|
||||
global server
|
||||
server.n_slots = n_slots
|
||||
server.temperature = 0.0
|
||||
server.start()
|
||||
|
||||
PROMPTS = [
|
||||
("Write a very long book.", "(very|special|big)+"),
|
||||
("Write another a poem.", "(small|house)+"),
|
||||
("What is LLM?", "(Dad|said)+"),
|
||||
("The sky is blue and I love it.", "(climb|leaf)+"),
|
||||
("Write another very long music lyrics.", "(friends|step|sky)+"),
|
||||
("Write a very long joke.", "(cat|Whiskers)+"),
|
||||
]
|
||||
def check_slots_status():
|
||||
should_all_slots_busy = n_requests >= n_slots
|
||||
time.sleep(0.1)
|
||||
res = server.make_request("GET", "/slots")
|
||||
n_busy = sum([1 for slot in res.body if slot["is_processing"]])
|
||||
if should_all_slots_busy:
|
||||
assert n_busy == n_slots
|
||||
else:
|
||||
assert n_busy <= n_slots
|
||||
|
||||
tasks = []
|
||||
for i in range(n_requests):
|
||||
prompt, re_content = PROMPTS[i % len(PROMPTS)]
|
||||
tasks.append((server.make_request, ("POST", "/completion", {
|
||||
"prompt": prompt,
|
||||
"seed": 42,
|
||||
"temperature": 1.0,
|
||||
})))
|
||||
tasks.append((check_slots_status, ()))
|
||||
results = parallel_function_calls(tasks)
|
||||
|
||||
# check results
|
||||
for i in range(n_requests):
|
||||
prompt, re_content = PROMPTS[i % len(PROMPTS)]
|
||||
res = results[i]
|
||||
assert res.status_code == 200
|
||||
assert match_regex(re_content, res.body["content"])
|
||||
|
|
|
@ -11,6 +11,7 @@ import sys
|
|||
import threading
|
||||
import requests
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
|
@ -19,7 +20,7 @@ from typing import (
|
|||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Set,
|
||||
)
|
||||
from re import RegexFlag
|
||||
|
@ -28,7 +29,7 @@ from re import RegexFlag
|
|||
class ServerResponse:
|
||||
headers: dict
|
||||
status_code: int
|
||||
body: dict
|
||||
body: dict | Any
|
||||
|
||||
|
||||
class ServerProcess:
|
||||
|
@ -322,30 +323,42 @@ class ServerPreset:
|
|||
return server
|
||||
|
||||
|
||||
def multiple_post_requests(
|
||||
server: ServerProcess, path: str, data: Sequence[dict], headers: dict | None = None
|
||||
) -> Sequence[ServerResponse]:
|
||||
def worker(data_chunk):
|
||||
def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]:
|
||||
"""
|
||||
Run multiple functions in parallel and return results in the same order as calls. Equivalent to Promise.all in JS.
|
||||
|
||||
Example usage:
|
||||
|
||||
results = parallel_function_calls([
|
||||
(func1, (arg1, arg2)),
|
||||
(func2, (arg3, arg4)),
|
||||
])
|
||||
"""
|
||||
results = [None] * len(function_list)
|
||||
exceptions = []
|
||||
|
||||
def worker(index, func, args):
|
||||
try:
|
||||
return server.make_request("POST", path, data=data_chunk, headers=headers)
|
||||
result = func(*args)
|
||||
results[index] = result
|
||||
except Exception as e:
|
||||
print(f"Error occurred: {e}", file=sys.stderr)
|
||||
os._exit(1) # terminate main thread
|
||||
exceptions.append((index, str(e)))
|
||||
|
||||
threads = []
|
||||
results = []
|
||||
with ThreadPoolExecutor() as executor:
|
||||
futures = []
|
||||
for i, (func, args) in enumerate(function_list):
|
||||
future = executor.submit(worker, i, func, args)
|
||||
futures.append(future)
|
||||
|
||||
def thread_target(data_chunk):
|
||||
result = worker(data_chunk)
|
||||
results.append(result)
|
||||
# Wait for all futures to complete
|
||||
for future in as_completed(futures):
|
||||
pass
|
||||
|
||||
for chunk in data:
|
||||
thread = threading.Thread(target=thread_target, args=(chunk,))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
# Check if there were any exceptions
|
||||
if exceptions:
|
||||
print("Exceptions occurred:")
|
||||
for index, error in exceptions:
|
||||
print(f"Function at index {index}: {error}")
|
||||
|
||||
return results
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue