Merge eb9a1ff63d into b18532a4ef

commit da9e19ff6f

3 changed files with 66 additions and 1 deletion
@@ -3039,3 +3039,46 @@ llama_control_vector_data llama_control_vector_load(const std::vector<llama_cont
     return result;
 }
 
+// apply chat template for (n_msgs - 1) and (n_msgs), then get the added part
+std::string chat_get_added_part(const std::vector<chat_message> & messages, const std::string & tmpl) {
+    auto apply_chat_template = [&tmpl](const std::vector<chat_message> & msgs, size_t delta, bool add_ass) {
+        std::vector<llama_chat_message> chat(msgs.size());
+        size_t alloc_size = 0;
+        size_t chat_size = chat.size() - delta;
+        for (size_t i = 0; i < msgs.size(); ++i) {
+            chat[i].role    = msgs[i].role.c_str();
+            chat[i].content = msgs[i].content.c_str();
+            alloc_size += msgs[i].role.size() + msgs[i].content.size();
+        }
+
+        const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
+        std::vector<char> buf(alloc_size * 2);
+
+        // run the first time to get the total output length
+        int32_t res = llama_chat_apply_template(nullptr, ptr_tmpl, chat.data(), chat_size, add_ass, buf.data(), buf.size());
+
+        // if it turns out that our buffer is too small, we resize it
+        if ((size_t) res > buf.size()) {
+            buf.resize(res);
+            res = llama_chat_apply_template(nullptr, ptr_tmpl, chat.data(), chat_size, add_ass, buf.data(), buf.size());
+        }
+
+        const std::string formatted_chat(buf.data(), res);
+        return formatted_chat;
+    };
+
+    std::string formatted_chat_last = messages.size() > 0
+        ? apply_chat_template(messages, 1, false) // (n_msgs - 1) messages
+        : "";
+    std::string formatted_chat_curr = apply_chat_template(messages, 0, true);
+
+    // Extract the added part (user prompt)
+    auto get_diff_part = [](const std::string & str1, const std::string & str2) {
+        size_t i = 0;
+        while (i < str1.size() && i < str2.size() && str1[i] == str2[i])
+            ++i;
+        return str2.substr(i);
+    };
+    return get_diff_part(formatted_chat_last, formatted_chat_curr);
+}
+
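The helper above formats the conversation twice, with (n_msgs - 1) and n_msgs messages, and keeps only what the longer rendering adds beyond the common prefix. A minimal self-contained sketch of that prefix-diff step follows; the ChatML-style strings are illustrative assumptions, not output this commit asserts.

// mirrors the get_diff_part lambda from the hunk above; standalone for clarity
#include <cassert>
#include <string>

static std::string get_diff_part(const std::string & str1, const std::string & str2) {
    size_t i = 0;
    while (i < str1.size() && i < str2.size() && str1[i] == str2[i])
        ++i;
    return str2.substr(i); // whatever str2 adds beyond the shared prefix
}

int main() {
    // assumed ChatML renderings of (n_msgs - 1) and n_msgs messages
    const std::string prev = "<|im_start|>system\nBe brief.<|im_end|>\n";
    const std::string curr = prev + "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n";

    // the "added part" is the newest message plus the trailing assistant prompt
    assert(get_diff_part(prev, curr) == "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n");
    return 0;
}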
@@ -336,3 +336,13 @@ llama_control_vector_data llama_control_vector_load(const std::vector<llama_cont
 static const char * const LLM_KV_SPLIT_NO = "split.no";
 static const char * const LLM_KV_SPLIT_COUNT = "split.count";
 static const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
+
+//
+// Chat templates utils
+//
+typedef struct chat_message {
+    std::string role;
+    std::string content;
+} chat_message;
+
+std::string chat_get_added_part(const std::vector<chat_message> & messages, const std::string & tmpl);
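A hedged usage sketch for the new declaration. The "chatml" template name and the message contents are assumptions (this presumes the template engine accepts the "chatml" shortcut); a concrete template string is used because an empty tmpl makes the implementation pass nullptr to llama_chat_apply_template alongside a null model.

#include <iostream>
#include <string>
#include <vector>

#include "common.h" // chat_message and chat_get_added_part from this commit

int main() {
    std::vector<chat_message> messages = {
        {"system", "You are a helpful assistant."},
        {"user",   "Hello!"},
    };

    // returns only the suffix contributed by the last message plus the
    // assistant prompt -- under ChatML, roughly:
    //   "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n"
    std::string added = chat_get_added_part(messages, "chatml");
    std::cout << added << std::endl;
    return 0;
}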
@@ -7,6 +7,7 @@
 #include <cassert>
 
 #include "llama.h"
+#include "common.h"
 
 int main(void) {
     llama_chat_message conversation[] = {
@@ -104,8 +105,19 @@ int main(void) {
         );
         formatted_chat.resize(res);
         std::string output(formatted_chat.data(), formatted_chat.size());
-        std::cout << output << "\n-------------------------\n";
+        std::cout << output << "\n-----\n";
         assert(output == expected);
+
+        std::vector<chat_message> v_messages;
+        for (size_t i = 0; i < message_count; ++i) {
+            v_messages.push_back({
+                conversation[i].role,
+                conversation[i].content,
+            });
+        }
+        std::cout << "chat_get_added_part(): " << chat_get_added_part(v_messages, custom_template);
+        std::cout << "\n-------------------------\n";
+        // TODO: chat_get_added_part is currently printed for debugging. Should we add tests for it in the future?
     }
     return 0;
 }
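Regarding the TODO above: if a test is added later, one possible shape is to pin the added part for a template whose rendering is known. The expected string below is an assumption based on ChatML, not something this commit verifies; it would slot into the test file, which already pulls in <cassert> and common.h.

// hypothetical follow-up assertion, not part of this commit
std::vector<chat_message> msgs = {
    {"system", "You are a helpful assistant"},
    {"user",   "Hello"},
};
std::string added = chat_get_added_part(msgs, "chatml");
// assumed ChatML rendering of the delta; verify before relying on it
assert(added == "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n");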