add chat_get_added_part

ngxson 2024-04-21 18:13:42 +02:00
parent 3b8f1ec4b1
commit eb9a1ff63d
3 changed files with 66 additions and 1 deletion


@@ -2947,3 +2947,46 @@ llama_control_vector_data llama_control_vector_load(const std::vector<llama_cont
    return result;
}

// apply chat template for (n_msgs - 1) and (n_msgs), then get the added part
std::string chat_get_added_part(const std::vector<chat_message> & messages, const std::string & tmpl) {
    auto apply_chat_template = [&tmpl](const std::vector<chat_message> & msgs, size_t delta, bool add_ass) {
        std::vector<llama_chat_message> chat(msgs.size());
        size_t alloc_size = 0;
        size_t chat_size  = chat.size() - delta;
        for (size_t i = 0; i < msgs.size(); ++i) {
            chat[i].role    = msgs[i].role.c_str();
            chat[i].content = msgs[i].content.c_str();
            alloc_size += msgs[i].role.size() + msgs[i].content.size();
        }
        const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
        std::vector<char> buf(alloc_size * 2);
        // run the first time to get the total output length
        int32_t res = llama_chat_apply_template(nullptr, ptr_tmpl, chat.data(), chat_size, add_ass, buf.data(), buf.size());
        // if it turns out that our buffer is too small, we resize it
        if ((size_t) res > buf.size()) {
            buf.resize(res);
            res = llama_chat_apply_template(nullptr, ptr_tmpl, chat.data(), chat_size, add_ass, buf.data(), buf.size());
        }
        const std::string formatted_chat(buf.data(), res);
        return formatted_chat;
    };

    std::string formatted_chat_last = messages.size() > 0
        ? apply_chat_template(messages, 1, false) // (n_msgs - 1) messages
        : "";
    std::string formatted_chat_curr = apply_chat_template(messages, 0, true);

    // Extract the added part (user prompt)
    auto get_diff_part = [](const std::string & str1, const std::string & str2) {
        size_t i = 0;
        while (i < str1.size() && i < str2.size() && str1[i] == str2[i]) {
            ++i;
        }
        return str2.substr(i);
    };
    return get_diff_part(formatted_chat_last, formatted_chat_curr);
}
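For context, a minimal usage sketch (not part of this commit). The ChatML-style template string, the example conversation, and the output shown in the comment are illustrative assumptions; since llama_chat_apply_template can recognize a ChatML template from the "<|im_start|>" marker in the string itself, no model handle should be needed here.

// usage sketch (illustrative only, not part of this commit)
#include <iostream>
#include <string>
#include <vector>

#include "common.h"

int main() {
    // assumed ChatML-style template; llama_chat_apply_template detects it from the "<|im_start|>" marker
    std::string chatml_tmpl =
        "{% for message in messages %}"
        "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
        "{% endfor %}"
        "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}";

    std::vector<chat_message> messages = {
        {"system",    "You are a helpful assistant"},
        {"user",      "Hello"},
        {"assistant", "Hi there"},
        {"user",      "How are you?"}, // the newly appended turn
    };

    // formats (n_msgs - 1) and (n_msgs) messages, then returns only the suffix added by the last turn,
    // e.g. roughly "<|im_start|>user\nHow are you?<|im_end|>\n<|im_start|>assistant\n" for ChatML
    std::cout << chat_get_added_part(messages, chatml_tmpl) << std::endl;
    return 0;
}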


@@ -322,3 +322,13 @@ llama_control_vector_data llama_control_vector_load(const std::vector<llama_cont
static const char * const LLM_KV_SPLIT_NO = "split.no";
static const char * const LLM_KV_SPLIT_COUNT = "split.count";
static const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";

//
// Chat templates utils
//

typedef struct chat_message {
    std::string role;
    std::string content;
} chat_message;

std::string chat_get_added_part(const std::vector<chat_message> & messages, const std::string & tmpl);


@@ -7,6 +7,7 @@
#include <cassert>
#include "llama.h"
#include "common.h"
int main(void) {
llama_chat_message conversation[] = {
@@ -96,8 +97,19 @@ int main(void) {
        );
        formatted_chat.resize(res);
        std::string output(formatted_chat.data(), formatted_chat.size());
        std::cout << output << "\n-----\n";
        assert(output == expected);

        std::vector<chat_message> v_messages;
        for (size_t i = 0; i < message_count; ++i) {
            v_messages.push_back({
                conversation[i].role,
                conversation[i].content,
            });
        }
        std::cout << "chat_get_added_part(): " << chat_get_added_part(v_messages, custom_template);
        std::cout << "\n-------------------------\n";
        // TODO: chat_get_added_part is currently printed for debugging. Should we add tests for it in the future?
    }

    return 0;
}
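One possible direction for the TODO above (a sketch only, not part of this commit): since the added part is, by construction, a suffix of the formatting of the full conversation, it could be checked against the test's existing output without hard-coding per-template strings. The helper name below is hypothetical and assumes the test file's existing includes (<cassert>, <string>, <vector>, "common.h").

// hypothetical follow-up check (sketch only): the formatted output of the full conversation
// should end with the part reported as "added" for its last message
static void check_added_part_is_suffix(const std::vector<chat_message> & msgs,
                                       const std::string & tmpl,
                                       const std::string & full_output) {
    std::string added = chat_get_added_part(msgs, tmpl);
    assert(added.size() <= full_output.size());
    // the added part must match the tail of the full formatting
    assert(full_output.compare(full_output.size() - added.size(), added.size(), added) == 0);
}

It could then be called at the end of the loop body, e.g. check_added_part_is_suffix(v_messages, custom_template, output);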