add chat template support for llama-cli
This commit is contained in:
parent
3e58b0ee35
commit
5a2fde8385
5 changed files with 134 additions and 20 deletions
|
@ -2967,12 +2967,54 @@ bool llama_should_add_bos_token(const llama_model * model) {
|
||||||
return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
|
return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Chat template utils
|
||||||
|
//
|
||||||
|
|
||||||
bool llama_chat_verify_template(const std::string & tmpl) {
|
bool llama_chat_verify_template(const std::string & tmpl) {
|
||||||
llama_chat_message chat[] = {{"user", "test"}};
|
llama_chat_message chat[] = {{"user", "test"}};
|
||||||
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
|
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
|
||||||
return res >= 0;
|
return res >= 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string llama_chat_format(const struct llama_model * model,
|
||||||
|
const std::string & tmpl,
|
||||||
|
const std::vector<llama_chat_msg> & msgs,
|
||||||
|
bool add_ass) {
|
||||||
|
std::vector<llama_chat_message> chat;
|
||||||
|
for (auto & msg : msgs) {
|
||||||
|
chat.push_back({msg.role.c_str(), msg.content.c_str()});
|
||||||
|
}
|
||||||
|
|
||||||
|
const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
|
||||||
|
std::vector<char> buf;
|
||||||
|
|
||||||
|
// run the first time to get the total output length
|
||||||
|
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||||
|
|
||||||
|
// if it turns out that our buffer is too small, we resize it
|
||||||
|
if ((size_t) res > buf.size()) {
|
||||||
|
buf.resize(res);
|
||||||
|
res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string formatted_chat(buf.data(), res);
|
||||||
|
return formatted_chat;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string llama_chat_format_single(const struct llama_model * model,
|
||||||
|
const std::string & tmpl,
|
||||||
|
const std::vector<llama_chat_msg> & past_msg,
|
||||||
|
const llama_chat_msg & new_msg,
|
||||||
|
bool add_ass) {
|
||||||
|
auto fmt_past_msg = llama_chat_format(model, tmpl, past_msg, false);
|
||||||
|
std::vector<llama_chat_msg> chat_new(past_msg);
|
||||||
|
chat_new.push_back(new_msg);
|
||||||
|
auto fmt_new_msg = llama_chat_format(model, tmpl, chat_new, add_ass);
|
||||||
|
auto formatted = fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
|
||||||
|
return formatted;
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// KV cache utils
|
// KV cache utils
|
||||||
//
|
//
|
||||||
|
|
|
@ -360,9 +360,28 @@ bool llama_should_add_bos_token(const llama_model * model);
|
||||||
// Chat template utils
|
// Chat template utils
|
||||||
//
|
//
|
||||||
|
|
||||||
|
// same with llama_chat_message, but uses std::string
|
||||||
|
struct llama_chat_msg {
|
||||||
|
std::string role;
|
||||||
|
std::string content;
|
||||||
|
};
|
||||||
|
|
||||||
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
||||||
bool llama_chat_verify_template(const std::string & tmpl);
|
bool llama_chat_verify_template(const std::string & tmpl);
|
||||||
|
|
||||||
|
// CPP wrapper for llama_chat_apply_template
|
||||||
|
std::string llama_chat_format(const struct llama_model * model,
|
||||||
|
const std::string & tmpl,
|
||||||
|
const std::vector<llama_chat_msg> & chat,
|
||||||
|
bool add_ass);
|
||||||
|
|
||||||
|
// Format single message, while taking into account the position of that message in chat history
|
||||||
|
std::string llama_chat_format_single(const struct llama_model * model,
|
||||||
|
const std::string & tmpl,
|
||||||
|
const std::vector<llama_chat_msg> & past_msg,
|
||||||
|
const llama_chat_msg & new_msg,
|
||||||
|
bool add_ass);
|
||||||
|
|
||||||
//
|
//
|
||||||
// KV cache utils
|
// KV cache utils
|
||||||
//
|
//
|
||||||
|
|
|
@ -37,6 +37,7 @@ static gpt_params * g_params;
|
||||||
static std::vector<llama_token> * g_input_tokens;
|
static std::vector<llama_token> * g_input_tokens;
|
||||||
static std::ostringstream * g_output_ss;
|
static std::ostringstream * g_output_ss;
|
||||||
static std::vector<llama_token> * g_output_tokens;
|
static std::vector<llama_token> * g_output_tokens;
|
||||||
|
static std::vector<llama_chat_msg> * g_chat_msgs;
|
||||||
static bool is_interacting = false;
|
static bool is_interacting = false;
|
||||||
|
|
||||||
static bool file_exists(const std::string & path) {
|
static bool file_exists(const std::string & path) {
|
||||||
|
@ -117,6 +118,14 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, v
|
||||||
LOG_TEE("%s", text);
|
LOG_TEE("%s", text);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static std::string chat_add_and_format(std::string role, std::string content) {
|
||||||
|
llama_chat_msg new_msg{role, content};
|
||||||
|
auto formatted = llama_chat_format_single(
|
||||||
|
*g_model, g_params->chat_template, *g_chat_msgs, new_msg, role == "user");
|
||||||
|
g_chat_msgs->push_back({role, content});
|
||||||
|
return formatted;
|
||||||
|
}
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
gpt_params params;
|
gpt_params params;
|
||||||
g_params = ¶ms;
|
g_params = ¶ms;
|
||||||
|
@ -190,8 +199,10 @@ int main(int argc, char ** argv) {
|
||||||
llama_model * model;
|
llama_model * model;
|
||||||
llama_context * ctx;
|
llama_context * ctx;
|
||||||
llama_context * ctx_guidance = NULL;
|
llama_context * ctx_guidance = NULL;
|
||||||
|
std::vector<llama_chat_msg> chat_msgs;
|
||||||
g_model = &model;
|
g_model = &model;
|
||||||
g_ctx = &ctx;
|
g_ctx = &ctx;
|
||||||
|
g_chat_msgs = &chat_msgs;
|
||||||
|
|
||||||
// load the model and apply lora adapter, if any
|
// load the model and apply lora adapter, if any
|
||||||
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
|
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
|
||||||
|
@ -249,16 +260,21 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
std::vector<llama_token> embd_inp;
|
std::vector<llama_token> embd_inp;
|
||||||
|
|
||||||
|
{
|
||||||
|
auto prompt = params.conversation
|
||||||
|
? chat_add_and_format("system", params.prompt) // format the system prompt in conversation mode
|
||||||
|
: params.prompt;
|
||||||
if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
|
if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
|
||||||
LOG("tokenize the prompt\n");
|
LOG("tokenize the prompt\n");
|
||||||
embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
|
embd_inp = ::llama_tokenize(ctx, prompt, true, true);
|
||||||
} else {
|
} else {
|
||||||
LOG("use session tokens\n");
|
LOG("use session tokens\n");
|
||||||
embd_inp = session_tokens;
|
embd_inp = session_tokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
|
LOG("prompt: \"%s\"\n", log_tostr(prompt));
|
||||||
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
|
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
|
||||||
|
}
|
||||||
|
|
||||||
// Should not run without any tokens
|
// Should not run without any tokens
|
||||||
if (embd_inp.empty()) {
|
if (embd_inp.empty()) {
|
||||||
|
@ -478,6 +494,7 @@ int main(int argc, char ** argv) {
|
||||||
std::vector<int> input_tokens; g_input_tokens = &input_tokens;
|
std::vector<int> input_tokens; g_input_tokens = &input_tokens;
|
||||||
std::vector<int> output_tokens; g_output_tokens = &output_tokens;
|
std::vector<int> output_tokens; g_output_tokens = &output_tokens;
|
||||||
std::ostringstream output_ss; g_output_ss = &output_ss;
|
std::ostringstream output_ss; g_output_ss = &output_ss;
|
||||||
|
std::ostringstream assistant_ss; // for storing current assistant message, used in conversation mode
|
||||||
|
|
||||||
// the first thing we will do is to output the prompt, so set color accordingly
|
// the first thing we will do is to output the prompt, so set color accordingly
|
||||||
console::set_display(console::prompt);
|
console::set_display(console::prompt);
|
||||||
|
@ -793,11 +810,18 @@ int main(int argc, char ** argv) {
|
||||||
is_antiprompt = true;
|
is_antiprompt = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
chat_add_and_format("system", assistant_ss.str());
|
||||||
is_interacting = true;
|
is_interacting = true;
|
||||||
printf("\n");
|
printf("\n");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if current token is not EOG, we add it to current assistant message
|
||||||
|
if (params.conversation) {
|
||||||
|
auto id = llama_sampling_last(ctx_sampling);
|
||||||
|
assistant_ss << llama_token_to_piece(ctx, id, false);
|
||||||
|
}
|
||||||
|
|
||||||
if (n_past > 0 && is_interacting) {
|
if (n_past > 0 && is_interacting) {
|
||||||
LOG("waiting for user input\n");
|
LOG("waiting for user input\n");
|
||||||
|
|
||||||
|
@ -848,8 +872,14 @@ int main(int argc, char ** argv) {
|
||||||
string_process_escapes(buffer);
|
string_process_escapes(buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string user_inp = params.conversation
|
||||||
|
? chat_add_and_format("user", buffer)
|
||||||
|
: buffer;
|
||||||
|
// TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
|
||||||
|
bool accept_special_content = params.conversation;
|
||||||
|
|
||||||
const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
|
const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
|
||||||
const auto line_inp = ::llama_tokenize(ctx, buffer, false, false);
|
const auto line_inp = ::llama_tokenize(ctx, user_inp, false, accept_special_content);
|
||||||
const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
|
const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
|
||||||
|
|
||||||
LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
|
LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
|
||||||
|
@ -864,6 +894,9 @@ int main(int argc, char ** argv) {
|
||||||
output_ss << llama_token_to_piece(ctx, token);
|
output_ss << llama_token_to_piece(ctx, token);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// reset assistant message
|
||||||
|
assistant_ss.str("");
|
||||||
|
|
||||||
n_remain -= line_inp.size();
|
n_remain -= line_inp.size();
|
||||||
LOG("n_remain: %d\n", n_remain);
|
LOG("n_remain: %d\n", n_remain);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -18589,10 +18589,10 @@ static int32_t llama_chat_apply_template_internal(
|
||||||
if (add_ass) {
|
if (add_ass) {
|
||||||
ss << "<|im_start|>assistant\n";
|
ss << "<|im_start|>assistant\n";
|
||||||
}
|
}
|
||||||
} else if (tmpl == "llama2" || tmpl.find("[INST]") != std::string::npos) {
|
} else if (tmpl == "llama2" || tmpl == "mistral" || tmpl.find("[INST]") != std::string::npos) {
|
||||||
// llama2 template and its variants
|
// llama2 template and its variants
|
||||||
// [variant] support system message
|
// [variant] support system message
|
||||||
bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos;
|
bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos || tmpl == "mistral";
|
||||||
// [variant] space before + after response
|
// [variant] space before + after response
|
||||||
bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos;
|
bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos;
|
||||||
// [variant] add BOS inside history
|
// [variant] add BOS inside history
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
#include "common.h"
|
||||||
|
|
||||||
int main(void) {
|
int main(void) {
|
||||||
llama_chat_message conversation[] = {
|
llama_chat_message conversation[] = {
|
||||||
|
@ -119,5 +120,24 @@ int main(void) {
|
||||||
std::cout << output << "\n-------------------------\n";
|
std::cout << output << "\n-------------------------\n";
|
||||||
assert(output == expected);
|
assert(output == expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// test llama_chat_format_single
|
||||||
|
std::cout << "\n\n=== llama_chat_format_single ===\n\n";
|
||||||
|
std::vector<llama_chat_msg> chat2;
|
||||||
|
chat2.push_back({"system", "You are a helpful assistant"});
|
||||||
|
chat2.push_back({"user", "Hello"});
|
||||||
|
chat2.push_back({"assistant", "I am assistant"});
|
||||||
|
llama_chat_msg new_msg{"user", "How are you"};
|
||||||
|
|
||||||
|
auto fmt_single = [&](std::string tmpl) {
|
||||||
|
auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true);
|
||||||
|
std::cout << "fmt_single(" << tmpl << ")\n" << output << "\n-------------------------\n";
|
||||||
|
return output;
|
||||||
|
};
|
||||||
|
assert(fmt_single("chatml") == "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
|
||||||
|
assert(fmt_single("llama2") == "[INST] How are you [/INST]");
|
||||||
|
assert(fmt_single("gemma") == "<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
|
||||||
|
assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue