add parallel completion test

parent 1c2f0f708c
commit 78e3cb3cf2

2 changed files with 129 additions and 41 deletions
@@ -1,4 +1,5 @@
 import pytest
+import time
 from openai import OpenAI
 from utils import *
 
@@ -10,7 +11,6 @@ def create_server():
     global server
     server = ServerPreset.tinyllama2()
 
-
 @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
     ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
     ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
@@ -52,24 +52,6 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
         content += data["content"]
 
 
-# FIXME: This test is not working because /completions endpoint is not OAI-compatible
-@pytest.mark.skip(reason="Only /chat/completions is OAI-compatible for now")
-def test_completion_with_openai_library():
-    global server
-    server.start()
-    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
-    res = client.completions.create(
-        model="gpt-3.5-turbo-instruct",
-        prompt="I believe the meaning of life is",
-        max_tokens=8,
-        seed=42,
-        temperature=0.8,
-    )
-    print(res)
-    assert res.choices[0].finish_reason == "length"
-    assert match_regex("(going|bed)+", res.choices[0].text)
-
-
 @pytest.mark.parametrize("n_slots", [1, 2])
 def test_consistent_result_same_seed(n_slots: int):
     global server
@@ -121,4 +103,97 @@ def test_consistent_result_different_batch_size(n_batch: int, temperature: float
             assert res.body["content"] == last_res.body["content"]
         last_res = res
 
-# TODO: add completion with tokens as input, mixed token+string input
+
+def test_completion_with_tokens_input():
+    global server
+    server.temperature = 0.0
+    server.start()
+    prompt_str = "I believe the meaning of life is"
+    res = server.make_request("POST", "/tokenize", data={
+        "content": prompt_str,
+        "add_special": True,
+    })
+    assert res.status_code == 200
+    tokens = res.body["tokens"]
+
+    # single completion
+    res = server.make_request("POST", "/completion", data={
+        "prompt": tokens,
+    })
+    assert res.status_code == 200
+    assert type(res.body["content"]) == str
+
+    # batch completion
+    res = server.make_request("POST", "/completion", data={
+        "prompt": [tokens, tokens],
+    })
+    assert res.status_code == 200
+    assert type(res.body) == list
+    assert len(res.body) == 2
+    assert res.body[0]["content"] == res.body[1]["content"]
+
+    # mixed string and tokens
+    res = server.make_request("POST", "/completion", data={
+        "prompt": [tokens, prompt_str],
+    })
+    assert res.status_code == 200
+    assert type(res.body) == list
+    assert len(res.body) == 2
+    assert res.body[0]["content"] == res.body[1]["content"]
+
+    # mixed string and tokens in one sequence
+    res = server.make_request("POST", "/completion", data={
+        "prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
+    })
+    assert res.status_code == 200
+    assert type(res.body["content"]) == str
+
+
+@pytest.mark.parametrize("n_slots,n_requests", [
+    (1, 3),
+    (2, 2),
+    (2, 4),
+    (4, 2), # some slots must be idle
+    (4, 6),
+])
+def test_completion_parallel_slots(n_slots: int, n_requests: int):
+    global server
+    server.n_slots = n_slots
+    server.temperature = 0.0
+    server.start()
+
+    PROMPTS = [
+        ("Write a very long book.", "(very|special|big)+"),
+        ("Write another a poem.", "(small|house)+"),
+        ("What is LLM?", "(Dad|said)+"),
+        ("The sky is blue and I love it.", "(climb|leaf)+"),
+        ("Write another very long music lyrics.", "(friends|step|sky)+"),
+        ("Write a very long joke.", "(cat|Whiskers)+"),
+    ]
+    def check_slots_status():
+        should_all_slots_busy = n_requests >= n_slots
+        time.sleep(0.1)
+        res = server.make_request("GET", "/slots")
+        n_busy = sum([1 for slot in res.body if slot["is_processing"]])
+        if should_all_slots_busy:
+            assert n_busy == n_slots
+        else:
+            assert n_busy <= n_slots
+
+    tasks = []
+    for i in range(n_requests):
+        prompt, re_content = PROMPTS[i % len(PROMPTS)]
+        tasks.append((server.make_request, ("POST", "/completion", {
+            "prompt": prompt,
+            "seed": 42,
+            "temperature": 1.0,
+        })))
+    tasks.append((check_slots_status, ()))
+    results = parallel_function_calls(tasks)
+
+    # check results
+    for i in range(n_requests):
+        prompt, re_content = PROMPTS[i % len(PROMPTS)]
+        res = results[i]
+        assert res.status_code == 200
+        assert match_regex(re_content, res.body["content"])
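The new test_completion_with_tokens_input drives /tokenize and feeds the returned token ids back as a /completion prompt, including list-of-prompts and mixed token/string forms. For reference, the same round trip with plain requests might look like the sketch below (assumptions: a server already listening on localhost:8080; endpoint names and JSON fields exactly as in the test above):

    import requests

    base = "http://localhost:8080"  # assumed host/port, not part of the diff

    # tokenize the prompt, keeping special tokens (add_special=True)
    tokens = requests.post(f"{base}/tokenize", json={
        "content": "I believe the meaning of life is",
        "add_special": True,
    }).json()["tokens"]

    # feed the token ids back as the prompt; a single prompt yields one object
    res = requests.post(f"{base}/completion", json={"prompt": tokens})
    assert isinstance(res.json()["content"], str)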
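In test_completion_parallel_slots, check_slots_status runs as one of the parallel tasks and races a GET /slots poll against the in-flight completions: with at least as many requests as slots, every slot should report is_processing. A standalone version of the busy-slot count, assuming the same /slots response shape (a JSON list with one object per slot, carrying an "is_processing" flag as the test implies):

    import requests

    def count_busy_slots(base_url: str = "http://localhost:8080") -> int:  # assumed URL
        res = requests.get(f"{base_url}/slots")
        res.raise_for_status()
        # one object per slot; count those currently processing a request
        return sum(1 for slot in res.json() if slot["is_processing"])

The 0.1 s sleep before polling is a heuristic that gives every request time to claim a slot; the check could still flake if completions finish faster than that.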
@@ -11,6 +11,7 @@ import sys
 import threading
 import requests
 import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import (
     Any,
     Callable,
@@ -19,7 +20,7 @@ from typing import (
     Iterator,
     List,
     Literal,
-    Sequence,
+    Tuple,
     Set,
 )
 from re import RegexFlag
@@ -28,7 +29,7 @@ from re import RegexFlag
 class ServerResponse:
     headers: dict
     status_code: int
-    body: dict
+    body: dict | Any
 
 
 class ServerProcess:
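body is widened to dict | Any because a batched /completion request (a list of prompts, as in the new test) returns a JSON array, so res.body may now be a list of per-prompt objects instead of a single dict. A small illustration of handling both shapes (hypothetical helper, not part of the diff):

    def first_content(body) -> str:
        if isinstance(body, list):    # batched request: one dict per prompt
            return body[0]["content"]
        return body["content"]        # single prompt: plain dict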
@@ -322,30 +323,42 @@ class ServerPreset:
         return server
 
 
-def multiple_post_requests(
-    server: ServerProcess, path: str, data: Sequence[dict], headers: dict | None = None
-) -> Sequence[ServerResponse]:
-    def worker(data_chunk):
-        try:
-            return server.make_request("POST", path, data=data_chunk, headers=headers)
-        except Exception as e:
-            print(f"Error occurred: {e}", file=sys.stderr)
-            os._exit(1)  # terminate main thread
-
-    threads = []
-    results = []
-
-    def thread_target(data_chunk):
-        result = worker(data_chunk)
-        results.append(result)
-
-    for chunk in data:
-        thread = threading.Thread(target=thread_target, args=(chunk,))
-        threads.append(thread)
-        thread.start()
-
-    for thread in threads:
-        thread.join()
-
-    return results
+def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]:
+    """
+    Run multiple functions in parallel and return results in the same order as calls. Equivalent to Promise.all in JS.
+
+    Example usage:
+
+    results = parallel_function_calls([
+        (func1, (arg1, arg2)),
+        (func2, (arg3, arg4)),
+    ])
+    """
+    results = [None] * len(function_list)
+    exceptions = []
+
+    def worker(index, func, args):
+        try:
+            result = func(*args)
+            results[index] = result
+        except Exception as e:
+            exceptions.append((index, str(e)))
+
+    with ThreadPoolExecutor() as executor:
+        futures = []
+        for i, (func, args) in enumerate(function_list):
+            future = executor.submit(worker, i, func, args)
+            futures.append(future)
+
+        # Wait for all futures to complete
+        for future in as_completed(futures):
+            pass
+
+    # Check if there were any exceptions
+    if exceptions:
+        print("Exceptions occurred:")
+        for index, error in exceptions:
+            print(f"Function at index {index}: {error}")
+
+    return results
 
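parallel_function_calls replaces the POST-specific multiple_post_requests with a generic Promise.all-style helper: each result is written into a preallocated list at its call index, so results come back in call order regardless of completion order, and the as_completed loop (together with the executor's context exit) simply blocks until every future has finished. Unlike the old os._exit(1) hard abort, a failed call now leaves None at its index and the error is only printed. A minimal standalone sketch (hypothetical slow_double helper; assumes parallel_function_calls is in scope, e.g. via from utils import * as the tests do):

    import time

    def slow_double(x: int) -> int:
        time.sleep(0.05)  # simulate I/O so the calls actually overlap
        return 2 * x

    results = parallel_function_calls([
        (slow_double, (3,)),
        (slow_double, (1,)),
        (slow_double, (2,)),
    ])
    assert results == [6, 2, 4]  # call order preserved, not completion order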