llama_chat_apply_template: use term "chat" everywhere
This commit is contained in:
parent
dba4337f85
commit
73fbd67901
2 changed files with 23 additions and 15 deletions
28
llama.cpp
28
llama.cpp
|
@ -12459,7 +12459,10 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
|
|||
return 0;
|
||||
}
|
||||
|
||||
int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_template, std::vector<const llama_chat_message *> conversation, bool add_ass);
|
||||
int32_t llama_chat_apply_template_internal(
|
||||
const std::string & chat_template,
|
||||
const std::vector<const llama_chat_message *> & chat,
|
||||
std::string & dest, bool add_ass);
|
||||
|
||||
// trim whitespace from the beginning and end of a string
|
||||
static std::string trim(const std::string & str) {
|
||||
|
@ -12476,12 +12479,15 @@ static std::string trim(const std::string & str) {
|
|||
|
||||
// Simple version of "llama_chat_apply_template" that only works with strings
|
||||
// This function uses heuristic checks to determine which commonly used template is in use. It is not a jinja parser.
|
||||
int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_template, std::vector<const llama_chat_message *> conversation, bool add_ass) {
|
||||
int32_t llama_chat_apply_template_internal(
|
||||
const std::string & chat_template,
|
||||
const std::vector<const llama_chat_message *> & chat,
|
||||
std::string & dest, bool add_ass) {
|
||||
// Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
|
||||
std::stringstream ss;
|
||||
if (chat_template.find("<|im_start|>") != std::string::npos) {
|
||||
// chatml template
|
||||
for (auto message : conversation) {
|
||||
for (auto message : chat) {
|
||||
ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
|
||||
}
|
||||
if (add_ass) {
|
||||
|
@ -12500,7 +12506,7 @@ int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_t
|
|||
// construct the prompt
|
||||
bool is_inside_turn = true; // skip BOS at the beginning
|
||||
ss << "[INST] ";
|
||||
for (auto message : conversation) {
|
||||
for (auto message : chat) {
|
||||
std::string content = strip_message ? trim(message->content) : message->content;
|
||||
std::string role(message->role);
|
||||
if (!is_inside_turn) {
|
||||
|
@ -12524,7 +12530,7 @@ int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_t
|
|||
// llama2 templates seem to not care about "add_generation_prompt"
|
||||
} else if (chat_template.find("<|user|>") != std::string::npos) {
|
||||
// zephyr template
|
||||
for (auto message : conversation) {
|
||||
for (auto message : chat) {
|
||||
ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
|
||||
}
|
||||
if (add_ass) {
|
||||
|
@ -12541,7 +12547,7 @@ int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_t
|
|||
LLAMA_API int32_t llama_chat_apply_template(
|
||||
const struct llama_model * model,
|
||||
const char * custom_template,
|
||||
const struct llama_chat_message * msg,
|
||||
const struct llama_chat_message * chat,
|
||||
size_t n_msg,
|
||||
bool add_ass,
|
||||
char * buf,
|
||||
|
@ -12560,14 +12566,14 @@ LLAMA_API int32_t llama_chat_apply_template(
|
|||
current_template = std::string(model_template.data(), model_template.size());
|
||||
}
|
||||
}
|
||||
// format the conversation to string
|
||||
std::vector<const llama_chat_message *> conversation_vec;
|
||||
conversation_vec.resize(n_msg);
|
||||
// format the chat to string
|
||||
std::vector<const llama_chat_message *> chat_vec;
|
||||
chat_vec.resize(n_msg);
|
||||
for (size_t i = 0; i < n_msg; i++) {
|
||||
conversation_vec[i] = &msg[i];
|
||||
chat_vec[i] = &chat[i];
|
||||
}
|
||||
std::string formatted_chat;
|
||||
int32_t res = llama_chat_apply_template_internal(formatted_chat, current_template, conversation_vec, add_ass);
|
||||
int32_t res = llama_chat_apply_template_internal(current_template, chat_vec, formatted_chat, add_ass);
|
||||
if (res < 0) {
|
||||
return res;
|
||||
}
|
||||
|
|
10
llama.h
10
llama.h
|
@ -704,18 +704,20 @@ extern "C" {
|
|||
char * buf,
|
||||
int32_t length);
|
||||
|
||||
/// Apply chat template and maybe tokenize it. Inspired by hf apply_chat_template() on python.
|
||||
/// Apply chat template. Inspired by hf apply_chat_template() on python.
|
||||
/// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
|
||||
/// NOTE: This function only supports some known jinja templates. It is not a jinja parser.
|
||||
/// @param custom_template A Jinja template to use for this conversion. If this is nullptr, the model’s default chat template will be used instead.
|
||||
/// @param msg Pointer to a list of multiple llama_chat_message
|
||||
/// @param custom_template A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead.
|
||||
/// @param chat Pointer to a list of multiple llama_chat_message
|
||||
/// @param n_msg Number of llama_chat_message in this chat
|
||||
/// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message.
|
||||
/// @param buf A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages)
|
||||
/// @param length The size of the allocated buffer
|
||||
/// @return The total number of bytes of the formatted prompt. If it is larger than the size of the buffer, you may need to re-alloc it and then re-apply the template.
|
||||
LLAMA_API int32_t llama_chat_apply_template(
|
||||
const struct llama_model * model,
|
||||
const char * custom_template,
|
||||
const struct llama_chat_message * msg,
|
||||
const struct llama_chat_message * chat,
|
||||
size_t n_msg,
|
||||
bool add_ass,
|
||||
char * buf,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue