Changed system message logic and added tests for all 4

This commit is contained in:
juk 2024-11-29 00:40:10 +00:00
parent dbbde92374
commit f55c434eb4
2 changed files with 15 additions and 7 deletions

View file

@ -21867,21 +21867,21 @@ static int32_t llama_chat_apply_template_internal(
// See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md
std::string leading_space = (tmpl == "mistral-v1" ? " " : "");
std::string trailing_space = (tmpl != "mistral-v3-tekken" ? " " : "");
std::string system_message = "";
bool is_inside_turn = false;
for (auto message : chat) {
if (!is_inside_turn) {
ss << leading_space << "[INST]" << trailing_space;
is_inside_turn = true;
}
std::string role(message->role);
std::string content = trim(message->content);
if (role == "system") {
system_message = content;
ss << system_message << "\n\n";
} else if (role == "user") {
ss << leading_space << "[INST]" << trailing_space;
if (!system_message.empty()) {
ss << system_message << "\n\n";
system_message = "";
}
ss << content << leading_space << "[/INST]";
} else {
ss << trailing_space << content << "</s>";
is_inside_turn = false;
}
}
} else if (tmpl == "llama2" || tmpl == "mistral" || tmpl_contains("[INST]")) {

View file

@ -154,6 +154,10 @@ int main(void) {
return output;
};
assert(fmt_sys("chatml") == "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n");
assert(fmt_sys("mistral-v1") == " [INST] You are a helpful assistant\n\n");
assert(fmt_sys("mistral-v2") == "[INST] You are a helpful assistant\n\n");
assert(fmt_sys("mistral-v3") == "[INST] You are a helpful assistant\n\n");
assert(fmt_sys("mistral-v3-tekken") == "[INST]You are a helpful assistant\n\n");
assert(fmt_sys("llama2") == "[INST] You are a helpful assistant\n");
assert(fmt_sys("gemma") == ""); // for gemma, system message is merged with user message
assert(fmt_sys("llama3") == "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|>");
@ -173,6 +177,10 @@ int main(void) {
return output;
};
assert(fmt_single("chatml") == "\n<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
assert(fmt_single("mistral-v1") == " [INST] How are you [/INST]");
assert(fmt_single("mistral-v2") == "[INST] How are you[/INST]");
assert(fmt_single("mistral-v3") == "[INST] How are you[/INST]");
assert(fmt_single("mistral-v3-tekken") == "[INST]How are you[/INST]");
assert(fmt_single("llama2") == "[INST] How are you [/INST]");
assert(fmt_single("gemma") == "\n<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");