From be26435fdee42be2b6bdbc7bd280e79bf30b45ca Mon Sep 17 00:00:00 2001 From: ngxson Date: Thu, 27 Jun 2024 18:46:09 +0200 Subject: [PATCH] tmp_contains --- src/llama.cpp | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index b97b5e279..b307e4780 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -19414,7 +19414,10 @@ static int32_t llama_chat_apply_template_internal( std::string & dest, bool add_ass) { // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527 std::stringstream ss; - if (tmpl == "chatml" || tmpl.find("<|im_start|>") != std::string::npos) { + auto tmpl_contains = [&tmpl](std::string haystack) -> bool { + return tmpl.find(haystack) != std::string::npos; + }; + if (tmpl == "chatml" || tmpl_contains("<|im_start|>")) { // chatml template for (auto message : chat) { ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n"; @@ -19422,16 +19425,16 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|im_start|>assistant\n"; } - } else if (tmpl == "llama2" || tmpl == "mistral" || tmpl.find("[INST]") != std::string::npos) { + } else if (tmpl == "llama2" || tmpl == "mistral" || tmpl_contains("[INST]")) { // llama2 template and its variants // [variant] support system message - bool support_system_message = tmpl.find("<>") != std::string::npos || tmpl == "mistral"; + bool support_system_message = tmpl_contains("<>") || tmpl == "mistral"; // [variant] space before + after response - bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos; + bool space_around_response = tmpl_contains("' ' + eos_token"); // [variant] add BOS inside history - bool add_bos_inside_history = tmpl.find("bos_token + '[INST]") != std::string::npos; + bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]"); // [variant] trim spaces from the input message - bool strip_message = tmpl.find("content.strip()") != std::string::npos; + bool strip_message = tmpl_contains("content.strip()"); // construct the prompt bool is_inside_turn = true; // skip BOS at the beginning ss << "[INST] "; @@ -19457,7 +19460,7 @@ static int32_t llama_chat_apply_template_internal( } } // llama2 templates seem to not care about "add_generation_prompt" - } else if (tmpl == "phi3" || (tmpl.find("<|assistant|>") != std::string::npos && tmpl.find("<|end|>") != std::string::npos)) { + } else if (tmpl == "phi3" || (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>"))) { // Phi 3 for (auto message : chat) { std::string role(message->role); @@ -19466,7 +19469,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|assistant|>\n"; } - } else if (tmpl == "zephyr" || tmpl.find("<|user|>") != std::string::npos) { + } else if (tmpl == "zephyr" || tmpl_contains("<|user|>")) { // zephyr template for (auto message : chat) { ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n"; @@ -19474,7 +19477,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|assistant|>\n"; } - } else if (tmpl == "monarch" || tmpl.find("bos_token + message['role']") != std::string::npos) { + } else if (tmpl == "monarch" || tmpl_contains("bos_token + message['role']")) { // mlabonne/AlphaMonarch-7B template (the is included inside history) for (auto message : chat) { std::string bos = (message == chat.front()) ? "" : ""; // skip BOS for first message @@ -19483,7 +19486,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "assistant\n"; } - } else if (tmpl == "gemma" || tmpl.find("") != std::string::npos) { + } else if (tmpl == "gemma" || tmpl_contains("")) { // google/gemma-7b-it std::string system_prompt = ""; for (auto message : chat) { @@ -19505,7 +19508,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "model\n"; } - } else if (tmpl == "orion" || tmpl.find("'\\n\\nAssistant: ' + eos_token") != std::string::npos) { + } else if (tmpl == "orion" || tmpl_contains("'\\n\\nAssistant: ' + eos_token")) { // OrionStarAI/Orion-14B-Chat std::string system_prompt = ""; for (auto message : chat) { @@ -19525,7 +19528,7 @@ static int32_t llama_chat_apply_template_internal( ss << message->content << ""; } } - } else if (tmpl == "openchat" || tmpl.find("GPT4 Correct ") != std::string::npos) { + } else if (tmpl == "openchat" || tmpl_contains("GPT4 Correct ")) { // openchat/openchat-3.5-0106, for (auto message : chat) { std::string role(message->role); @@ -19539,13 +19542,13 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "GPT4 Correct Assistant:"; } - } else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl.find("USER: ") != std::string::npos && tmpl.find("ASSISTANT: ") != std::string::npos)) { + } else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: "))) { // eachadea/vicuna-13b-1.1 (and Orca variant) for (auto message : chat) { std::string role(message->role); if (role == "system") { // Orca-Vicuna variant uses a system prefix - if (tmpl == "vicuna-orca" || tmpl.find("SYSTEM: ") != std::string::npos) { + if (tmpl == "vicuna-orca" || tmpl_contains("SYSTEM: ")) { ss << "SYSTEM: " << message->content << "\n"; } else { ss << message->content << "\n\n"; @@ -19559,7 +19562,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "ASSISTANT:"; } - } else if (tmpl == "deepseek" || (tmpl.find("### Instruction:") != std::string::npos && tmpl.find("<|EOT|>") != std::string::npos)) { + } else if (tmpl == "deepseek" || (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>"))) { // deepseek-ai/deepseek-coder-33b-instruct for (auto message : chat) { std::string role(message->role); @@ -19574,7 +19577,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "### Response:\n"; } - } else if (tmpl == "command-r" || (tmpl.find("<|START_OF_TURN_TOKEN|>") != std::string::npos && tmpl.find("<|USER_TOKEN|>") != std::string::npos)) { + } else if (tmpl == "command-r" || (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>"))) { // CohereForAI/c4ai-command-r-plus for (auto message : chat) { std::string role(message->role); @@ -19589,7 +19592,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"; } - } else if (tmpl == "llama3" || (tmpl.find("<|start_header_id|>") != std::string::npos && tmpl.find("<|end_header_id|>") != std::string::npos)) { + } else if (tmpl == "llama3" || (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>"))) { // Llama 3 for (auto message : chat) { std::string role(message->role);