From 98c46cfbfa57483e3788fb9189f3cecc7beb6229 Mon Sep 17 00:00:00 2001
From: ngxson
Date: Mon, 22 Apr 2024 08:49:57 +0200
Subject: [PATCH] fix llama_chat_apply_template

---
 llama.cpp                    | 95 +++++++++++++++++++++++++++++-------
 tests/test-chat-template.cpp | 24 ++++-----
 2 files changed, 89 insertions(+), 30 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 16786bff3..436082aca 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -17177,6 +17177,11 @@ static int32_t llama_chat_get_prefix(
         }
         return output;
     };
+    auto str_tofirstcap = [](std::string & str) {
+        std::string output(str);
+        output[0] = toupper(output[0]);
+        return output;
+    };
     switch (tmpl) {
         case LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED:
             return -1;
@@ -17189,13 +17194,20 @@ static int32_t llama_chat_get_prefix(
             }
             break;
         case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS:
-            if (!sprev_role.empty()) {
-                ss << "<s>";
+            if (srole == "system") {
+                ss << "[INST] <<SYS>>\n";
+            } else if (srole == "user" && sprev_role != "system") {
+                if (!sprev_role.empty()) {
+                    ss << "<s>";
+                }
+                ss << "[INST] ";
+            } else if (srole == "assistant") {
+                ss << " ";
             }
-            // do not add "break"
+            break;
         case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS:
             if (srole == "system") {
-                ss << "[INST]<<SYS>>\n";
+                ss << "[INST] <<SYS>>\n";
             } else if (srole == "user" && sprev_role != "system") {
                 ss << "[INST] ";
             }
@@ -17216,15 +17228,16 @@ static int32_t llama_chat_get_prefix(
         case LLAMA_CHAT_TEMPLATE_ORION:
             // for orion, "user" is "human"
             srole = srole == "user" ? "human" : srole;
-            srole[0] = toupper(srole[0]); // upper case for first letter
-            ss << srole << ": ";
+            ss << str_tofirstcap(srole) << ": ";
+            if (srole == "assistant") {
+                ss << "</s>";
+            }
             break;
         case LLAMA_CHAT_TEMPLATE_OPENCHAT:
             if (srole == "system") {
                 ss << "";
             } else {
-                srole[0] = toupper(srole[0]); // upper case for first letter
-                ss << "GPT4 Correct " << srole << ": ";
+                ss << "GPT4 Correct " << str_tofirstcap(srole) << ": ";
             }
             break;
         case LLAMA_CHAT_TEMPLATE_VICUNA:
@@ -17250,7 +17263,7 @@ static int32_t llama_chat_get_prefix(
     }
     std::string output = ss.str();
     snprintf(buf, length, "%s", output.c_str());
-    return output.size() + 1;
+    return output.size();
 }
 
 static int32_t llama_chat_get_postfix(
@@ -17271,14 +17284,18 @@ static int32_t llama_chat_get_postfix(
         case LLAMA_CHAT_TEMPLATE_LLAMA2:
             if (srole == "user") {
                 ss << " [/INST]";
+            } else if (srole == "assistant") {
+                ss << "</s>";
             }
             break;
         case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS:
         case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS:
             if (srole == "system") {
                 ss << "\n<</SYS>>\n\n";
-            } else {
-                ss << "</s>";
+            } else if (srole == "user") {
+                ss << " [/INST]";
+            } else if (srole == "assistant") {
+                ss << " </s>";
             }
             break;
         case LLAMA_CHAT_TEMPLATE_ZEPHYR:
@@ -17291,7 +17308,11 @@ static int32_t llama_chat_get_postfix(
             ss << "\n";
             break;
         case LLAMA_CHAT_TEMPLATE_ORION:
-            ss << "</s>";
+            if (srole == "assistant") {
+                ss << "</s>";
+            } else {
+                ss << "\n\n";
+            }
             break;
         case LLAMA_CHAT_TEMPLATE_OPENCHAT:
             srole[0] = toupper(srole[0]);
@@ -17299,7 +17320,11 @@ static int32_t llama_chat_get_postfix(
             break;
         case LLAMA_CHAT_TEMPLATE_VICUNA:
         case LLAMA_CHAT_TEMPLATE_VICUNA_ORCA:
-            ss << "</s>\n";
+            if (srole == "assistant") {
+                ss << "</s>\n";
+            } else {
+                ss << "\n";
+            }
             break;
         case LLAMA_CHAT_TEMPLATE_DEEPSEEK:
             if (srole == "user") {
@@ -17317,7 +17342,25 @@ static int32_t llama_chat_get_postfix(
     }
     std::string output = ss.str();
     snprintf(buf, length, "%s", output.c_str());
-    return output.size() + 1;
+    return output.size();
+}
+
+static bool llama_chat_support_system_message(const llama_chat_template tmpl) {
+    switch (tmpl) {
+        case LLAMA_CHAT_TEMPLATE_CHATML:
+        case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS:
+        case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS:
+        case LLAMA_CHAT_TEMPLATE_ZEPHYR:
+        case LLAMA_CHAT_TEMPLATE_MONARCH:
+        case LLAMA_CHAT_TEMPLATE_ORION:
+        case LLAMA_CHAT_TEMPLATE_OPENCHAT:
+        case LLAMA_CHAT_TEMPLATE_VICUNA_ORCA:
+        case LLAMA_CHAT_TEMPLATE_COMMAND_R:
+        case LLAMA_CHAT_TEMPLATE_LLAMA3:
+            return true;
+        default:
+            return false;
+    }
 }
 
 LLAMA_API int32_t llama_chat_apply_template(
@@ -17342,6 +17385,7 @@ LLAMA_API int32_t llama_chat_apply_template(
 
     // detect template type
     llama_chat_template ttmpl = llama_chat_get_template_type(curr_tmpl.c_str());
+    bool support_system_message = llama_chat_support_system_message(ttmpl);
     if (ttmpl == LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED) {
         return -1;
     }
@@ -17349,22 +17393,37 @@ LLAMA_API int32_t llama_chat_apply_template(
     // format the chat to string
     std::stringstream ss;
     std::string prev_role;
-    std::vector<char> prefix(1024, 0);
-    std::vector<char> postfix(1024, 0);
     for (size_t i = 0; i < n_msg; i++) {
         std::string role(chat[i].role);
        std::string content(chat[i].content);
+        if (!support_system_message) {
+            // if the template does not support system message, we convert it to user message
+            role = role == "system" ? "user" : role;
+        }
+        std::vector<char> prefix(1024, 0);
+        std::vector<char> postfix(1024, 0);
         llama_chat_get_prefix(ttmpl, role.c_str(), prev_role.c_str(), prefix.data(), prefix.size());
         llama_chat_get_postfix(ttmpl, role.c_str(), prev_role.c_str(), postfix.data(), postfix.size());
-        ss << std::string(prefix.data(), prefix.size()) << content << std::string(postfix.data(), postfix.size());
+        ss << std::string(prefix.data()) << trim(content) << std::string(postfix.data());
         prev_role = role;
     }
+    if (add_ass) {
+        std::vector<char> prefix(1024, 0);
+        llama_chat_get_prefix(ttmpl, "assistant", prev_role.c_str(), prefix.data(), prefix.size());
+        std::string assistant_prompt(prefix.data());
+        if (assistant_prompt.back() == ' ') {
+            // Some templates need a trailing space to be tokenized with the next word.
+            // We should make sure there is no trailing space in the output text.
+            assistant_prompt.pop_back();
+        }
+        ss << assistant_prompt;
+    }
+
     std::string output = ss.str();
     if (buf && length > 0) {
         snprintf(buf, length, "%s", output.c_str());
     }
-    return output.size() + 1;
+    return output.size();
 }
 
 LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index cddf86a41..c2db07257 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -14,7 +14,7 @@ int main(void) {
         {"user", "Hello"},
         {"assistant", "Hi there"},
         {"user", "Who are you"},
-        {"assistant", " I am an assistant "},
+        {"assistant", "I am an assistant"},
         {"user", "Another question"},
     };
     size_t message_count = 6;
@@ -52,27 +52,27 @@ int main(void) {
     };
     std::vector<std::string> expected_output = {
         // teknium/OpenHermes-2.5-Mistral-7B
-        "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
+        "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\nI am an assistant<|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
         // mistralai/Mistral-7B-Instruct-v0.2
-        "[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
+        "[INST] You are a helpful assistant [/INST][INST] Hello [/INST]Hi there</s>[INST] Who are you [/INST]I am an assistant</s>[INST] Another question [/INST]",
         // TheBloke/FusionNet_34Bx2_MoE-AWQ
-        "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
+        "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
         // bofenghuang/vigogne-2-70b-chat
-        "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
+        "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]Hi there </s>[INST] Who are you [/INST]I am an assistant </s>[INST] Another question [/INST]",
         // mlabonne/AlphaMonarch-7B
-        "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
+        "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\nI am an assistant</s>\n<s>user\nAnother question</s>\n<s>assistant\n",
         // google/gemma-7b-it
-        "<start_of_turn>user\nYou are a helpful assistant\n\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
+        "<start_of_turn>user\nYou are a helpful assistant<end_of_turn>\n<start_of_turn>user\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
         // OrionStarAI/Orion-14B-Chat
-        "Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
+        "System: You are a helpful assistant\n\nHuman: Hello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s>I am an assistant</s>Human: Another question\n\nAssistant: </s>",
         // openchat/openchat-3.5-0106
-        "You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:",
+        "You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant<|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:",
         // deepseek-ai/deepseek-coder-33b-instruct
-        "You are a helpful assistant### Instruction:\nHello\n### Response:\nHi there\n<|EOT|>\n### Instruction:\nWho are you\n### Response:\n I am an assistant \n<|EOT|>\n### Instruction:\nAnother question\n### Response:\n",
+        "### Instruction:\nYou are a helpful assistant\n### Instruction:\nHello\n### Response:\nHi there\n<|EOT|>\n### Instruction:\nWho are you\n### Response:\nI am an assistant\n<|EOT|>\n### Instruction:\nAnother question\n### Response:\n",
         // eachadea/vicuna-13b-1.1
-        "You are a helpful assistant\n\nUSER: Hello\nASSISTANT: Hi there</s>\nUSER: Who are you\nASSISTANT: I am an assistant </s>\nUSER: Another question\nASSISTANT:",
+        "USER: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there</s>\nUSER: Who are you\nASSISTANT: I am an assistant</s>\nUSER: Another question\nASSISTANT:",
         // Orca-Vicuna
-        "SYSTEM: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there</s>\nUSER: Who are you\nASSISTANT: I am an assistant </s>\nUSER: Another question\nASSISTANT:",
+        "SYSTEM: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there</s>\nUSER: Who are you\nASSISTANT: I am an assistant</s>\nUSER: Another question\nASSISTANT:",
         // CohereForAI/c4ai-command-r-plus
         "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Who are you<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>I am an assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Another question<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
         // Llama 3