diff --git a/common/chat-template.hpp b/common/chat-template.hpp
index dfd46d750..2c3d96c36 100644
--- a/common/chat-template.hpp
+++ b/common/chat-template.hpp
@@ -270,6 +270,28 @@ class chat_template {
     const std::string & eos_token() const { return eos_token_; }
     const chat_template_caps & original_caps() const { return caps_; }
 
+    // Deprecated, please use the form with chat_template_inputs and chat_template_options
+    std::string apply(
+        const nlohmann::ordered_json & messages,
+        const nlohmann::ordered_json & tools,
+        bool add_generation_prompt,
+        const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(),
+        bool apply_polyfills = true)
+    {
+        fprintf(stderr, "[%s] Deprecated!\n", __func__);
+        chat_template_inputs inputs;
+        inputs.messages = messages;
+        inputs.tools = tools;
+        inputs.add_generation_prompt = add_generation_prompt;
+        inputs.extra_context = extra_context;
+        inputs.now = std::chrono::system_clock::now();
+
+        chat_template_options opts;
+        opts.apply_polyfills = apply_polyfills;
+
+        return apply(inputs, opts);
+    }
+
     std::string apply(
         const chat_template_inputs & inputs,
         const chat_template_options & opts = chat_template_options()) const
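Note on the header change above: the positional apply(...) overload is now deprecated in favour of aggregating arguments into chat_template_inputs and rendering switches into chat_template_options. A minimal migration sketch follows (illustrative only; the render() wrapper is hypothetical, while the struct fields and the minja:: qualification are taken from this diff and from examples/run/run.cpp below):

#include <chrono>
#include <nlohmann/json.hpp>
#include "chat-template.hpp"  // minja::chat_template and the new input/option structs

std::string render(const minja::chat_template & tmpl,
                   const nlohmann::ordered_json & messages,
                   const nlohmann::ordered_json & tools) {
    // Before (still compiles, but now logs "[apply] Deprecated!" to stderr):
    //   return tmpl.apply(messages, tools, /* add_generation_prompt= */ true);

    // After: gather the arguments into chat_template_inputs...
    minja::chat_template_inputs inputs;
    inputs.messages = messages;
    inputs.tools = tools;
    inputs.add_generation_prompt = true;
    inputs.now = std::chrono::system_clock::now();

    // ...and the rendering switches into chat_template_options.
    minja::chat_template_options opts;
    opts.apply_polyfills = true;  // matches the deprecated overload's default

    return tmpl.apply(inputs, opts);
}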
"" : inputs.tools.dump(2))}, - }, /* adjust_inputs= */ false); + }); if (!inputs.tools.is_null() && !inputs.tools.empty()) { data.grammar_lazy = inputs.tool_choice != "required"; data.grammar = build_grammar([&](const common_grammar_builder & builder) { @@ -661,7 +683,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar common_chat_params data; - data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2; if (!inputs.tools.is_null() && !inputs.tools.empty()) { data.grammar_lazy = inputs.tool_choice != "required"; @@ -788,7 +810,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con data.grammar_triggers.push_back({"" }; }, grammar_options); - data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO; return data; } @@ -904,7 +926,7 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { common_chat_params data; - data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? 
diff --git a/examples/server/tests/unit/test_tool_call.py b/examples/server/tests/unit/test_tool_call.py
index 156940d24..424fe8c16 100644
--- a/examples/server/tests/unit/test_tool_call.py
+++ b/examples/server/tests/unit/test_tool_call.py
@@ -340,6 +340,112 @@ def test_weather_tool_call(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
     assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'
 
 
+@pytest.mark.slow
+@pytest.mark.parametrize("hf_repo,template_override", [
+    ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
+
+    ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
+    ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
+
+    ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
+    ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
+
+    ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
+    ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
+
+    ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
+    ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
+
+    ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
+    ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
+
+    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
+    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
+
+    ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
+    ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
+
+    ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
+    ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
+
+    # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
+    ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
+
+    # ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
+])
+def test_calc_result(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
+    global server
+    n_predict = 512
+    server.n_slots = 1
+    server.jinja = True
+    server.n_ctx = 8192
+    server.n_predict = n_predict
+    server.model_hf_repo = hf_repo
+    server.model_hf_file = None
+    if isinstance(template_override, tuple):
+        (template_hf_repo, template_variant) = template_override
+        server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
+        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
+    elif isinstance(template_override, str):
+        server.chat_template = template_override
+    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": n_predict,
+        "messages": [
+            {"role": "system", "content": "You are a chatbot that uses tools/functions. Don't overthink things."},
+            {"role": "user", "content": "What's the y coordinate of a point on the unit sphere at angle 30 degrees?"},
+            {
+                "role": "assistant",
+                "content": None,
+                "tool_calls": [
+                    {
+                        "function": {
+                            "name": "calculate",
+                            "arguments": "{\"expression\":\"sin(30 * pi / 180)\"}"
+                        }
+                    }
+                ]
+            },
+            {
+                "role": "tool",
+                "name": "calculate",
+                "content": "0.5"
+            }
+        ],
+        "tools": [
+            {
+                "type": "function",
+                "function": {
+                    "name": "calculate",
+                    "description": "A calculator function that computes values of arithmetic expressions in the Python syntax",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "expression": {
+                                "type": "string",
+                                "description": "An arithmetic expression to compute the value of (Python syntax, assuming all floats)"
+                            }
+                        },
+                        "required": ["expression"]
+                    }
+                }
+            }
+        ]
+    }, timeout=TIMEOUT_HTTP_REQUEST)
+    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
+    choice = res.body["choices"][0]
+    tool_calls = choice["message"].get("tool_calls")
+    assert tool_calls is None, f'Expected no tool call in {choice["message"]}'
+    content = choice["message"].get("content")
+    assert content is not None, f'Expected content in {choice["message"]}'
+    assert re.match('^[\\s\\S]*?0\\.5', content), f'Expected the computed result 0.5 to be mentioned, got {content}'
+
+
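Sanity note on the fixture above: the canned tool call evaluates sin(30 * pi / 180), i.e. sin(30°), which is exactly 0.5, so the tool message's "0.5" and the final-answer assertion agree (quick check below, not part of the diff):

import math

# sin(30°) == 0.5 up to float rounding, matching the canned tool
# result "0.5" that the model is expected to echo in its answer.
assert abs(math.sin(30 * math.pi / 180) - 0.5) < 1e-12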
+ ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + + # ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), +]) +def test_calc_result(hf_repo: str, template_override: str | Tuple[str, str | None] | None): + global server + n_predict = 512 + server.n_slots = 1 + server.jinja = True + server.n_ctx = 8192 + server.n_predict = n_predict + server.model_hf_repo = hf_repo + server.model_hf_file = None + if isinstance(template_override, tuple): + (template_hf_repo, template_variant) = template_override + server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" + assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + elif isinstance(template_override, str): + server.chat_template = template_override + server.start(timeout_seconds=TIMEOUT_SERVER_START) + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": n_predict, + "messages": [ + {"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."}, + {"role": "user", "content": "What's the y coordinate of a point on the unit sphere at angle 30 degrees?"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "function": { + "name": "calculate", + "arguments": "{\"expression\":\"sin(30 * pi / 180)\"}" + } + } + ] + }, + { + "role": "tool", + "name": "calculate", + "content": "0.5" + } + ], + "tools": [ + { + "type":"function", + "function":{ + "name":"calculate", + "description":"A calculator function that computes values of arithmetic expressions in the Python syntax", + "parameters":{ + "type":"object", + "properties":{ + "expression":{ + "type":"string", + "description":"An arithmetic expression to compute the value of (Python syntad, assuming all floats)" + } + }, + "required":["expression"] + } + } + } + ] + }, timeout=TIMEOUT_HTTP_REQUEST) + assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" + choice = res.body["choices"][0] + tool_calls = choice["message"].get("tool_calls") + assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' + tool_call = tool_calls[0] + assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"] + actual_arguments = json.loads(tool_call["function"]["arguments"]) + assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}" + location = actual_arguments["location"] + assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}" + assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}' + + @pytest.mark.slow @pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [ (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),