apply renames from jinja branch

ochafik 2025-01-20 23:59:01 +00:00
parent 9bab6939cd
commit b110374714
5 changed files with 7 additions and 44 deletions


@@ -1913,38 +1913,6 @@ common_chat_templates common_chat_templates_from_model(const struct llama_model
     };
 }
-static std::string _llama_model_meta_val_str(const struct llama_model * model, const char * key) {
-    int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
-    if (tlen > 0) {
-        std::vector<char> curr_tmpl_buf(tlen + 1, 0);
-        if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
-            return std::string(curr_tmpl_buf.data(), tlen);
-        }
-    }
-    return "";
-}
-minja::chat_template llama_chat_template_from_model(
-    const struct llama_model * model,
-    const std::string & chat_template_override,
-    bool prefer_tool_use)
-{
-    // TODO: handle "chatml"?
-    std::string chat_template = chat_template_override;
-    if (chat_template.empty()) {
-        if (prefer_tool_use) {
-            chat_template = _llama_model_meta_val_str(model, "tokenizer.chat_template.tool_use");
-        }
-        if (chat_template.empty()) {
-            chat_template = _llama_model_meta_val_str(model, "tokenizer.chat_template");
-        }
-    }
-    const auto vocab = llama_model_get_vocab(model);
-    auto bos_token = common_token_to_piece(vocab, llama_vocab_bos(vocab), true);
-    auto eos_token = common_token_to_piece(vocab, llama_vocab_eos(vocab), true);
-    return {std::move(chat_template), bos_token, eos_token};
-}
 //
 // KV cache utils
 //


@@ -645,11 +645,6 @@ std::string common_chat_format_example(
 common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
-llama_chat_template llama_chat_template_from_model(
-    const struct llama_model * model,
-    const std::string & chat_template_override = "",
-    bool prefer_tool_use = false);
 //
 // KV cache utils
 //


@@ -74,7 +74,7 @@ std::string llama_tool_call_style_name(llama_tool_call_style style) {
     }
 }
-llama_tool_call_style llama_tool_call_style_detect(const llama_chat_template & chat_template) {
+llama_tool_call_style llama_tool_call_style_detect(const common_chat_template & chat_template) {
     const auto & src = chat_template.source();
     if (src.find("<tool_call>") != std::string::npos) {
@@ -399,7 +399,7 @@ static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages
 llama_tool_call_handler llama_tool_call_handler_init(
     llama_tool_call_style style,
-    const llama_chat_template & tmpl,
+    const common_chat_template & tmpl,
     bool allow_content,
     const nlohmann::ordered_json & parallel_tool_calls,
     const nlohmann::ordered_json & messages,


@@ -41,13 +41,13 @@ struct llama_tool_call_handler {
 std::string llama_tool_call_style_name(llama_tool_call_style style);
-llama_tool_call_style llama_tool_call_style_detect(const llama_chat_template & chat_template);
+llama_tool_call_style llama_tool_call_style_detect(const common_chat_template & chat_template);
 llama_tool_calls parse_tool_calls(llama_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input);
 llama_tool_call_handler llama_tool_call_handler_init(
     llama_tool_call_style style,
-    const llama_chat_template & tmpl,
+    const common_chat_template & tmpl,
     bool allow_content,
     const nlohmann::ordered_json & parallel_tool_calls,
     const nlohmann::ordered_json & messages,


@@ -311,7 +311,7 @@ static void test_parsing() {
 }
 static void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) {
-    const llama_chat_template tmpl(read_file(template_file), "<s>", "</s>");
+    const common_chat_template tmpl(read_file(template_file), "<s>", "</s>");
     auto tool_call_style = llama_tool_call_style_detect(tmpl);
     std::cout << "# Testing tool call style of: " << template_file << std::endl << std::flush;
     assert_equals(expected, tool_call_style);
@@ -331,7 +331,7 @@ static void test_tool_call_style_detection() {
     test_tool_call_style("tests/chat/templates/google-gemma-7b-it.jinja", Generic);
 }
-static std::string get_message_prompt_delta(const llama_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
+static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
     auto prefix = tmpl.apply(json::array({user_message}), tools, /* add_generation_prompt= */ true, json::object());
     auto full = tmpl.apply(json::array({user_message, delta_message}), tools, /* add_generation_prompt= */ false, json::object());
@@ -356,7 +356,7 @@ static std::string get_message_prompt_delta(const llama_chat_template & tmpl, co
 static void test_template(const std::string & template_file, const char * bos_token, const char * eos_token, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) {
     std::cout << "# Testing template: " << template_file << std::endl << std::flush;
-    const llama_chat_template tmpl(read_file(template_file), bos_token, eos_token);
+    const common_chat_template tmpl(read_file(template_file), bos_token, eos_token);
     auto tool_call_style = llama_tool_call_style_detect(tmpl);
     auto & tool_calls = tool_calling_message.at("tool_calls");
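For orientation only, not part of the commit: a minimal sketch of how the renamed common_chat_template type is exercised, built solely from the constructor and function signatures visible in the hunks above. The include paths and the read_file stand-in are assumptions, not code from this repository.

// Sketch only: the two project headers below are guesses; the calls themselves
// mirror the test code in this diff (construct from Jinja source + BOS/EOS,
// then detect the tool-call style and print its name).
#include "chat-template.hpp"   // assumed to declare common_chat_template
#include "tool-call.h"         // assumed to declare llama_tool_call_style_detect/_name

#include <fstream>
#include <iostream>
#include <sstream>
#include <string>

// Stand-in for the tests' read_file helper (its definition is not shown in the diff).
static std::string read_file(const std::string & path) {
    std::ifstream f(path);
    std::ostringstream ss;
    ss << f.rdbuf();
    return ss.str();
}

int main() {
    // Build a template from a Jinja source plus BOS/EOS pieces, as the tests do.
    const common_chat_template tmpl(
        read_file("tests/chat/templates/google-gemma-7b-it.jinja"), "<s>", "</s>");

    // Detect which tool-call style the template implies and print its name.
    const auto style = llama_tool_call_style_detect(tmpl);
    std::cout << llama_tool_call_style_name(style) << std::endl;
    return 0;
}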