Merge branch 'tool-call' into tool-call-handler

ochafik committed on 2025-01-26 15:32:53 +00:00
commit 11594557e3
4 changed files with 58 additions and 50 deletions

View file

@@ -254,7 +254,7 @@ int main(int argc, char ** argv) {
         }
     }
-    const bool add_bos = llama_vocab_get_add_bos(vocab);
+    const bool add_bos = llama_vocab_get_add_bos(vocab) && !params.use_jinja;
     if (!llama_model_has_encoder(model)) {
         GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
     }
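A note on the change above: when a Jinja chat template is in use, the rendered prompt typically already carries the BOS marker, so letting the tokenizer prepend another one would double it; hence the `&& !params.use_jinja` guard. A minimal, self-contained Python sketch of the hazard, using toy stand-ins (none of these names are llama.cpp APIs):

    BOS = "<s>"

    def render_chat_template(user_msg: str) -> str:
        # Chat templates usually emit the BOS marker themselves.
        return f"{BOS}[INST] {user_msg} [/INST]"

    def tokenize(text: str, add_bos: bool) -> list[str]:
        # Toy whitespace tokenizer standing in for the real vocab.
        tokens = text.split()
        return ([BOS] + tokens) if add_bos else tokens

    use_jinja = True
    add_bos = True and not use_jinja  # mirrors: add_bos && !params.use_jinja
    print(tokenize(render_chat_template("hello"), add_bos))
    # ['<s>[INST]', 'hello', '[/INST]'] -- a single BOS, not two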

View file

@@ -2787,6 +2787,11 @@ struct server_context {
         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;

+        auto accept_special_token = [&](llama_token token) {
+            const auto & trigger_tokens = params_base.sampling.grammar_trigger_tokens;
+            return params_base.special || std::find(trigger_tokens.begin(), trigger_tokens.end(), token) != trigger_tokens.end();
+        };
+
         // frist, add sampled tokens from any ongoing sequences
         for (auto & slot : slots) {
             if (slot.state != SLOT_STATE_GENERATING) {
@@ -3150,7 +3155,7 @@ struct server_context {
                 completion_token_output result;
                 result.tok = id;
-                result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+                result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(result.tok));
                 result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs

                 if (slot.params.sampling.n_probs > 0) {
@@ -3239,7 +3244,7 @@ struct server_context {
                     completion_token_output result;
                     result.tok = ids[i];
-                    result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+                    result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(result.tok));
                     result.prob = 1.0f; // set later

                     // TODO: set result.probs
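Together, these hunks change when the server renders a special token into the outgoing text: previously only when running with `special` enabled, now also when the token is one of the grammar's trigger tokens (such as a tool-call opener), so the text a tool-call parser sees actually contains the trigger. A minimal Python sketch of the predicate (illustrative, not the llama.cpp API):

    def accept_special_token(token: int, special: bool, trigger_tokens: list[int]) -> bool:
        # Forward a special token to the client if the server is in
        # "show special tokens" mode, or if the token can trigger a grammar.
        return special or token in trigger_tokens

    TOOL_CALL_OPENER = 42  # hypothetical special-token id
    assert accept_special_token(TOOL_CALL_OPENER, special=False, trigger_tokens=[42])
    assert not accept_special_token(7, special=False, trigger_tokens=[42])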

View file

@@ -207,13 +207,13 @@ PYTHON_TOOL = {
     "type": "function",
     "function": {
         "name": "python",
-        "description": "Runs code in a Python interpreter and returns the result of the execution after 60 seconds.",
+        "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
         "parameters": {
             "type": "object",
             "properties": {
                 "code": {
                     "type": "string",
-                    "description": "The code to run in the Python interpreter."
+                    "description": "The code to run in the ipython interpreter."
                 }
             },
             "required": ["code"]
@@ -226,23 +226,23 @@ CODE_INTEPRETER_TOOL = {
 }

-@pytest.mark.parametrize("template_name,n_predict,tool,expected_arguments", [
-    ("meetkai-functionary-medium-v3.1", 128, TEST_TOOL, {"success": True} ),
-    ("meetkai-functionary-medium-v3.1", 128, PYTHON_TOOL, {"code": ". She was so excited to go to the park and climble agace. She was so excited to go to the park and play with her friends.\nThey played together and had lots of fun. They were very happy. At the park, they found the park and had a great time. After a while, they found"} ),
-    ("meetkai-functionary-medium-v3.2", 128, TEST_TOOL, {"success": True} ),
-    ("meetkai-functionary-medium-v3.2", 128, PYTHON_TOOL, {"code": "It's a spector."} ),
-    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, TEST_TOOL, {"success": True} ),
-    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, {"code": "Yes, you can."} ),
-    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, TEST_TOOL, {"success": True} ),
-    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, {"code": "Yes, you can."} ),
-    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, TEST_TOOL, {"success": True} ),
-    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, {"code": "It's a spector."} ),
-    ("meta-llama-Llama-3.2-3B-Instruct", 128, TEST_TOOL, {"success": True} ),
-    ("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, {"code": "It's a spectork."} ),
-    ("mistralai-Mistral-Nemo-Instruct-2407", 128, TEST_TOOL, {"success": True} ),
-    ("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, {"code": "It's a speciachy!"} ),
+@pytest.mark.parametrize("template_name,n_predict,tool,argument_key", [
+    ("meetkai-functionary-medium-v3.1", 128, TEST_TOOL, "success"),
+    ("meetkai-functionary-medium-v3.1", 128, PYTHON_TOOL, "code"),
+    ("meetkai-functionary-medium-v3.2", 128, TEST_TOOL, "success"),
+    ("meetkai-functionary-medium-v3.2", 128, PYTHON_TOOL, "code"),
+    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, TEST_TOOL, "success"),
+    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, "code"),
+    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, TEST_TOOL, "success"),
+    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, "code"),
+    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, TEST_TOOL, "success"),
+    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, "code"),
+    ("meta-llama-Llama-3.2-3B-Instruct", 128, TEST_TOOL, "success"),
+    ("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, "code"),
+    ("mistralai-Mistral-Nemo-Instruct-2407", 128, TEST_TOOL, "success"),
+    ("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, "code"),
 ])
-def test_completion_with_required_tool(template_name: str, n_predict: int, tool: dict, expected_arguments: dict):
+def test_completion_with_required_tool(template_name: str, n_predict: int, tool: dict, argument_key: str):
     global server
     # server = ServerPreset.stories15m_moe()
     server.jinja = True
@@ -269,7 +269,7 @@ def test_completion_with_required_tool(template_name: str, n_predict: int, tool:
     tool_call = tool_calls[0]
     assert tool["function"]["name"] == tool_call["function"]["name"]
     actual_arguments = json.loads(tool_call["function"]["arguments"])
-    assert json.dumps(expected_arguments) == json.dumps(actual_arguments), f"tool arguments: {json.dumps(actual_arguments)}, expected: {json.dumps(expected_arguments)}"
+    assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"


 @pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
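The required-tool test previously pinned the exact, often nonsensical strings each template+model pairing happened to produce; the new `argument_key` form only checks that the schema-required key is present, which is what the grammar actually guarantees. A small self-contained sketch of the contrast (the literal values are illustrative):

    import json

    # Stand-in for tool_call["function"]["arguments"] as returned by the server.
    actual_arguments = json.loads('{"code": "It\'s a spector."}')

    # Old, brittle: pinned to one sampled wording.
    assert json.dumps({"code": "It's a spector."}) == json.dumps(actual_arguments)

    # New, robust: only the key promised by the tool schema is checked.
    argument_key = "code"
    assert argument_key in actual_arguments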
@@ -308,30 +308,31 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:

 @pytest.mark.slow
 @pytest.mark.parametrize("tool,expected_arguments,hf_repo,hf_file,template_override", [
-    (PYTHON_TOOL, {"code": "print('Hello, world!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
-    (CODE_INTEPRETER_TOOL, {"code": "print('Hello, world!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
-    (PYTHON_TOOL, {"code": "print('Hello World!')"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
-    (CODE_INTEPRETER_TOOL, {"code": "print('Hello World!')"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
-    (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
-    (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
+    (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
+    (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
+    (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
+    (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
+    (PYTHON_TOOL, {"code": "print('Hello World!')"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
+    (CODE_INTEPRETER_TOOL, {"code": "print('Hello World!')"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
     (PYTHON_TOOL, {"code": "print('Hello, world!')"}, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
     (CODE_INTEPRETER_TOOL, {"code": "print('Hello, world!')"}, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch-Hermes-2-Pro-Llama-3-8B", "tool_use")),
     (PYTHON_TOOL, {"code": "print('Hello World!')"}, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
     (CODE_INTEPRETER_TOOL, {"code": "print('Hello World!')"}, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
-    (PYTHON_TOOL, {"code": "print('Hello, World!'}"}, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
-    (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!'}"}, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
+    (PYTHON_TOOL, {"code": "print(\"Hello, World!\")"}, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
+    (CODE_INTEPRETER_TOOL, {"code": "print(\"Hello, World!\")"}, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
     (PYTHON_TOOL, {"code": "print("}, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
     (CODE_INTEPRETER_TOOL, {"code": "print("}, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
-    (PYTHON_TOOL, {"code": "print("}, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
+    (PYTHON_TOOL, {"code": "print(\"hello world\")"}, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
     (CODE_INTEPRETER_TOOL, {"code": "print("}, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
-    (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
-    (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
+    (PYTHON_TOOL, {"code": "print('Hello, World!')\n"}, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", None),
+    (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')\n"}, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", ("mistralai-Mistral-Nemo-Instruct-2407", None)),
     # TODO: fix this model
-    # (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", None),
-    # (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", ("mistralai-Mistral-Nemo-Instruct-2407", None)),
+    # (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
+    # (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
 ])
 def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
     global server
+    server.n_slots = 1
     server.jinja = True
     server.n_ctx = 8192
     server.n_predict = 128
@@ -346,12 +347,13 @@ def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: st
         "max_tokens": 256,
         "messages": [
             {"role": "system", "content": "You are a coding assistant."},
-            {"role": "user", "content": "say hello world with python"},
+            # {"role": "user", "content": "say hello world with python"},
+            {"role": "user", "content": "Print a hello world message with python"},
         ],
         "tools": [tool],
-        "temperature": 0.0,
-        "top_k": 1,
-        "top_p": 1.0,
+        "temperature": 0.5,
+        "top_k": 10,
+        "top_p": 0.9,
     })
     assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
     choice = res.body["choices"][0]
@@ -361,7 +363,7 @@ def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: st
     if tool["type"] == "function":
         assert tool["function"]["name"] == tool_call["function"]["name"]
     elif tool["type"] == "code_interpreter":
-        assert tool_call["function"]["name"] == "python"
+        assert re.match('i?python', tool_call["function"]["name"])
     actual_arguments = json.loads(tool_call["function"]["arguments"])
     assert json.dumps(expected_arguments) == json.dumps(actual_arguments), f"tool arguments: {json.dumps(actual_arguments)}, expected: {json.dumps(expected_arguments)}"
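The loosened name check pairs with the tool-description change from "a Python interpreter" to "an ipython interpreter": depending on the template, models may emit either name for the code-interpreter function. A tiny sketch of what the new assertion accepts (note that `re.match` anchors only at the start, so it deliberately tolerates either spelling):

    import re

    # Both spellings now count as the code-interpreter function name.
    for name in ("python", "ipython"):
        assert re.match('i?python', name)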

View file

@@ -1155,15 +1155,17 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
 void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
     GGML_ASSERT(grammar.vocab != nullptr);

+    const auto & piece = grammar.vocab->token_to_piece(token);
+
     if (grammar.awaiting_trigger) {
         if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
             grammar.awaiting_trigger = false;
             grammar.trigger_buffer.clear();
-            llama_grammar_accept_str(grammar, grammar.vocab->token_to_piece(token));
+            llama_grammar_accept_str(grammar, piece);
             return;
         } else {
             // TODO: consider a smarter incremental substring search algorithm (store last position to search from).
-            grammar.trigger_buffer += grammar.vocab->token_to_piece(token);
+            grammar.trigger_buffer += piece;
             for (const auto & word : grammar.trigger_words) {
                 auto pos = grammar.trigger_buffer.find(word);
                 if (pos != std::string::npos) {
@@ -1187,7 +1189,6 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
             GGML_ABORT("fatal error");
         }
     }

-    const std::string & piece = grammar.vocab->token_to_piece(token);
     llama_grammar_accept_str(grammar, piece);
 }
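For context, this function implements the lazy-grammar flow: while `awaiting_trigger` is set, tokens are not constrained; a trigger token engages the grammar immediately, and otherwise pieces accumulate in `trigger_buffer` until a trigger word is found, at which point the buffer from that word onward is replayed into the grammar. A rough Python sketch of that control flow (a simplified model, not the actual implementation):

    def accept(grammar: dict, piece: str, is_trigger_token: bool) -> None:
        # Mirrors llama_grammar_accept_impl's lazy-trigger control flow.
        if grammar["awaiting_trigger"]:
            if is_trigger_token:
                grammar["awaiting_trigger"] = False
                grammar["buffer"] = ""
                grammar["accepted"].append(piece)   # grammar engages on this token
                return
            grammar["buffer"] += piece
            for word in grammar["trigger_words"]:
                pos = grammar["buffer"].find(word)
                if pos != -1:
                    grammar["awaiting_trigger"] = False
                    # Replay the buffer from the trigger word onward.
                    grammar["accepted"].append(grammar["buffer"][pos:])
                    grammar["buffer"] = ""
                    return
            return
        grammar["accepted"].append(piece)           # normal constrained path

    g = {"awaiting_trigger": True, "buffer": "", "trigger_words": ["<tool_call>"], "accepted": []}
    accept(g, "Thinking out loud... ", False)       # free-form text, only buffered
    accept(g, "<tool_call>{", False)                # trigger word found in buffer
    print(g["accepted"])                            # ['<tool_call>{']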