diff --git a/common/tool-call.cpp b/common/tool-call.cpp
index adff1b2f8..b209c9145 100644
--- a/common/tool-call.cpp
+++ b/common/tool-call.cpp
@@ -67,6 +67,8 @@ std::string llama_tool_call_style_name(llama_tool_call_style style) {
             return "CommandRPlus";
         case llama_tool_call_style::MistralNemo:
             return "MistralNemo";
+        case llama_tool_call_style::FirefunctionV2:
+            return "FirefunctionV2";
         default:
             return "Unknown";
     }
@@ -92,6 +94,8 @@ llama_tool_call_style llama_tool_call_style_detect(const minja::chat_template &
         return CommandRPlus;
     } else if (src.find("[TOOL_CALLS]") != std::string::npos) {
         return MistralNemo;
+    } else if (src.find(" functools[") != std::string::npos) {
+        return FirefunctionV2;
     } else {
         return Generic;
     }
@@ -315,8 +319,8 @@ static llama_tool_calls parse_generic_tool_calls(const std::string& input) {
     return result;
 }
 
-static llama_tool_calls parse_mistral_nemo_tool_calls(const std::string& input) {
-    auto content_end = input.find("[TOOL_CALLS]");
+static llama_tool_calls parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
+    auto content_end = input.find(prefix);
     size_t tc_start = std::string::npos;
 
     llama_tool_calls result;
@@ -330,25 +334,27 @@ static llama_tool_calls parse_mistral_nemo_tool_calls(const std::string& input)
             });
         }
     };
-    if (content_end != std::string::npos) {
-        tc_start = content_end + 12;
+    if (content_end == std::string::npos) {
+        result.content = input;
+    } else {
+        tc_start = content_end + prefix.size() - rstrip_prefix;
         result.content = input.substr(0, content_end);
         auto tool_calls = json::parse(input.substr(tc_start));
         process_tool_calls(tool_calls);
-    } else {
-        // Somehow not getting [TOOL_CALLS] in the output. Oh well, just do without it.
-        try {
-            auto tool_calls = json::parse(input);
-            process_tool_calls(tool_calls);
-        } catch (const json::exception & e) {
-            throw std::runtime_error("Failed to parse tool calls: " + std::string(e.what()) + ":\n" + input);
-        }
     }
     return result;
 }
 
+static llama_tool_calls parse_mistral_nemo_tool_calls(const std::string& input) {
+    return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
+}
+
+static llama_tool_calls parse_firefunction_v2_tool_calls(const std::string& input) {
+    return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
+}
+
 llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tools, const std::string& input) {
-    // fprintf(stderr, "# parse_tool_calls(%s):\n\n%s\n\n", llama_tool_call_style_name(style).c_str(), input.c_str());
+    fprintf(stderr, "# parse_tool_calls(%s):\n\n%s\n\n", llama_tool_call_style_name(style).c_str(), input.c_str());
     switch (style) {
         case llama_tool_call_style::None:
             return {input, {}};
@@ -366,6 +372,8 @@ llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tool
             return parse_hermes_tool_calls(input);
         case llama_tool_call_style::MistralNemo:
             return parse_mistral_nemo_tool_calls(input);
+        case llama_tool_call_style::FirefunctionV2:
+            return parse_firefunction_v2_tool_calls(input);
         default:
             throw std::runtime_error("Unsupported tool call style");
     }
@@ -406,16 +414,14 @@ llama_tool_call_handler llama_tool_call_handler_init(
             auto tool_call_schemas = json::array();
             for (const auto & tool : actual_tools) {
                 const auto & function = tool["function"];
-                std::string name = function["name"];
-                auto parameters = function["parameters"];
                 auto tool_schema = json {
                     {"type", "object"},
                     {"properties", {
                         {"name", {
                             {"type", "string"},
-                            {"const", name},
+                            {"const", function["name"]},
                         }},
-                        {"arguments", parameters},
+                        {"arguments", function["parameters"]},
                     }},
                     {"required", json::array({"name", "arguments"})},
                 };
@@ -483,18 +489,16 @@ llama_tool_call_handler llama_tool_call_handler_init(
             auto schemas = json::array();
             for (const auto & tool : actual_tools) {
                 const auto & function = tool["function"];
-                std::string name = function["name"];
-                auto parameters = function["parameters"];
-                auto schema = json {
+                schemas.push_back({
                     {"type", "object"},
                     {"properties", {
                         // Important note: the model is probably trained to take a JSON stringified arguments value.
                         // It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object.
-                        {"arguments", parameters},
                         {"name", {
                             {"type", "string"},
-                            {"const", name},
+                            {"const", function["name"]},
                         }},
+                        {"arguments", function["parameters"]},
                         {"id", {
                             {"type", "string"},
                             // Nemo's template expects a 9-character alphanumeric ID.
@@ -502,8 +506,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
                         }},
                     }},
                     {"required", json::array({"name", "arguments", "id"})},
-                };
-                schemas.push_back(schema);
+                });
             }
             auto schema = json {
                 {"type", "array"},
@@ -517,9 +520,41 @@ llama_tool_call_handler llama_tool_call_handler_init(
             });
             if (allow_content) {
                 handler.grammar_trigger_words.push_back("[TOOL_CALLS]");
-                handler.grammar_trigger_words.push_back("[{\"arguments\":");
             }
-            // auto tweaked_messages = add_system(messages, "You are a helpful AI with tool calling capabilities. Prefix any tool calls with [TOOL_CALLS]");
+            handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
+            break;
+        }
+        case llama_tool_call_style::FirefunctionV2: {
+            auto actual_tools = normalize_tools(tools);
+            handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+                auto schemas = json::array();
+                for (const auto & tool : actual_tools) {
+                    const auto & function = tool["function"];
+                    schemas.push_back({
+                        {"type", "object"},
+                        {"properties", {
+                            {"name", {
+                                {"type", "string"},
+                                {"const", function["name"]},
+                            }},
+                            {"arguments", function["parameters"]},
+                        }},
+                        {"required", json::array({"name", "arguments", "id"})},
+                    });
+                }
+                auto schema = json {
+                    {"type", "array"},
+                    {"items", json {{"anyOf", schemas}}},
+                    {"minItems", 1},
+                };
+                if (!parallel) {
+                    schema["maxItems"] = 1;
+                }
+                builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
+            });
+            if (allow_content) {
+                handler.grammar_trigger_words.push_back(" functools[");
+            }
             handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
             break;
         }
diff --git a/common/tool-call.h b/common/tool-call.h
index 6d1265460..c2d068441 100644
--- a/common/tool-call.h
+++ b/common/tool-call.h
@@ -18,6 +18,7 @@ enum llama_tool_call_style {
     Hermes2Pro,
     CommandRPlus,
     MistralNemo,
+    FirefunctionV2,
 };
 
 struct llama_tool_call {
diff --git a/examples/agent/README.md b/examples/agent/README.md
index 830c6493c..4770720c6 100644
--- a/examples/agent/README.md
+++ b/examples/agent/README.md
@@ -1,10 +1,11 @@
 # Agents / Tool Calling w/ llama.cpp
 
 While *any model* should work (using some generic support), we only support the native call style of a few models:
-- Llama 3.x
+- Firefunction v2
+- Mistral Nemo
 - Functionary 3.x
-- Hermes 2/3, Qwen 2.5
-- Mistral Nemo.
+- Llama 3.x
+- Hermes 2/3 / Qwen 2.5 / QwQ
 
 For natively supported models, it's important to have the right template (it might not be in the GGUF; note that we prefer the `tool_use` variant of the Jinja template if it's present in the GGUF metadata). You can check which template is defined by inspecting `http://localhost:8080/props`, and inspect the logs for `Tool call style: `.
 
@@ -23,31 +24,35 @@ Here's how to run an agent w/ local tool call:
 # and consume more tokens)
 
 ./build/bin/llama-server --jinja -fa --verbose \
-    -hfr bartowski/Qwen2.5-7B-Instruct-GGUF -hff Qwen2.5-7B-Instruct-Q4_K_M.gguf
-
-./build/bin/llama-server --jinja -fa --verbose \
-    -hfr NousResearch/Hermes-3-Llama-3.1-8B-GGUF -hff Hermes-3-Llama-3.1-8B.Q4_K_M.gguf \
-    --chat-template-file tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja
-
-./build/bin/llama-server --jinja -fa --verbose \
-    -hfr meetkai/functionary-small-v3.2-GGUF -hff functionary-small-v3.2.Q8_0.gguf \
-    --chat-template-file tests/chat/templates/meetkai-functionary-medium-v3.2.jinja
-
-./build/bin/llama-server --jinja -fa --verbose \
-    -hfr lmstudio-community/Llama-3.2-3B-Instruct-GGUF -hff Llama-3.2-3B-Instruct-Q6_K.gguf \
-    --chat-template-file tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja
+    -hfr mav23/llama-3-firefunction-v2-GGUF -hff llama-3-firefunction-v2.Q4_K_M.gguf \
+    --chat-template-file <( python scripts/get_hf_chat_template.py fireworks-ai/firellama-3-firefunction-v2 )
 
 # Note the --special flag: this is needed b/c of a regression from the last merge, will fix!
-./build/bin/llama-server --jinja -fa --verbose --special \
+./llama-server --jinja -fa --special \
     -hfr bartowski/Mistral-Nemo-Instruct-2407-GGUF -hff Mistral-Nemo-Instruct-2407-Q8_0.gguf \
-    --chat-template-file tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja
+    --chat-template-file <( python scripts/get_hf_chat_template.py mistralai/Mistral-Nemo-Instruct-2407 )
+
+./llama-server --jinja -fa \
+    -hfr NousResearch/Hermes-3-Llama-3.1-8B-GGUF -hff Hermes-3-Llama-3.1-8B.Q4_K_M.gguf \
+    --chat-template-file <( python scripts/get_hf_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use )
+
+./llama-server --jinja -fa \
+    -hfr meetkai/functionary-small-v3.2-GGUF -hff functionary-small-v3.2.Q8_0.gguf \
+    --chat-template-file <( python scripts/get_hf_chat_template.py meetkai/functionary-medium-v3.2 )
+
+./llama-server --jinja -fa \
+    -hfr bartowski/Qwen2.5-7B-Instruct-GGUF -hff Qwen2.5-7B-Instruct-Q4_K_M.gguf
+
+./llama-server --jinja -fa \
+    -hfr lmstudio-community/Llama-3.2-3B-Instruct-GGUF -hff Llama-3.2-3B-Instruct-Q6_K.gguf \
+    --chat-template-file <( python scripts/get_hf_chat_template.py meta-llama/Llama-3.2-3B-Instruct )
 
 # Generic support, e.g. Phi 3.5, Gemma 2b, but really anything goes
-./build/bin/llama-server --jinja -fa --verbose \
+./llama-server --jinja -fa \
     -hfr bartowski/Phi-3.5-mini-instruct-GGUF -hff Phi-3.5-mini-instruct-Q4_K_M.gguf
 
-./build/bin/llama-server --jinja -fa --verbose \
+./llama-server --jinja -fa \
     -hfr bartowski/gemma-2-2b-it-GGUF -hff gemma-2-2b-it-Q4_K_M.gguf
 ```
diff --git a/tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja b/tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja
new file mode 100644
index 000000000..9b8136df7
--- /dev/null
+++ b/tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja
@@ -0,0 +1,57 @@
+{%- set loop_messages = messages -%}
+{%- set message_roles = ['system', 'user', 'assistant', 'tool'] -%}
+{%- set system_prompt_suffix -%}
+{%- filter trim -%}
+In addition to plain text responses, you can chose to call one or more of the provided functions.
+
+Use the following rule to decide when to call a function:
+ * if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
+ * if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls
+
+If you decide to call functions:
+ * prefix function calls with functools marker (no closing marker required)
+ * all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
+ * follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples
+ * respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
+ * make sure you pick the right functions that match the user intent
+
+Available functions as JSON spec:
+{%- endfilter -%}
+{%- endset -%}
+{%- set system_prompt_suffix = system_prompt_suffix + "\n" + functions -%}
+{%- set system_prompt_suffix = system_prompt_suffix + '\nToday is ' + datetime + '.' -%}
+{%- set ns = namespace(role='', content='') -%}
+{#- Basic consistency checks -#}
+{%- if not loop_messages -%}
+    {{ raise_exception('Expected non-empty messages') }}
+{%- endif -%}
+{%- for message in loop_messages -%}
+    {%- set ns.role = message['role'] | lower -%}
+    {%- if ns.role not in message_roles -%}
+        {%- set message_roles_string = message_roles | join(', ') -%}
+        {{ raise_exception('Invalid role ' + message['role'] + '. Only ' + message_roles_string + ' are supported.') }}
+    {%- endif -%}
+    {%- set msg_content = message['content'] | default('', true) | trim -%}
+    {%- if loop.index0 == 0 -%}
+        {%- if ns.role == 'system' -%}
+            {%- set system_prompt = '<|start_header_id|>' + 'system' + '<|end_header_id|>\n\n' + message['content'] | trim + '\n' + system_prompt_suffix + '<|eot_id|>' -%}
+        {%- else -%}
+            {%- set system_prompt = '<|start_header_id|>' + 'system' + '<|end_header_id|>\n\nYou are a helpful assistant with access to functions.\n' + system_prompt_suffix + '<|eot_id|>' -%}
+        {%- endif -%}
+        {%- set ns.content = bos_token + system_prompt -%}
+        {{- ns.content -}}
+    {%- endif -%}
+    {%- if loop.index0 > 0 or ns.role != 'system' -%}
+        {%- set ns.content = '<|start_header_id|>' + ns.role + '<|end_header_id|>\n\n' + msg_content -%}
+        {%- if 'tool_calls' in message and message['tool_calls'] -%}
+            {%- set tool = namespace(calls=[]) -%}
+            {%- for call in message['tool_calls'] -%}
+                {%- set tool.calls = tool.calls + ['{"name": "' + call['function']['name'] + '", "arguments": ' + call['function']['arguments'] + '}'] -%}
+            {%- endfor -%}
+            {%- set ns.content = ns.content + ' functools[' + tool.calls | join(', ') + ']' -%}
+        {%- endif -%}
+        {%- set ns.content = ns.content + '<|eot_id|>' -%}
+        {{- ns.content -}}
+    {%- endif -%}
+{%- endfor -%}
+{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
diff --git a/tests/test-tool-call.cpp b/tests/test-tool-call.cpp
index c81a4c15a..d112e395e 100644
--- a/tests/test-tool-call.cpp
+++ b/tests/test-tool-call.cpp
@@ -306,10 +306,11 @@ static void test_parsing() {
         "Bleh[TOOL_CALLS][{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\", \"id\": \"123456789\"}]",
         "Bleh",
         json::array({special_function_call_with_id}));
-    test_parse_tool_call(llama_tool_call_style::MistralNemo, tools,
-        "[{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\", \"id\": \"123456789\"}]",
-        "",
-        json::array({special_function_call_with_id}));
+
+    test_parse_tool_call(llama_tool_call_style::FirefunctionV2, tools,
+        "Bleh functools[{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\"}]",
+        "Bleh",
+        json::array({special_function_call}));
 }
 
 static void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) {
@@ -322,6 +323,7 @@ static void test_tool_call_style(const std::string & template_file, llama_tool_c
 static void test_tool_call_style_detection() {
     test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", FunctionaryV3Llama31);
     test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", FunctionaryV3Llama3);
+    test_tool_call_style("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja", FirefunctionV2);
     test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", Llama31);
     test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", Llama32);
     test_tool_call_style("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja", Hermes2Pro);
@@ -414,6 +416,7 @@ static void test_grammars() {
     test_template("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", "", "", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
     test_template("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", "", "", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
     test_template("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", "", "", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
+    test_template("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja", "", "", { "<|eot_id|>" }, tool_call_message, tools);
     test_template("tests/chat/templates/google-gemma-2-2b-it.jinja", "", "", { "<end_of_turn>" }, tool_call_message_with_id, tools);
     test_template("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja", "", "", { "<|end|>" }, tool_call_message_with_id, tools);
 }