Test chat_template in e2e test
This commit is contained in:
parent
a6afb2735f
commit
b4083e4155
2 changed files with 15 additions and 13 deletions
|
@@ -11,17 +11,19 @@ def create_server():
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja",
|
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template",
|
||||||
[
|
[
|
||||||
(None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", False),
|
(None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", False, None),
|
||||||
(None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", True),
|
(None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", True, None),
|
||||||
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False),
|
(None, "Book", "What is the best book", 8, " blue and shin", 23, 8, "length", True, "This is not a chat template, it is"),
|
||||||
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True),
|
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
|
||||||
|
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja):
|
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
|
||||||
global server
|
global server
|
||||||
server.jinja = jinja
|
server.jinja = jinja
|
||||||
|
server.chat_template = chat_template
|
||||||
server.start()
|
server.start()
|
||||||
res = server.make_request("POST", "/chat/completions", data={
|
res = server.make_request("POST", "/chat/completions", data={
|
||||||
"model": model,
|
"model": model,
|
||||||
|
|
|
@@ -70,13 +70,13 @@ class ServerProcess:
|
||||||
draft: int | None = None
|
draft: int | None = None
|
||||||
api_key: str | None = None
|
api_key: str | None = None
|
||||||
lora_files: List[str] | None = None
|
lora_files: List[str] | None = None
|
||||||
chat_template_file: str | None = None
|
|
||||||
jinja: bool | None = None
|
|
||||||
disable_ctx_shift: int | None = False
|
disable_ctx_shift: int | None = False
|
||||||
draft_min: int | None = None
|
draft_min: int | None = None
|
||||||
draft_max: int | None = None
|
draft_max: int | None = None
|
||||||
no_webui: bool | None = None
|
no_webui: bool | None = None
|
||||||
|
jinja: bool | None = None
|
||||||
chat_template: str | None = None
|
chat_template: str | None = None
|
||||||
|
chat_template_file: str | None = None
|
||||||
|
|
||||||
# session variables
|
# session variables
|
||||||
process: subprocess.Popen | None = None
|
process: subprocess.Popen | None = None
|
||||||
|
@@ -157,10 +157,6 @@ class ServerProcess:
|
||||||
if self.lora_files:
|
if self.lora_files:
|
||||||
for lora_file in self.lora_files:
|
for lora_file in self.lora_files:
|
||||||
server_args.extend(["--lora", lora_file])
|
server_args.extend(["--lora", lora_file])
|
||||||
if self.chat_template_file:
|
|
||||||
server_args.extend(["--chat-template-file", self.chat_template_file])
|
|
||||||
if self.jinja:
|
|
||||||
server_args.append("--jinja")
|
|
||||||
if self.disable_ctx_shift:
|
if self.disable_ctx_shift:
|
||||||
server_args.extend(["--no-context-shift"])
|
server_args.extend(["--no-context-shift"])
|
||||||
if self.api_key:
|
if self.api_key:
|
||||||
|
@@ -171,9 +167,13 @@ class ServerProcess:
|
||||||
server_args.extend(["--draft-min", self.draft_min])
|
server_args.extend(["--draft-min", self.draft_min])
|
||||||
if self.no_webui:
|
if self.no_webui:
|
||||||
server_args.append("--no-webui")
|
server_args.append("--no-webui")
|
||||||
|
if self.jinja:
|
||||||
|
server_args.append("--jinja")
|
||||||
if self.chat_template:
|
if self.chat_template:
|
||||||
server_args.extend(["--chat-template", self.chat_template])
|
server_args.extend(["--chat-template", self.chat_template])
|
||||||
|
if self.chat_template_file:
|
||||||
|
server_args.extend(["--chat-template-file", self.chat_template_file])
|
||||||
|
|
||||||
args = [str(arg) for arg in [server_path, *server_args]]
|
args = [str(arg) for arg in [server_path, *server_args]]
|
||||||
print(f"bench: starting server with: {' '.join(args)}")
|
print(f"bench: starting server with: {' '.join(args)}")
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue