From b11037471422f8db70903ea52d5ed8f47e99967d Mon Sep 17 00:00:00 2001
From: ochafik
Date: Mon, 20 Jan 2025 23:59:01 +0000
Subject: [PATCH] apply renames from jinja branch

---
 common/common.cpp        | 32 --------------------------------
 common/common.h          |  5 -----
 common/tool-call.cpp     |  4 ++--
 common/tool-call.h       |  4 ++--
 tests/test-tool-call.cpp |  6 +++---
 5 files changed, 7 insertions(+), 44 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index a00927b42..046e236f2 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1913,38 +1913,6 @@ common_chat_templates common_chat_templates_from_model(const struct llama_model
     };
 }
 
-static std::string _llama_model_meta_val_str(const struct llama_model * model, const char * key) {
-    int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
-    if (tlen > 0) {
-        std::vector<char> curr_tmpl_buf(tlen + 1, 0);
-        if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
-            return std::string(curr_tmpl_buf.data(), tlen);
-        }
-    }
-    return "";
-}
-
-minja::chat_template llama_chat_template_from_model(
-    const struct llama_model * model,
-    const std::string & chat_template_override,
-    bool prefer_tool_use)
-{
-    // TODO: handle "chatml"?
-    std::string chat_template = chat_template_override;
-    if (chat_template.empty()) {
-        if (prefer_tool_use) {
-            chat_template = _llama_model_meta_val_str(model, "tokenizer.chat_template.tool_use");
-        }
-        if (chat_template.empty()) {
-            chat_template = _llama_model_meta_val_str(model, "tokenizer.chat_template");
-        }
-    }
-    const auto vocab = llama_model_get_vocab(model);
-    auto bos_token = common_token_to_piece(vocab, llama_vocab_bos(vocab), true);
-    auto eos_token = common_token_to_piece(vocab, llama_vocab_eos(vocab), true);
-    return {std::move(chat_template), bos_token, eos_token};
-}
-
 //
 // KV cache utils
 //
diff --git a/common/common.h b/common/common.h
index c83df8063..3035dfb24 100644
--- a/common/common.h
+++ b/common/common.h
@@ -645,11 +645,6 @@ std::string common_chat_format_example(
 
 common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
 
-llama_chat_template llama_chat_template_from_model(
-    const struct llama_model * model,
-    const std::string & chat_template_override = "",
-    bool prefer_tool_use = false);
-
 //
 // KV cache utils
 //
diff --git a/common/tool-call.cpp b/common/tool-call.cpp
index 26bb60479..0c2e802bd 100644
--- a/common/tool-call.cpp
+++ b/common/tool-call.cpp
@@ -74,7 +74,7 @@ std::string llama_tool_call_style_name(llama_tool_call_style style) {
     }
 }
 
-llama_tool_call_style llama_tool_call_style_detect(const llama_chat_template & chat_template) {
+llama_tool_call_style llama_tool_call_style_detect(const common_chat_template & chat_template) {
    const auto & src = chat_template.source();
 
     if (src.find("<tool_call>") != std::string::npos) {
@@ -399,7 +399,7 @@ static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages
 
 llama_tool_call_handler llama_tool_call_handler_init(
     llama_tool_call_style style,
-    const llama_chat_template & tmpl,
+    const common_chat_template & tmpl,
     bool allow_content,
     const nlohmann::ordered_json & parallel_tool_calls,
     const nlohmann::ordered_json & messages,
diff --git a/common/tool-call.h b/common/tool-call.h
index f96ed2b1f..b83faa772 100644
--- a/common/tool-call.h
+++ b/common/tool-call.h
@@ -41,13 +41,13 @@ struct llama_tool_call_handler {
 
 std::string llama_tool_call_style_name(llama_tool_call_style style);
 
-llama_tool_call_style llama_tool_call_style_detect(const llama_chat_template & chat_template);
+llama_tool_call_style llama_tool_call_style_detect(const common_chat_template & chat_template);
 
 llama_tool_calls parse_tool_calls(llama_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input);
 
 llama_tool_call_handler llama_tool_call_handler_init(
     llama_tool_call_style style,
-    const llama_chat_template & tmpl,
+    const common_chat_template & tmpl,
     bool allow_content,
     const nlohmann::ordered_json & parallel_tool_calls,
     const nlohmann::ordered_json & messages,
diff --git a/tests/test-tool-call.cpp b/tests/test-tool-call.cpp
index 2230bfa65..95762395b 100644
--- a/tests/test-tool-call.cpp
+++ b/tests/test-tool-call.cpp
@@ -311,7 +311,7 @@ static void test_parsing() {
 }
 
 static void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) {
-    const llama_chat_template tmpl(read_file(template_file), "", "");
+    const common_chat_template tmpl(read_file(template_file), "", "");
     auto tool_call_style = llama_tool_call_style_detect(tmpl);
     std::cout << "# Testing tool call style of: " << template_file << std::endl << std::flush;
     assert_equals(expected, tool_call_style);
@@ -331,7 +331,7 @@ static void test_tool_call_style_detection() {
     test_tool_call_style("tests/chat/templates/google-gemma-7b-it.jinja", Generic);
 }
 
-static std::string get_message_prompt_delta(const llama_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
+static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
     auto prefix = tmpl.apply(json::array({user_message}), tools, /* add_generation_prompt= */ true, json::object());
     auto full = tmpl.apply(json::array({user_message, delta_message}), tools, /* add_generation_prompt= */ false, json::object());
 
@@ -356,7 +356,7 @@ static std::string get_message_prompt_delta(const llama_chat_template & tmpl, co
 
 static void test_template(const std::string & template_file, const char * bos_token, const char * eos_token, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) {
     std::cout << "# Testing template: " << template_file << std::endl << std::flush;
-    const llama_chat_template tmpl(read_file(template_file), bos_token, eos_token);
+    const common_chat_template tmpl(read_file(template_file), bos_token, eos_token);
     auto tool_call_style = llama_tool_call_style_detect(tmpl);
     auto & tool_calls = tool_calling_message.at("tool_calls");