llama : update llama_chat_apply_template
ggml-ci
This commit is contained in:
parent
22b31cd16d
commit
68db76595e
8 changed files with 31 additions and 34 deletions
|
@ -1648,7 +1648,7 @@ std::string common_get_builtin_chat_template(const struct llama_model * model) {
|
|||
|
||||
bool common_chat_verify_template(const std::string & tmpl) {
|
||||
llama_chat_message chat[] = {{"user", "test"}};
|
||||
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
|
||||
const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
|
||||
return res >= 0;
|
||||
}
|
||||
|
||||
|
@ -1659,16 +1659,16 @@ std::string common_chat_apply_template(const struct llama_model * model,
|
|||
int alloc_size = 0;
|
||||
bool fallback = false; // indicate if we must fallback to default chatml
|
||||
std::vector<llama_chat_message> chat;
|
||||
for (auto & msg : msgs) {
|
||||
for (const auto & msg : msgs) {
|
||||
chat.push_back({msg.role.c_str(), msg.content.c_str()});
|
||||
alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
|
||||
}
|
||||
|
||||
const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
|
||||
const char * ptr_tmpl = tmpl.empty() ? llama_model_chat_template(model) : tmpl.c_str();
|
||||
std::vector<char> buf(alloc_size);
|
||||
|
||||
// run the first time to get the total output length
|
||||
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
int32_t res = llama_chat_apply_template(ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
|
||||
// error: chat template is not supported
|
||||
if (res < 0) {
|
||||
|
@ -1676,18 +1676,17 @@ std::string common_chat_apply_template(const struct llama_model * model,
|
|||
// if the custom "tmpl" is not supported, we throw an error
|
||||
// this check is intentionally redundant, since we cannot be sure the user validated the custom template with common_chat_verify_template()
|
||||
throw std::runtime_error("this custom template is not supported");
|
||||
} else {
|
||||
// If the built-in template is not supported, we default to chatml
|
||||
res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
fallback = true;
|
||||
}
|
||||
|
||||
// If the built-in template is not supported, we default to chatml
|
||||
res = llama_chat_apply_template("chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
fallback = true;
|
||||
}
|
||||
|
||||
// if it turns out that our buffer is too small, we resize it
|
||||
if ((size_t) res > buf.size()) {
|
||||
buf.resize(res);
|
||||
res = llama_chat_apply_template(
|
||||
fallback ? nullptr : model,
|
||||
fallback ? "chatml" : ptr_tmpl,
|
||||
chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
}
|
||||
|
|
|
@ -713,11 +713,11 @@ static void add_message(const char * role, const std::string & text, LlamaData &
|
|||
// Function to apply the chat template and resize `formatted` if needed
|
||||
static int apply_chat_template(LlamaData & llama_data, const bool append) {
|
||||
int result = llama_chat_apply_template(
|
||||
llama_data.model.get(), nullptr, llama_data.messages.data(), llama_data.messages.size(), append,
|
||||
llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
|
||||
append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
|
||||
if (append && result > static_cast<int>(llama_data.fmtted.size())) {
|
||||
llama_data.fmtted.resize(result);
|
||||
result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(),
|
||||
result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
|
||||
llama_data.messages.size(), append, llama_data.fmtted.data(),
|
||||
llama_data.fmtted.size());
|
||||
}
|
||||
|
|
|
@ -1740,7 +1740,8 @@ struct server_context {
|
|||
|
||||
bool validate_builtin_chat_template() const {
|
||||
llama_chat_message chat[] = {{"user", "test"}};
|
||||
int32_t chat_res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0);
|
||||
const char * tmpl = llama_model_chat_template(model);
|
||||
const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
|
||||
return chat_res > 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -161,12 +161,14 @@ int main(int argc, char ** argv) {
|
|||
break;
|
||||
}
|
||||
|
||||
const char * tmpl = llama_model_chat_template(model);
|
||||
|
||||
// add the user input to the message list and format it
|
||||
messages.push_back({"user", strdup(user.c_str())});
|
||||
int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
|
||||
int new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
|
||||
if (new_len > (int)formatted.size()) {
|
||||
formatted.resize(new_len);
|
||||
new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
|
||||
new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
|
||||
}
|
||||
if (new_len < 0) {
|
||||
fprintf(stderr, "failed to apply the chat template\n");
|
||||
|
@ -183,7 +185,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// add the response to the messages
|
||||
messages.push_back({"assistant", strdup(response.c_str())});
|
||||
prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, nullptr, 0);
|
||||
prev_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), false, nullptr, 0);
|
||||
if (prev_len < 0) {
|
||||
fprintf(stderr, "failed to apply the chat template\n");
|
||||
return 1;
|
||||
|
|
|
@ -489,6 +489,9 @@ extern "C" {
|
|||
// Returns the total size of all the tensors in the model in bytes
|
||||
LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
|
||||
|
||||
// Get the default chat template. Returns nullptr if not available
|
||||
LLAMA_API const char * llama_model_chat_template(const struct llama_model * model);
|
||||
|
||||
// Returns the total number of parameters in the model
|
||||
LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
|
||||
|
||||
|
@ -1009,9 +1012,7 @@ extern "C" {
|
|||
/// @param buf A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages)
|
||||
/// @param length The size of the allocated buffer
|
||||
/// @return The total number of bytes of the formatted prompt. If it is larger than the size of the buffer, you may need to re-alloc it and then re-apply the template.
|
||||
/// TODO: change to llama_vocab
|
||||
LLAMA_API int32_t llama_chat_apply_template(
|
||||
const struct llama_model * model,
|
||||
const char * tmpl,
|
||||
const struct llama_chat_message * chat,
|
||||
size_t n_msg,
|
||||
|
|
|
@ -3839,6 +3839,15 @@ uint64_t llama_model_size(const struct llama_model * model) {
|
|||
return model->size();
|
||||
}
|
||||
|
||||
const char * llama_model_chat_template(const struct llama_model * model) {
|
||||
const auto & it = model->gguf_kv.find("tokenizer.chat_template");
|
||||
if (it == model->gguf_kv.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return it->second.c_str();
|
||||
}
|
||||
|
||||
uint64_t llama_model_n_params(const struct llama_model * model) {
|
||||
return model->n_elements();
|
||||
}
|
||||
|
|
|
@ -9834,27 +9834,13 @@ int32_t llama_decode(
|
|||
//
|
||||
|
||||
int32_t llama_chat_apply_template(
|
||||
const struct llama_model * model,
|
||||
const char * tmpl,
|
||||
const struct llama_chat_message * chat,
|
||||
size_t n_msg,
|
||||
bool add_ass,
|
||||
char * buf,
|
||||
int32_t length) {
|
||||
std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
|
||||
if (tmpl == nullptr) {
|
||||
GGML_ASSERT(model != nullptr);
|
||||
|
||||
// load template from model, if available
|
||||
const auto & it = model->gguf_kv.find("tokenizer.chat_template");
|
||||
if (it != model->gguf_kv.end() && it->second.size() > 0) {
|
||||
curr_tmpl = it->second;
|
||||
}
|
||||
else {
|
||||
// worst case: there is no information about template, we will use chatml by default
|
||||
curr_tmpl = "chatml"; // see llm_chat_apply_template
|
||||
}
|
||||
}
|
||||
const std::string curr_tmpl(tmpl == nullptr ? "chatml" : tmpl);
|
||||
|
||||
// format the chat to string
|
||||
std::vector<const llama_chat_message *> chat_vec;
|
||||
|
|
|
@ -157,7 +157,7 @@ int main(void) {
|
|||
}
|
||||
|
||||
// test invalid chat template
|
||||
res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
|
||||
res = llama_chat_apply_template("INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
|
||||
assert(res < 0);
|
||||
|
||||
for (size_t i = 0; i < templates.size(); i++) {
|
||||
|
@ -165,7 +165,6 @@ int main(void) {
|
|||
std::string expected = expected_output[i];
|
||||
formatted_chat.resize(1024);
|
||||
res = llama_chat_apply_template(
|
||||
nullptr,
|
||||
custom_template.c_str(),
|
||||
conversation,
|
||||
message_count,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue