From 00efa65eb1185ef322fbe8c945fb658422b6ca56 Mon Sep 17 00:00:00 2001
From: ngxson
Date: Sun, 24 Mar 2024 12:41:15 +0100
Subject: [PATCH] add minicpm chat template

---
 llama.cpp                    | 59 ++++++++++++++++++++++++------------
 tests/test-chat-template.cpp |  4 +++
 2 files changed, 43 insertions(+), 20 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 1a9fe0c4d..30444e017 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -14491,12 +14491,26 @@ static std::string trim(const std::string & str) {
 
 // Simple version of "llama_apply_chat_template" that only works with strings
 // This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
+// List of all supported chat templates: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
 static int32_t llama_chat_apply_template_internal(
     const std::string & tmpl,
     const std::vector<const llama_chat_message *> & chat,
     std::string & dest, bool add_ass) {
-    // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
     std::stringstream ss;
+
+    // utility functions to deal with chat templates that don't support system message
+    // in that case, we will merge system message with the next user message
+    std::string sys_msg;
+    auto sys_msg_push = [&](const llama_chat_message * msg, bool need_trim) {
+        sys_msg = need_trim ? trim(msg->content) : msg->content;
+    };
+    auto sys_msg_pop = [&]() {
+        if (!sys_msg.empty()) {
+            ss << sys_msg << "\n\n";
+            sys_msg = "";
+        }
+    };
+
     if (tmpl == "chatml" || tmpl.find("<|im_start|>") != std::string::npos) {
         // chatml template
         for (auto message : chat) {
@@ -14559,46 +14573,51 @@ static int32_t llama_chat_apply_template_internal(
         }
     } else if (tmpl == "gemma" || tmpl.find("<start_of_turn>") != std::string::npos) {
         // google/gemma-7b-it
-        std::string system_prompt = "";
         for (auto message : chat) {
             std::string role(message->role);
             if (role == "system") {
-                // there is no system message for gemma, but we will merge it with user prompt, so nothing is broken
-                system_prompt = trim(message->content);
-                continue;
+                // there is no system message support, we will merge it with user prompt
+                sys_msg_push(message, true);
+            } else {
+                // in gemma, "assistant" is "model"
+                role = role == "assistant" ? "model" : message->role;
+                ss << "<start_of_turn>" << role << "\n";
+                if (role != "model") {
+                    sys_msg_pop();
+                }
+                ss << trim(message->content) << "<end_of_turn>\n";
             }
-            // in gemma, "assistant" is "model"
-            role = role == "assistant" ? "model" : message->role;
-            ss << "<start_of_turn>" << role << "\n";
-            if (!system_prompt.empty() && role != "model") {
-                ss << system_prompt << "\n\n";
-                system_prompt = "";
-            }
-            ss << trim(message->content) << "<end_of_turn>\n";
         }
         if (add_ass) {
             ss << "<start_of_turn>model\n";
         }
     } else if (tmpl == "orion" || tmpl.find("'\\n\\nAssistant: ' + eos_token") != std::string::npos) {
         // OrionStarAI/Orion-14B-Chat
-        std::string system_prompt = "";
         for (auto message : chat) {
             std::string role(message->role);
             if (role == "system") {
                 // there is no system message support, we will merge it with user prompt
-                system_prompt = message->content;
-                continue;
+                sys_msg_push(message, false);
             } else if (role == "user") {
                 ss << "Human: ";
-                if (!system_prompt.empty()) {
-                    ss << system_prompt << "\n\n";
-                    system_prompt = "";
-                }
+                sys_msg_pop();
                 ss << message->content << "\n\nAssistant: </s>";
             } else {
                 ss << message->content << "</s>";
             }
         }
+    } else if (tmpl == "minicpm" || tmpl.find("<\xe7\x94\xa8\xe6\x88\xb7>") != std::string::npos) {
+        // openbmb/MiniCPM-2B-dpo-fp32
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "user") {
+                ss << "<\xe7\x94\xa8\xe6\x88\xb7>";
+                ss << trim(message->content);
+                ss << "<AI>";
+            } else {
+                ss << trim(message->content);
+            }
+        }
     } else {
         // template not supported
         return -1;
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index 6e9e4bd1e..2ac32f713 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -33,6 +33,8 @@ int main(void) {
         "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}",
         // OrionStarAI/Orion-14B-Chat
         "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}",
+        // openbmb/MiniCPM-2B-dpo-fp32
+        "{% for message in messages %}{% if message['role'] == 'user' %}{{'<\xe7\x94\xa8\xe6\x88\xb7>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}",
     };
     std::vector<std::string> expected_output = {
         // teknium/OpenHermes-2.5-Mistral-7B
@@ -49,6 +51,8 @@ int main(void) {
         "<start_of_turn>user\nYou are a helpful assistant\n\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
         // OrionStarAI/Orion-14B-Chat
         "Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s>I am an assistant </s>Human: Another question\n\nAssistant: </s>",
+        // openbmb/MiniCPM-2B-dpo-fp32
+        "You are a helpful assistant<\xe7\x94\xa8\xe6\x88\xb7>Hello<AI>Hi there<\xe7\x94\xa8\xe6\x88\xb7>Who are you<AI>I am an assistant<\xe7\x94\xa8\xe6\x88\xb7>Another question<AI>",
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;