diff --git a/common/common.cpp b/common/common.cpp
index 5f5302074..d1e305103 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1869,10 +1869,18 @@ std::string common_chat_format_example(const common_chat_template & tmpl, bool u
     return common_chat_apply_template(tmpl, msgs, true, use_jinja);
 }
 
+#define CHATML_TEMPLATE_SRC \
+    "{%- for message in messages -%}\n" \
+    "  {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
+    "{%- endfor -%}\n" \
+    "{%- if add_generation_prompt -%}\n" \
+    "  {{- '<|im_start|>assistant\n' -}}\n" \
+    "{%- endif -%}"
+
 common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
 {
-    std::string default_template_src = chat_template_override;
-    std::string template_tool_use_src = chat_template_override;
+    std::string default_template_src = chat_template_override == "chatml" ? CHATML_TEMPLATE_SRC : chat_template_override;
+    std::string template_tool_use_src = chat_template_override == "chatml" ? CHATML_TEMPLATE_SRC : "";
     bool has_explicit_template = !chat_template_override.empty();
     if (chat_template_override.empty()) {
         auto str = llama_model_chat_template(model, /* name */ nullptr);
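
For reference, the built-in ChatML source above should render a conversation
containing a single user message "Hello", with add_generation_prompt set, as:

    <|im_start|>user
    Hello<|im_end|>
    <|im_start|>assistant

Note also that for any override other than "chatml", template_tool_use_src is
now left empty rather than set to the override, presumably so that a single
explicit template is not also applied as the tool-use variant.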