improve
This commit is contained in:
parent
962be6a834
commit
43cab6bfc6
2 changed files with 12 additions and 12 deletions
|
@ -2984,13 +2984,15 @@ std::string llama_chat_apply_template(const struct llama_model * model,
|
|||
const std::string & tmpl,
|
||||
const std::vector<llama_chat_msg> & msgs,
|
||||
bool add_ass) {
|
||||
int alloc_size = 0;
|
||||
std::vector<llama_chat_message> chat;
|
||||
for (auto & msg : msgs) {
|
||||
chat.push_back({msg.role.c_str(), msg.content.c_str()});
|
||||
alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
|
||||
}
|
||||
|
||||
const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
|
||||
std::vector<char> buf;
|
||||
std::vector<char> buf(alloc_size);
|
||||
|
||||
// run the first time to get the total output length
|
||||
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
|
@ -3001,7 +3003,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
|
|||
res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
}
|
||||
|
||||
const std::string formatted_chat(buf.data(), res);
|
||||
std::string formatted_chat(buf.data(), res);
|
||||
return formatted_chat;
|
||||
}
|
||||
|
||||
|
|
|
@ -37,7 +37,6 @@ static gpt_params * g_params;
|
|||
static std::vector<llama_token> * g_input_tokens;
|
||||
static std::ostringstream * g_output_ss;
|
||||
static std::vector<llama_token> * g_output_tokens;
|
||||
static std::vector<llama_chat_msg> * g_chat_msgs;
|
||||
static bool is_interacting = false;
|
||||
|
||||
static bool file_exists(const std::string & path) {
|
||||
|
@ -118,13 +117,13 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, v
|
|||
LOG_TEE("%s", text);
|
||||
}
|
||||
|
||||
static std::string chat_add_and_format(std::string role, std::string content) {
|
||||
static std::string chat_add_and_format(struct llama_model * model, std::vector<llama_chat_msg> & chat_msgs, std::string role, std::string content) {
|
||||
llama_chat_msg new_msg{role, content};
|
||||
auto formatted = llama_chat_format_single(
|
||||
*g_model, g_params->chat_template, *g_chat_msgs, new_msg, role == "user");
|
||||
g_chat_msgs->push_back({role, content});
|
||||
model, g_params->chat_template, chat_msgs, new_msg, role == "user");
|
||||
chat_msgs.push_back({role, content});
|
||||
return formatted;
|
||||
}
|
||||
};
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
gpt_params params;
|
||||
|
@ -202,7 +201,6 @@ int main(int argc, char ** argv) {
|
|||
std::vector<llama_chat_msg> chat_msgs;
|
||||
g_model = &model;
|
||||
g_ctx = &ctx;
|
||||
g_chat_msgs = &chat_msgs;
|
||||
|
||||
// load the model and apply lora adapter, if any
|
||||
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
|
||||
|
@ -262,7 +260,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
{
|
||||
auto prompt = params.conversation
|
||||
? chat_add_and_format("system", params.prompt) // format the system prompt in conversation mode
|
||||
? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
|
||||
: params.prompt;
|
||||
if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
|
||||
LOG("tokenize the prompt\n");
|
||||
|
@ -810,7 +808,7 @@ int main(int argc, char ** argv) {
|
|||
is_antiprompt = true;
|
||||
}
|
||||
|
||||
chat_add_and_format("system", assistant_ss.str());
|
||||
chat_add_and_format(model, chat_msgs, "system", assistant_ss.str());
|
||||
is_interacting = true;
|
||||
printf("\n");
|
||||
}
|
||||
|
@ -873,8 +871,8 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
std::string user_inp = params.conversation
|
||||
? chat_add_and_format("user", buffer)
|
||||
: buffer;
|
||||
? chat_add_and_format(model, chat_msgs, "user", buffer)
|
||||
: std::move(buffer);
|
||||
// TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
|
||||
const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
|
||||
const auto line_inp = ::llama_tokenize(ctx, user_inp, false, params.conversation);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue