improve
parent 962be6a834
commit 43cab6bfc6
2 changed files with 12 additions and 12 deletions
@@ -2984,13 +2984,15 @@ std::string llama_chat_apply_template(const struct llama_model * model,
         const std::string & tmpl,
         const std::vector<llama_chat_msg> & msgs,
         bool add_ass) {
+    int alloc_size = 0;
     std::vector<llama_chat_message> chat;
     for (auto & msg : msgs) {
         chat.push_back({msg.role.c_str(), msg.content.c_str()});
+        alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
     }

     const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
-    std::vector<char> buf;
+    std::vector<char> buf(alloc_size);

     // run the first time to get the total output length
     int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
@@ -3001,7 +3003,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
         res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }

-    const std::string formatted_chat(buf.data(), res);
+    std::string formatted_chat(buf.data(), res);
     return formatted_chat;
 }

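The wrapper change above pre-sizes the scratch buffer from the accumulated message sizes (the 1.25 factor presumably leaves headroom for the markup the template adds) instead of starting from an empty buffer, and keeps the resize-and-retry path for the case where the guess is still too small. Below is a minimal self-contained sketch of that grow-on-demand contract, not llama.cpp code: the toy format_msgs() stands in for the C-level llama_chat_apply_template and, like it, reports the required output length whether or not the buffer was large enough.

#include <cstring>
#include <string>
#include <vector>

// stand-in for the C API: writes into (buf, size) when the output fits,
// and always returns the length the full output needs
static int32_t format_msgs(const std::vector<std::string> & msgs, char * buf, size_t size) {
    std::string out;
    for (const auto & m : msgs) {
        out += m;
        out += '\n';
    }
    if (out.size() <= size) {
        memcpy(buf, out.data(), out.size());
    }
    return (int32_t) out.size();
}

static std::string format_with_headroom(const std::vector<std::string> & msgs) {
    int alloc_size = 0;
    for (const auto & m : msgs) {
        alloc_size += m.size() * 1.25; // same headroom heuristic as in the diff
    }
    std::vector<char> buf(alloc_size);
    int32_t res = format_msgs(msgs, buf.data(), buf.size());
    if ((size_t) res > buf.size()) {
        // the guess was too small: grow to the reported size and retry once
        buf.resize(res);
        res = format_msgs(msgs, buf.data(), buf.size());
    }
    return std::string(buf.data(), res);
}

With the pre-sized buffer, the retry only triggers when the formatted output exceeds the 25% headroom, so the common case needs a single call.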
@@ -37,7 +37,6 @@ static gpt_params * g_params;
 static std::vector<llama_token> * g_input_tokens;
 static std::ostringstream * g_output_ss;
 static std::vector<llama_token> * g_output_tokens;
-static std::vector<llama_chat_msg> * g_chat_msgs;
 static bool is_interacting = false;

 static bool file_exists(const std::string & path) {
@@ -118,13 +117,13 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, void * user_data) {
     LOG_TEE("%s", text);
 }

-static std::string chat_add_and_format(std::string role, std::string content) {
+static std::string chat_add_and_format(struct llama_model * model, std::vector<llama_chat_msg> & chat_msgs, std::string role, std::string content) {
     llama_chat_msg new_msg{role, content};
     auto formatted = llama_chat_format_single(
-        *g_model, g_params->chat_template, *g_chat_msgs, new_msg, role == "user");
-    g_chat_msgs->push_back({role, content});
+        model, g_params->chat_template, chat_msgs, new_msg, role == "user");
+    chat_msgs.push_back({role, content});
     return formatted;
-}
+};

 int main(int argc, char ** argv) {
     gpt_params params;
@@ -202,7 +201,6 @@ int main(int argc, char ** argv) {
     std::vector<llama_chat_msg> chat_msgs;
     g_model = &model;
     g_ctx = &ctx;
-    g_chat_msgs = &chat_msgs;

     // load the model and apply lora adapter, if any
     LOG("%s: load the model and apply lora adapter, if any\n", __func__);
@@ -262,7 +260,7 @@ int main(int argc, char ** argv) {

     {
         auto prompt = params.conversation
-            ? chat_add_and_format("system", params.prompt) // format the system prompt in conversation mode
+            ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
             : params.prompt;
         if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
             LOG("tokenize the prompt\n");
@@ -810,7 +808,7 @@ int main(int argc, char ** argv) {
                         is_antiprompt = true;
                     }

-                    chat_add_and_format("system", assistant_ss.str());
+                    chat_add_and_format(model, chat_msgs, "system", assistant_ss.str());
                     is_interacting = true;
                     printf("\n");
                 }
@@ -873,8 +871,8 @@ int main(int argc, char ** argv) {
                 }

                 std::string user_inp = params.conversation
-                    ? chat_add_and_format("user", buffer)
-                    : buffer;
+                    ? chat_add_and_format(model, chat_msgs, "user", buffer)
+                    : std::move(buffer);
                 // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
                 const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
                 const auto line_inp = ::llama_tokenize(ctx, user_inp, false, params.conversation);
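Taken together, the second file's hunks remove the g_chat_msgs global and pass the model and the chat history into chat_add_and_format explicitly at every call site. The following is a self-contained mock of the resulting calling convention, not llama.cpp code: llama_model, llama_chat_msg, and the formatting itself are stubbed out purely for illustration.

#include <iostream>
#include <string>
#include <vector>

struct llama_model {};                                    // stub for the real model handle
struct llama_chat_msg { std::string role, content; };     // stub mirroring the common header

// same shape as the refactored helper in the diff, with the template formatting stubbed out
static std::string chat_add_and_format(llama_model * /*model*/, std::vector<llama_chat_msg> & chat_msgs,
                                        std::string role, std::string content) {
    std::string formatted = "<|" + role + "|>" + content + "\n";
    chat_msgs.push_back({role, content});
    return formatted;
}

int main() {
    llama_model model_storage;
    llama_model * model = &model_storage;

    std::vector<llama_chat_msg> chat_msgs;   // history owned by main(), no globals involved
    std::cout << chat_add_and_format(model, chat_msgs, "system", "You are helpful.");
    std::cout << chat_add_and_format(model, chat_msgs, "user",   "Hello!");
    std::cout << "history size: " << chat_msgs.size() << "\n";  // 2
}

Each returned string is the newly formatted text for that turn, which main.cpp then hands to ::llama_tokenize as shown in the last hunk above.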