refactor chat template api

ngxson 2024-04-22 06:45:01 +02:00
parent 5cf5e7d490
commit d6df8ecb9c
2 changed files with 279 additions and 200 deletions

llama.cpp (447 changes)

@@ -17074,195 +17074,250 @@ static std::string trim(const std::string & str) {
return str.substr(start, end - start);
}
The monolithic llama_chat_apply_template_internal (which detected the template and formatted every message in a single pass) is removed; its per-template formatting now lives in llama_chat_get_prefix and llama_chat_get_postfix below.

static int32_t llama_chat_get_model_template(
        const struct llama_model * model,
        const char * name,
        char * buf,
        int32_t length) {
    GGML_ASSERT(model != nullptr);
    auto get_meta = [&model](std::string template_key) {
        // load template from model
        std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
        int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
        if (res < 0) {
            return std::string(); // not found
        } else {
            return std::string(model_template.data(), model_template.size());
        }
    };
    std::string default_meta = "tokenizer.chat_template";
    std::string model_template;
    if (name != nullptr) {
        // support for named template: https://github.com/ggerganov/llama.cpp/pull/6588
        model_template = get_meta(std::string("tokenizer.chat_template.") + name);
        if (model_template.empty()) {
            model_template = get_meta(default_meta);
        }
    } else {
        // default template
        model_template = get_meta(default_meta);
    }
    if (model_template.empty()) {
        return -1;
    } else {
        snprintf(buf, length, "%s", model_template.c_str());
        return model_template.size() + 1;
    }
}

// This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
// Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
static llama_chat_template llama_chat_get_template_type(const char * tmpl) {
    if (tmpl == nullptr) {
        return LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED;
    }
    std::string stmpl(tmpl);
    auto tmpl_contains = [&stmpl](std::string needle) {
        return stmpl.find(needle) != std::string::npos;
    };
    if (stmpl == "chatml" || tmpl_contains("<|im_start|>")) {
        return LLAMA_CHAT_TEMPLATE_CHATML;
    } else if (stmpl == "llama2" || tmpl_contains("[INST]")) {
        // [variant] support system message
        bool support_system_message = tmpl_contains("<<SYS>>");
        // [variant] add BOS inside history
        bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]");
        if (support_system_message && add_bos_inside_history) {
            return LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS;
        } else if (support_system_message) {
            return LLAMA_CHAT_TEMPLATE_LLAMA2_SYS;
        } else {
            return LLAMA_CHAT_TEMPLATE_LLAMA2;
        }
    } else if (stmpl == "zephyr" || tmpl_contains("<|user|>")) {
        return LLAMA_CHAT_TEMPLATE_ZEPHYR;
    } else if (stmpl == "monarch" || tmpl_contains("bos_token + message['role']")) {
        return LLAMA_CHAT_TEMPLATE_MONARCH;
    } else if (stmpl == "gemma" || tmpl_contains("<start_of_turn>")) {
        return LLAMA_CHAT_TEMPLATE_GEMMA;
    } else if (stmpl == "orion" || tmpl_contains("'\\n\\nAssistant: ' + eos_token")) {
        return LLAMA_CHAT_TEMPLATE_ORION;
    } else if (stmpl == "openchat" || tmpl_contains("GPT4 Correct ")) {
        return LLAMA_CHAT_TEMPLATE_OPENCHAT;
    } else if (stmpl == "vicuna" || stmpl == "vicuna-orca" || (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: "))) {
        // [variant] support system message
        if (stmpl == "vicuna-orca" || tmpl_contains("SYSTEM: ")) {
            return LLAMA_CHAT_TEMPLATE_VICUNA_ORCA;
        } else {
            return LLAMA_CHAT_TEMPLATE_VICUNA;
        }
    } else if (stmpl == "deepseek" || (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>"))) {
        return LLAMA_CHAT_TEMPLATE_DEEPSEEK;
    } else if (stmpl == "command-r" || (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>"))) {
        return LLAMA_CHAT_TEMPLATE_COMMAND_R;
    } else if (stmpl == "llama3" || (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>"))) {
        return LLAMA_CHAT_TEMPLATE_LLAMA3;
    } else {
        // template not supported
        return LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED;
    }
}

static int32_t llama_chat_get_prefix(
        const llama_chat_template tmpl,
        const char * role,
        const char * prev_role,
        char * buf,
        int32_t length) {
    std::stringstream ss;
    std::string srole(role);
    std::string sprev_role(prev_role == nullptr ? "" : prev_role);
    auto str_toupper = [](std::string & str) {
        std::string output(str);
        for (size_t i = 0; i < output.size(); i++) {
            output[i] = toupper(output[i]);
        }
        return output;
    };
    switch (tmpl) {
        case LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED:
            return -1;
        case LLAMA_CHAT_TEMPLATE_CHATML:
            ss << "<|im_start|>" << srole << "\n";
            break;
        case LLAMA_CHAT_TEMPLATE_LLAMA2:
            if (srole == "user") {
                ss << "[INST] ";
            }
            break;
        case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS:
            if (!sprev_role.empty()) {
                ss << "<s>";
            }
            // do not add "break"
        case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS:
            if (srole == "system") {
                ss << "[INST]<<SYS>>\n";
            } else if (srole == "user" && sprev_role != "system") {
                ss << "[INST] ";
            }
            break;
        case LLAMA_CHAT_TEMPLATE_ZEPHYR:
            ss << "<|" << srole << "|>" << "\n";
            break;
        case LLAMA_CHAT_TEMPLATE_MONARCH:
            {
                std::string bos = sprev_role.empty() ? "" : "<s>"; // skip BOS for first message
                ss << bos << srole << "\n";
            } break;
        case LLAMA_CHAT_TEMPLATE_GEMMA:
            // for gemma, "assistant" is "model"
            srole = srole == "assistant" ? "model" : srole;
            ss << "<start_of_turn>" << srole << "\n";
            break;
        case LLAMA_CHAT_TEMPLATE_ORION:
            // for orion, "user" is "human"
            srole = srole == "user" ? "human" : srole;
            srole[0] = toupper(srole[0]); // upper case for first letter
            ss << srole << ": ";
            break;
        case LLAMA_CHAT_TEMPLATE_OPENCHAT:
            if (srole == "system") {
                ss << "";
            } else {
                srole[0] = toupper(srole[0]); // upper case for first letter
                ss << "GPT4 Correct " << srole << ": ";
            }
            break;
        case LLAMA_CHAT_TEMPLATE_VICUNA:
        case LLAMA_CHAT_TEMPLATE_VICUNA_ORCA:
            // TODO: original vicuna template does not support system message
            ss << str_toupper(srole) << ": ";
            break;
        case LLAMA_CHAT_TEMPLATE_DEEPSEEK:
            if (srole == "user") {
                ss << "### Instruction:\n";
            } else {
                ss << "### Response:\n";
            }
            break;
        case LLAMA_CHAT_TEMPLATE_COMMAND_R:
            // for command-r, "assistant" is "chatbot"
            srole = srole == "assistant" ? "chatbot" : srole;
            ss << "<|START_OF_TURN_TOKEN|><|" << str_toupper(srole) << "_TOKEN|>";
            break;
        case LLAMA_CHAT_TEMPLATE_LLAMA3:
            ss << "<|start_header_id|>" << srole << "<|end_header_id|>\n\n";
            break;
    }
    std::string output = ss.str();
    snprintf(buf, length, "%s", output.c_str());
    return output.size() + 1;
}

static int32_t llama_chat_get_postfix(
        const llama_chat_template tmpl,
        const char * role,
        const char * prev_role,
        char * buf,
        int32_t length) {
    std::stringstream ss;
    std::string srole(role);
    std::string sprev_role(prev_role == nullptr ? "" : prev_role);
    switch (tmpl) {
        case LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED:
            return -1;
        case LLAMA_CHAT_TEMPLATE_CHATML:
            ss << "<|im_end|>\n";
            break;
        case LLAMA_CHAT_TEMPLATE_LLAMA2:
            if (srole == "user") {
                ss << " [/INST]";
            }
            break;
        case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS:
        case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS:
            if (srole == "system") {
                ss << "\n<</SYS>>\n\n";
            } else {
                ss << "</s>";
            }
            break;
        case LLAMA_CHAT_TEMPLATE_ZEPHYR:
            ss << "<|endoftext|>" << "\n";
            break;
        case LLAMA_CHAT_TEMPLATE_MONARCH:
            ss << "</s>\n";
            break;
        case LLAMA_CHAT_TEMPLATE_GEMMA:
            ss << "<end_of_turn>\n";
            break;
        case LLAMA_CHAT_TEMPLATE_ORION:
            ss << "</s>";
            break;
        case LLAMA_CHAT_TEMPLATE_OPENCHAT:
            srole[0] = toupper(srole[0]);
            ss << "<|end_of_turn|>";
            break;
        case LLAMA_CHAT_TEMPLATE_VICUNA:
        case LLAMA_CHAT_TEMPLATE_VICUNA_ORCA:
            ss << "</s>\n";
            break;
        case LLAMA_CHAT_TEMPLATE_DEEPSEEK:
            if (srole == "user") {
                ss << "\n";
            } else {
                ss << "\n<|EOT|>\n";
            }
            break;
        case LLAMA_CHAT_TEMPLATE_COMMAND_R:
            ss << "<|END_OF_TURN_TOKEN|>";
            break;
        case LLAMA_CHAT_TEMPLATE_LLAMA3:
            ss << "<|eot_id|>";
            break;
    }
    std::string output = ss.str();
    snprintf(buf, length, "%s", output.c_str());
    return output.size() + 1;
}
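
For orientation, a standalone sketch (not part of the diff) of what one "user" turn looks like when composed as prefix + content + postfix; the literal strings are copied from the ChatML and Llama 3 cases above:

// Illustration only, not library code: compose one turn the way the refactored
// llama_chat_apply_template below does, i.e. prefix + content + postfix.
#include <iostream>
#include <string>

int main() {
    std::string content = "Hello";
    // LLAMA_CHAT_TEMPLATE_CHATML
    std::cout << "<|im_start|>user\n" << content << "<|im_end|>\n";
    // LLAMA_CHAT_TEMPLATE_LLAMA3
    std::cout << "<|start_header_id|>user<|end_header_id|>\n\n" << content << "<|eot_id|>";
    return 0;
}
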
@@ -17275,11 +17330,8 @@ LLAMA_API int32_t llama_chat_apply_template(
        int32_t length) {
    std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
    if (tmpl == nullptr) {
        GGML_ASSERT(model != nullptr);
        // load template from model
        std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
        int32_t res = llama_chat_get_model_template(model, nullptr, model_template.data(), model_template.size());
        if (res < 0) {
            // worst case: there is no information about template, we will use chatml by default
            curr_tmpl = "chatml"; // see llama_chat_get_template_type
@@ -17288,22 +17340,31 @@ LLAMA_API int32_t llama_chat_apply_template(
        } else {
            curr_tmpl = std::string(model_template.data(), model_template.size());
        }
    }
    // detect template type
    llama_chat_template ttmpl = llama_chat_get_template_type(curr_tmpl.c_str());
    if (ttmpl == LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED) {
        return -1;
    }

    // format the chat to string
    std::stringstream ss;
    std::string prev_role;
    std::vector<char> prefix(1024, 0);
    std::vector<char> postfix(1024, 0);
    for (size_t i = 0; i < n_msg; i++) {
        std::string role(chat[i].role);
        std::string content(chat[i].content);
        llama_chat_get_prefix (ttmpl, role.c_str(), prev_role.c_str(), prefix.data(),  prefix.size());
        llama_chat_get_postfix(ttmpl, role.c_str(), prev_role.c_str(), postfix.data(), postfix.size());
        ss << prefix.data() << content << postfix.data();
        prev_role = role;
    }
    std::string output = ss.str();
    if (buf && length > 0) {
        snprintf(buf, length, "%s", output.c_str());
    }
    return output.size() + 1;
}
LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
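
Below is a caller-side sketch (not from this commit) of the unchanged public entry point; format_chat_example is a hypothetical helper name, and the grow-and-retry pattern relies on the function returning the size needed for the full result when the supplied buffer is too small:

#include <string>
#include <vector>
#include "llama.h"

// Hypothetical helper: format a fixed conversation with an explicit "chatml" template.
// Passing a real model instead of nullptr would make the template be read from the
// tokenizer.chat_template metadata via llama_chat_get_model_template.
static std::string format_chat_example() {
    std::vector<llama_chat_message> msgs = {
        {"system", "You are a helpful assistant."},
        {"user",   "Hello"},
    };
    std::vector<char> buf(256, 0);
    int32_t res = llama_chat_apply_template(nullptr, "chatml", msgs.data(), msgs.size(),
                                            /*add_ass=*/true, buf.data(), buf.size());
    if (res < 0) {
        return ""; // template not supported
    }
    if ((size_t) res > buf.size()) {
        // the return value is the size needed to hold the full result: grow and retry
        buf.resize(res);
        res = llama_chat_apply_template(nullptr, "chatml", msgs.data(), msgs.size(),
                                        /*add_ass=*/true, buf.data(), buf.size());
    }
    return std::string(buf.data());
}
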

llama.h (18 changes)

@@ -147,6 +147,24 @@ extern "C" {
LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs
};
enum llama_chat_template {
LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED = 0,
LLAMA_CHAT_TEMPLATE_CHATML = 1, // Example: teknium/OpenHermes-2.5-Mistral-7B
LLAMA_CHAT_TEMPLATE_LLAMA2 = 2, // Original llama2 template (no <<SYS>> support)
LLAMA_CHAT_TEMPLATE_LLAMA2_SYS = 3, // <<SYS>> support (example: bofenghuang/vigogne-2-70b-chat)
LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS = 4, // <<SYS>> support with BOS inside history (example: TomGrc/FusionNet_34Bx2_MoE)
LLAMA_CHAT_TEMPLATE_ZEPHYR = 5, // Example: HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1
LLAMA_CHAT_TEMPLATE_MONARCH = 6, // Example: mlabonne/AlphaMonarch-7B
LLAMA_CHAT_TEMPLATE_GEMMA = 7, // Example: google/gemma-7b-it
LLAMA_CHAT_TEMPLATE_ORION = 8, // Example: OrionStarAI/Orion-14B-Chat
LLAMA_CHAT_TEMPLATE_OPENCHAT = 9, // Example: openchat/openchat-3.5-0106
LLAMA_CHAT_TEMPLATE_VICUNA = 10, // Example: NousResearch/Nous-Capybara-34B
LLAMA_CHAT_TEMPLATE_VICUNA_ORCA = 11, // Variant of vicuna that supports system role
LLAMA_CHAT_TEMPLATE_DEEPSEEK = 12, // Example: deepseek-ai/deepseek-coder-33b-instruct
LLAMA_CHAT_TEMPLATE_COMMAND_R = 13, // Example: CohereForAI/c4ai-command-r-plus
LLAMA_CHAT_TEMPLATE_LLAMA3 = 14, // Example: meta-llama/Meta-Llama-3-8B-Instruct
};
typedef struct llama_token_data {
llama_token id; // token id
float logit; // log-odds of the token
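
Finally, a sketch of reading a named template straight from the GGUF metadata with the public llama_model_meta_val_str, mirroring what llama_chat_get_model_template does internally; read_named_template is a hypothetical helper and the name "tool_use" is only an example, a given model may not ship such a key:

#include <string>
#include <vector>
#include "llama.h"

// Hypothetical helper: fetch "tokenizer.chat_template.<name>" from the model metadata,
// or return an empty string if the key is not present.
static std::string read_named_template(const struct llama_model * model, const std::string & name) {
    const std::string key = "tokenizer.chat_template." + name; // e.g. tokenizer.chat_template.tool_use
    std::vector<char> buf(2048, 0); // same size assumption as llama_chat_get_model_template above
    int32_t res = llama_model_meta_val_str(model, key.c_str(), buf.data(), buf.size());
    if (res < 0) {
        return ""; // key not found
    }
    return std::string(buf.data());
}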