Merge branch 'tool-call' of github.com:ochafik/llama.cpp into tool-call
commit 40cc3f2fde
7 changed files with 86 additions and 69 deletions
@@ -768,7 +768,6 @@ struct server_task_result_cmpl_partial : server_task_result {
     oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
     std::string oaicompat_model;
     std::string oaicompat_cmpl_id;
-    std::shared_ptr<common_chat_parser> chat_parser;

     virtual int get_index() override {
         return index;
@@ -1191,7 +1190,6 @@ struct server_slot {

     std::string stopping_word;

-    std::shared_ptr<common_chat_parser> chat_parser;

     // sampling
     json json_schema;
@@ -1200,6 +1198,8 @@ struct server_slot {

     llama_token sampled;

+    common_chat_parser chat_parser;
+
     // stats
     size_t n_sent_text = 0; // number of sent text character

@@ -3998,8 +3998,6 @@ int main(int argc, char ** argv) {

         auto body = json::parse(req.body);
         const auto & chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default;
-        LOG_INF("Request: %s\n", body.dump(2).c_str());
-
         json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja);

         return handle_completions_impl(
@@ -61,28 +61,7 @@ WEATHER_TOOL = {
 }


-@pytest.mark.parametrize("template_name,tool,argument_key", [
-    ("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
-    ("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
-    ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
-    ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
-    ("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
-    ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
-    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
-    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"),
-    ("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
-    ("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"),
-    ("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
-    ("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"),
-    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
-    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"),
-    ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"),
-    ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"),
-    ("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"),
-    ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"),
-    # TODO: fix these
-])
-def test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None):
+def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None):
     n_predict = 512
     global server
     # server = ServerPreset.stories15m_moe()
@@ -117,6 +96,40 @@ def test_completion_with_required_tool_tiny(template_name: str, tool: dict, argu
     assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"


+@pytest.mark.parametrize("template_name,tool,argument_key", [
+    ("google-gemma-2-2b-it", TEST_TOOL, "success"),
+    ("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
+    ("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
+])
+def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
+    do_test_completion_with_required_tool_tiny(template_name, tool, argument_key)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("template_name,tool,argument_key", [
+    ("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
+    ("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
+    ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
+    ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
+    ("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
+    ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
+    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
+    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"),
+    ("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
+    ("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"),
+    ("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
+    ("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"),
+    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
+    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"),
+    ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"),
+    ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"),
+    ("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"),
+    ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"),
+])
+def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None):
+    do_test_completion_with_required_tool_tiny(template_name, tool, argument_key)
+
+
 @pytest.mark.slow
 @pytest.mark.parametrize("tool,argument_key,hf_repo,hf_file,template_override", [
     (TEST_TOOL, "success", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
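Note: the hunks above split the old single parametrized test into a shared do_test_completion_with_required_tool_tiny body plus two thin wrappers: a small default matrix and a @pytest.mark.slow matrix covering the full set of templates (the same split is applied to the no-tool-call tests further down). A minimal, self-contained sketch of that pattern, using illustrative names rather than the ones from this diff:

    import pytest

    def do_check_template(template_name: str) -> None:
        # shared assertions live here, exactly once
        assert template_name != ""

    @pytest.mark.parametrize("template_name", ["google-gemma-2-2b-it"])
    def test_check_template_fast(template_name: str) -> None:
        do_check_template(template_name)

    @pytest.mark.slow
    @pytest.mark.parametrize("template_name", ["meetkai-functionary-medium-v3.1"])
    def test_check_template_slow(template_name: str) -> None:
        do_check_template(template_name)

With standard pytest marker selection, `pytest -m "not slow"` runs only the fast wrappers; how CI actually filters the slow cases is not shown in this diff.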
@@ -154,7 +167,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
     if template_override:
         (template_hf_repo, template_variant) = template_override
         server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
-        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_hf_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
+        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     server.start()
     res = server.make_request("POST", "/chat/completions", data={
         "max_tokens": n_predict,
@@ -183,18 +196,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
     assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"


-@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
-    ("meetkai-functionary-medium-v3.1", 128, [], None),
-    ("meetkai-functionary-medium-v3.1", 128, [TEST_TOOL], None),
-    ("meetkai-functionary-medium-v3.1", 128, [PYTHON_TOOL], 'none'),
-    ("meetkai-functionary-medium-v3.2", 128, [], None),
-    ("meetkai-functionary-medium-v3.2", 128, [TEST_TOOL], None),
-    ("meetkai-functionary-medium-v3.2", 128, [PYTHON_TOOL], 'none'),
-    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, [], None),
-    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, [TEST_TOOL], None),
-    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, [PYTHON_TOOL], 'none'),
-])
-def test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
+def do_test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
     global server
     server.jinja = True
     server.n_predict = n_predict
@@ -217,6 +219,31 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
     assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'


+@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
+    ("meta-llama-Llama-3.3-70B-Instruct", 128, [], None),
+    ("meta-llama-Llama-3.3-70B-Instruct", 128, [TEST_TOOL], None),
+    ("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'),
+])
+def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
+    do_test_completion_without_tool_call(template_name, n_predict, tools, tool_choice)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
+    ("meetkai-functionary-medium-v3.1", 128, [], None),
+    ("meetkai-functionary-medium-v3.1", 128, [TEST_TOOL], None),
+    ("meetkai-functionary-medium-v3.1", 128, [PYTHON_TOOL], 'none'),
+    ("meetkai-functionary-medium-v3.2", 128, [], None),
+    ("meetkai-functionary-medium-v3.2", 128, [TEST_TOOL], None),
+    ("meetkai-functionary-medium-v3.2", 128, [PYTHON_TOOL], 'none'),
+    ("meta-llama-Llama-3.2-3B-Instruct", 128, [], None),
+    ("meta-llama-Llama-3.2-3B-Instruct", 128, [TEST_TOOL], None),
+    ("meta-llama-Llama-3.2-3B-Instruct", 128, [PYTHON_TOOL], 'none'),
+])
+def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
+    do_test_completion_without_tool_call(template_name, n_predict, tools, tool_choice)
+
+
 @pytest.mark.slow
 @pytest.mark.parametrize("hf_repo,hf_file,template_override", [
     ("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
@@ -243,7 +270,7 @@ def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[
     if template_override:
         (template_hf_repo, template_variant) = template_override
         server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
-        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_hf_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
+        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     server.start(timeout_seconds=15*60)
     res = server.make_request("POST", "/chat/completions", data={
         "max_tokens": 256,
@@ -292,7 +319,7 @@ def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_
     if template_override:
         (template_hf_repo, template_variant) = template_override
         server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
-        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_hf_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
+        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     server.start(timeout_seconds=15*60)
     res = server.make_request("POST", "/chat/completions", data={
         "max_tokens": 256,
@@ -596,6 +596,11 @@ static json oaicompat_completion_params_parse(
             throw std::runtime_error("tools param requires --jinja flag");
         }
     }
+    if (!use_jinja) {
+        if (body.contains("tool_choice") && !body.at("tool_choice").is_null()) {
+            throw std::runtime_error("Unsupported param: tool_choice");
+        }
+    }

     // Handle "stop" field
     if (body.contains("stop") && body.at("stop").is_string()) {
@@ -605,7 +610,6 @@ static json oaicompat_completion_params_parse(
     }

     // Handle "response_format" field
-    auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
     if (body.contains("response_format")) {
         json response_format = json_value(body, "response_format", json::object());
         std::string response_type = json_value(response_format, "type", std::string());
@@ -649,16 +653,6 @@ static json oaicompat_completion_params_parse(
         throw std::runtime_error("top_logprobs requires logprobs to be set to true");
     }

-    // Params supported by OAI but unsupported by llama.cpp
-    if (!use_jinja) {
-        static const std::vector<std::string> unsupported_params { "tool_choice" };
-        for (const auto & param : unsupported_params) {
-            if (body.contains(param)) {
-                throw std::runtime_error("Unsupported param: " + param);
-            }
-        }
-    }
-
     // Copy remaining properties to llama_params
     // This allows user to use llama.cpp-specific params like "mirostat", ... via OAI endpoint.
     // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp
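Note: the hunks above replace the old end-of-function loop over unsupported_params with an up-front check: without --jinja, any non-null "tool_choice" in the request body is rejected immediately (an explicit null is now tolerated, and the unused tool_choice default lookup is dropped). A rough client-side probe of that behaviour, assuming a server built from this branch listening on localhost:8080 and started without --jinja (host, port and exact error payload are assumptions):

    import json
    import urllib.error
    import urllib.request

    req = urllib.request.Request(
        "http://localhost:8080/chat/completions",
        data=json.dumps({
            "messages": [{"role": "user", "content": "hi"}],
            "tool_choice": "required",  # non-null, so a non-jinja server should refuse it
        }).encode(),
        headers={"Content-Type": "application/json"},
    )
    try:
        urllib.request.urlopen(req)
    except urllib.error.HTTPError as e:
        print(e.read().decode())  # expected to mention: Unsupported param: tool_choice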
@@ -4,12 +4,12 @@
 If a model has multiple chat templates, you can specify the variant name.

 Syntax:
-    ./scripts/get_hf_chat_template.py model_id [variant]
+    ./scripts/get_chat_template.py model_id [variant]

 Examples:
-    ./scripts/get_hf_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct
-    ./scripts/get_hf_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use
-    ./scripts/get_hf_chat_template.py meta-llama/Llama-3.2-3B-Instruct
+    ./scripts/get_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct
+    ./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use
+    ./scripts/get_chat_template.py meta-llama/Llama-3.2-3B-Instruct
 '''

 import json
@@ -17,7 +17,7 @@ import re
 import sys


-def get_hf_chat_template(model_id, variant=None):
+def get_chat_template(model_id, variant=None):
     try:
         # Use huggingface_hub library if available.
         # Allows access to gated models if the user has access and ran `huggingface-cli login`.
@@ -69,9 +69,10 @@ def main(args):
     model_id = args[0]
     variant = None if len(args) < 2 else args[1]

-    template = get_hf_chat_template(model_id, variant)
+    template = get_chat_template(model_id, variant)
     sys.stdout.write(template)


 if __name__ == '__main__':
     main(sys.argv[1:])
+
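Note: the script rename above (get_hf_chat_template.py to get_chat_template.py, with the function renamed to match) is what the updated assertion messages in the server test hunks point at. A small sketch of regenerating one of the template files those tests assert on, run from the repository root; the destination layout mirrors the tests' server.chat_template_file pattern and is an assumption about the checkout:

    import subprocess

    repo = "NousResearch/Hermes-3-Llama-3.1-8B"
    variant = "tool_use"

    # Same invocation the test assertion message suggests: the script prints the
    # template to stdout, which we capture and save under tests/chat/templates/.
    template = subprocess.run(
        ["python", "scripts/get_chat_template.py", repo, variant],
        check=True, capture_output=True, text=True,
    ).stdout

    dest = f"tests/chat/templates/{repo.replace('/', '')}-{variant}.jinja"
    with open(dest, "w") as f:
        f.write(template)
    print(f"wrote {len(template)} characters to {dest}")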
@@ -560,7 +560,7 @@ bool llama_grammar_parser::parse(const char * src) {
             }
         }
     } catch (const std::exception & err) {
-        fprintf(stderr, "\n%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src);
+        fprintf(stderr, "%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src);
         rules.clear();
         return false;
     }
@@ -118,8 +118,8 @@ struct llama_grammar {
     // lazy grammars wait for trigger words or tokens before constraining the sampling.
     // we still ahve trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
     // (useful e.g. for tool_choice=required)
-    bool lazy; // Useful when resetting
-    bool awaiting_trigger; // Initialized to lazy
+    bool lazy;
+    bool awaiting_trigger; // Initialized to true for lazy grammars only
     std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
     std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
     std::vector<std::string> trigger_words;
@@ -169,9 +169,6 @@ struct delta_data {
 };

 static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
-    fprintf(stderr, "Template source: %s\n", tmpl.source().c_str());
-    fprintf(stderr, "Delta message: %s\n", delta_message.dump(2).c_str());
-
     common_chat_params params;
     params.parallel_tool_calls = true;
     params.messages = json::array();
@@ -209,12 +206,14 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
     return {delta, full_data.grammar, full_data.parser};
 }

+/*
+  Applies the template to 1 user message w/ add_generation_prompt=true, then w/ the test message w/ add_generation_prompt=false,
+  gets the diff, removes any end tokens and parses the result w/ the grammar, checking that
+  the parsed message is the same as the test_message
+*/
 static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & test_message, const json & tools = {}, const std::string & expected_delta = "", bool skip_grammar_test = false, bool skip_parser_test = false) {
-    // auto tool_call_style = common_tool_call_style_detect(tmpl);
     common_chat_msg expected_msg = msg_from_json(test_message);

-    // Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false,
-    // get the diff and try and parse it w/ the grammar.
     auto user_message = json {
         {"role", "user"},
         {"content", "Hello, world!"}
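Note: the comment block added above documents the delta technique these tests rely on: render the conversation once with just the user turn and a generation prompt, render it again with the test message appended, and treat the suffix as what the template emitted for that message. A minimal sketch of the idea, with a hypothetical render callable standing in for the real template application:

    def compute_delta(render, user_message, test_message):
        # Render the prompt up to (and including) the generation prompt...
        prefix = render([user_message], add_generation_prompt=True)
        # ...then the same conversation with the test message appended, no prompt.
        full = render([user_message, test_message], add_generation_prompt=False)
        assert full.startswith(prefix)
        # The suffix is what the template produced for the test message alone;
        # the real test then strips end tokens and re-parses it with the grammar.
        return full[len(prefix):]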
@@ -228,7 +227,6 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
     params.tools = tools;

     auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools);
-    std::cout << "Full delta:\n```\n" << data.delta << "\n```" << std::endl;
     if (!expected_delta.empty()) {
         assert_equals(expected_delta, data.delta);
     }
@@ -449,7 +447,6 @@ static void test_template_output_parsers() {
 }

 int main() {
-    // test_parsing();
     test_template_output_parsers();

     std::cout << "\n[tool-call] All tests passed!" << std::endl;