tmp_contains

This commit is contained in:
ngxson 2024-06-27 18:46:09 +02:00
parent f675b20a3b
commit be26435fde

View file

@ -19414,7 +19414,10 @@ static int32_t llama_chat_apply_template_internal(
std::string & dest, bool add_ass) {
// Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
std::stringstream ss;
if (tmpl == "chatml" || tmpl.find("<|im_start|>") != std::string::npos) {
auto tmpl_contains = [&tmpl](std::string haystack) -> bool {
return tmpl.find(haystack) != std::string::npos;
};
if (tmpl == "chatml" || tmpl_contains("<|im_start|>")) {
// chatml template
for (auto message : chat) {
ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
@ -19422,16 +19425,16 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<|im_start|>assistant\n";
}
} else if (tmpl == "llama2" || tmpl == "mistral" || tmpl.find("[INST]") != std::string::npos) {
} else if (tmpl == "llama2" || tmpl == "mistral" || tmpl_contains("[INST]")) {
// llama2 template and its variants
// [variant] support system message
bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos || tmpl == "mistral";
bool support_system_message = tmpl_contains("<<SYS>>") || tmpl == "mistral";
// [variant] space before + after response
bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos;
bool space_around_response = tmpl_contains("' ' + eos_token");
// [variant] add BOS inside history
bool add_bos_inside_history = tmpl.find("bos_token + '[INST]") != std::string::npos;
bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]");
// [variant] trim spaces from the input message
bool strip_message = tmpl.find("content.strip()") != std::string::npos;
bool strip_message = tmpl_contains("content.strip()");
// construct the prompt
bool is_inside_turn = true; // skip BOS at the beginning
ss << "[INST] ";
@ -19457,7 +19460,7 @@ static int32_t llama_chat_apply_template_internal(
}
}
// llama2 templates seem to not care about "add_generation_prompt"
} else if (tmpl == "phi3" || (tmpl.find("<|assistant|>") != std::string::npos && tmpl.find("<|end|>") != std::string::npos)) {
} else if (tmpl == "phi3" || (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>"))) {
// Phi 3
for (auto message : chat) {
std::string role(message->role);
@ -19466,7 +19469,7 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<|assistant|>\n";
}
} else if (tmpl == "zephyr" || tmpl.find("<|user|>") != std::string::npos) {
} else if (tmpl == "zephyr" || tmpl_contains("<|user|>")) {
// zephyr template
for (auto message : chat) {
ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
@ -19474,7 +19477,7 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<|assistant|>\n";
}
} else if (tmpl == "monarch" || tmpl.find("bos_token + message['role']") != std::string::npos) {
} else if (tmpl == "monarch" || tmpl_contains("bos_token + message['role']")) {
// mlabonne/AlphaMonarch-7B template (the <s> is included inside history)
for (auto message : chat) {
std::string bos = (message == chat.front()) ? "" : "<s>"; // skip BOS for first message
@ -19483,7 +19486,7 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<s>assistant\n";
}
} else if (tmpl == "gemma" || tmpl.find("<start_of_turn>") != std::string::npos) {
} else if (tmpl == "gemma" || tmpl_contains("<start_of_turn>")) {
// google/gemma-7b-it
std::string system_prompt = "";
for (auto message : chat) {
@ -19505,7 +19508,7 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<start_of_turn>model\n";
}
} else if (tmpl == "orion" || tmpl.find("'\\n\\nAssistant: ' + eos_token") != std::string::npos) {
} else if (tmpl == "orion" || tmpl_contains("'\\n\\nAssistant: ' + eos_token")) {
// OrionStarAI/Orion-14B-Chat
std::string system_prompt = "";
for (auto message : chat) {
@ -19525,7 +19528,7 @@ static int32_t llama_chat_apply_template_internal(
ss << message->content << "</s>";
}
}
} else if (tmpl == "openchat" || tmpl.find("GPT4 Correct ") != std::string::npos) {
} else if (tmpl == "openchat" || tmpl_contains("GPT4 Correct ")) {
// openchat/openchat-3.5-0106,
for (auto message : chat) {
std::string role(message->role);
@ -19539,13 +19542,13 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "GPT4 Correct Assistant:";
}
} else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl.find("USER: ") != std::string::npos && tmpl.find("ASSISTANT: ") != std::string::npos)) {
} else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: "))) {
// eachadea/vicuna-13b-1.1 (and Orca variant)
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
// Orca-Vicuna variant uses a system prefix
if (tmpl == "vicuna-orca" || tmpl.find("SYSTEM: ") != std::string::npos) {
if (tmpl == "vicuna-orca" || tmpl_contains("SYSTEM: ")) {
ss << "SYSTEM: " << message->content << "\n";
} else {
ss << message->content << "\n\n";
@ -19559,7 +19562,7 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "ASSISTANT:";
}
} else if (tmpl == "deepseek" || (tmpl.find("### Instruction:") != std::string::npos && tmpl.find("<|EOT|>") != std::string::npos)) {
} else if (tmpl == "deepseek" || (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>"))) {
// deepseek-ai/deepseek-coder-33b-instruct
for (auto message : chat) {
std::string role(message->role);
@ -19574,7 +19577,7 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "### Response:\n";
}
} else if (tmpl == "command-r" || (tmpl.find("<|START_OF_TURN_TOKEN|>") != std::string::npos && tmpl.find("<|USER_TOKEN|>") != std::string::npos)) {
} else if (tmpl == "command-r" || (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>"))) {
// CohereForAI/c4ai-command-r-plus
for (auto message : chat) {
std::string role(message->role);
@ -19589,7 +19592,7 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>";
}
} else if (tmpl == "llama3" || (tmpl.find("<|start_header_id|>") != std::string::npos && tmpl.find("<|end_header_id|>") != std::string::npos)) {
} else if (tmpl == "llama3" || (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>"))) {
// Llama 3
for (auto message : chat) {
std::string role(message->role);