fix llama_chat_format_single for mistral

commit f70331ebd2
parent b841d07408

3 changed files with 21 additions and 4 deletions
@@ -2721,7 +2721,7 @@ std::string llama_chat_format_single(const struct llama_model * model,
         const llama_chat_msg & new_msg,
         bool add_ass) {
     std::ostringstream ss;
-    auto fmt_past_msg = llama_chat_apply_template(model, tmpl, past_msg, false);
+    auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false);
     std::vector<llama_chat_msg> chat_new(past_msg);
     // if the past_msg ends with a newline, we must preserve it in the formatted version
     if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
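For context, llama_chat_format_single formats only the newest message by rendering the conversation twice, with and without it, and returning the appended suffix. The guard added above matters because some templates (llama2/mistral among them) can render a non-empty scaffold even for an empty message list, which shifts the suffix offset and truncates the first formatted message. A minimal self-contained sketch of the idea, with a toy chatml-style renderer standing in for llama_chat_apply_template (apply_template and format_single here are illustrative, not the library code):

#include <iostream>
#include <string>
#include <vector>

struct chat_msg { std::string role, content; };

// Toy chatml-style renderer standing in for llama_chat_apply_template.
std::string apply_template(const std::vector<chat_msg> & msgs) {
    std::string out;
    for (const auto & m : msgs) {
        out += "<|im_start|>" + m.role + "\n" + m.content + "<|im_end|>\n";
    }
    return out;
}

// Diff-based formatting of one new message: render the history with and
// without the message, then keep only the appended suffix.
std::string format_single(const std::vector<chat_msg> & past, const chat_msg & msg) {
    // The fix in the hunk above: with an empty history, some templates
    // still emit a non-empty prefix, which would make `before` longer
    // than it should be and corrupt the suffix. Skip the call entirely.
    const std::string before = past.empty() ? "" : apply_template(past);
    std::vector<chat_msg> with_new(past);
    with_new.push_back(msg);
    const std::string after = apply_template(with_new);
    return after.substr(before.size());
}

int main() {
    std::vector<chat_msg> history;
    std::cout << format_single(history, {"system", "You are a helpful assistant"});
}

With the guard in place, the very first message comes back fully rendered instead of being clipped by a phantom prefix, which is exactly what the new system-message tests below assert.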
@@ -124,6 +124,7 @@ static std::string chat_add_and_format(struct llama_model * model, std::vector<l
     auto formatted = llama_chat_format_single(
         model, g_params->chat_template, chat_msgs, new_msg, role == "user");
     chat_msgs.push_back({role, content});
+    LOG("formatted: %s\n", formatted.c_str());
     return formatted;
 }
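The added LOG line makes the per-turn delta visible, i.e. the exact text that gets tokenized for each new message. A rough sketch of the helper's flow (this chat_add_and_format mirrors the hunk's structure but is a simplified stand-in, not the example's real code):

#include <cstdio>
#include <string>
#include <vector>

struct chat_msg { std::string role, content; };

// Toy single-message formatter standing in for llama_chat_format_single
// (chatml-flavoured, history ignored for brevity).
static std::string format_single(const std::vector<chat_msg> & /*past*/, const chat_msg & msg) {
    return "<|im_start|>" + msg.role + "\n" + msg.content + "<|im_end|>\n";
}

// Mirrors the shape of chat_add_and_format above: format the new turn
// against the history, record it, log the rendered delta, and hand the
// delta back for tokenization.
static std::string chat_add_and_format(std::vector<chat_msg> & chat_msgs,
                                       const std::string & role, const std::string & content) {
    chat_msg new_msg{role, content};
    auto formatted = format_single(chat_msgs, new_msg);
    chat_msgs.push_back(new_msg);
    printf("formatted: %s\n", formatted.c_str()); // stand-in for LOG(...)
    return formatted;
}

int main() {
    std::vector<chat_msg> chat_msgs;
    chat_add_and_format(chat_msgs, "user", "Hello");
    chat_add_and_format(chat_msgs, "assistant", "Hi there");
}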
@@ -137,9 +137,25 @@ int main(void) {
         assert(output == expected);
     }
 
-    // test llama_chat_format_single
-    std::cout << "\n\n=== llama_chat_format_single ===\n\n";
+    // test llama_chat_format_single for system message
+    std::cout << "\n\n=== llama_chat_format_single (system message) ===\n\n";
     std::vector<llama_chat_msg> chat2;
+    llama_chat_msg sys_msg{"system", "You are a helpful assistant"};
+
+    auto fmt_sys = [&](std::string tmpl) {
+        auto output = llama_chat_format_single(nullptr, tmpl, chat2, sys_msg, false);
+        std::cout << "fmt_sys(" << tmpl << ")\n" << output << "\n-------------------------\n";
+        return output;
+    };
+
+    assert(fmt_sys("chatml") == "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n");
+    assert(fmt_sys("llama2") == "[INST] You are a helpful assistant\n");
+    assert(fmt_sys("gemma") == ""); // for gemma, system message is merged with user message
+    assert(fmt_sys("llama3") == "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|>");
+
+
+    // test llama_chat_format_single for user message
+    std::cout << "\n\n=== llama_chat_format_single (user message) ===\n\n";
     chat2.push_back({"system", "You are a helpful assistant"});
     chat2.push_back({"user", "Hello"});
     chat2.push_back({"assistant", "I am assistant"});
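The page's rendering cuts this hunk off after the context lines above; the user-message checks that follow use the same harness shape as fmt_sys. A hedged sketch of that continuation (new_msg, fmt_single, and the probed templates are illustrative, not necessarily the committed assertions):

    // Illustrative continuation in the fmt_sys style; the committed
    // assertions are not visible on this page.
    llama_chat_msg new_msg{"user", "How are you"};

    auto fmt_single = [&](std::string tmpl) {
        // add_ass = true so the assistant prefix is appended for the reply
        auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true);
        std::cout << "fmt_single(" << tmpl << ")\n" << output << "\n-------------------------\n";
        return output;
    };
    fmt_single("chatml");
    fmt_single("llama2");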