fix llama_chat_apply_template
This commit is contained in:
parent
d6df8ecb9c
commit
98c46cfbfa
2 changed files with 89 additions and 30 deletions
95
llama.cpp
95
llama.cpp
|
@ -17177,6 +17177,11 @@ static int32_t llama_chat_get_prefix(
|
||||||
}
|
}
|
||||||
return output;
|
return output;
|
||||||
};
|
};
|
||||||
|
auto str_tofirstcap = [](std::string & str) {
|
||||||
|
std::string output(str);
|
||||||
|
output[0] = toupper(output[0]);
|
||||||
|
return output;
|
||||||
|
};
|
||||||
switch (tmpl) {
|
switch (tmpl) {
|
||||||
case LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED:
|
case LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED:
|
||||||
return -1;
|
return -1;
|
||||||
|
@ -17189,13 +17194,20 @@ static int32_t llama_chat_get_prefix(
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS:
|
case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS:
|
||||||
if (!sprev_role.empty()) {
|
if (srole == "system") {
|
||||||
ss << "<s>";
|
ss << "[INST] <<SYS>>\n";
|
||||||
|
} else if (srole == "user" && sprev_role != "system") {
|
||||||
|
if (!sprev_role.empty()) {
|
||||||
|
ss << "<s>";
|
||||||
|
}
|
||||||
|
ss << "[INST] ";
|
||||||
|
} else if (srole == "assistant") {
|
||||||
|
ss << " ";
|
||||||
}
|
}
|
||||||
// do not add "break"
|
break;
|
||||||
case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS:
|
case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS:
|
||||||
if (srole == "system") {
|
if (srole == "system") {
|
||||||
ss << "[INST]<<SYS>>\n";
|
ss << "[INST] <<SYS>>\n";
|
||||||
} else if (srole == "user" && sprev_role != "system") {
|
} else if (srole == "user" && sprev_role != "system") {
|
||||||
ss << "[INST] ";
|
ss << "[INST] ";
|
||||||
}
|
}
|
||||||
|
@ -17216,15 +17228,16 @@ static int32_t llama_chat_get_prefix(
|
||||||
case LLAMA_CHAT_TEMPLATE_ORION:
|
case LLAMA_CHAT_TEMPLATE_ORION:
|
||||||
// for orion, "user" is "human"
|
// for orion, "user" is "human"
|
||||||
srole = srole == "user" ? "human" : srole;
|
srole = srole == "user" ? "human" : srole;
|
||||||
srole[0] = toupper(srole[0]); // upper case for first letter
|
ss << str_tofirstcap(srole) << ": ";
|
||||||
ss << srole << ": ";
|
if (srole == "assistant") {
|
||||||
|
ss << "</s>";
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
case LLAMA_CHAT_TEMPLATE_OPENCHAT:
|
case LLAMA_CHAT_TEMPLATE_OPENCHAT:
|
||||||
if (srole == "system") {
|
if (srole == "system") {
|
||||||
ss << "";
|
ss << "";
|
||||||
} else {
|
} else {
|
||||||
srole[0] = toupper(srole[0]); // upper case for first letter
|
ss << "GPT4 Correct " << str_tofirstcap(srole) << ": ";
|
||||||
ss << "GPT4 Correct " << srole << ": ";
|
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case LLAMA_CHAT_TEMPLATE_VICUNA:
|
case LLAMA_CHAT_TEMPLATE_VICUNA:
|
||||||
|
@ -17250,7 +17263,7 @@ static int32_t llama_chat_get_prefix(
|
||||||
}
|
}
|
||||||
std::string output = ss.str();
|
std::string output = ss.str();
|
||||||
snprintf(buf, length, "%s", output.c_str());
|
snprintf(buf, length, "%s", output.c_str());
|
||||||
return output.size() + 1;
|
return output.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
static int32_t llama_chat_get_postfix(
|
static int32_t llama_chat_get_postfix(
|
||||||
|
@ -17271,14 +17284,18 @@ static int32_t llama_chat_get_postfix(
|
||||||
case LLAMA_CHAT_TEMPLATE_LLAMA2:
|
case LLAMA_CHAT_TEMPLATE_LLAMA2:
|
||||||
if (srole == "user") {
|
if (srole == "user") {
|
||||||
ss << " [/INST]";
|
ss << " [/INST]";
|
||||||
|
} else if (srole == "assistant") {
|
||||||
|
ss << "</s>";
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS:
|
case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS:
|
||||||
case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS:
|
case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS:
|
||||||
if (srole == "system") {
|
if (srole == "system") {
|
||||||
ss << "\n<</SYS>>\n\n";
|
ss << "\n<</SYS>>\n\n";
|
||||||
} else {
|
} else if (srole == "user") {
|
||||||
ss << "</s>";
|
ss << " [/INST]";
|
||||||
|
} else if (srole == "assistant") {
|
||||||
|
ss << " </s>";
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case LLAMA_CHAT_TEMPLATE_ZEPHYR:
|
case LLAMA_CHAT_TEMPLATE_ZEPHYR:
|
||||||
|
@ -17291,7 +17308,11 @@ static int32_t llama_chat_get_postfix(
|
||||||
ss << "<end_of_turn>\n";
|
ss << "<end_of_turn>\n";
|
||||||
break;
|
break;
|
||||||
case LLAMA_CHAT_TEMPLATE_ORION:
|
case LLAMA_CHAT_TEMPLATE_ORION:
|
||||||
ss << "</s>";
|
if (srole == "assistant") {
|
||||||
|
ss << "</s>";
|
||||||
|
} else {
|
||||||
|
ss << "\n\n";
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
case LLAMA_CHAT_TEMPLATE_OPENCHAT:
|
case LLAMA_CHAT_TEMPLATE_OPENCHAT:
|
||||||
srole[0] = toupper(srole[0]);
|
srole[0] = toupper(srole[0]);
|
||||||
|
@ -17299,7 +17320,11 @@ static int32_t llama_chat_get_postfix(
|
||||||
break;
|
break;
|
||||||
case LLAMA_CHAT_TEMPLATE_VICUNA:
|
case LLAMA_CHAT_TEMPLATE_VICUNA:
|
||||||
case LLAMA_CHAT_TEMPLATE_VICUNA_ORCA:
|
case LLAMA_CHAT_TEMPLATE_VICUNA_ORCA:
|
||||||
ss << "</s>\n";
|
if (srole == "assistant") {
|
||||||
|
ss << "</s>\n";
|
||||||
|
} else {
|
||||||
|
ss << "\n";
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
case LLAMA_CHAT_TEMPLATE_DEEPSEEK:
|
case LLAMA_CHAT_TEMPLATE_DEEPSEEK:
|
||||||
if (srole == "user") {
|
if (srole == "user") {
|
||||||
|
@ -17317,7 +17342,25 @@ static int32_t llama_chat_get_postfix(
|
||||||
}
|
}
|
||||||
std::string output = ss.str();
|
std::string output = ss.str();
|
||||||
snprintf(buf, length, "%s", output.c_str());
|
snprintf(buf, length, "%s", output.c_str());
|
||||||
return output.size() + 1;
|
return output.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool llama_chat_support_system_message(const llama_chat_template tmpl) {
|
||||||
|
switch (tmpl) {
|
||||||
|
case LLAMA_CHAT_TEMPLATE_CHATML:
|
||||||
|
case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS:
|
||||||
|
case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS:
|
||||||
|
case LLAMA_CHAT_TEMPLATE_ZEPHYR:
|
||||||
|
case LLAMA_CHAT_TEMPLATE_MONARCH:
|
||||||
|
case LLAMA_CHAT_TEMPLATE_ORION:
|
||||||
|
case LLAMA_CHAT_TEMPLATE_OPENCHAT:
|
||||||
|
case LLAMA_CHAT_TEMPLATE_VICUNA_ORCA:
|
||||||
|
case LLAMA_CHAT_TEMPLATE_COMMAND_R:
|
||||||
|
case LLAMA_CHAT_TEMPLATE_LLAMA3:
|
||||||
|
return true;
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
LLAMA_API int32_t llama_chat_apply_template(
|
LLAMA_API int32_t llama_chat_apply_template(
|
||||||
|
@ -17342,6 +17385,7 @@ LLAMA_API int32_t llama_chat_apply_template(
|
||||||
|
|
||||||
// detect template type
|
// detect template type
|
||||||
llama_chat_template ttmpl = llama_chat_get_template_type(curr_tmpl.c_str());
|
llama_chat_template ttmpl = llama_chat_get_template_type(curr_tmpl.c_str());
|
||||||
|
bool support_system_message = llama_chat_support_system_message(ttmpl);
|
||||||
if (ttmpl == LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED) {
|
if (ttmpl == LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED) {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
@ -17349,22 +17393,37 @@ LLAMA_API int32_t llama_chat_apply_template(
|
||||||
// format the chat to string
|
// format the chat to string
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
std::string prev_role;
|
std::string prev_role;
|
||||||
std::vector<char> prefix(1024, 0);
|
|
||||||
std::vector<char> postfix(1024, 0);
|
|
||||||
for (size_t i = 0; i < n_msg; i++) {
|
for (size_t i = 0; i < n_msg; i++) {
|
||||||
std::string role(chat[i].role);
|
std::string role(chat[i].role);
|
||||||
std::string content(chat[i].content);
|
std::string content(chat[i].content);
|
||||||
|
if (!support_system_message) {
|
||||||
|
// if the template does not support system message, we convert it to user message
|
||||||
|
role = role == "system" ? "user" : role;
|
||||||
|
}
|
||||||
|
std::vector<char> prefix(1024, 0);
|
||||||
|
std::vector<char> postfix(1024, 0);
|
||||||
llama_chat_get_prefix(ttmpl, role.c_str(), prev_role.c_str(), prefix.data(), prefix.size());
|
llama_chat_get_prefix(ttmpl, role.c_str(), prev_role.c_str(), prefix.data(), prefix.size());
|
||||||
llama_chat_get_postfix(ttmpl, role.c_str(), prev_role.c_str(), postfix.data(), postfix.size());
|
llama_chat_get_postfix(ttmpl, role.c_str(), prev_role.c_str(), postfix.data(), postfix.size());
|
||||||
ss << std::string(prefix.data(), prefix.size()) << content << std::string(postfix.data(), postfix.size());
|
ss << std::string(prefix.data()) << trim(content) << std::string(postfix.data());
|
||||||
prev_role = role;
|
prev_role = role;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (add_ass) {
|
||||||
|
std::vector<char> prefix(1024, 0);
|
||||||
|
llama_chat_get_prefix(ttmpl, "assistant", prev_role.c_str(), prefix.data(), prefix.size());
|
||||||
|
std::string assistant_prompt(prefix.data());
|
||||||
|
if (assistant_prompt.back() == ' ') {
|
||||||
|
// Some templates need a trailing space so that it is tokenized together with the next word. We should make sure there is no trailing space in the output text
|
||||||
|
assistant_prompt.pop_back();
|
||||||
|
}
|
||||||
|
ss << assistant_prompt;
|
||||||
|
}
|
||||||
|
|
||||||
std::string output = ss.str();
|
std::string output = ss.str();
|
||||||
if (buf && length > 0) {
|
if (buf && length > 0) {
|
||||||
snprintf(buf, length, "%s", output.c_str());
|
snprintf(buf, length, "%s", output.c_str());
|
||||||
}
|
}
|
||||||
return output.size() + 1;
|
return output.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
|
LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
|
||||||
|
|
|
@ -14,7 +14,7 @@ int main(void) {
|
||||||
{"user", "Hello"},
|
{"user", "Hello"},
|
||||||
{"assistant", "Hi there"},
|
{"assistant", "Hi there"},
|
||||||
{"user", "Who are you"},
|
{"user", "Who are you"},
|
||||||
{"assistant", " I am an assistant "},
|
{"assistant", "I am an assistant"},
|
||||||
{"user", "Another question"},
|
{"user", "Another question"},
|
||||||
};
|
};
|
||||||
size_t message_count = 6;
|
size_t message_count = 6;
|
||||||
|
@ -52,27 +52,27 @@ int main(void) {
|
||||||
};
|
};
|
||||||
std::vector<std::string> expected_output = {
|
std::vector<std::string> expected_output = {
|
||||||
// teknium/OpenHermes-2.5-Mistral-7B
|
// teknium/OpenHermes-2.5-Mistral-7B
|
||||||
"<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
|
"<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\nI am an assistant<|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
|
||||||
// mistralai/Mistral-7B-Instruct-v0.2
|
// mistralai/Mistral-7B-Instruct-v0.2
|
||||||
"[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
|
"[INST] You are a helpful assistant [/INST][INST] Hello [/INST]Hi there</s>[INST] Who are you [/INST]I am an assistant</s>[INST] Another question [/INST]",
|
||||||
// TheBloke/FusionNet_34Bx2_MoE-AWQ
|
// TheBloke/FusionNet_34Bx2_MoE-AWQ
|
||||||
"[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
|
"[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
|
||||||
// bofenghuang/vigogne-2-70b-chat
|
// bofenghuang/vigogne-2-70b-chat
|
||||||
"[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
|
"[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]Hi there </s>[INST] Who are you [/INST]I am an assistant </s>[INST] Another question [/INST]",
|
||||||
// mlabonne/AlphaMonarch-7B
|
// mlabonne/AlphaMonarch-7B
|
||||||
"system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
|
"system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\nI am an assistant</s>\n<s>user\nAnother question</s>\n<s>assistant\n",
|
||||||
// google/gemma-7b-it
|
// google/gemma-7b-it
|
||||||
"<start_of_turn>user\nYou are a helpful assistant\n\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
|
"<start_of_turn>user\nYou are a helpful assistant<end_of_turn>\n<start_of_turn>user\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
|
||||||
// OrionStarAI/Orion-14B-Chat
|
// OrionStarAI/Orion-14B-Chat
|
||||||
"Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
|
"System: You are a helpful assistant\n\nHuman: Hello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s>I am an assistant</s>Human: Another question\n\nAssistant: </s>",
|
||||||
// openchat/openchat-3.5-0106
|
// openchat/openchat-3.5-0106
|
||||||
"You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:",
|
"You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant<|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:",
|
||||||
// deepseek-ai/deepseek-coder-33b-instruct
|
// deepseek-ai/deepseek-coder-33b-instruct
|
||||||
"You are a helpful assistant### Instruction:\nHello\n### Response:\nHi there\n<|EOT|>\n### Instruction:\nWho are you\n### Response:\n I am an assistant \n<|EOT|>\n### Instruction:\nAnother question\n### Response:\n",
|
"### Instruction:\nYou are a helpful assistant\n### Instruction:\nHello\n### Response:\nHi there\n<|EOT|>\n### Instruction:\nWho are you\n### Response:\nI am an assistant\n<|EOT|>\n### Instruction:\nAnother question\n### Response:\n",
|
||||||
// eachadea/vicuna-13b-1.1
|
// eachadea/vicuna-13b-1.1
|
||||||
"You are a helpful assistant\n\nUSER: Hello\nASSISTANT: Hi there</s>\nUSER: Who are you\nASSISTANT: I am an assistant </s>\nUSER: Another question\nASSISTANT:",
|
"USER: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there</s>\nUSER: Who are you\nASSISTANT: I am an assistant</s>\nUSER: Another question\nASSISTANT:",
|
||||||
// Orca-Vicuna
|
// Orca-Vicuna
|
||||||
"SYSTEM: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there</s>\nUSER: Who are you\nASSISTANT: I am an assistant </s>\nUSER: Another question\nASSISTANT:",
|
"SYSTEM: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there</s>\nUSER: Who are you\nASSISTANT: I am an assistant</s>\nUSER: Another question\nASSISTANT:",
|
||||||
// CohereForAI/c4ai-command-r-plus
|
// CohereForAI/c4ai-command-r-plus
|
||||||
"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Who are you<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>I am an assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Another question<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
|
"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Who are you<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>I am an assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Another question<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
|
||||||
// Llama 3
|
// Llama 3
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue