tool-call: add firefunction-v2 style

ochafik 2024-12-07 03:09:50 +00:00
parent 5d0033f57a
commit 1f0b15799b
5 changed files with 151 additions and 50 deletions


@@ -67,6 +67,8 @@ std::string llama_tool_call_style_name(llama_tool_call_style style) {
             return "CommandRPlus";
         case llama_tool_call_style::MistralNemo:
             return "MistralNemo";
+        case llama_tool_call_style::FirefunctionV2:
+            return "FirefunctionV2";
         default:
             return "Unknown";
     }
@@ -92,6 +94,8 @@ llama_tool_call_style llama_tool_call_style_detect(const minja::chat_template &
         return CommandRPlus;
     } else if (src.find("[TOOL_CALLS]") != std::string::npos) {
         return MistralNemo;
+    } else if (src.find(" functools[") != std::string::npos) {
+        return FirefunctionV2;
     } else {
        return Generic;
    }
@@ -315,8 +319,8 @@ static llama_tool_calls parse_generic_tool_calls(const std::string& input) {
     return result;
 }

-static llama_tool_calls parse_mistral_nemo_tool_calls(const std::string& input) {
-    auto content_end = input.find("[TOOL_CALLS]");
+static llama_tool_calls parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
+    auto content_end = input.find(prefix);
     size_t tc_start = std::string::npos;
     llama_tool_calls result;
@@ -330,25 +334,27 @@ static llama_tool_calls parse_mistral_nemo_tool_calls(const std::string& input)
             });
         }
     };
-    if (content_end != std::string::npos) {
-        tc_start = content_end + 12;
+    if (content_end == std::string::npos) {
+        result.content = input;
+    } else {
+        tc_start = content_end + prefix.size() - rstrip_prefix;
         result.content = input.substr(0, content_end);
         auto tool_calls = json::parse(input.substr(tc_start));
         process_tool_calls(tool_calls);
-    } else {
-        // Somehow not getting [TOOL_CALLS] in the output. Oh well, just do without it.
-        try {
-            auto tool_calls = json::parse(input);
-            process_tool_calls(tool_calls);
-        } catch (const json::exception & e) {
-            throw std::runtime_error("Failed to parse tool calls: " + std::string(e.what()) + ":\n" + input);
-        }
     }
     return result;
 }

+static llama_tool_calls parse_mistral_nemo_tool_calls(const std::string& input) {
+    return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
+}
+
+static llama_tool_calls parse_firefunction_v2_tool_calls(const std::string& input) {
+    return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
+}
+
 llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tools, const std::string& input) {
-    // fprintf(stderr, "# parse_tool_calls(%s):\n\n%s\n\n", llama_tool_call_style_name(style).c_str(), input.c_str());
+    fprintf(stderr, "# parse_tool_calls(%s):\n\n%s\n\n", llama_tool_call_style_name(style).c_str(), input.c_str());
     switch (style) {
         case llama_tool_call_style::None:
             return {input, {}};
@@ -366,6 +372,8 @@ llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tools, const std::string& input) {
             return parse_hermes_tool_calls(input);
         case llama_tool_call_style::MistralNemo:
             return parse_mistral_nemo_tool_calls(input);
+        case llama_tool_call_style::FirefunctionV2:
+            return parse_firefunction_v2_tool_calls(input);
         default:
             throw std::runtime_error("Unsupported tool call style");
     }
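
The Mistral Nemo parser is generalized here into `parse_prefixed_json_tool_call_array`: everything before the marker is returned as plain content, and the rest is parsed as a JSON array of calls, with `rstrip_prefix` characters of the marker given back so that Firefunction's trailing `[` stays part of the JSON. A standalone sketch of that splitting logic, assuming nlohmann::json (an illustration, not the llama.cpp code itself):

```cpp
#include <nlohmann/json.hpp>
#include <iostream>
#include <string>

using json = nlohmann::json;

int main() {
    const std::string input  = "Bleh functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]";
    const std::string prefix = " functools[";
    const size_t rstrip_prefix = 1; // give the trailing '[' back so the JSON array parses

    auto pos = input.find(prefix);
    if (pos == std::string::npos) {
        std::cout << "no tool calls, content = " << input << "\n";
        return 0;
    }
    // Everything before the marker is plain content; the rest is a JSON array of calls.
    std::string content = input.substr(0, pos);
    auto calls = json::parse(input.substr(pos + prefix.size() - rstrip_prefix));

    std::cout << "content = " << content << "\n";
    for (const auto & call : calls) {
        std::cout << "call " << call["name"].get<std::string>()
                  << " with arguments " << call["arguments"].dump() << "\n";
    }
}
```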
@@ -406,16 +414,14 @@ llama_tool_call_handler llama_tool_call_handler_init(
         auto tool_call_schemas = json::array();
         for (const auto & tool : actual_tools) {
             const auto & function = tool["function"];
-            std::string name = function["name"];
-            auto parameters = function["parameters"];
             auto tool_schema = json {
                 {"type", "object"},
                 {"properties", {
                     {"name", {
                         {"type", "string"},
-                        {"const", name},
+                        {"const", function["name"]},
                     }},
-                    {"arguments", parameters},
+                    {"arguments", function["parameters"]},
                 }},
                 {"required", json::array({"name", "arguments"})},
             };
@@ -483,18 +489,16 @@ llama_tool_call_handler llama_tool_call_handler_init(
             auto schemas = json::array();
             for (const auto & tool : actual_tools) {
                 const auto & function = tool["function"];
-                std::string name = function["name"];
-                auto parameters = function["parameters"];
-                auto schema = json {
+                schemas.push_back({
                     {"type", "object"},
                     {"properties", {
                         // Important note: the model is probably trained to take a JSON stringified arguments value.
                         // It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object.
-                        {"arguments", parameters},
                         {"name", {
                             {"type", "string"},
-                            {"const", name},
+                            {"const", function["name"]},
                         }},
+                        {"arguments", function["parameters"]},
                         {"id", {
                             {"type", "string"},
                             // Nemo's template expects a 9-character alphanumeric ID.
@@ -502,8 +506,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
                         }},
                     }},
                     {"required", json::array({"name", "arguments", "id"})},
-                };
-                schemas.push_back(schema);
+                });
             }
             auto schema = json {
                 {"type", "array"},
@@ -517,9 +520,41 @@ llama_tool_call_handler llama_tool_call_handler_init(
             });
             if (allow_content) {
                 handler.grammar_trigger_words.push_back("[TOOL_CALLS]");
+                handler.grammar_trigger_words.push_back("[{\"arguments\":");
             }
-            // auto tweaked_messages = add_system(messages, "You are a helpful AI with tool calling capabilities. Prefix any tool calls with [TOOL_CALLS]");
             handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
             break;
         }
+        case llama_tool_call_style::FirefunctionV2: {
+            auto actual_tools = normalize_tools(tools);
+            handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+                auto schemas = json::array();
+                for (const auto & tool : actual_tools) {
+                    const auto & function = tool["function"];
+                    schemas.push_back({
+                        {"type", "object"},
+                        {"properties", {
+                            {"name", {
+                                {"type", "string"},
+                                {"const", function["name"]},
+                            }},
+                            {"arguments", function["parameters"]},
+                        }},
+                        {"required", json::array({"name", "arguments", "id"})},
+                    });
+                }
+                auto schema = json {
+                    {"type", "array"},
+                    {"items", json {{"anyOf", schemas}}},
+                    {"minItems", 1},
+                };
+                if (!parallel) {
+                    schema["maxItems"] = 1;
+                }
+                builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
+            });
+            if (allow_content) {
+                handler.grammar_trigger_words.push_back(" functools[");
+            }
+            handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
+            break;
+        }
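
The FirefunctionV2 case constrains output with a JSON-schema-derived grammar whose root is an optional `" functools"` literal followed by an array of per-tool call objects. As a rough illustration (an assumed shape, not captured output), for a single hypothetical tool `get_weather` with one string argument `location` and parallel calls disabled, the array schema built above would look like:

```cpp
#include <nlohmann/json.hpp>
#include <iostream>

int main() {
    // "required" mirrors the branch above, which lists "id" even though no "id"
    // property is defined for this style. "maxItems" is only set when parallel
    // tool calls are disabled.
    auto schema = nlohmann::json::parse(R"json({
        "type": "array",
        "items": { "anyOf": [ {
            "type": "object",
            "properties": {
                "name":      { "type": "string", "const": "get_weather" },
                "arguments": {
                    "type": "object",
                    "properties": { "location": { "type": "string" } },
                    "required": ["location"]
                }
            },
            "required": ["name", "arguments", "id"]
        } ] },
        "minItems": 1,
        "maxItems": 1
    })json");
    std::cout << schema.dump(2) << std::endl;
}
```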


@@ -18,6 +18,7 @@ enum llama_tool_call_style {
     Hermes2Pro,
     CommandRPlus,
     MistralNemo,
+    FirefunctionV2,
 };

 struct llama_tool_call {


@@ -1,10 +1,11 @@
 # Agents / Tool Calling w/ llama.cpp

 While *any model* should work (using some generic support), we only support the native call style of a few models:

-- Llama 3.x
+- Firefunction v2
+- Mistral Nemo
 - Functionary 3.x
-- Hermes 2/3, Qwen 2.5
-- Mistral Nemo.
+- Llama 3.x
+- Hermes 2/3 / Qwen 2.5 / QwQ

 For natively supported models, it's important to have the right template (it might not be in the GGUF; note that we prefer the `tool_use` variant of the Jinja template if it's present in the GGUF metadata). You can check which template is defined by inspecting `http://localhost:8080/props`, and inspect the logs for `Tool call style: `.
@@ -23,31 +24,35 @@ Here's how to run an agent w/ local tool call:
 # and consume more tokens)
 ./build/bin/llama-server --jinja -fa --verbose \
-  -hfr bartowski/Qwen2.5-7B-Instruct-GGUF -hff Qwen2.5-7B-Instruct-Q4_K_M.gguf
-
-./build/bin/llama-server --jinja -fa --verbose \
-  -hfr NousResearch/Hermes-3-Llama-3.1-8B-GGUF -hff Hermes-3-Llama-3.1-8B.Q4_K_M.gguf \
-  --chat-template-file tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja
-
-./build/bin/llama-server --jinja -fa --verbose \
-  -hfr meetkai/functionary-small-v3.2-GGUF -hff functionary-small-v3.2.Q8_0.gguf \
-  --chat-template-file tests/chat/templates/meetkai-functionary-medium-v3.2.jinja
-
-./build/bin/llama-server --jinja -fa --verbose \
-  -hfr lmstudio-community/Llama-3.2-3B-Instruct-GGUF -hff Llama-3.2-3B-Instruct-Q6_K.gguf \
-  --chat-template-file tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja
+  -hfr mav23/llama-3-firefunction-v2-GGUF -hff llama-3-firefunction-v2.Q4_K_M.gguf \
+  --chat-template-file <( python scripts/get_hf_chat_template.py fireworks-ai/firellama-3-firefunction-v2 )

 # Note the --special flag: this is needed b/c of a regression from the last merge, will fix!
-./build/bin/llama-server --jinja -fa --verbose --special \
+./llama-server --jinja -fa --special \
   -hfr bartowski/Mistral-Nemo-Instruct-2407-GGUF -hff Mistral-Nemo-Instruct-2407-Q8_0.gguf \
-  --chat-template-file tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja
+  --chat-template-file <( python scripts/get_hf_chat_template.py mistralai/Mistral-Nemo-Instruct-2407 )
+
+./llama-server --jinja -fa \
+  -hfr NousResearch/Hermes-3-Llama-3.1-8B-GGUF -hff Hermes-3-Llama-3.1-8B.Q4_K_M.gguf \
+  --chat-template-file <( python scripts/get_hf_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use )
+
+./llama-server --jinja -fa \
+  -hfr meetkai/functionary-small-v3.2-GGUF -hff functionary-small-v3.2.Q8_0.gguf \
+  --chat-template-file <( python scripts/get_hf_chat_template.py meetkai/functionary-medium-v3.2 )
+
+./llama-server --jinja -fa \
+  -hfr bartowski/Qwen2.5-7B-Instruct-GGUF -hff Qwen2.5-7B-Instruct-Q4_K_M.gguf
+
+./llama-server --jinja -fa \
+  -hfr lmstudio-community/Llama-3.2-3B-Instruct-GGUF -hff Llama-3.2-3B-Instruct-Q6_K.gguf \
+  --chat-template-file <( python scripts/get_hf_chat_template.py meta-llama/Llama-3.2-3B-Instruct )

 # Generic support, e.g. Phi 3.5, Gemma 2b, but really anything goes
-./build/bin/llama-server --jinja -fa --verbose \
+./llama-server --jinja -fa \
   -hfr bartowski/Phi-3.5-mini-instruct-GGUF -hff Phi-3.5-mini-instruct-Q4_K_M.gguf
-./build/bin/llama-server --jinja -fa --verbose \
+./llama-server --jinja -fa \
   -hfr bartowski/gemma-2-2b-it-GGUF -hff gemma-2-2b-it-Q4_K_M.gguf
 ```


@@ -0,0 +1,57 @@
{%- set loop_messages = messages -%}
{%- set message_roles = ['system', 'user', 'assistant', 'tool'] -%}
{%- set system_prompt_suffix -%}
{%- filter trim -%}
In addition to plain text responses, you can chose to call one or more of the provided functions.
Use the following rule to decide when to call a function:
* if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
* if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls
If you decide to call functions:
* prefix function calls with functools marker (no closing marker required)
* all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
* follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples
* respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
* make sure you pick the right functions that match the user intent
Available functions as JSON spec:
{%- endfilter -%}
{%- endset -%}
{%- set system_prompt_suffix = system_prompt_suffix + "\n" + functions -%}
{%- set system_prompt_suffix = system_prompt_suffix + '\nToday is ' + datetime + '.' -%}
{%- set ns = namespace(role='', content='') -%}
{#- Basic consistency checks -#}
{%- if not loop_messages -%}
{{ raise_exception('Expected non-empty messages') }}
{%- endif -%}
{%- for message in loop_messages -%}
{%- set ns.role = message['role'] | lower -%}
{%- if ns.role not in message_roles -%}
{%- set message_roles_string = message_roles | join(', ') -%}
{{ raise_exception('Invalid role ' + message['role'] + '. Only ' + message_roles_string + ' are supported.') }}
{%- endif -%}
{%- set msg_content = message['content'] | default('', true) | trim -%}
{%- if loop.index0 == 0 -%}
{%- if ns.role == 'system' -%}
{%- set system_prompt = '<|start_header_id|>' + 'system' + '<|end_header_id|>\n\n' + message['content'] | trim + '\n' + system_prompt_suffix + '<|eot_id|>' -%}
{%- else -%}
{%- set system_prompt = '<|start_header_id|>' + 'system' + '<|end_header_id|>\n\nYou are a helpful assistant with access to functions.\n' + system_prompt_suffix + '<|eot_id|>' -%}
{%- endif -%}
{%- set ns.content = bos_token + system_prompt -%}
{{- ns.content -}}
{%- endif -%}
{%- if loop.index0 > 0 or ns.role != 'system' -%}
{%- set ns.content = '<|start_header_id|>' + ns.role + '<|end_header_id|>\n\n' + msg_content -%}
{%- if 'tool_calls' in message and message['tool_calls'] -%}
{%- set tool = namespace(calls=[]) -%}
{%- for call in message['tool_calls'] -%}
{%- set tool.calls = tool.calls + ['{"name": "' + call['function']['name'] + '", "arguments": ' + call['function']['arguments'] + '}'] -%}
{%- endfor -%}
{%- set ns.content = ns.content + ' functools[' + tool.calls | join(', ') + ']' -%}
{%- endif -%}
{%- set ns.content = ns.content + '<|eot_id|>' -%}
{{- ns.content -}}
{%- endif -%}
{%- endfor -%}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
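
For reference, the template above flattens any `tool_calls` on an assistant message into a single ` functools[...]` suffix (each call's name plus its already-stringified arguments, joined with commas), which is exactly the marker the parser and grammar trigger word key on. A rough sketch of that flattening, assuming nlohmann::json and an OpenAI-style `tool_calls` message shape (illustrative only):

```cpp
#include <nlohmann/json.hpp>
#include <iostream>
#include <string>
#include <vector>

using json = nlohmann::json;

int main() {
    // Hypothetical assistant turn carrying one tool call with stringified arguments.
    json message = {
        {"role", "assistant"},
        {"content", ""},
        {"tool_calls", json::array({
            {{"function", {{"name", "special_function"}, {"arguments", "{\"arg1\": 1}"}}}},
        })},
    };

    // Mirror the template's loop: build one JSON-ish string per call.
    std::vector<std::string> calls;
    for (const auto & call : message["tool_calls"]) {
        calls.push_back("{\"name\": \"" + call["function"]["name"].get<std::string>() +
                        "\", \"arguments\": " + call["function"]["arguments"].get<std::string>() + "}");
    }

    // Join them inside a single " functools[...]" suffix appended to the content.
    std::string suffix = " functools[";
    for (size_t i = 0; i < calls.size(); i++) {
        if (i > 0) suffix += ", ";
        suffix += calls[i];
    }
    suffix += "]";

    std::cout << message["content"].get<std::string>() + suffix << std::endl;
    // Prints:  functools[{"name": "special_function", "arguments": {"arg1": 1}}]
}
```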


@@ -306,10 +306,11 @@ static void test_parsing() {
         "Bleh[TOOL_CALLS][{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\", \"id\": \"123456789\"}]",
         "Bleh",
         json::array({special_function_call_with_id}));
-    test_parse_tool_call(llama_tool_call_style::MistralNemo, tools,
-        "[{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\", \"id\": \"123456789\"}]",
-        "",
-        json::array({special_function_call_with_id}));
+
+    test_parse_tool_call(llama_tool_call_style::FirefunctionV2, tools,
+        "Bleh functools[{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\"}]",
+        "Bleh",
+        json::array({special_function_call}));
 }

 static void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) {
@@ -322,6 +323,7 @@ static void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) {
 static void test_tool_call_style_detection() {
     test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", FunctionaryV3Llama31);
     test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", FunctionaryV3Llama3);
+    test_tool_call_style("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja", FirefunctionV2);
     test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", Llama31);
     test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", Llama32);
     test_tool_call_style("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja", Hermes2Pro);
@@ -414,6 +416,7 @@ static void test_grammars() {
     test_template("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
     test_template("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
     test_template("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
+    test_template("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja", "<s>", "</s>", { "<|eot_id|>" }, tool_call_message, tools);
     test_template("tests/chat/templates/google-gemma-2-2b-it.jinja", "<s>", "</s>", { "<end_of_turn>" }, tool_call_message_with_id, tools);
     test_template("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja", "<s>", "</s>", { "<|end|>" }, tool_call_message_with_id, tools);
 }