Fix tool-call server tests
This commit is contained in:
parent
0a5d527508
commit
a2fe8a4922
5 changed files with 180 additions and 25 deletions
|
@ -1778,11 +1778,9 @@ minja::chat_template llama_chat_template_from_model(
|
|||
if (chat_template.empty()) {
|
||||
if (prefer_tool_use) {
|
||||
chat_template = _llama_model_meta_val_str(model, "tokenizer.chat_template.tool_use");
|
||||
fprintf(stderr, "# tokenizer.chat_template.tool_use: %s\n", chat_template.c_str());
|
||||
}
|
||||
if (chat_template.empty()) {
|
||||
chat_template = _llama_model_meta_val_str(model, "tokenizer.chat_template");
|
||||
fprintf(stderr, "# tokenizer.chat_template: %s\n", chat_template.c_str());
|
||||
}
|
||||
}
|
||||
auto bos_token = _common_token_to_piece(model, llama_token_bos(model), true);
|
||||
|
|
|
@ -1900,8 +1900,8 @@ struct server_context {
|
|||
auto match = slot.antiprompts.findSingleTokenMatch(result.tok);
|
||||
|
||||
// remember which tokens were sampled - used for repetition penalties during sampling
|
||||
const std::string token_str = result.text_to_send;
|
||||
// const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special || (match.pos != std::string::npos && match.is_grammar_trigger));
|
||||
// const std::string token_str = result.text_to_send;
|
||||
const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special || (match.pos != std::string::npos && match.is_grammar_trigger));
|
||||
slot.sampled = result.tok;
|
||||
|
||||
if (match.pos != std::string::npos && !match.is_partial) {
|
||||
|
|
|
@ -2,10 +2,9 @@ import pytest
|
|||
from openai import OpenAI
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.tinyllama2()
|
||||
server: ServerProcess
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
@ -277,37 +276,41 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
|
|||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("tool,expected_arguments,hf_repo,hf_file,template_override", [
|
||||
(PYTHON_TOOL, {"code": "print('Hello, world!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
|
||||
(PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", None),
|
||||
(PYTHON_TOOL, {"code": "print(\"Hello World\")"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
|
||||
(PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
|
||||
(PYTHON_TOOL, {"code": "print('Hello World!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
|
||||
(PYTHON_TOOL, {"code": "print(\"Hello World!\")"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
|
||||
(PYTHON_TOOL, {"code": "print('Hello World')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
|
||||
(PYTHON_TOOL, {"code": "print('Hello, world!')"}, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
(PYTHON_TOOL, {"code": "print('hello world')"}, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
(PYTHON_TOOL, {"code": "print('Hello, world!')"}, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
(PYTHON_TOOL, {"code": "print("}, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
(PYTHON_TOOL, {"code": "print("}, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
||||
(PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
|
||||
(CODE_INTEPRETER_TOOL, {"code": "print('Hello, world!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
|
||||
(CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", ("mistralai-Mistral-Nemo-Instruct-2407", None)),
|
||||
(CODE_INTEPRETER_TOOL, {"code": "print(\"Hello World\")"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
|
||||
(CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
|
||||
(CODE_INTEPRETER_TOOL, {"code": "print('Hello World!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
|
||||
(CODE_INTEPRETER_TOOL, {"code": "print('Hello World!')"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
|
||||
(CODE_INTEPRETER_TOOL, {"code": "print('Hello World')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
|
||||
(CODE_INTEPRETER_TOOL, {"code": "print('Hello, world!')"}, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch-Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
(CODE_INTEPRETER_TOOL, {"code": "print('hello world')"}, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "lmstudio-community/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
(CODE_INTEPRETER_TOOL, {"code": "print('hello world')"}, "lmstudio-community/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
(CODE_INTEPRETER_TOOL, {"code": "print("}, "lmstudio-community/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
(CODE_INTEPRETER_TOOL, {"code": "print("}, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
||||
(CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
|
||||
# TODO: fix tool call handling of these models
|
||||
# (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
|
||||
# (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
|
||||
# (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", None),
|
||||
# (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", ("mistralai-Mistral-Nemo-Instruct-2407", None)),
|
||||
])
|
||||
def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
|
||||
global server
|
||||
server.use_jinja = True
|
||||
server.n_ctx = 8192
|
||||
server.n_predict = 128
|
||||
server.model_hf_repo = hf_repo
|
||||
server.model_hf_file = hf_file
|
||||
if template_override:
|
||||
(template_hf_repo, template_variant) = template_override
|
||||
server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
|
||||
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/fetch_server_test_models.py {template_hf_repo} {template_variant}` to download the template."
|
||||
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_hf_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
|
||||
# else:
|
||||
# server.chat_template_file = None
|
||||
server.start(timeout_seconds=15*60)
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": 256,
|
||||
|
@ -322,7 +325,10 @@ def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: st
|
|||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
assert tool["function"]["name"] == tool_call["function"]["name"]
|
||||
if tool["type"] == "function":
|
||||
assert tool["function"]["name"] == tool_call["function"]["name"]
|
||||
elif tool["type"] == "code_interpreter":
|
||||
assert tool_call["function"]["name"] == "python"
|
||||
actual_arguments = json.loads(tool_call["function"]["arguments"])
|
||||
assert json.dumps(expected_arguments) == json.dumps(actual_arguments), f"tool arguments: {json.dumps(actual_arguments)}, expected: {json.dumps(expected_arguments)}"
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@ import os
|
|||
from typing import Generator
|
||||
from pydantic import BaseModel
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
class HuggingFaceModel(BaseModel):
|
||||
|
@ -41,7 +40,7 @@ def collect_hf_model_test_parameters(test_file) -> Generator[HuggingFaceModel, N
|
|||
for dec in node.decorator_list:
|
||||
if isinstance(dec, ast.Call) and isinstance(dec.func, ast.Attribute) and dec.func.attr == 'parametrize':
|
||||
param_names = ast.literal_eval(dec.args[0]).split(",")
|
||||
if not "hf_repo" in param_names or not "hf_file" in param_names:
|
||||
if "hf_repo" not in param_names or "hf_file" not in param_names:
|
||||
continue
|
||||
|
||||
raw_param_values = dec.args[1]
|
||||
|
@ -78,8 +77,7 @@ if __name__ == '__main__':
|
|||
'LLAMA_SERVER_BIN_PATH',
|
||||
os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
'../build/bin/Release/llama-cli.exe' if os.name == 'nt' \
|
||||
else '../build/bin/llama-cli'))
|
||||
'../build/bin/Release/llama-cli.exe' if os.name == 'nt' else '../build/bin/llama-cli'))
|
||||
|
||||
for m in models:
|
||||
if '<' in m.hf_repo or '<' in m.hf_file:
|
||||
|
|
|
@ -0,0 +1,153 @@
|
|||
{%- macro json_to_python_type(json_spec) %}
|
||||
{%- set basic_type_map = {
|
||||
"string": "str",
|
||||
"number": "float",
|
||||
"integer": "int",
|
||||
"boolean": "bool"
|
||||
} %}
|
||||
|
||||
{%- if basic_type_map[json_spec.type] is defined %}
|
||||
{{- basic_type_map[json_spec.type] }}
|
||||
{%- elif json_spec.type == "array" %}
|
||||
{{- "list[" + json_to_python_type(json_spec|items) + "]"}}
|
||||
{%- elif json_spec.type == "object" %}
|
||||
{%- if json_spec.additionalProperties is defined %}
|
||||
{{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']'}}
|
||||
{%- else %}
|
||||
{{- "dict" }}
|
||||
{%- endif %}
|
||||
{%- elif json_spec.type is iterable %}
|
||||
{{- "Union[" }}
|
||||
{%- for t in json_spec.type %}
|
||||
{{- json_to_python_type({"type": t}) }}
|
||||
{%- if not loop.last %}
|
||||
{{- "," }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{- "]" }}
|
||||
{%- else %}
|
||||
{{- "Any" }}
|
||||
{%- endif %}
|
||||
{%- endmacro %}
|
||||
|
||||
|
||||
{{- bos_token }}
|
||||
{{- '<|im_start|>system
|
||||
' }}
|
||||
{{- "You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools> " }}
|
||||
{%- for tool in tools %}
|
||||
{%- if tool.function is defined %}
|
||||
{%- set tool = tool.function %}
|
||||
{%- endif %}
|
||||
{{- '{"type": "function", "function": ' }}
|
||||
{{- '{"name": "' + tool.name + '", ' }}
|
||||
{{- '"description": "' + tool.name + '(' }}
|
||||
{%- for param_name, param_fields in tool.parameters.properties|items %}
|
||||
{{- param_name + ": " + json_to_python_type(param_fields) }}
|
||||
{%- if not loop.last %}
|
||||
{{- ", " }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{- ")" }}
|
||||
{%- if tool.return is defined %}
|
||||
{{- " -> " + json_to_python_type(tool.return) }}
|
||||
{%- endif %}
|
||||
{{- " - " + tool.description + "
|
||||
|
||||
" }}
|
||||
{%- for param_name, param_fields in tool.parameters.properties|items %}
|
||||
{%- if loop.first %}
|
||||
{{- " Args:
|
||||
" }}
|
||||
{%- endif %}
|
||||
{{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }}
|
||||
{%- endfor %}
|
||||
{%- if tool.return is defined and tool.return.description is defined %}
|
||||
{{- "
|
||||
Returns:
|
||||
" + tool.return.description }}
|
||||
{%- endif %}
|
||||
{{- '"' }}
|
||||
{{- ', "parameters": ' }}
|
||||
{%- if tool.parameters.properties | length == 0 %}
|
||||
{{- "{}" }}
|
||||
{%- else %}
|
||||
{{- tool.parameters|tojson }}
|
||||
{%- endif %}
|
||||
{{- "}" }}
|
||||
{%- if not loop.last %}
|
||||
{{- "
|
||||
" }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{- " </tools>" }}
|
||||
{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}}
|
||||
' }}
|
||||
{{- "For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
|
||||
" }}
|
||||
{{- "<tool_call>
|
||||
" }}
|
||||
{{- '{"name": <function-name>, "arguments": <args-dict>}
|
||||
' }}
|
||||
{{- '</tool_call><|im_end|>
|
||||
' }}
|
||||
{%- for message in messages %}
|
||||
{%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %}
|
||||
{{- '<|im_start|>' + message.role + '
|
||||
' + message.content + '<|im_end|>' + '
|
||||
' }}
|
||||
{%- elif message.role == "assistant" %}
|
||||
{{- '<|im_start|>' + message.role }}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{{- '
|
||||
<tool_call>
|
||||
' }} {%- if tool_call.function is defined %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '{' }}
|
||||
{{- '"name": "' }}
|
||||
{{- tool_call.name }}
|
||||
{{- '"' }}
|
||||
{{- ', '}}
|
||||
{%- if tool_call.arguments is defined %}
|
||||
{{- '"arguments": ' }}
|
||||
{%- if tool_call.arguments is string %}
|
||||
{{- tool_call.arguments }}
|
||||
{%- else %}
|
||||
{{- tool_call.arguments|tojson }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{{- '}' }}
|
||||
{{- '
|
||||
</tool_call>' }}
|
||||
{%- endfor %}
|
||||
{{- '<|im_end|>
|
||||
' }}
|
||||
{%- elif message.role == "tool" %}
|
||||
{%- if loop.previtem and loop.previtem.role != "tool" %}
|
||||
{{- '<|im_start|>tool
|
||||
' }}
|
||||
{%- endif %}
|
||||
{{- '<tool_response>
|
||||
' }}
|
||||
{{- message.content }}
|
||||
{%- if not loop.last %}
|
||||
{{- '
|
||||
</tool_response>
|
||||
' }}
|
||||
{%- else %}
|
||||
{{- '
|
||||
</tool_response>' }}
|
||||
{%- endif %}
|
||||
{%- if not loop.last and loop.nextitem.role != "tool" %}
|
||||
{{- '<|im_end|>' }}
|
||||
{%- elif loop.last %}
|
||||
{{- '<|im_end|>' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant
|
||||
' }}
|
||||
{%- endif %}
|
||||
|
Loading…
Add table
Add a link
Reference in a new issue