Fix firefunction w/ jinja: requires two variables, use the chat handlers everywhere templates are used

This commit is contained in:
ochafik 2025-01-29 01:15:44 +00:00
parent d603d067d5
commit 64263910d8
3 changed files with 71 additions and 52 deletions

View file

@@ -162,6 +162,14 @@ static void foreach_function(const json & tools, const std::function<void(const
}
}
// Pass-through parser for chat formats without tool-call syntax: wraps the
// raw model output verbatim in an assistant message with no tool calls.
static common_chat_msg no_op_text_parser(const std::string & input) {
return {
/* .role = */ "assistant",
/* .content = */ input,
/* .tool_calls = */ {},
};
}
static common_chat_data common_chat_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
common_chat_data data;
@@ -498,40 +506,49 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
}
static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "%s\n", __func__);
common_chat_data data;
data.grammar_lazy = params.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(params.tools, [&](const json & tool) {
const auto & function = tool["function"];
schemas.push_back({
{"type", "object"},
{"properties", {
{"name", {
{"type", "string"},
{"const", function["name"]},
if (!params.tools.is_null() && !params.tools.empty()) {
data.grammar_lazy = params.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(params.tools, [&](const json & tool) {
const auto & function = tool["function"];
schemas.push_back({
{"type", "object"},
{"properties", {
{"name", {
{"type", "string"},
{"const", function["name"]},
}},
{"arguments", function["parameters"]},
}},
{"arguments", function["parameters"]},
}},
{"required", json::array({"name", "arguments", "id"})},
{"required", json::array({"name", "arguments", "id"})},
});
});
});
auto schema = json {
{"type", "array"},
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
{"minItems", 1},
auto schema = json {
{"type", "array"},
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
{"minItems", 1},
};
if (!params.parallel_tool_calls) {
schema["maxItems"] = 1;
}
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
}, grammar_options);
data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
data.parser = [](const std::string & input) {
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
};
if (!params.parallel_tool_calls) {
schema["maxItems"] = 1;
}
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
}, grammar_options);
data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.format = "firefunction v2 tool calls";
data.parser = [](const std::string & input) {
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
};
data.format = "firefunction v2 tool calls";
} else {
data.parser = no_op_text_parser;
data.format = "firefunction v2 text-only";
}
data.prompt = tmpl.apply(params.messages, /* tools= */ nullptr, params.add_generation_prompt, {
{"datetime", "Jan 29 2025 13:00:00 GMT"},
{"functions", json(params.tools.empty() ? "" : params.tools.dump(2))},
}, /* adjust_inputs= */ false);
return data;
}
@@ -747,13 +764,7 @@ static common_chat_data common_chat_init_without_tools(const common_chat_templat
common_chat_data data;
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.format = "content-only";
data.parser = [](const std::string & input) -> common_chat_msg {
return {
/* .role = */ "assistant",
/* .content = */ input,
/* .tool_calls = */ {},
};
};
data.parser = no_op_text_parser;
data.grammar_lazy = false;
if (!params.json_schema.is_null()) {
if (!params.grammar.empty()) {
@@ -777,6 +788,10 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struc
// Functionary prepends "all\n" to plain content outputs, so we use the parser no matter when
return common_chat_init_functionary_v3_2_tool_call(tmpl, params);
}
if (src.find(" functools[") != std::string::npos) {
// Firefunction v2 requires datetime and functions in the context
return common_chat_init_firefunction_v2_tool_call(tmpl, params);
}
if (has_tools) {
return common_chat_init_without_tools(tmpl, params);
@@ -807,8 +822,5 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struc
if (src.find("[TOOL_CALLS]") != std::string::npos) {
return common_chat_init_mistral_nemo_tool_call(tmpl, params);
}
if (src.find(" functools[") != std::string::npos) {
return common_chat_init_firefunction_v2_tool_call(tmpl, params);
}
return common_chat_init_generic_tool_call(tmpl, params);
}

View file

@@ -12,6 +12,7 @@
#include "json.hpp"
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "chat-handler.hpp"
#include "chat-template.hpp"
#include <algorithm>
@@ -1774,11 +1775,13 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
if (use_jinja) {
try {
auto chat_template = minja::chat_template(tmpl, "<s>", "</s>");
chat_template.apply({{
auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
common_chat_params params;
params.messages = json::array({{
{"role", "user"},
{"content", "test"},
}}, json(), true);
}});
common_chat_init(chat_template, params);
return true;
} catch (const std::exception & e) {
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
@@ -1800,7 +1803,11 @@ std::string common_chat_apply_template(
for (const auto & msg : msgs) {
messages.push_back({{"role", msg.role}, {"content", msg.content}});
}
return tmpl.apply(messages, /* tools= */ json(), add_ass);
common_chat_params params;
params.messages = messages;
params.add_generation_prompt = add_ass;
auto data = common_chat_init(tmpl, params);
return data.prompt;
}
int alloc_size = 0;

View file

@@ -1794,17 +1794,16 @@ struct server_context {
if (use_jinja) {
auto templates = common_chat_templates_from_model(model, "");
common_chat_params params;
params.messages = json::array({{
{"role", "user"},
{"content", "test"},
}});
GGML_ASSERT(templates.template_default);
try {
templates.template_default->apply({{
{"role", "user"},
{"content", "test"},
}}, json(), true);
common_chat_init(*templates.template_default, params);
if (templates.template_tool_use) {
templates.template_tool_use->apply({{
{"role", "user"},
{"content", "test"},
}}, json(), true);
common_chat_init(*templates.template_tool_use, params);
}
return true;
} catch (const std::exception & e) {
@@ -3770,6 +3769,7 @@ int main(int argc, char ** argv) {
/* .stream = */ json_value(data, "stream", false),
/* .grammar = */ json_value(data, "grammar", std::string("")),
});
LOG_INF("Chat format: %s\n", chat_data.format.c_str());
if (data.contains("grammar")) {
if (!chat_data.grammar.empty()) {
throw std::runtime_error("Cannot provide grammar and tools");