llama_chat_apply_template: change variable name to "tmpl"

ngxson 2024-02-18 22:06:44 +01:00
parent 73fbd67901
commit 649f6f8290
2 changed files with 18 additions and 23 deletions

llama.cpp

@@ -12459,11 +12459,6 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
return 0;
}
-int32_t llama_chat_apply_template_internal(
-const std::string & chat_template,
-const std::vector<const llama_chat_message *> & chat,
-std::string & dest, bool add_ass);
// trim whitespace from the beginning and end of a string
static std::string trim(const std::string & str) {
size_t start = 0;
@@ -12479,13 +12474,13 @@ static std::string trim(const std::string & str) {
// Simple version of "llama_apply_chat_template" that only works with strings
// This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
-int32_t llama_chat_apply_template_internal(
-const std::string & chat_template,
+static int32_t llama_chat_apply_template_internal(
+const std::string & tmpl,
const std::vector<const llama_chat_message *> & chat,
std::string & dest, bool add_ass) {
// Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
std::stringstream ss;
-if (chat_template.find("<|im_start|>") != std::string::npos) {
+if (tmpl.find("<|im_start|>") != std::string::npos) {
// chatml template
for (auto message : chat) {
ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
@@ -12493,16 +12488,16 @@ int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<|im_start|>assistant\n";
}
-} else if (chat_template.find("[INST]") != std::string::npos) {
+} else if (tmpl.find("[INST]") != std::string::npos) {
// llama2 template and its variants
// [variant] support system message
-bool support_system_message = chat_template.find("<<SYS>>") != std::string::npos;
+bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos;
// [variant] space before + after response
-bool space_around_response = chat_template.find("' ' + eos_token") != std::string::npos;
+bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos;
// [variant] add BOS inside history
-bool add_bos_inside_history = chat_template.find("bos_token + '[INST]") != std::string::npos;
+bool add_bos_inside_history = tmpl.find("bos_token + '[INST]") != std::string::npos;
// [variant] trim spaces from the input message
-bool strip_message = chat_template.find("content.strip()") != std::string::npos;
+bool strip_message = tmpl.find("content.strip()") != std::string::npos;
// construct the prompt
bool is_inside_turn = true; // skip BOS at the beginning
ss << "[INST] ";
@@ -12528,7 +12523,7 @@ int32_t llama_chat_apply_template_internal(
}
}
// llama2 templates seem to not care about "add_generation_prompt"
-} else if (chat_template.find("<|user|>") != std::string::npos) {
+} else if (tmpl.find("<|user|>") != std::string::npos) {
// zephyr template
for (auto message : chat) {
ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
@@ -12546,24 +12541,24 @@ int32_t llama_chat_apply_template_internal(
LLAMA_API int32_t llama_chat_apply_template(
const struct llama_model * model,
-const char * custom_template,
+const char * tmpl,
const struct llama_chat_message * chat,
size_t n_msg,
bool add_ass,
char * buf,
int32_t length) {
-std::string current_template(custom_template == nullptr ? "" : custom_template);
-if (custom_template == nullptr) {
+std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
+if (tmpl == nullptr) {
GGML_ASSERT(model != nullptr);
// load template from model
std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
std::string template_key = "tokenizer.chat_template";
-int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), current_template.size());
+int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), curr_tmpl.size());
if (res < 0) {
// worst case: there is no information about template, we will use chatml by default
-current_template = "<|im_start|>"; // see llama_chat_apply_template_internal
+curr_tmpl = "<|im_start|>"; // see llama_chat_apply_template_internal
} else {
-current_template = std::string(model_template.data(), model_template.size());
+curr_tmpl = std::string(model_template.data(), model_template.size());
}
}
// format the chat to string
@@ -12573,7 +12568,7 @@ LLAMA_API int32_t llama_chat_apply_template(
chat_vec[i] = &chat[i];
}
std::string formatted_chat;
-int32_t res = llama_chat_apply_template_internal(current_template, chat_vec, formatted_chat, add_ass);
+int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
if (res < 0) {
return res;
}
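For orientation while reading the hunks above: a minimal standalone sketch of the substring heuristic that llama_chat_apply_template_internal applies to tmpl. The enum and helper name below are invented for illustration only; they are not part of the upstream code.

    #include <string>

    // Hypothetical helper (not upstream code): the heuristic keys on marker
    // substrings inside the template text rather than parsing the Jinja template.
    enum class chat_template_kind { CHATML, LLAMA2, ZEPHYR, UNKNOWN };

    static chat_template_kind detect_template(const std::string & tmpl) {
        if (tmpl.find("<|im_start|>") != std::string::npos) return chat_template_kind::CHATML; // chatml
        if (tmpl.find("[INST]")       != std::string::npos) return chat_template_kind::LLAMA2; // llama2 and variants
        if (tmpl.find("<|user|>")     != std::string::npos) return chat_template_kind::ZEPHYR; // zephyr
        return chat_template_kind::UNKNOWN; // unrecognized: the real function handles this case separately
    }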

llama.h

@@ -707,7 +707,7 @@ extern "C" {
/// Apply chat template. Inspired by hf apply_chat_template() on python.
/// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
/// NOTE: This function only support some known jinja templates. It is not a jinja parser.
-/// @param custom_template A Jinja template to use for this chat. If this is nullptr, the model's default chat template will be used instead.
+/// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model's default chat template will be used instead.
/// @param chat Pointer to a list of multiple llama_chat_message
/// @param n_msg Number of llama_chat_message in this chat
/// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message.
@@ -716,7 +716,7 @@ extern "C" {
/// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template.
LLAMA_API int32_t llama_chat_apply_template(
const struct llama_model * model,
-const char * custom_template,
+const char * tmpl,
const struct llama_chat_message * chat,
size_t n_msg,
bool add_ass,
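A hedged usage sketch of llama_chat_apply_template as declared above, using the renamed tmpl parameter. The template string, message contents, and buffer size are illustrative assumptions, not values taken from this commit.

    #include <string>
    #include <vector>
    #include "llama.h"

    int main() {
        // llama_chat_message carries C-string role/content pairs.
        std::vector<llama_chat_message> msgs = {
            { "system", "You are a helpful assistant." },
            { "user",   "Hello!" },
        };

        // Any template text containing "<|im_start|>" selects the chatml path at
        // this commit; passing tmpl == nullptr (with a non-null model) would read
        // the model's tokenizer.chat_template metadata instead.
        const char * tmpl = "<|im_start|>"; // stand-in for a full Jinja template

        std::vector<char> buf(4096);
        int32_t n = llama_chat_apply_template(
            /*model =*/ nullptr, tmpl,
            msgs.data(), msgs.size(),
            /*add_ass =*/ true,
            buf.data(), (int32_t) buf.size());

        if (n > (int32_t) buf.size()) { // buffer too small: re-alloc and re-apply
            buf.resize(n);
            n = llama_chat_apply_template(nullptr, tmpl, msgs.data(), msgs.size(),
                                          true, buf.data(), (int32_t) buf.size());
        }
        std::string prompt = n >= 0 ? std::string(buf.data(), n) : std::string();
        return prompt.empty() ? 1 : 0;
    }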