When llama_chat_apply_template doesn't work
Try minja. With granite-code, if we fall back to jinja on failure, it's fine.

Co-authored-by: Michael Engel <mengel@redhat.com>
Signed-off-by: Eric Curtin <ecurtin@redhat.com>
parent d774ab3acc
commit d23abdc3f6

2 changed files with 70 additions and 43 deletions

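Both hunks below implement the same two-stage pattern: try the built-in C formatter, llama_chat_apply_template, first, and only when it reports failure (a negative return value) hand the conversation to the minja-based Jinja renderer. A minimal, self-contained sketch of that control flow, with hypothetical stand-ins (format_builtin, format_jinja, format_chat) in place of the real functions:

    #include <cstdio>
    #include <optional>
    #include <string>

    // Hypothetical stand-ins for the two formatters: the built-in C one
    // (llama_chat_apply_template) and the minja-based Jinja renderer.
    static std::optional<std::string> format_builtin(const std::string & tmpl) {
        if (tmpl == "chatml") {
            return "<|im_start|>user\ntest<|im_end|>\n";  // supported template
        }
        return std::nullopt;  // e.g. granite-code: not supported
    }

    static std::string format_jinja(const std::string & tmpl) {
        return "rendered via jinja: " + tmpl;
    }

    // The control flow this commit introduces: use Jinja when requested,
    // otherwise try the built-in formatter and fall back to Jinja on failure.
    static std::string format_chat(const std::string & tmpl, bool use_jinja) {
        if (use_jinja) {
            return format_jinja(tmpl);
        }
        if (auto out = format_builtin(tmpl)) {
            return *out;
        }
        return format_jinja(tmpl);
    }

    int main() {
        printf("%s\n", format_chat("chatml", false).c_str());        // built-in path
        printf("%s\n", format_chat("granite-code", false).c_str());  // Jinja fallback
    }

The real code returns the formatted length or -1 rather than a string, but the branching is the same.
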
@@ -837,37 +837,50 @@ static void add_message(const char * role, const std::string & text, LlamaData &
     llama_data.messages.push_back({ role, llama_data.msg_strs.back().c_str() });
 }
 
+// Function to handle Jinja template application
+static int handle_jinja_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append) {
+    json messages = json::array();
+    for (const auto & msg : llama_data.messages) {
+        messages.push_back({
+            { "role", msg.role },
+            { "content", msg.content },
+        });
+    }
+
+    try {
+        minja::chat_template_inputs tmpl_inputs;
+        tmpl_inputs.messages = messages;
+        tmpl_inputs.add_generation_prompt = append;
+
+        minja::chat_template_options tmpl_opts;
+        tmpl_opts.use_bos_token = false;
+        tmpl_opts.use_eos_token = false;
+
+        auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
+        llama_data.fmtted.resize(result.size() + 1);
+        memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
+        return result.size();
+    } catch (const std::exception & e) {
+        printe("failed to render the chat template: %s\n", e.what());
+    }
+
+    return -1;
+}
+
 // Function to apply the chat template and resize `formatted` if needed
 static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
     if (use_jinja) {
-        json messages = json::array();
-        for (const auto & msg : llama_data.messages) {
-            messages.push_back({
-                {"role", msg.role},
-                {"content", msg.content},
-            });
-        }
-        try {
-            minja::chat_template_inputs tmpl_inputs;
-            tmpl_inputs.messages = messages;
-            tmpl_inputs.add_generation_prompt = append;
-
-            minja::chat_template_options tmpl_opts;
-            tmpl_opts.use_bos_token = false;
-            tmpl_opts.use_eos_token = false;
-
-            auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
-            llama_data.fmtted.resize(result.size() + 1);
-            memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
-            return result.size();
-        } catch (const std::exception & e) {
-            printe("failed to render the chat template: %s\n", e.what());
-            return -1;
-        }
+        return handle_jinja_template(tmpl, llama_data, append);
     }
 
     int result = llama_chat_apply_template(
         tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
         append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
+    // If llama_chat_apply_template fails to apply template, fallback to using jinja
+    if (result < 0) {
+        return handle_jinja_template(tmpl, llama_data, append);
+    }
+
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
         result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
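
Two details in the new helper are easy to miss: tmpl_opts.use_bos_token and use_eos_token are disabled, presumably leaving special tokens to the tokenizer, and the rendered string is copied into llama_data.fmtted including its NUL terminator (result.size() + 1 bytes) while the returned length excludes it, matching what llama_chat_apply_template reports. A minimal sketch of that buffer idiom, with fmtted modeled as a plain std::vector<char> (an assumption; the real member lives in LlamaData):

    #include <cstdio>
    #include <cstring>
    #include <string>
    #include <vector>

    // Copy a rendered template into a reusable char buffer, keeping the
    // trailing NUL so the buffer works with C APIs, but report the length
    // without it, the same convention handle_jinja_template follows.
    static int store_rendered(const std::string & result, std::vector<char> & fmtted) {
        fmtted.resize(result.size() + 1);
        memcpy(fmtted.data(), result.c_str(), result.size() + 1);
        return static_cast<int>(result.size());
    }

    int main() {
        std::vector<char> fmtted;
        const int n = store_rendered("<|im_start|>user\ntest<|im_end|>\n", fmtted);
        printf("%d bytes: %s", n, fmtted.data());
    }
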
@@ -1895,30 +1895,44 @@ struct server_context {
         return true;
     }
 
+    bool apply_jinja_templates() const {
+        auto templates = common_chat_templates_from_model(model, "");
+        common_chat_inputs inputs;
+        inputs.messages = json::array({
+            {
+                { "role", "user" },
+                { "content", "test" },
+            }
+        });
+        GGML_ASSERT(templates.template_default);
+        try {
+            common_chat_params_init(*templates.template_default, inputs);
+            if (templates.template_tool_use) {
+                common_chat_params_init(*templates.template_tool_use, inputs);
+            }
+
+            return true;
+        } catch (const std::exception & e) {
+            SRV_ERR("failed to apply template: %s\n", e.what());
+
+            return false;
+        }
+    }
+
     bool validate_builtin_chat_template(bool use_jinja) const {
-        llama_chat_message chat[] = {{"user", "test"}};
+        llama_chat_message chat[] = {
+            { "user", "test" }
+        };
 
         if (use_jinja) {
-            auto templates = common_chat_templates_from_model(model, "");
-            common_chat_inputs inputs;
-            inputs.messages = json::array({{
-                {"role", "user"},
-                {"content", "test"},
-            }});
-            GGML_ASSERT(templates.template_default);
-            try {
-                common_chat_params_init(*templates.template_default, inputs);
-                if (templates.template_tool_use) {
-                    common_chat_params_init(*templates.template_tool_use, inputs);
-                }
-                return true;
-            } catch (const std::exception & e) {
-                SRV_ERR("failed to apply template: %s\n", e.what());
-                return false;
-            }
+            return apply_jinja_templates();
         } else {
             const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
             const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
+            if (chat_res < 0) {
+                return apply_jinja_templates();
+            }
+
             return chat_res > 0;
         }
     }
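
On the server side, validation relies on llama_chat_apply_template's size-query mode: called with a null buffer and zero length, it returns the buffer size a render would need, or a negative value when the template cannot be handled. A hedged sketch of the probe-then-fallback shape in isolation (jinja_ok is a hypothetical stand-in for apply_jinja_templates(), which needs a loaded model):

    #include <llama.h>

    // Hypothetical stand-in for apply_jinja_templates(); assume the Jinja
    // engine accepts the template for the purpose of this sketch.
    static bool jinja_ok() { return true; }

    // Probe a template with a one-message "test" conversation, the way
    // validate_builtin_chat_template does. A negative result from the
    // built-in formatter now triggers the Jinja fallback instead of
    // failing validation outright.
    static bool template_usable(const char * tmpl) {
        llama_chat_message chat[] = {
            { "user", "test" }
        };
        const int32_t res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
        if (res < 0) {
            return jinja_ok();
        }
        return res > 0;
    }

    int main() {
        // "chatml" is one of the built-in template names llama.cpp recognizes.
        return template_usable("chatml") ? 0 : 1;
    }

Note that apply_jinja_templates() exercises both the default template and the tool-use template when the model provides one, so a template that only breaks the built-in parser is still validated end to end.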