add minicpm chat template

ngxson 2024-03-24 12:41:15 +01:00
parent d0a71233fb
commit 00efa65eb1
2 changed files with 43 additions and 20 deletions

llama.cpp

@@ -14491,12 +14491,26 @@ static std::string trim(const std::string & str) {
 // Simple version of "llama_apply_chat_template" that only works with strings
 // This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
 // List of all supported chat templates: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
 static int32_t llama_chat_apply_template_internal(
     const std::string & tmpl,
     const std::vector<const llama_chat_message *> & chat,
     std::string & dest, bool add_ass) {
     // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
     std::stringstream ss;
+    // utility functions to deal with chat templates that don't support system message
+    // in that case, we will merge system message with the next user message
+    std::string sys_msg;
+    auto sys_msg_push = [&](const llama_chat_message * msg, bool need_trim) {
+        sys_msg = need_trim ? trim(msg->content) : msg->content;
+    };
+    auto sys_msg_pop = [&]() {
+        if (!sys_msg.empty()) {
+            ss << sys_msg << "\n\n";
+            sys_msg = "";
+        }
+    };
     if (tmpl == "chatml" || tmpl.find("<|im_start|>") != std::string::npos) {
         // chatml template
         for (auto message : chat) {
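
The two lambdas added above implement a small buffer-and-flush pattern: sys_msg_push() stores the system message (optionally trimmed) instead of emitting it, and sys_msg_pop() prepends the buffered text to the next turn that can carry it, then clears the buffer. A rough standalone sketch of the same idea, using a toy message struct rather than the llama.cpp types (illustrative only):

    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    // Toy message type standing in for llama_chat_message (illustration only).
    struct toy_msg { std::string role; std::string content; };

    int main() {
        std::vector<toy_msg> chat = {
            {"system",    "You are a helpful assistant"},
            {"user",      "Hello"},
            {"assistant", "Hi there"},
        };
        std::stringstream ss;
        std::string sys_msg;                    // buffered system message
        for (const auto & m : chat) {
            if (m.role == "system") {
                sys_msg = m.content;            // "push": remember it, emit nothing yet
            } else if (m.role == "user") {
                ss << "Human: ";
                if (!sys_msg.empty()) {         // "pop": merge it into this user turn
                    ss << sys_msg << "\n\n";
                    sys_msg = "";
                }
                ss << m.content << "\n\nAssistant: ";
            } else {
                ss << m.content << "\n";
            }
        }
        // Prints: Human: You are a helpful assistant\n\nHello\n\nAssistant: Hi there
        std::cout << ss.str() << std::endl;
    }
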
@@ -14559,46 +14573,51 @@ static int32_t llama_chat_apply_template_internal(
         }
     } else if (tmpl == "gemma" || tmpl.find("<start_of_turn>") != std::string::npos) {
         // google/gemma-7b-it
-        std::string system_prompt = "";
         for (auto message : chat) {
             std::string role(message->role);
             if (role == "system") {
-                // there is no system message for gemma, but we will merge it with user prompt, so nothing is broken
-                system_prompt = trim(message->content);
-                continue;
+                // there is no system message support, we will merge it with user prompt
+                sys_msg_push(message, true);
+            } else {
+                // in gemma, "assistant" is "model"
+                role = role == "assistant" ? "model" : message->role;
+                ss << "<start_of_turn>" << role << "\n";
+                if (role != "model") {
+                    sys_msg_pop();
+                }
+                ss << trim(message->content) << "<end_of_turn>\n";
             }
-            // in gemma, "assistant" is "model"
-            role = role == "assistant" ? "model" : message->role;
-            ss << "<start_of_turn>" << role << "\n";
-            if (!system_prompt.empty() && role != "model") {
-                ss << system_prompt << "\n\n";
-                system_prompt = "";
-            }
-            ss << trim(message->content) << "<end_of_turn>\n";
         }
         if (add_ass) {
             ss << "<start_of_turn>model\n";
         }
     } else if (tmpl == "orion" || tmpl.find("'\\n\\nAssistant: ' + eos_token") != std::string::npos) {
         // OrionStarAI/Orion-14B-Chat
-        std::string system_prompt = "";
         for (auto message : chat) {
             std::string role(message->role);
             if (role == "system") {
                 // there is no system message support, we will merge it with user prompt
-                system_prompt = message->content;
-                continue;
+                sys_msg_push(message, false);
             } else if (role == "user") {
                 ss << "Human: ";
-                if (!system_prompt.empty()) {
-                    ss << system_prompt << "\n\n";
-                    system_prompt = "";
-                }
+                sys_msg_pop();
                 ss << message->content << "\n\nAssistant: </s>";
             } else {
                 ss << message->content << "</s>";
             }
         }
+    } else if (tmpl == "minicpm" || tmpl.find("<\xe7\x94\xa8\xe6\x88\xb7>") != std::string::npos) {
+        // openbmb/MiniCPM-2B-dpo-fp32
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "user") {
+                ss << "<\xe7\x94\xa8\xe6\x88\xb7>";
+                ss << trim(message->content);
+                ss << "<AI>";
+            } else {
+                ss << trim(message->content);
+            }
+        }
     } else {
         // template not supported
         return -1;
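
The byte sequence "\xe7\x94\xa8\xe6\x88\xb7" is the UTF-8 encoding of 用户 ("user"), so the new branch matches MiniCPM-style templates that use the <用户> role token. Assuming the llama_chat_apply_template() signature from llama.h at the time of this commit (model, tmpl, chat, n_msg, add_ass, buf, length), a rough caller-side sketch that would reach this branch looks like the following; the template string and conversation are illustrative, not taken from the diff:

    #include <cstdio>
    #include <vector>
    #include "llama.h"

    int main() {
        // Passing an explicit template string; a model pointer is not needed in that case.
        const char * tmpl =
            "{% for message in messages %}{% if message['role'] == 'user' %}"
            "{{'<\xe7\x94\xa8\xe6\x88\xb7>' + message['content'].strip() + '<AI>'}}"
            "{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}";
        std::vector<llama_chat_message> chat = {
            {"user",      "Hello"},
            {"assistant", "Hi there"},
            {"user",      "Who are you"},
        };
        std::vector<char> buf(1024);
        int32_t n = llama_chat_apply_template(nullptr, tmpl, chat.data(), chat.size(),
                                              /*add_ass=*/true, buf.data(), (int32_t) buf.size());
        if (n >= 0 && (size_t) n <= buf.size()) {
            // Expected per the branch above: <用户>Hello<AI>Hi there<用户>Who are you<AI>
            printf("%.*s\n", n, buf.data());
        }
        return 0;
    }
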

tests/test-chat-template.cpp

@@ -33,6 +33,8 @@ int main(void) {
         "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}",
         // OrionStarAI/Orion-14B-Chat
         "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}",
+        // openbmb/MiniCPM-2B-dpo-fp32
+        "{% for message in messages %}{% if message['role'] == 'user' %}{{'<\xe7\x94\xa8\xe6\x88\xb7>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}",
     };
     std::vector<std::string> expected_output = {
         // teknium/OpenHermes-2.5-Mistral-7B
@@ -49,6 +51,8 @@ int main(void) {
         "<start_of_turn>user\nYou are a helpful assistant\n\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
         // OrionStarAI/Orion-14B-Chat
         "Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
+        // openbmb/MiniCPM-2B-dpo-fp32
+        "You are a helpful assistant<\xe7\x94\xa8\xe6\x88\xb7>Hello<AI>Hi there<\xe7\x94\xa8\xe6\x88\xb7>Who are you<AI>I am an assistant<\xe7\x94\xa8\xe6\x88\xb7>Another question<AI>",
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;
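
The rest of the test (not shown in this hunk) presumably walks the two vectors in lock-step, formats the shared conversation with each template, and compares against the matching expected string, so the new minicpm pair is exercised automatically. A sketch of that loop, assuming the first vector is named templates, the conversation/message_count variables are declared earlier in main(), and <cassert> is included:

    for (size_t i = 0; i < templates.size(); i++) {
        std::string custom_template = templates[i];
        std::string expected        = expected_output[i];
        formatted_chat.resize(1024);
        res = llama_chat_apply_template(
            nullptr,                      // no model: the template string is given explicitly
            custom_template.c_str(),
            conversation,
            message_count,
            true,                         // add_ass
            formatted_chat.data(),
            formatted_chat.size()
        );
        formatted_chat.resize(res);       // res is the number of bytes written
        std::string output(formatted_chat.data(), formatted_chat.size());
        assert(output == expected);       // the new minicpm pair is checked the same way
    }
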