diff --git a/common/common.cpp b/common/common.cpp
index 388f650ec..b1de5615b 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -2984,13 +2984,15 @@ std::string llama_chat_apply_template(const struct llama_model * model,
         const std::string & tmpl,
         const std::vector<llama_chat_msg> & msgs,
         bool add_ass) {
+    int alloc_size = 0;
     std::vector<llama_chat_message> chat;
     for (auto & msg : msgs) {
         chat.push_back({msg.role.c_str(), msg.content.c_str()});
+        alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
     }
 
     const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
-    std::vector<char> buf;
+    std::vector<char> buf(alloc_size);
 
     // run the first time to get the total output length
     int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
@@ -3001,7 +3003,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
         res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }
 
-    const std::string formatted_chat(buf.data(), res);
+    std::string formatted_chat(buf.data(), res);
     return formatted_chat;
 }
 
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 36f060401..e1f0a1a12 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -37,7 +37,6 @@ static gpt_params * g_params;
 static std::vector<llama_token> * g_input_tokens;
 static std::ostringstream * g_output_ss;
 static std::vector<llama_token> * g_output_tokens;
-static std::vector<llama_chat_msg> * g_chat_msgs;
 static bool is_interacting = false;
 
 static bool file_exists(const std::string & path) {
@@ -118,13 +117,13 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, v
     LOG_TEE("%s", text);
 }
 
-static std::string chat_add_and_format(std::string role, std::string content) {
+static std::string chat_add_and_format(struct llama_model * model, std::vector<llama_chat_msg> & chat_msgs, std::string role, std::string content) {
     llama_chat_msg new_msg{role, content};
     auto formatted = llama_chat_format_single(
-        *g_model, g_params->chat_template, *g_chat_msgs, new_msg, role == "user");
-    g_chat_msgs->push_back({role, content});
+        model, g_params->chat_template, chat_msgs, new_msg, role == "user");
+    chat_msgs.push_back({role, content});
     return formatted;
-}
+};
 
 int main(int argc, char ** argv) {
     gpt_params params;
@@ -202,7 +201,6 @@ int main(int argc, char ** argv) {
     std::vector<llama_chat_msg> chat_msgs;
     g_model = &model;
     g_ctx = &ctx;
-    g_chat_msgs = &chat_msgs;
 
     // load the model and apply lora adapter, if any
     LOG("%s: load the model and apply lora adapter, if any\n", __func__);
@@ -262,7 +260,7 @@ int main(int argc, char ** argv) {
 
     {
         auto prompt = params.conversation
-            ? chat_add_and_format("system", params.prompt) // format the system prompt in conversation mode
+            ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
             : params.prompt;
         if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
            LOG("tokenize the prompt\n");
@@ -810,7 +808,7 @@ int main(int argc, char ** argv) {
                        is_antiprompt = true;
                    }
 
-                    chat_add_and_format("system", assistant_ss.str());
+                    chat_add_and_format(model, chat_msgs, "system", assistant_ss.str());
                     is_interacting = true;
                     printf("\n");
                 }
@@ -873,8 +871,8 @@ int main(int argc, char ** argv) {
                 }
 
                 std::string user_inp = params.conversation
-                    ? chat_add_and_format("user", buffer)
-                    : buffer;
+                    ? chat_add_and_format(model, chat_msgs, "user", buffer)
+                    : std::move(buffer);
                 // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
                 const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
                 const auto line_inp = ::llama_tokenize(ctx, user_inp, false, params.conversation);
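
The common.cpp hunk above replaces the initially empty scratch buffer with one pre-sized to roughly 1.25x the combined role and content length, so the template usually renders in a single pass and the second llama_chat_apply_template call only runs when the estimate falls short. Below is a minimal, self-contained sketch of that pattern; msg, fake_apply_template and format_chat are hypothetical stand-ins for illustration only, not llama.cpp APIs.

// Sketch of the "pre-size, render, grow and retry" strategy from the patch,
// assuming the renderer always reports the required output length.
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

struct msg { std::string role, content; };

// Pretend template renderer: writes into buf if it fits, always returns the required length.
static int32_t fake_apply_template(const std::vector<msg> & msgs, char * buf, int32_t len) {
    std::string out;
    for (const auto & m : msgs) {
        out += "<|" + m.role + "|>" + m.content + "\n";
    }
    if ((int32_t) out.size() <= len) {
        std::memcpy(buf, out.data(), out.size());
    }
    return (int32_t) out.size();
}

static std::string format_chat(const std::vector<msg> & msgs) {
    // guess the output size with ~25% headroom for template markup, as the patch does
    int32_t alloc_size = 0;
    for (const auto & m : msgs) {
        alloc_size += (int32_t) ((m.role.size() + m.content.size()) * 1.25);
    }
    std::vector<char> buf(alloc_size);

    int32_t res = fake_apply_template(msgs, buf.data(), (int32_t) buf.size());
    if (res > (int32_t) buf.size()) {
        // the guess was too small: grow to the reported size and render again
        buf.resize(res);
        res = fake_apply_template(msgs, buf.data(), (int32_t) buf.size());
    }
    return std::string(buf.data(), res);
}

int main() {
    std::vector<msg> msgs = {{"system", "You are helpful."}, {"user", "Hi"}};
    return format_chat(msgs).empty() ? 1 : 0;
}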