diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 03e95a78b..418d5e5be 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -768,7 +768,6 @@ struct server_task_result_cmpl_partial : server_task_result {
     oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
     std::string oaicompat_model;
     std::string oaicompat_cmpl_id;
-    std::shared_ptr chat_parser;

     virtual int get_index() override {
         return index;
@@ -1191,7 +1190,6 @@ struct server_slot {
     std::string stopping_word;

-    std::shared_ptr chat_parser;

     // sampling
     json json_schema;
@@ -1200,6 +1198,8 @@ struct server_slot {

     llama_token sampled;

+    common_chat_parser chat_parser;
+
     // stats
     size_t n_sent_text = 0; // number of sent text character
@@ -3998,8 +3998,6 @@ int main(int argc, char ** argv) {
         auto body = json::parse(req.body);
         const auto & chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default;
-        LOG_INF("Request: %s\n", body.dump(2).c_str());
-
         json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja);

         return handle_completions_impl(
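The server.cpp hunks above move the streaming chat/tool-call parsing state off the per-chunk partial result and into the slot itself (a single `common_chat_parser chat_parser` member), and drop a verbose per-request `LOG_INF`. The sketch below is not the server's parser; it is a minimal Python illustration, with made-up marker strings and tool-call JSON, of why that parsing state has to persist across streamed deltas instead of being recreated per chunk:

```python
# Toy sketch only: illustrates persistent per-slot parser state, not llama.cpp's
# actual common_chat_parser. Marker strings and the JSON payload are invented.
import json

class IncrementalToolCallParser:
    """Buffers streamed text and extracts a tool call once it is complete."""

    def __init__(self, start_marker="<tool_call>", end_marker="</tool_call>"):
        self.start_marker = start_marker
        self.end_marker = end_marker
        self.buffer = ""      # survives across chunks -- this is the per-slot state
        self.tool_calls = []

    def feed(self, chunk: str) -> str:
        """Consume one streamed delta; return plain text that is safe to forward."""
        self.buffer += chunk
        start = self.buffer.find(self.start_marker)
        if start == -1:
            out, self.buffer = self.buffer, ""
            return out                      # no tool call started: pass text through
        end = self.buffer.find(self.end_marker, start)
        if end == -1:
            out = self.buffer[:start]
            self.buffer = self.buffer[start:]  # hold back the partial tool call
            return out
        payload = self.buffer[start + len(self.start_marker):end]
        self.tool_calls.append(json.loads(payload))
        out = self.buffer[:start]
        self.buffer = self.buffer[end + len(self.end_marker):]
        return out

parser = IncrementalToolCallParser()
for delta in ['Sure. <tool_call>{"name": "test", ', '"arguments": {"success": true}}</tool_call>']:
    parser.feed(delta)
print(parser.tool_calls)  # [{'name': 'test', 'arguments': {'success': True}}]
```

Keeping the buffer with the slot means a tool call split across several deltas is still recognized once its closing marker finally arrives.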
diff --git a/examples/server/tests/unit/test_tool_call.py b/examples/server/tests/unit/test_tool_call.py
index 57c053e5d..117fd2da8 100644
--- a/examples/server/tests/unit/test_tool_call.py
+++ b/examples/server/tests/unit/test_tool_call.py
@@ -61,28 +61,7 @@ WEATHER_TOOL = {
 }


-@pytest.mark.parametrize("template_name,tool,argument_key", [
-    ("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
-    ("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
-    ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
-    ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
-    ("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
-    ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
-    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
-    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"),
-    ("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
-    ("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"),
-    ("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
-    ("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"),
-    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
-    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"),
-    ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"),
-    ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"),
-    ("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"),
-    ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"),
-    # TODO: fix these
-])
-def test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None):
+def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None):
     n_predict = 512
     global server
     # server = ServerPreset.stories15m_moe()
@@ -117,6 +96,40 @@ def test_completion_with_required_tool_tiny(template_name: str, tool: dict, argu
     assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"


+@pytest.mark.parametrize("template_name,tool,argument_key", [
+    ("google-gemma-2-2b-it", TEST_TOOL, "success"),
+    ("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
+    ("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
+])
+def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
+    do_test_completion_with_required_tool_tiny(template_name, tool, argument_key)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("template_name,tool,argument_key", [
+    ("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
+    ("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
+    ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
+    ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
+    ("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
+    ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
+    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
+    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"),
+    ("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
+    ("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"),
+    ("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
+    ("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"),
+    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
+    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"),
+    ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"),
+    ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"),
+    ("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"),
+    ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"),
+])
+def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None):
+    do_test_completion_with_required_tool_tiny(template_name, tool, argument_key)
+
+
 @pytest.mark.slow
 @pytest.mark.parametrize("tool,argument_key,hf_repo,hf_file,template_override", [
     (TEST_TOOL, "success", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
@@ -154,7 +167,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
     if template_override:
         (template_hf_repo, template_variant) = template_override
         server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
-        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_hf_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
+        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     server.start()
     res = server.make_request("POST", "/chat/completions", data={
         "max_tokens": n_predict,
@@ -183,18 +196,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
     assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"


-@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
-    ("meetkai-functionary-medium-v3.1", 128, [], None),
-    ("meetkai-functionary-medium-v3.1", 128, [TEST_TOOL], None),
-    ("meetkai-functionary-medium-v3.1", 128, [PYTHON_TOOL], 'none'),
-    ("meetkai-functionary-medium-v3.2", 128, [], None),
-    ("meetkai-functionary-medium-v3.2", 128, [TEST_TOOL], None),
-    ("meetkai-functionary-medium-v3.2", 128, [PYTHON_TOOL], 'none'),
-    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, [], None),
-    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, [TEST_TOOL], None),
-    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, [PYTHON_TOOL], 'none'),
-])
-def test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
+def do_test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
     global server
     server.jinja = True
     server.n_predict = n_predict
@@ -217,6 +219,31 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
     assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'


+@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
+    ("meta-llama-Llama-3.3-70B-Instruct", 128, [], None),
+    ("meta-llama-Llama-3.3-70B-Instruct", 128, [TEST_TOOL], None),
+    ("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'),
+])
+def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
+    do_test_completion_without_tool_call(template_name, n_predict, tools, tool_choice)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
+    ("meetkai-functionary-medium-v3.1", 128, [], None),
+    ("meetkai-functionary-medium-v3.1", 128, [TEST_TOOL], None),
+    ("meetkai-functionary-medium-v3.1", 128, [PYTHON_TOOL], 'none'),
+    ("meetkai-functionary-medium-v3.2", 128, [], None),
+    ("meetkai-functionary-medium-v3.2", 128, [TEST_TOOL], None),
+    ("meetkai-functionary-medium-v3.2", 128, [PYTHON_TOOL], 'none'),
+    ("meta-llama-Llama-3.2-3B-Instruct", 128, [], None),
+    ("meta-llama-Llama-3.2-3B-Instruct", 128, [TEST_TOOL], None),
+    ("meta-llama-Llama-3.2-3B-Instruct", 128, [PYTHON_TOOL], 'none'),
+])
+def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
+    do_test_completion_without_tool_call(template_name, n_predict, tools, tool_choice)
+
+
 @pytest.mark.slow
 @pytest.mark.parametrize("hf_repo,hf_file,template_override", [
     ("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
@@ -243,7 +270,7 @@ def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[
     if template_override:
         (template_hf_repo, template_variant) = template_override
         server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
-        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_hf_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
+        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     server.start(timeout_seconds=15*60)
     res = server.make_request("POST", "/chat/completions", data={
         "max_tokens": 256,
@@ -292,7 +319,7 @@ def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_
     if template_override:
         (template_hf_repo, template_variant) = template_override
         server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
-        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_hf_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
+        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     server.start(timeout_seconds=15*60)
     res = server.make_request("POST", "/chat/completions", data={
         "max_tokens": 256,
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index b6e4e1def..7593b4691 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -596,6 +596,11 @@ static json oaicompat_completion_params_parse(
             throw std::runtime_error("tools param requires --jinja flag");
         }
     }
+    if (!use_jinja) {
+        if (body.contains("tool_choice") && !body.at("tool_choice").is_null()) {
+            throw std::runtime_error("Unsupported param: tool_choice");
+        }
+    }

     // Handle "stop" field
     if (body.contains("stop") && body.at("stop").is_string()) {
@@ -605,7 +610,6 @@ static json oaicompat_completion_params_parse(
     }

     // Handle "response_format" field
-    auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
     if (body.contains("response_format")) {
         json response_format = json_value(body, "response_format", json::object());
         std::string response_type = json_value(response_format, "type", std::string());
@@ -649,16 +653,6 @@ static json oaicompat_completion_params_parse(
         throw std::runtime_error("top_logprobs requires logprobs to be set to true");
     }

-    // Params supported by OAI but unsupported by llama.cpp
-    if (!use_jinja) {
-        static const std::vector unsupported_params { "tool_choice" };
-        for (const auto & param : unsupported_params) {
-            if (body.contains(param)) {
-                throw std::runtime_error("Unsupported param: " + param);
-            }
-        }
-    }
-
     // Copy remaining properties to llama_params
     // This allows user to use llama.cpp-specific params like "mirostat", ... via OAI endpoint.
     // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp
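The utils.hpp change above narrows the old blanket rejection: without `--jinja`, a request is now refused only when `tool_choice` is present and non-null, so clients that always send `"tool_choice": null` keep working. A hedged sketch of what that looks like from the client side (the address and port are assumptions for a locally running llama-server started without `--jinja`):

```python
# Assumes a llama-server instance started WITHOUT --jinja at this address.
import requests

BASE = "http://127.0.0.1:8080"  # assumed default llama-server host/port

payload = {
    "messages": [{"role": "user", "content": "Hello"}],
    "tool_choice": None,  # serialized as JSON null -> no longer rejected
}
res = requests.post(f"{BASE}/v1/chat/completions", json=payload)
print(res.status_code)  # expected to pass the param check and run normally

payload["tool_choice"] = "required"  # non-null without --jinja
res = requests.post(f"{BASE}/v1/chat/completions", json=payload)
print(res.status_code, res.text[:200])  # expected: error mentioning "Unsupported param: tool_choice"
```

The fast/slow split in test_tool_call.py above pairs with this: the quick template-only cases can run via `pytest -m "not slow"`, while the `@pytest.mark.slow` matrices that need real models stay opt-in.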
diff --git a/scripts/get_hf_chat_template.py b/scripts/get_chat_template.py
similarity index 86%
rename from scripts/get_hf_chat_template.py
rename to scripts/get_chat_template.py
index 23bb1de59..fbea9c927 100644
--- a/scripts/get_hf_chat_template.py
+++ b/scripts/get_chat_template.py
@@ -4,12 +4,12 @@
 If a model has multiple chat templates, you can specify the variant name.

 Syntax:
-    ./scripts/get_hf_chat_template.py model_id [variant]
+    ./scripts/get_chat_template.py model_id [variant]

 Examples:
-    ./scripts/get_hf_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct
-    ./scripts/get_hf_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use
-    ./scripts/get_hf_chat_template.py meta-llama/Llama-3.2-3B-Instruct
+    ./scripts/get_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct
+    ./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use
+    ./scripts/get_chat_template.py meta-llama/Llama-3.2-3B-Instruct
 '''

 import json
@@ -17,7 +17,7 @@ import re
 import sys


-def get_hf_chat_template(model_id, variant=None):
+def get_chat_template(model_id, variant=None):
     try:
         # Use huggingface_hub library if available.
         # Allows access to gated models if the user has access and ran `huggingface-cli login`.
@@ -69,9 +69,10 @@ def main(args):
     model_id = args[0]
     variant = None if len(args) < 2 else args[1]

-    template = get_hf_chat_template(model_id, variant)
+    template = get_chat_template(model_id, variant)
     sys.stdout.write(template)


 if __name__ == '__main__':
     main(sys.argv[1:])
+
diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp
index 589324a85..cd5798773 100644
--- a/src/llama-grammar.cpp
+++ b/src/llama-grammar.cpp
@@ -560,7 +560,7 @@ bool llama_grammar_parser::parse(const char * src) {
             }
         }
     } catch (const std::exception & err) {
-        fprintf(stderr, "\n%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src);
+        fprintf(stderr, "%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src);
         rules.clear();
         return false;
     }
diff --git a/src/llama-grammar.h b/src/llama-grammar.h
index dfd0f4764..4ebde1452 100644
--- a/src/llama-grammar.h
+++ b/src/llama-grammar.h
@@ -118,8 +118,8 @@ struct llama_grammar {
     // lazy grammars wait for trigger words or tokens before constraining the sampling.
     // we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
     // (useful e.g. for tool_choice=required)
-    bool lazy; // Useful when resetting
-    bool awaiting_trigger; // Initialized to lazy
+    bool lazy;
+    bool awaiting_trigger; // Initialized to true for lazy grammars only
     std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
     std::vector trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
     std::vector trigger_words;
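The clarified comments in llama-grammar.h describe the lazy-grammar handshake: `awaiting_trigger` starts out true only for lazy grammars, generated text is held in `trigger_buffer`, and once a trigger word or token shows up the grammar begins constraining and the buffer is cleared. A rough Python sketch of that state machine (the names mirror the header, but this is an illustration, not the llama.cpp implementation):

```python
# Illustration only: the lazy-grammar triggering idea, not llama.cpp's sampler code.
class LazyGrammarState:
    def __init__(self, lazy: bool, trigger_words: list[str]):
        self.lazy = lazy
        self.awaiting_trigger = lazy       # true for lazy grammars only
        self.trigger_buffer = ""           # output held back until a trigger is found
        self.trigger_words = trigger_words

    def accept_text(self, piece: str) -> bool:
        """Return True when the grammar should constrain sampling for this piece."""
        if not self.awaiting_trigger:
            return True                    # grammar already active
        self.trigger_buffer += piece
        for word in self.trigger_words:
            if word in self.trigger_buffer:
                # Trigger found: the grammar starts constraining from here on,
                # and the buffered free-form text is discarded.
                self.awaiting_trigger = False
                self.trigger_buffer = ""
                return True
        return False                       # still free-form text, grammar stays idle

state = LazyGrammarState(lazy=True, trigger_words=["<tool_call>"])
print(state.accept_text("Let me check the weather. "))  # False: plain prose
print(state.accept_text("<tool_call>"))                 # True: grammar kicks in
```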
diff --git a/tests/test-chat-handler.cpp b/tests/test-chat-handler.cpp
index aef5fbd22..c6ea5c02e 100644
--- a/tests/test-chat-handler.cpp
+++ b/tests/test-chat-handler.cpp
@@ -169,9 +169,6 @@ struct delta_data {
 };

 static delta_data init_delta(const common_chat_template & tmpl, const std::vector & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
-    fprintf(stderr, "Template source: %s\n", tmpl.source().c_str());
-    fprintf(stderr, "Delta message: %s\n", delta_message.dump(2).c_str());
-
     common_chat_params params;
     params.parallel_tool_calls = true;
     params.messages = json::array();
@@ -209,12 +206,14 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
     return {delta, full_data.grammar, full_data.parser};
 }

+/*
+  Applies the template to 1 user message w/ add_generation_prompt=true, then w/ the test message w/ add_generation_prompt=false,
+  gets the diff, removes any end tokens and parses the result w/ the grammar, checking that
+  the parsed message is the same as the test_message
+*/
 static void test_template(const common_chat_template & tmpl, const std::vector & end_tokens, const json & test_message, const json & tools = {}, const std::string & expected_delta = "", bool skip_grammar_test = false, bool skip_parser_test = false) {
-    // auto tool_call_style = common_tool_call_style_detect(tmpl);
     common_chat_msg expected_msg = msg_from_json(test_message);

-    // Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false,
-    // get the diff and try and parse it w/ the grammar.
     auto user_message = json {
         {"role", "user"},
         {"content", "Hello, world!"}
@@ -228,7 +227,6 @@ static void test_template(const common_chat_template & tmpl, const std::vector
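The comment added to `test_template` summarizes the delta trick the C++ test uses: render the prompt once with just the user message and `add_generation_prompt=true`, render it again with the test message appended and `add_generation_prompt=false`, take the suffix as the expected model output, strip end tokens, then parse that delta and compare it with the original message. A simplified Python rendition, where `render()` is a toy stand-in for applying the Jinja chat template and the token strings are invented:

```python
# Sketch of the delta-extraction strategy from the test_template comment; the
# render() below is a trivial stand-in for a real chat template.
def render(messages, add_generation_prompt):
    out = ""
    for m in messages:
        out += f"<|{m['role']}|>{m['content']}<|end|>"
    if add_generation_prompt:
        out += "<|assistant|>"
    return out

def delta_for(user_message, test_message, end_tokens=("<|end|>",)):
    prefix = render([user_message], add_generation_prompt=True)
    full   = render([user_message, test_message], add_generation_prompt=False)
    assert full.startswith(prefix), "template did not produce a clean prefix"
    delta = full[len(prefix):]
    for tok in end_tokens:              # drop end-of-turn tokens before parsing
        if delta.endswith(tok):
            delta = delta[:-len(tok)]
    return delta

user = {"role": "user", "content": "Hello, world!"}
test = {"role": "assistant", "content": "Hi there"}
delta = delta_for(user, test)
print(delta)                            # "Hi there"
assert delta == test["content"]         # the parsed delta matches the test message
```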