Fix firefunction w/ jinja: requires two variables, use the chat handlers everywhere templates are used

This commit is contained in:
ochafik 2025-01-29 01:15:44 +00:00
parent d603d067d5
commit 64263910d8
3 changed files with 71 additions and 52 deletions

View file

@ -162,6 +162,14 @@ static void foreach_function(const json & tools, const std::function<void(const
} }
} }
static common_chat_msg no_op_text_parser(const std::string & input) {
return {
/* .role = */ "assistant",
/* .content = */ input,
/* .tool_calls = */ {},
};
}
static common_chat_data common_chat_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { static common_chat_data common_chat_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
common_chat_data data; common_chat_data data;
@ -498,40 +506,49 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
} }
static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "%s\n", __func__);
common_chat_data data; common_chat_data data;
data.grammar_lazy = params.tool_choice != "required"; if (!params.tools.is_null() && !params.tools.empty()) {
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar_lazy = params.tool_choice != "required";
auto schemas = json::array(); data.grammar = build_grammar([&](const common_grammar_builder & builder) {
foreach_function(params.tools, [&](const json & tool) { auto schemas = json::array();
const auto & function = tool["function"]; foreach_function(params.tools, [&](const json & tool) {
schemas.push_back({ const auto & function = tool["function"];
{"type", "object"}, schemas.push_back({
{"properties", { {"type", "object"},
{"name", { {"properties", {
{"type", "string"}, {"name", {
{"const", function["name"]}, {"type", "string"},
{"const", function["name"]},
}},
{"arguments", function["parameters"]},
}}, }},
{"arguments", function["parameters"]}, {"required", json::array({"name", "arguments", "id"})},
}}, });
{"required", json::array({"name", "arguments", "id"})},
}); });
}); auto schema = json {
auto schema = json { {"type", "array"},
{"type", "array"}, {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, {"minItems", 1},
{"minItems", 1}, };
if (!params.parallel_tool_calls) {
schema["maxItems"] = 1;
}
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
}, grammar_options);
data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
data.parser = [](const std::string & input) {
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
}; };
if (!params.parallel_tool_calls) { data.format = "firefunction v2 tool calls";
schema["maxItems"] = 1; } else {
} data.parser = no_op_text_parser;
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema)); data.format = "firefunction v2 text-only";
}, grammar_options); }
data.grammar_triggers.push_back({" functools[", /* .at_start = */ false}); data.prompt = tmpl.apply(params.messages, /* tools= */ nullptr, params.add_generation_prompt, {
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt); {"datetime", "Jan 29 2025 13:00:00 GMT"},
data.format = "firefunction v2 tool calls"; {"functions", json(params.tools.empty() ? "" : params.tools.dump(2))},
data.parser = [](const std::string & input) { }, /* adjust_inputs= */ false);
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
};
return data; return data;
} }
@ -747,13 +764,7 @@ static common_chat_data common_chat_init_without_tools(const common_chat_templat
common_chat_data data; common_chat_data data;
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.format = "content-only"; data.format = "content-only";
data.parser = [](const std::string & input) -> common_chat_msg { data.parser = no_op_text_parser;
return {
/* .role = */ "assistant",
/* .content = */ input,
/* .tool_calls = */ {},
};
};
data.grammar_lazy = false; data.grammar_lazy = false;
if (!params.json_schema.is_null()) { if (!params.json_schema.is_null()) {
if (!params.grammar.empty()) { if (!params.grammar.empty()) {
@ -777,6 +788,10 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struc
// Functionary prepends "all\n" to plain content outputs, so we use the parser no matter when // Functionary prepends "all\n" to plain content outputs, so we use the parser no matter when
return common_chat_init_functionary_v3_2_tool_call(tmpl, params); return common_chat_init_functionary_v3_2_tool_call(tmpl, params);
} }
if (src.find(" functools[") != std::string::npos) {
// Firefunction v2 requires datetime and functions in the context
return common_chat_init_firefunction_v2_tool_call(tmpl, params);
}
if (has_tools) { if (has_tools) {
return common_chat_init_without_tools(tmpl, params); return common_chat_init_without_tools(tmpl, params);
@ -807,8 +822,5 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struc
if (src.find("[TOOL_CALLS]") != std::string::npos) { if (src.find("[TOOL_CALLS]") != std::string::npos) {
return common_chat_init_mistral_nemo_tool_call(tmpl, params); return common_chat_init_mistral_nemo_tool_call(tmpl, params);
} }
if (src.find(" functools[") != std::string::npos) {
return common_chat_init_firefunction_v2_tool_call(tmpl, params);
}
return common_chat_init_generic_tool_call(tmpl, params); return common_chat_init_generic_tool_call(tmpl, params);
} }

View file

@ -12,6 +12,7 @@
#include "json.hpp" #include "json.hpp"
#include "json-schema-to-grammar.h" #include "json-schema-to-grammar.h"
#include "llama.h" #include "llama.h"
#include "chat-handler.hpp"
#include "chat-template.hpp" #include "chat-template.hpp"
#include <algorithm> #include <algorithm>
@ -1774,11 +1775,13 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
if (use_jinja) { if (use_jinja) {
try { try {
auto chat_template = minja::chat_template(tmpl, "<s>", "</s>"); auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
chat_template.apply({{ common_chat_params params;
params.messages = json::array({{
{"role", "user"}, {"role", "user"},
{"content", "test"}, {"content", "test"},
}}, json(), true); }});
common_chat_init(chat_template, params);
return true; return true;
} catch (const std::exception & e) { } catch (const std::exception & e) {
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what()); LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
@ -1800,7 +1803,11 @@ std::string common_chat_apply_template(
for (const auto & msg : msgs) { for (const auto & msg : msgs) {
messages.push_back({{"role", msg.role}, {"content", msg.content}}); messages.push_back({{"role", msg.role}, {"content", msg.content}});
} }
return tmpl.apply(messages, /* tools= */ json(), add_ass); common_chat_params params;
params.messages = messages;
params.add_generation_prompt = add_ass;
auto data = common_chat_init(tmpl, params);
return data.prompt;
} }
int alloc_size = 0; int alloc_size = 0;

View file

@ -1794,17 +1794,16 @@ struct server_context {
if (use_jinja) { if (use_jinja) {
auto templates = common_chat_templates_from_model(model, ""); auto templates = common_chat_templates_from_model(model, "");
common_chat_params params;
params.messages = json::array({{
{"role", "user"},
{"content", "test"},
}});
GGML_ASSERT(templates.template_default); GGML_ASSERT(templates.template_default);
try { try {
templates.template_default->apply({{ common_chat_init(*templates.template_default, params);
{"role", "user"},
{"content", "test"},
}}, json(), true);
if (templates.template_tool_use) { if (templates.template_tool_use) {
templates.template_tool_use->apply({{ common_chat_init(*templates.template_tool_use, params);
{"role", "user"},
{"content", "test"},
}}, json(), true);
} }
return true; return true;
} catch (const std::exception & e) { } catch (const std::exception & e) {
@ -3770,6 +3769,7 @@ int main(int argc, char ** argv) {
/* .stream = */ json_value(data, "stream", false), /* .stream = */ json_value(data, "stream", false),
/* .grammar = */ json_value(data, "grammar", std::string("")), /* .grammar = */ json_value(data, "grammar", std::string("")),
}); });
LOG_INF("Chat format: %s\n", chat_data.format.c_str());
if (data.contains("grammar")) { if (data.contains("grammar")) {
if (!chat_data.grammar.empty()) { if (!chat_data.grammar.empty()) {
throw std::runtime_error("Cannot provide grammar and tools"); throw std::runtime_error("Cannot provide grammar and tools");