server: simplify format_chat

This commit is contained in:
ngxson 2024-06-22 20:30:33 +02:00
parent c91f972775
commit 317452730d

View file

@@ -118,36 +118,17 @@ static inline void server_log(const char * level, const char * function, int lin
// Format given chat. If tmpl is empty, we take the template from model metadata // Format given chat. If tmpl is empty, we take the template from model metadata
// NOTE: this listing is a side-by-side diff of format_chat. Each line shows the pre-commit text
// followed by the post-commit text; lines carrying a single column exist only in the pre-commit version.
//   model    - model handle, forwarded unchanged to the chat-template routine
//   tmpl     - chat template string; the pre-commit code passes nullptr in its place when it is empty
//   messages - JSON objects from which the "role" and "content" fields are read
// Returns the formatted chat text, which is also emitted through LOG_VERBOSE before returning.
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) { inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
// pre-commit: alloc_size accumulates the content lengths and later sizes the output buffer;
// post-commit: chat collects one {role, content} entry per message
size_t alloc_size = 0; std::vector<llama_chat_msg> chat;
// vector holding all allocated string to be passed to llama_chat_apply_template
std::vector<std::string> str(messages.size() * 2);
std::vector<llama_chat_message> chat(messages.size());
for (size_t i = 0; i < messages.size(); ++i) { for (size_t i = 0; i < messages.size(); ++i) {
const auto & curr_msg = messages[i]; const auto & curr_msg = messages[i];
// both versions call json_value with std::string("") as the third argument
// (presumably the fallback used when the key is missing - confirm against json_value)
str[i*2 + 0] = json_value(curr_msg, "role", std::string("")); std::string role = json_value(curr_msg, "role", std::string(""));
str[i*2 + 1] = json_value(curr_msg, "content", std::string("")); std::string content = json_value(curr_msg, "content", std::string(""));
alloc_size += str[i*2 + 1].length(); chat.push_back({role, content});
// pre-commit only: chat[i] stores the c_str() pointers of the strings kept in `str`,
// so `str` has to stay alive for as long as `chat` is used
chat[i].role = str[i*2 + 0].c_str();
chat[i].content = str[i*2 + 1].c_str();
} }
// post-commit: one llama_chat_format call takes tmpl directly and replaces the manual buffer handling below
const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str(); auto formatted_chat = llama_chat_format(model, tmpl, chat, true);
std::vector<char> buf(alloc_size * 2);
// run the first time to get the total output length
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
// if it turns out that our buffer is too small, we resize it
// NOTE(review): res is cast to size_t before the comparison, so a negative res would compare as a
// very large value - presumably llama_chat_apply_template can report failure that way; confirm
if ((size_t) res > buf.size()) {
buf.resize(res);
res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
}
const std::string formatted_chat(buf.data(), res);
LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}}); LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});
return formatted_chat; return formatted_chat;
} }