apply renames from jinja branch
parent 9bab6939cd
commit b110374714
5 changed files with 7 additions and 44 deletions
@@ -1913,38 +1913,6 @@ common_chat_templates common_chat_templates_from_model(const struct llama_model
     };
 }
 
-static std::string _llama_model_meta_val_str(const struct llama_model * model, const char * key) {
-    int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
-    if (tlen > 0) {
-        std::vector<char> curr_tmpl_buf(tlen + 1, 0);
-        if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
-            return std::string(curr_tmpl_buf.data(), tlen);
-        }
-    }
-    return "";
-}
-
-minja::chat_template llama_chat_template_from_model(
-    const struct llama_model * model,
-    const std::string & chat_template_override,
-    bool prefer_tool_use)
-{
-    // TODO: handle "chatml"?
-    std::string chat_template = chat_template_override;
-    if (chat_template.empty()) {
-        if (prefer_tool_use) {
-            chat_template = _llama_model_meta_val_str(model, "tokenizer.chat_template.tool_use");
-        }
-        if (chat_template.empty()) {
-            chat_template = _llama_model_meta_val_str(model, "tokenizer.chat_template");
-        }
-    }
-    const auto vocab = llama_model_get_vocab(model);
-    auto bos_token = common_token_to_piece(vocab, llama_vocab_bos(vocab), true);
-    auto eos_token = common_token_to_piece(vocab, llama_vocab_eos(vocab), true);
-    return {std::move(chat_template), bos_token, eos_token};
-}
-
 //
 // KV cache utils
 //
@@ -645,11 +645,6 @@ std::string common_chat_format_example(
 
 common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
 
-llama_chat_template llama_chat_template_from_model(
-    const struct llama_model * model,
-    const std::string & chat_template_override = "",
-    bool prefer_tool_use = false);
-
 //
 // KV cache utils
 //
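
The two hunks above drop llama_chat_template_from_model, which read the tokenizer.chat_template / tokenizer.chat_template.tool_use GGUF keys itself, now that common_chat_templates_from_model (declared just above) is the single entry point for loading templates from a model. A minimal caller-side sketch of that migration follows; it assumes the declarations live in common.h as the hunk headers suggest, and the template_default / template_tool_use member names are assumptions not shown in this diff.

// Hedged sketch of the caller-side migration implied by these hunks.
// Assumptions: the declarations come from common.h, and common_chat_templates
// exposes template_default / template_tool_use members (not shown in this diff).
#include "common.h"

static void load_templates_example(const struct llama_model * model) {
    // Before this commit, callers picked one template up front:
    //   minja::chat_template tmpl = llama_chat_template_from_model(
    //       model, /* chat_template_override = */ "", /* prefer_tool_use = */ true);

    // After it, a single call loads the templates and the caller chooses later:
    common_chat_templates chat_templates =
        common_chat_templates_from_model(model, /* chat_template_override = */ "");
    // const auto & tmpl = chat_templates.template_tool_use   // assumed member names
    //     ? *chat_templates.template_tool_use
    //     : *chat_templates.template_default;
    (void) chat_templates;
}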
@@ -74,7 +74,7 @@ std::string llama_tool_call_style_name(llama_tool_call_style style) {
     }
 }
 
-llama_tool_call_style llama_tool_call_style_detect(const llama_chat_template & chat_template) {
+llama_tool_call_style llama_tool_call_style_detect(const common_chat_template & chat_template) {
     const auto & src = chat_template.source();
 
     if (src.find("<tool_call>") != std::string::npos) {
@@ -399,7 +399,7 @@ static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages
 
 llama_tool_call_handler llama_tool_call_handler_init(
     llama_tool_call_style style,
-    const llama_chat_template & tmpl,
+    const common_chat_template & tmpl,
     bool allow_content,
     const nlohmann::ordered_json & parallel_tool_calls,
     const nlohmann::ordered_json & messages,
@@ -41,13 +41,13 @@ struct llama_tool_call_handler {
 
 std::string llama_tool_call_style_name(llama_tool_call_style style);
 
-llama_tool_call_style llama_tool_call_style_detect(const llama_chat_template & chat_template);
+llama_tool_call_style llama_tool_call_style_detect(const common_chat_template & chat_template);
 
 llama_tool_calls parse_tool_calls(llama_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input);
 
 llama_tool_call_handler llama_tool_call_handler_init(
     llama_tool_call_style style,
-    const llama_chat_template & tmpl,
+    const common_chat_template & tmpl,
     bool allow_content,
     const nlohmann::ordered_json & parallel_tool_calls,
    const nlohmann::ordered_json & messages,
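
The three hunks above are purely mechanical: the chat-template parameter of the tool-call API changes type from llama_chat_template to common_chat_template with no behavioural change. A short usage sketch, assuming these declarations live in tool-call.h and that common_chat_template keeps the (source, bos_token, eos_token) constructor used by the tests further down:

#include <iostream>
#include <string>

#include "tool-call.h"   // assumed header name for the declarations above

int main() {
    // Illustrative Jinja source; any template string works here.
    const std::string source = "{{ bos_token }}{% for m in messages %}{{ m.content }}{% endfor %}";
    const common_chat_template tmpl(source, /* bos_token = */ "<s>", /* eos_token = */ "</s>");

    // Only the parameter type was renamed; the detection logic is untouched.
    const llama_tool_call_style style = llama_tool_call_style_detect(tmpl);
    std::cout << llama_tool_call_style_name(style) << std::endl;
    return 0;
}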
@@ -311,7 +311,7 @@ static void test_parsing() {
 }
 
 static void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) {
-    const llama_chat_template tmpl(read_file(template_file), "<s>", "</s>");
+    const common_chat_template tmpl(read_file(template_file), "<s>", "</s>");
     auto tool_call_style = llama_tool_call_style_detect(tmpl);
     std::cout << "# Testing tool call style of: " << template_file << std::endl << std::flush;
     assert_equals(expected, tool_call_style);
@@ -331,7 +331,7 @@ static void test_tool_call_style_detection() {
     test_tool_call_style("tests/chat/templates/google-gemma-7b-it.jinja", Generic);
 }
 
-static std::string get_message_prompt_delta(const llama_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
+static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
     auto prefix = tmpl.apply(json::array({user_message}), tools, /* add_generation_prompt= */ true, json::object());
     auto full = tmpl.apply(json::array({user_message, delta_message}), tools, /* add_generation_prompt= */ false, json::object());
 
@@ -356,7 +356,7 @@ static std::string get_message_prompt_delta(const llama_chat_template & tmpl, co
 
 static void test_template(const std::string & template_file, const char * bos_token, const char * eos_token, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) {
     std::cout << "# Testing template: " << template_file << std::endl << std::flush;
-    const llama_chat_template tmpl(read_file(template_file), bos_token, eos_token);
+    const common_chat_template tmpl(read_file(template_file), bos_token, eos_token);
     auto tool_call_style = llama_tool_call_style_detect(tmpl);
     auto & tool_calls = tool_calling_message.at("tool_calls");
 
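
The test hunks only show the two tmpl.apply(messages, tools, add_generation_prompt, extra_context) calls inside get_message_prompt_delta; the rest of its body lies outside this diff. As a hedged sketch, the per-message delta can be recovered by stripping the shared prefix from the full rendering and cutting at the first end-of-turn token, roughly as follows (the helper name and details are illustrative, not the actual implementation):

// Hedged sketch of turning a prefix/full rendering pair into a message delta.
// json here would be nlohmann::ordered_json, as in the tests above.
#include <string>
#include <vector>

static std::string message_delta_sketch(const std::string & prefix, const std::string & full,
                                        const std::vector<std::string> & end_tokens) {
    // Everything the template adds for the extra message is the part of the
    // full rendering that follows the prefix shared with the shorter rendering.
    size_t common = 0;
    while (common < prefix.size() && common < full.size() && prefix[common] == full[common]) {
        common++;
    }
    std::string delta = full.substr(common);

    // Cut the delta at the first end-of-turn token, mirroring how the tests
    // pass end_tokens into get_message_prompt_delta.
    for (const auto & tok : end_tokens) {
        const auto pos = delta.find(tok);
        if (pos != std::string::npos) {
            delta = delta.substr(0, pos);
            break;
        }
    }
    return delta;
}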