fix llama_chat_format_single for mistral

This commit is contained in:
ngxson 2024-07-23 21:43:07 +02:00
parent b841d07408
commit f70331ebd2
3 changed files with 21 additions and 4 deletions

View file

@@ -2721,7 +2721,7 @@ std::string llama_chat_format_single(const struct llama_model * model,
                                         const llama_chat_msg & new_msg,
                                         bool add_ass) {
     std::ostringstream ss;
-    auto fmt_past_msg = llama_chat_apply_template(model, tmpl, past_msg, false);
+    auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false);
     std::vector<llama_chat_msg> chat_new(past_msg);
     // if the past_msg ends with a newline, we must preserve it in the formatted version
     if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {

View file

@@ -124,6 +124,7 @@ static std::string chat_add_and_format(struct llama_model * model, std::vector<l
     auto formatted = llama_chat_format_single(
         model, g_params->chat_template, chat_msgs, new_msg, role == "user");
     chat_msgs.push_back({role, content});
+    LOG("formatted: %s\n", formatted.c_str());
     return formatted;
 }

View file

@@ -137,9 +137,25 @@ int main(void) {
         assert(output == expected);
     }
-    // test llama_chat_format_single
-    std::cout << "\n\n=== llama_chat_format_single ===\n\n";
+    // test llama_chat_format_single for system message
+    std::cout << "\n\n=== llama_chat_format_single (system message) ===\n\n";
     std::vector<llama_chat_msg> chat2;
+    llama_chat_msg sys_msg{"system", "You are a helpful assistant"};
+
+    auto fmt_sys = [&](std::string tmpl) {
+        auto output = llama_chat_format_single(nullptr, tmpl, chat2, sys_msg, false);
+        std::cout << "fmt_sys(" << tmpl << ")\n" << output << "\n-------------------------\n";
+        return output;
+    };
+    assert(fmt_sys("chatml") == "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n");
+    assert(fmt_sys("llama2") == "[INST] You are a helpful assistant\n");
+    assert(fmt_sys("gemma") == ""); // for gemma, system message is merged with user message
+    assert(fmt_sys("llama3") == "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|>");
+
+    // test llama_chat_format_single for user message
+    std::cout << "\n\n=== llama_chat_format_single (system message) ===\n\n";
     chat2.push_back({"system", "You are a helpful assistant"});
     chat2.push_back({"user", "Hello"});
     chat2.push_back({"assistant", "I am assistant"});