update to minja's new api
This commit is contained in:
parent
11c1f0c7d4
commit
30ea3591c9
4 changed files with 170 additions and 12 deletions
|
@ -270,6 +270,28 @@ class chat_template {
|
|||
const std::string & eos_token() const { return eos_token_; }
|
||||
const chat_template_caps & original_caps() const { return caps_; }
|
||||
|
||||
// Deprecated, please use the form with chat_template_inputs and chat_template_options
|
||||
std::string apply(
|
||||
const nlohmann::ordered_json & messages,
|
||||
const nlohmann::ordered_json & tools,
|
||||
bool add_generation_prompt,
|
||||
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(),
|
||||
bool apply_polyfills = true)
|
||||
{
|
||||
fprintf(stderr, "[%s] Deprecated!\n", __func__);
|
||||
chat_template_inputs inputs;
|
||||
inputs.messages = messages;
|
||||
inputs.tools = tools;
|
||||
inputs.add_generation_prompt = add_generation_prompt;
|
||||
inputs.extra_context = extra_context;
|
||||
inputs.now = std::chrono::system_clock::now();
|
||||
|
||||
chat_template_options opts;
|
||||
opts.apply_polyfills = apply_polyfills;
|
||||
|
||||
return apply(inputs, opts);
|
||||
}
|
||||
|
||||
std::string apply(
|
||||
const chat_template_inputs & inputs,
|
||||
const chat_template_options & opts = chat_template_options()) const
|
||||
|
|
|
@ -175,6 +175,28 @@ static void foreach_function(const json & tools, const std::function<void(const
|
|||
}
|
||||
}
|
||||
|
||||
static std::string apply(
|
||||
const common_chat_template & tmpl,
|
||||
const nlohmann::ordered_json & messages,
|
||||
const nlohmann::ordered_json & tools,
|
||||
bool add_generation_prompt,
|
||||
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json())
|
||||
{
|
||||
minja::chat_template_inputs tmpl_inputs;
|
||||
tmpl_inputs.messages = messages;
|
||||
tmpl_inputs.tools = tools;
|
||||
tmpl_inputs.add_generation_prompt = add_generation_prompt;
|
||||
tmpl_inputs.extra_context = extra_context;
|
||||
// TODO: add flag to control date/time, if only for testing purposes.
|
||||
// tmpl_inputs.now = std::chrono::system_clock::now();
|
||||
|
||||
minja::chat_template_options tmpl_opts;
|
||||
tmpl_opts.use_bos_token = false;
|
||||
tmpl_opts.use_eos_token = false;
|
||||
|
||||
return tmpl.apply(tmpl_inputs, tmpl_opts);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
|
@ -256,7 +278,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
|
|||
inputs.messages,
|
||||
"Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request");
|
||||
|
||||
data.prompt = tmpl.apply(tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.prompt = apply(tmpl, tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = COMMON_CHAT_FORMAT_GENERIC;
|
||||
return data;
|
||||
}
|
||||
|
@ -322,7 +344,7 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
|
|||
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
|
||||
}, grammar_options);
|
||||
data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
|
||||
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
|
||||
return data;
|
||||
}
|
||||
|
@ -372,7 +394,7 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
|
|||
"<|END_THINKING|>",
|
||||
"<|END_ACTION|>",
|
||||
};
|
||||
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
|
||||
return data;
|
||||
}
|
||||
|
@ -489,7 +511,7 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
|
|||
builder.add_rule("root", string_join(tool_rules, " | "));
|
||||
}, grammar_options);
|
||||
data.additional_stops.push_back("<|eom_id|>");
|
||||
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
|
||||
{"tools_in_user_message", false},
|
||||
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
|
||||
});
|
||||
|
@ -568,7 +590,7 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
|
|||
"<|tool▁call▁end|>",
|
||||
};
|
||||
}, grammar_options);
|
||||
auto prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
|
||||
// Hacks to fix the official (broken) prompt.
|
||||
// It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead,
|
||||
|
@ -614,10 +636,10 @@ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input)
|
|||
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
fprintf(stderr, "%s\n", __func__);
|
||||
common_chat_params data;
|
||||
data.prompt = tmpl.apply(inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
|
||||
data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
|
||||
{"datetime", "Jan 29 2025 13:00:00 GMT"},
|
||||
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
|
||||
}, /* adjust_inputs= */ false);
|
||||
});
|
||||
if (!inputs.tools.is_null() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
|
@ -661,7 +683,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
|||
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
||||
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
||||
common_chat_params data;
|
||||
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
|
||||
if (!inputs.tools.is_null() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
|
@ -788,7 +810,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
|
|||
data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
|
||||
}, grammar_options);
|
||||
|
||||
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
// TODO: if (has_raw_python)
|
||||
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
|
||||
return data;
|
||||
|
@ -843,7 +865,7 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
|
|||
data.preserved_tokens = { "</tool_call>" };
|
||||
}, grammar_options);
|
||||
|
||||
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
|
||||
return data;
|
||||
}
|
||||
|
@ -904,7 +926,7 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
|
|||
|
||||
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
common_chat_params data;
|
||||
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
data.grammar_lazy = false;
|
||||
if (!inputs.json_schema.is_null()) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue