add minicpm chat template

This commit is contained in:
ngxson 2024-03-24 12:41:15 +01:00
parent d0a71233fb
commit 00efa65eb1
2 changed files with 43 additions and 20 deletions

View file

@ -14491,12 +14491,26 @@ static std::string trim(const std::string & str) {
// Simple version of "llama_apply_chat_template" that only works with strings // Simple version of "llama_apply_chat_template" that only works with strings
// This function uses heuristic checks to determine commonly used template. It is not a jinja parser. // This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
// List of all supported chat templates: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
static int32_t llama_chat_apply_template_internal( static int32_t llama_chat_apply_template_internal(
const std::string & tmpl, const std::string & tmpl,
const std::vector<const llama_chat_message *> & chat, const std::vector<const llama_chat_message *> & chat,
std::string & dest, bool add_ass) { std::string & dest, bool add_ass) {
// Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
std::stringstream ss; std::stringstream ss;
// utility functions to deal with chat templates that don't support system message
// in that case, we will merge system message with the next user message
std::string sys_msg;
auto sys_msg_push = [&](const llama_chat_message * msg, bool need_trim) {
sys_msg = need_trim ? trim(msg->content) : msg->content;
};
auto sys_msg_pop = [&]() {
if (!sys_msg.empty()) {
ss << sys_msg << "\n\n";
sys_msg = "";
}
};
if (tmpl == "chatml" || tmpl.find("<|im_start|>") != std::string::npos) { if (tmpl == "chatml" || tmpl.find("<|im_start|>") != std::string::npos) {
// chatml template // chatml template
for (auto message : chat) { for (auto message : chat) {
@ -14559,46 +14573,51 @@ static int32_t llama_chat_apply_template_internal(
} }
} else if (tmpl == "gemma" || tmpl.find("<start_of_turn>") != std::string::npos) { } else if (tmpl == "gemma" || tmpl.find("<start_of_turn>") != std::string::npos) {
// google/gemma-7b-it // google/gemma-7b-it
std::string system_prompt = "";
for (auto message : chat) { for (auto message : chat) {
std::string role(message->role); std::string role(message->role);
if (role == "system") { if (role == "system") {
// there is no system message for gemma, but we will merge it with user prompt, so nothing is broken // there is no system message support, we will merge it with user prompt
system_prompt = trim(message->content); sys_msg_push(message, true);
continue; } else {
}
// in gemma, "assistant" is "model" // in gemma, "assistant" is "model"
role = role == "assistant" ? "model" : message->role; role = role == "assistant" ? "model" : message->role;
ss << "<start_of_turn>" << role << "\n"; ss << "<start_of_turn>" << role << "\n";
if (!system_prompt.empty() && role != "model") { if (role != "model") {
ss << system_prompt << "\n\n"; sys_msg_pop();
system_prompt = "";
} }
ss << trim(message->content) << "<end_of_turn>\n"; ss << trim(message->content) << "<end_of_turn>\n";
} }
}
if (add_ass) { if (add_ass) {
ss << "<start_of_turn>model\n"; ss << "<start_of_turn>model\n";
} }
} else if (tmpl == "orion" || tmpl.find("'\\n\\nAssistant: ' + eos_token") != std::string::npos) { } else if (tmpl == "orion" || tmpl.find("'\\n\\nAssistant: ' + eos_token") != std::string::npos) {
// OrionStarAI/Orion-14B-Chat // OrionStarAI/Orion-14B-Chat
std::string system_prompt = "";
for (auto message : chat) { for (auto message : chat) {
std::string role(message->role); std::string role(message->role);
if (role == "system") { if (role == "system") {
// there is no system message support, we will merge it with user prompt // there is no system message support, we will merge it with user prompt
system_prompt = message->content; sys_msg_push(message, false);
continue;
} else if (role == "user") { } else if (role == "user") {
ss << "Human: "; ss << "Human: ";
if (!system_prompt.empty()) { sys_msg_pop();
ss << system_prompt << "\n\n";
system_prompt = "";
}
ss << message->content << "\n\nAssistant: </s>"; ss << message->content << "\n\nAssistant: </s>";
} else { } else {
ss << message->content << "</s>"; ss << message->content << "</s>";
} }
} }
} else if (tmpl == "minicpm" || tmpl.find("<\xe7\x94\xa8\xe6\x88\xb7>") != std::string::npos) {
// openbmb/MiniCPM-2B-dpo-fp32
for (auto message : chat) {
std::string role(message->role);
if (role == "user") {
ss << "<\xe7\x94\xa8\xe6\x88\xb7>";
ss << trim(message->content);
ss << "<AI>";
} else {
ss << trim(message->content);
}
}
} else { } else {
// template not supported // template not supported
return -1; return -1;

View file

@ -33,6 +33,8 @@ int main(void) {
"{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}", "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}",
// OrionStarAI/Orion-14B-Chat // OrionStarAI/Orion-14B-Chat
"{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}", "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}",
// openbmb/MiniCPM-2B-dpo-fp32
"{% for message in messages %}{% if message['role'] == 'user' %}{{'<\xe7\x94\xa8\xe6\x88\xb7>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}",
}; };
std::vector<std::string> expected_output = { std::vector<std::string> expected_output = {
// teknium/OpenHermes-2.5-Mistral-7B // teknium/OpenHermes-2.5-Mistral-7B
@ -49,6 +51,8 @@ int main(void) {
"<start_of_turn>user\nYou are a helpful assistant\n\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n", "<start_of_turn>user\nYou are a helpful assistant\n\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
// OrionStarAI/Orion-14B-Chat // OrionStarAI/Orion-14B-Chat
"Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>", "Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
// openbmb/MiniCPM-2B-dpo-fp32
"You are a helpful assistant<\xe7\x94\xa8\xe6\x88\xb7>Hello<AI>Hi there<\xe7\x94\xa8\xe6\x88\xb7>Who are you<AI>I am an assistant<\xe7\x94\xa8\xe6\x88\xb7>Another question<AI>",
}; };
std::vector<char> formatted_chat(1024); std::vector<char> formatted_chat(1024);
int32_t res; int32_t res;