From c479d39abde46752b0dff717f716ee457a25de06 Mon Sep 17 00:00:00 2001
From: Olivier Chafik
Date: Sat, 25 Jan 2025 04:51:53 +0000
Subject: [PATCH 1/5] tool-call: allow special tokens that are grammar triggers

---
 examples/server/server.cpp | 9 +++++++--
 src/llama-grammar.cpp      | 7 ++++---
 2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 939e6c36a..a8ea4d05b 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2795,6 +2795,11 @@ struct server_context {
         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
 
+        auto accept_special_token = [&](llama_token token) {
+            const auto & trigger_tokens = params_base.sampling.grammar_trigger_tokens;
+            return params_base.special || std::find(trigger_tokens.begin(), trigger_tokens.end(), token) != trigger_tokens.end();
+        };
+
         // frist, add sampled tokens from any ongoing sequences
         for (auto & slot : slots) {
             if (slot.state != SLOT_STATE_GENERATING) {
@@ -3158,7 +3163,7 @@ struct server_context {
 
                 completion_token_output result;
                 result.tok          = id;
-                result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+                result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(result.tok));
                 result.prob         = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
 
                 if (slot.params.sampling.n_probs > 0) {
@@ -3247,7 +3252,7 @@ struct server_context {
 
                 completion_token_output result;
                 result.tok          = ids[i];
-                result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+                result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(result.tok));
                 result.prob         = 1.0f; // set later
 
                 // TODO: set result.probs
diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp
index 2c1ae0975..501b0037b 100644
--- a/src/llama-grammar.cpp
+++ b/src/llama-grammar.cpp
@@ -1155,15 +1155,17 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
 void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
     GGML_ASSERT(grammar.vocab != nullptr);
 
+    const auto & piece = grammar.vocab->token_to_piece(token);
+
     if (grammar.awaiting_trigger) {
         if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
             grammar.awaiting_trigger = false;
             grammar.trigger_buffer.clear();
-            llama_grammar_accept_str(grammar, grammar.vocab->token_to_piece(token));
+            llama_grammar_accept_str(grammar, piece);
             return;
         } else {
             // TODO: consider a smarter incremental substring search algorithm (store last position to search from).
-            grammar.trigger_buffer += grammar.vocab->token_to_piece(token);
+            grammar.trigger_buffer += piece;
             for (const auto & word : grammar.trigger_words) {
                 auto pos = grammar.trigger_buffer.find(word);
                 if (pos != std::string::npos) {
@@ -1187,7 +1189,6 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
         GGML_ABORT("fatal error");
     }
 
-    const std::string & piece = grammar.vocab->token_to_piece(token);
     llama_grammar_accept_str(grammar, piece);
 }
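
[Note on PATCH 1/5]
Lazy grammars are armed by trigger tokens (for example a <tool_call> special
token), but the server only renders special tokens into the generated text
when the `special` flag is set, so a trigger token would otherwise be dropped
from the output and the downstream tool-call parser could never see it. The
accept_special_token lambda above lets a special token through whenever it is
one of the sampler's grammar trigger tokens. A minimal Python sketch of the
predicate (illustrative only; the token id below is hypothetical):

    def accept_special_token(token: int, special: bool, trigger_tokens: list[int]) -> bool:
        return special or token in trigger_tokens

    TOOL_CALL_ID = 128008  # hypothetical id of a <tool_call> special token
    assert accept_special_token(TOOL_CALL_ID, special=False, trigger_tokens=[TOOL_CALL_ID])
    assert not accept_special_token(42, special=False, trigger_tokens=[TOOL_CALL_ID])
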
From 0208b20767ab96953c5d56e99ccfe5c55553c477 Mon Sep 17 00:00:00 2001
From: Olivier Chafik
Date: Sat, 25 Jan 2025 04:52:03 +0000
Subject: [PATCH 2/5] Update test_chat_completion.py

---
 .../server/tests/unit/test_chat_completion.py | 56 ++++++++++---------
 1 file changed, 29 insertions(+), 27 deletions(-)

diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py
index 4bbd10c0e..92286143d 100644
--- a/examples/server/tests/unit/test_chat_completion.py
+++ b/examples/server/tests/unit/test_chat_completion.py
@@ -207,13 +207,13 @@ PYTHON_TOOL = {
     "type": "function",
     "function": {
         "name": "python",
-        "description": "Runs code in a Python interpreter and returns the result of the execution after 60 seconds.",
+        "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
         "parameters": {
             "type": "object",
             "properties": {
                 "code": {
                     "type": "string",
-                    "description": "The code to run in the Python interpreter."
+                    "description": "The code to run in the ipython interpreter."
                 }
             },
             "required": ["code"]
@@ -308,30 +308,31 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
 
 @pytest.mark.slow
 @pytest.mark.parametrize("tool,expected_arguments,hf_repo,hf_file,template_override", [
-    (PYTHON_TOOL, {"code": "print('Hello, world!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
-    (CODE_INTEPRETER_TOOL, {"code": "print('Hello, world!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
-    (PYTHON_TOOL, {"code": "print('Hello World!')"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
-    (CODE_INTEPRETER_TOOL, {"code": "print('Hello World!')"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
-    (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
-    (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
-    (PYTHON_TOOL, {"code": "print('Hello, world!')"}, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
-    (CODE_INTEPRETER_TOOL, {"code": "print('Hello, world!')"}, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch-Hermes-2-Pro-Llama-3-8B", "tool_use")),
-    (PYTHON_TOOL, {"code": "print('Hello World!')"}, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
-    (CODE_INTEPRETER_TOOL, {"code": "print('Hello World!')"}, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
-    (PYTHON_TOOL, {"code": "print('Hello, World!'}"}, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
-    (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!'}"}, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
-    (PYTHON_TOOL, {"code": "print("}, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
-    (CODE_INTEPRETER_TOOL, {"code": "print("}, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
-    (PYTHON_TOOL, {"code": "print("}, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
-    (CODE_INTEPRETER_TOOL, {"code": "print("}, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
-    (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
-    (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
+    (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
+    (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
+    (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
+    (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
+    (PYTHON_TOOL, {"code": "print('Hello World!')"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
+    (CODE_INTEPRETER_TOOL, {"code": "print('Hello World!')"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
+    (PYTHON_TOOL, {"code": "print('Hello, world!')"}, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
+    (CODE_INTEPRETER_TOOL, {"code": "print('Hello, world!')"}, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch-Hermes-2-Pro-Llama-3-8B", "tool_use")),
+    (PYTHON_TOOL, {"code": "print('Hello World!')"}, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
+    (CODE_INTEPRETER_TOOL, {"code": "print('Hello World!')"}, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
+    (PYTHON_TOOL, {"code": "print(\"Hello, World!\")"}, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
+    (CODE_INTEPRETER_TOOL, {"code": "print(\"Hello, World!\")"}, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
+    (PYTHON_TOOL, {"code": "print("}, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
+    (CODE_INTEPRETER_TOOL, {"code": "print("}, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
+    (PYTHON_TOOL, {"code": "print(\"hello world\")"}, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
+    (CODE_INTEPRETER_TOOL, {"code": "print("}, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
+    (PYTHON_TOOL, {"code": "print('Hello, World!')\n"}, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", None),
+    (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')\n"}, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", ("mistralai-Mistral-Nemo-Instruct-2407", None)),
     # TODO: fix this model
-    # (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", None),
-    # (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", ("mistralai-Mistral-Nemo-Instruct-2407", None)),
+    # (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
+    # (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
 ])
 def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
     global server
+    server.n_slots = 1
     server.jinja = True
     server.n_ctx = 8192
     server.n_predict = 128
@@ -346,12 +347,13 @@ def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: st
         "max_tokens": 256,
         "messages": [
             {"role": "system", "content": "You are a coding assistant."},
-            {"role": "user", "content": "say hello world with python"},
+            # {"role": "user", "content": "say hello world with python"},
+            {"role": "user", "content": "Print a hello world message with python"},
         ],
         "tools": [tool],
-        "temperature": 0.0,
-        "top_k": 1,
-        "top_p": 1.0,
+        "temperature": 0.5,
+        "top_k": 10,
+        "top_p": 0.9,
     })
     assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
     choice = res.body["choices"][0]
@@ -361,7 +363,7 @@
     if tool["type"] == "function":
         assert tool["function"]["name"] == tool_call["function"]["name"]
     elif tool["type"] == "code_interpreter":
-        assert tool_call["function"]["name"] == "python"
+        assert re.match('i?python', tool_call["function"]["name"])
     actual_arguments = json.loads(tool_call["function"]["arguments"])
     assert json.dumps(expected_arguments) == json.dumps(actual_arguments), f"tool arguments: {json.dumps(actual_arguments)}, expected: {json.dumps(expected_arguments)}"
 
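
[Note on PATCH 2/5]
Besides refreshing the per-model expected outputs and moving off strictly
greedy sampling (temperature 0.5, top_k 10, top_p 0.9, single slot), this
relaxes the code-interpreter name check: Llama-3.1-style templates name the
builtin tool "ipython" while others use "python", and since re.match anchors
at the start of the string, the pattern 'i?python' accepts both spellings.
A quick self-contained check of that pattern:

    import re

    assert re.match('i?python', 'python')
    assert re.match('i?python', 'ipython')
    assert not re.match('i?python', 'node')
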
From a6463c1e353358c320edf39cebd65dfba8463b8b Mon Sep 17 00:00:00 2001
From: Olivier Chafik
Date: Sat, 25 Jan 2025 04:52:42 +0000
Subject: [PATCH 3/5] jinja: don't add BOS when jinja is enabled

---
 examples/main/main.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 1e2e98b64..e654d3542 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -254,7 +254,7 @@ int main(int argc, char ** argv) {
         }
     }
 
-    const bool add_bos = llama_vocab_get_add_bos(vocab);
+    const bool add_bos = llama_vocab_get_add_bos(vocab) && !params.use_jinja;
     if (!llama_model_has_encoder(model)) {
         GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
     }
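
[Note on PATCH 3/5]
A jinja chat template typically renders the BOS token itself, so tokenizing
its output with add_bos on top of that would prepend a second BOS. A toy
Python sketch of the failure mode this one-liner avoids (the strings are
illustrative, not real template output):

    BOS = "<|begin_of_text|>"
    rendered = BOS + "<|start_header_id|>user<|end_header_id|>\n\nhi"  # template already emits BOS

    vocab_add_bos = True   # model metadata says: prepend BOS when tokenizing
    use_jinja     = True

    add_bos = vocab_add_bos and not use_jinja  # the guard added above
    prompt  = (BOS + rendered) if add_bos else rendered
    assert not prompt.startswith(BOS + BOS)
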
From 51b7aab841aa48d31ae5ef875c36439a066376dd Mon Sep 17 00:00:00 2001
From: Olivier Chafik
Date: Sat, 25 Jan 2025 04:57:40 +0000
Subject: [PATCH 4/5] Update test_chat_completion.py

---
 .../server/tests/unit/test_chat_completion.py | 34 +++++++++----------
 1 file changed, 17 insertions(+), 17 deletions(-)

diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py
index 92286143d..399d8b937 100644
--- a/examples/server/tests/unit/test_chat_completion.py
+++ b/examples/server/tests/unit/test_chat_completion.py
@@ -226,23 +226,23 @@ CODE_INTEPRETER_TOOL = {
 }
 
 
-@pytest.mark.parametrize("template_name,n_predict,tool,expected_arguments", [
-    ("meetkai-functionary-medium-v3.1", 128, TEST_TOOL, {"success": True} ),
-    ("meetkai-functionary-medium-v3.1", 128, PYTHON_TOOL, {"code": ". She was so excited to go to the park and climble agace. She was so excited to go to the park and play with her friends.\nThey played together and had lots of fun. They were very happy. At the park, they found the park and had a great time. After a while, they found"} ),
-    ("meetkai-functionary-medium-v3.2", 128, TEST_TOOL, {"success": True} ),
-    ("meetkai-functionary-medium-v3.2", 128, PYTHON_TOOL, {"code": "It's a spector."} ),
-    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, TEST_TOOL, {"success": True} ),
-    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, {"code": "Yes, you can."} ),
-    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, TEST_TOOL, {"success": True} ),
-    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, {"code": "Yes, you can."} ),
-    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, TEST_TOOL, {"success": True} ),
-    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, {"code": "It's a spector."} ),
-    ("meta-llama-Llama-3.2-3B-Instruct", 128, TEST_TOOL, {"success": True} ),
-    ("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, {"code": "It's a spectork."} ),
-    ("mistralai-Mistral-Nemo-Instruct-2407", 128, TEST_TOOL, {"success": True} ),
-    ("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, {"code": "It's a speciachy!"} ),
+@pytest.mark.parametrize("template_name,n_predict,tool,argument_key", [
+    ("meetkai-functionary-medium-v3.1", 128, TEST_TOOL, "success"),
+    ("meetkai-functionary-medium-v3.1", 128, PYTHON_TOOL, "code"),
+    ("meetkai-functionary-medium-v3.2", 128, TEST_TOOL, "success"),
+    ("meetkai-functionary-medium-v3.2", 128, PYTHON_TOOL, "code"),
+    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, TEST_TOOL, "success"),
+    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, "code"),
+    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, TEST_TOOL, "success"),
+    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, "code"),
+    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, TEST_TOOL, "success"),
+    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, "code"),
+    ("meta-llama-Llama-3.2-3B-Instruct", 128, TEST_TOOL, "success"),
+    ("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, "code"),
+    ("mistralai-Mistral-Nemo-Instruct-2407", 128, TEST_TOOL, "success"),
+    ("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, "code"),
 ])
-def test_completion_with_required_tool(template_name: str, n_predict: int, tool: dict, expected_arguments: dict):
+def test_completion_with_required_tool(template_name: str, n_predict: int, tool: dict, argument_key: str):
     global server
     # server = ServerPreset.stories15m_moe()
     server.jinja = True
@@ -269,7 +269,7 @@ def test_completion_with_required_tool(template_name: str, n_predict: int, tool:
     tool_call = tool_calls[0]
     assert tool["function"]["name"] == tool_call["function"]["name"]
     actual_arguments = json.loads(tool_call["function"]["arguments"])
-    assert json.dumps(expected_arguments) == json.dumps(actual_arguments), f"tool arguments: {json.dumps(actual_arguments)}, expected: {json.dumps(expected_arguments)}"
+    assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
 
 
 @pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [

From 3f3fc0398344bb9f3b5cd5d7a79341bc11217cb7 Mon Sep 17 00:00:00 2001
From: ochafik
Date: Sun, 26 Jan 2025 15:32:13 +0000
Subject: [PATCH 5/5] nit: trailing spaces

---
 src/llama-grammar.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp
index 501b0037b..2eae29bb9 100644
--- a/src/llama-grammar.cpp
+++ b/src/llama-grammar.cpp
@@ -1156,7 +1156,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
     GGML_ASSERT(grammar.vocab != nullptr);
 
     const auto & piece = grammar.vocab->token_to_piece(token);
-    
+
     if (grammar.awaiting_trigger) {
         if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
             grammar.awaiting_trigger = false;
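
[Note on PATCH 1/5 and 5/5, llama-grammar.cpp]
For word (as opposed to token) triggers, the accept path buffers decoded
pieces while awaiting_trigger is set and scans the buffer for any trigger
word; once one is found, the text from the trigger word onward is replayed
into the grammar and normal acceptance resumes. A standalone Python rendition
of that loop (names are illustrative, not the llama.cpp API):

    def feed(pieces: list[str], trigger_words: list[str]) -> list[str]:
        accepted: list[str] = []  # stands in for llama_grammar_accept_str calls
        buffer, awaiting = "", True
        for piece in pieces:
            if not awaiting:
                accepted.append(piece)
                continue
            buffer += piece
            for word in trigger_words:
                pos = buffer.find(word)
                if pos != -1:
                    awaiting = False
                    accepted.append(buffer[pos:])  # text before the trigger is dropped
                    buffer = ""
                    break
        return accepted

    assert feed(["Sure! ", "<tool_", "call>", '{"name":'], ["<tool_call>"]) == ["<tool_call>", '{"name":']
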