Fix firefunction w/ jinja: requires two variables, use the chat handlers everywhere templates are used
This commit is contained in:
parent
d603d067d5
commit
64263910d8
3 changed files with 71 additions and 52 deletions
|
@ -162,6 +162,14 @@ static void foreach_function(const json & tools, const std::function<void(const
|
|||
}
|
||||
}
|
||||
|
||||
static common_chat_msg no_op_text_parser(const std::string & input) {
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ input,
|
||||
/* .tool_calls = */ {},
|
||||
};
|
||||
}
|
||||
|
||||
static common_chat_data common_chat_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
common_chat_data data;
|
||||
|
||||
|
@ -498,40 +506,49 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
|
|||
}
|
||||
|
||||
static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
fprintf(stderr, "%s\n", __func__);
|
||||
common_chat_data data;
|
||||
data.grammar_lazy = params.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schemas = json::array();
|
||||
foreach_function(params.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
schemas.push_back({
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"name", {
|
||||
{"type", "string"},
|
||||
{"const", function["name"]},
|
||||
if (!params.tools.is_null() && !params.tools.empty()) {
|
||||
data.grammar_lazy = params.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schemas = json::array();
|
||||
foreach_function(params.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
schemas.push_back({
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"name", {
|
||||
{"type", "string"},
|
||||
{"const", function["name"]},
|
||||
}},
|
||||
{"arguments", function["parameters"]},
|
||||
}},
|
||||
{"arguments", function["parameters"]},
|
||||
}},
|
||||
{"required", json::array({"name", "arguments", "id"})},
|
||||
{"required", json::array({"name", "arguments", "id"})},
|
||||
});
|
||||
});
|
||||
});
|
||||
auto schema = json {
|
||||
{"type", "array"},
|
||||
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
|
||||
{"minItems", 1},
|
||||
auto schema = json {
|
||||
{"type", "array"},
|
||||
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
|
||||
{"minItems", 1},
|
||||
};
|
||||
if (!params.parallel_tool_calls) {
|
||||
schema["maxItems"] = 1;
|
||||
}
|
||||
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
|
||||
}, grammar_options);
|
||||
data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
|
||||
data.parser = [](const std::string & input) {
|
||||
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
|
||||
};
|
||||
if (!params.parallel_tool_calls) {
|
||||
schema["maxItems"] = 1;
|
||||
}
|
||||
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
|
||||
}, grammar_options);
|
||||
data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
|
||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
||||
data.format = "firefunction v2 tool calls";
|
||||
data.parser = [](const std::string & input) {
|
||||
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
|
||||
};
|
||||
data.format = "firefunction v2 tool calls";
|
||||
} else {
|
||||
data.parser = no_op_text_parser;
|
||||
data.format = "firefunction v2 text-only";
|
||||
}
|
||||
data.prompt = tmpl.apply(params.messages, /* tools= */ nullptr, params.add_generation_prompt, {
|
||||
{"datetime", "Jan 29 2025 13:00:00 GMT"},
|
||||
{"functions", json(params.tools.empty() ? "" : params.tools.dump(2))},
|
||||
}, /* adjust_inputs= */ false);
|
||||
return data;
|
||||
}
|
||||
|
||||
|
@ -747,13 +764,7 @@ static common_chat_data common_chat_init_without_tools(const common_chat_templat
|
|||
common_chat_data data;
|
||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
||||
data.format = "content-only";
|
||||
data.parser = [](const std::string & input) -> common_chat_msg {
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ input,
|
||||
/* .tool_calls = */ {},
|
||||
};
|
||||
};
|
||||
data.parser = no_op_text_parser;
|
||||
data.grammar_lazy = false;
|
||||
if (!params.json_schema.is_null()) {
|
||||
if (!params.grammar.empty()) {
|
||||
|
@ -777,6 +788,10 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struc
|
|||
// Functionary prepends "all\n" to plain content outputs, so we use the parser no matter when
|
||||
return common_chat_init_functionary_v3_2_tool_call(tmpl, params);
|
||||
}
|
||||
if (src.find(" functools[") != std::string::npos) {
|
||||
// Firefunction v2 requires datetime and functions in the context
|
||||
return common_chat_init_firefunction_v2_tool_call(tmpl, params);
|
||||
}
|
||||
|
||||
if (has_tools) {
|
||||
return common_chat_init_without_tools(tmpl, params);
|
||||
|
@ -807,8 +822,5 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struc
|
|||
if (src.find("[TOOL_CALLS]") != std::string::npos) {
|
||||
return common_chat_init_mistral_nemo_tool_call(tmpl, params);
|
||||
}
|
||||
if (src.find(" functools[") != std::string::npos) {
|
||||
return common_chat_init_firefunction_v2_tool_call(tmpl, params);
|
||||
}
|
||||
return common_chat_init_generic_tool_call(tmpl, params);
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#include "json.hpp"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "llama.h"
|
||||
#include "chat-handler.hpp"
|
||||
#include "chat-template.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
|
@ -1774,11 +1775,13 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
|
|||
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
|
||||
if (use_jinja) {
|
||||
try {
|
||||
auto chat_template = minja::chat_template(tmpl, "<s>", "</s>");
|
||||
chat_template.apply({{
|
||||
auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
|
||||
common_chat_params params;
|
||||
params.messages = json::array({{
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
}}, json(), true);
|
||||
}});
|
||||
common_chat_init(chat_template, params);
|
||||
return true;
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
|
||||
|
@ -1800,7 +1803,11 @@ std::string common_chat_apply_template(
|
|||
for (const auto & msg : msgs) {
|
||||
messages.push_back({{"role", msg.role}, {"content", msg.content}});
|
||||
}
|
||||
return tmpl.apply(messages, /* tools= */ json(), add_ass);
|
||||
common_chat_params params;
|
||||
params.messages = messages;
|
||||
params.add_generation_prompt = add_ass;
|
||||
auto data = common_chat_init(tmpl, params);
|
||||
return data.prompt;
|
||||
}
|
||||
|
||||
int alloc_size = 0;
|
||||
|
|
|
@ -1794,17 +1794,16 @@ struct server_context {
|
|||
|
||||
if (use_jinja) {
|
||||
auto templates = common_chat_templates_from_model(model, "");
|
||||
common_chat_params params;
|
||||
params.messages = json::array({{
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
}});
|
||||
GGML_ASSERT(templates.template_default);
|
||||
try {
|
||||
templates.template_default->apply({{
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
}}, json(), true);
|
||||
common_chat_init(*templates.template_default, params);
|
||||
if (templates.template_tool_use) {
|
||||
templates.template_tool_use->apply({{
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
}}, json(), true);
|
||||
common_chat_init(*templates.template_tool_use, params);
|
||||
}
|
||||
return true;
|
||||
} catch (const std::exception & e) {
|
||||
|
@ -3770,6 +3769,7 @@ int main(int argc, char ** argv) {
|
|||
/* .stream = */ json_value(data, "stream", false),
|
||||
/* .grammar = */ json_value(data, "grammar", std::string("")),
|
||||
});
|
||||
LOG_INF("Chat format: %s\n", chat_data.format.c_str());
|
||||
if (data.contains("grammar")) {
|
||||
if (!chat_data.grammar.empty()) {
|
||||
throw std::runtime_error("Cannot provide grammar and tools");
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue