This commit is contained in:
ngxson 2024-06-24 10:45:31 +02:00
parent 962be6a834
commit 43cab6bfc6
2 changed files with 12 additions and 12 deletions

View file

@ -2984,13 +2984,15 @@ std::string llama_chat_apply_template(const struct llama_model * model,
const std::string & tmpl,
const std::vector<llama_chat_msg> & msgs,
bool add_ass) {
int alloc_size = 0;
std::vector<llama_chat_message> chat;
for (auto & msg : msgs) {
chat.push_back({msg.role.c_str(), msg.content.c_str()});
alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
}
const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
std::vector<char> buf;
std::vector<char> buf(alloc_size);
// run the first time to get the total output length
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
@ -3001,7 +3003,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
}
const std::string formatted_chat(buf.data(), res);
std::string formatted_chat(buf.data(), res);
return formatted_chat;
}

View file

@ -37,7 +37,6 @@ static gpt_params * g_params;
// File-scope (internal-linkage) state for this translation unit.
// NOTE(review): the three pointers below are presumably filled in by main() the same way
// g_chat_msgs is; their assignments are not visible in this excerpt, so confirm.
static std::vector<llama_token> * g_input_tokens;
static std::ostringstream * g_output_ss;
static std::vector<llama_token> * g_output_tokens;
// Points at main()'s local chat_msgs vector (main assigns &chat_msgs to it). The variant of
// chat_add_and_format() that takes only (role, content) reads the history through this
// pointer and appends to it.
static std::vector<llama_chat_msg> * g_chat_msgs;
// main() sets this to true right after the generated reply has been added to the chat history.
static bool is_interacting = false;
static bool file_exists(const std::string & path) {
@ -118,13 +117,13 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, v
LOG_TEE("%s", text);
}
// Formats one new chat message against the existing conversation history via
// llama_chat_format_single(), appends the message to that history, and returns the
// formatted text.
//
// NOTE(review): this span is diff output (see the "@ -118,13 +117,13 @" hunk header just
// above it), so it carries two revisions of the function interleaved rather than a single
// compilable definition. One revision reaches the model and the history through the
// file-scope globals g_model / g_chat_msgs; the other receives them as explicit parameters.
// The parameter-based lines appear to be the post-change revision, since the call sites in
// main() are shown passing `model, chat_msgs` — confirm against the repository.
//
// Signature, global-based revision:
static std::string chat_add_and_format(std::string role, std::string content) {
// Signature, parameter-based revision (chat_msgs is taken by non-const reference because
// the function appends to it):
static std::string chat_add_and_format(struct llama_model * model, std::vector<llama_chat_msg> & chat_msgs, std::string role, std::string content) {
// Bundle role/content into the message that is about to be formatted.
llama_chat_msg new_msg{role, content};
// The history passed here does not yet contain new_msg; it is appended only after formatting.
// The final argument is true only when role == "user".
// NOTE(review): presumably that flag requests the assistant-turn prefix — confirm against
// llama_chat_format_single().
auto formatted = llama_chat_format_single(
// Call arguments and history append, global-based revision:
*g_model, g_params->chat_template, *g_chat_msgs, new_msg, role == "user");
g_chat_msgs->push_back({role, content});
// Call arguments and history append, parameter-based revision:
model, g_params->chat_template, chat_msgs, new_msg, role == "user");
chat_msgs.push_back({role, content});
return formatted;
// Closing line of one revision:
}
// Closing line of the other revision; the trailing ';' is only an empty declaration.
};
int main(int argc, char ** argv) {
gpt_params params;
@ -202,7 +201,6 @@ int main(int argc, char ** argv) {
std::vector<llama_chat_msg> chat_msgs;
g_model = &model;
g_ctx = &ctx;
g_chat_msgs = &chat_msgs;
// load the model and apply lora adapter, if any
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
@ -262,7 +260,7 @@ int main(int argc, char ** argv) {
{
auto prompt = params.conversation
? chat_add_and_format("system", params.prompt) // format the system prompt in conversation mode
? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
: params.prompt;
if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
LOG("tokenize the prompt\n");
@ -810,7 +808,7 @@ int main(int argc, char ** argv) {
is_antiprompt = true;
}
chat_add_and_format("system", assistant_ss.str());
chat_add_and_format(model, chat_msgs, "system", assistant_ss.str());
is_interacting = true;
printf("\n");
}
@ -873,8 +871,8 @@ int main(int argc, char ** argv) {
}
std::string user_inp = params.conversation
? chat_add_and_format("user", buffer)
: buffer;
? chat_add_and_format(model, chat_msgs, "user", buffer)
: std::move(buffer);
// TODO: one inconvenience of the current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
const auto line_inp = ::llama_tokenize(ctx, user_inp, false, params.conversation);