tool-call: stabilize server tests

ochafik 2024-12-15 00:16:12 +00:00
parent 7bfcd0a8dd
commit 7e3feff073
5 changed files with 53 additions and 57 deletions

File 1 of 5

```diff
@@ -646,7 +646,7 @@ class llama_antiprompts {
     };

     std::vector<std::string> stop_words;
-    std::vector<std::string> grammar_trigger_words;
+    std::vector<std::string> grammar_triggers;

 private:
     // The AhoCorasick algorithm allows efficient string matching with multiple patterns.
```
```diff
@@ -740,25 +740,25 @@ private:
         stop_tokens.clear();
     }

-    void build(const llama_context * ctx, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_trigger_words) {
+    void build(const llama_context * ctx, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_triggers) {
        build(
            [&](const std::string & text) {
                return common_tokenize(ctx, text, /* special= */ true);
            },
            stop_words,
-           grammar_trigger_words
+           grammar_triggers
        );
    }

-    void build(const std::function<std::vector<llama_token>(const std::string &)> & tokenizer, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_trigger_words) {
+    void build(const std::function<std::vector<llama_token>(const std::string &)> & tokenizer, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_triggers) {
        clear();
        this->stop_words = stop_words;
-       this->grammar_trigger_words = grammar_trigger_words;
+       this->grammar_triggers = grammar_triggers;

        for (const std::string & stop_word : stop_words) {
            antiprompts.push_back({stop_word, /* is_grammar_trigger= */ false});
        }
-       for (const std::string & trigger : grammar_trigger_words) {
+       for (const std::string & trigger : grammar_triggers) {
            antiprompts.push_back({trigger, /* is_grammar_trigger= */ true});
        }
```
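For readers skimming the hunk above, a toy Python sketch (illustration only, not part of this commit) of what the second `build()` overload does: stop words and grammar-trigger words are registered as patterns in the same matcher, told apart only by the `is_grammar_trigger` flag.

```python
# Toy sketch mirroring the two loops in build() above; names other than
# is_grammar_trigger are hypothetical.
from dataclasses import dataclass

@dataclass
class Antiprompt:
    pattern: str
    is_grammar_trigger: bool

def build(stop_words: list[str], grammar_triggers: list[str]) -> list[Antiprompt]:
    antiprompts  = [Antiprompt(w, False) for w in stop_words]
    antiprompts += [Antiprompt(t, True)  for t in grammar_triggers]
    return antiprompts

print(build(["</s>"], ["[TOOL_CALLS]"]))
```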

File 2 of 5

```diff
@@ -520,7 +520,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
             if (!parallel) {
                 schema["maxItems"] = 1;
             }
-            builder.add_rule("root", "\"[TOOL_CALLS]\"? " + builder.add_schema("tool_calls", schema));
+            builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
        });
        if (allow_content) {
            handler.grammar_triggers.push_back("[TOOL_CALLS]");
```

File 3 of 5

```diff
@@ -93,7 +93,6 @@ struct slot_params {
     json input_prefix;
     json input_suffix;
     std::vector<std::string> antiprompt;
-    std::vector<std::string> grammar_triggers;

     bool timings_per_token = false;
     bool ignore_eos = false;
```
```diff
@@ -318,47 +317,39 @@ struct server_task {
                }
            }

-           if (data.contains("grammar_triggers")) {
-               const auto & triggers = data.at("grammar_triggers");
-               if (triggers.is_array()) {
-                   for (const auto & trigger : triggers) {
-                       if (trigger.is_string()) {
-                           params.grammar_triggers.push_back(trigger);
-                       }
-                   }
-               }
+           auto to_string_vec = [](const json & j) {
+               std::vector<std::string> out;
+               if (j.is_array()) {
+                   for (const auto & e : j) {
+                       if (e.is_string()) {
+                           out.push_back(e);
+                       }
+                   }
+               }
+               return out;
+           };
+
+           {
+               const auto grammar_trigger_words = data.find("grammar_trigger_words");
+               if (grammar_trigger_words != data.end()) {
+                   params.sampling.grammar_trigger_words = to_string_vec(*grammar_trigger_words);
+               }
            }

            {
-               params.antiprompt.clear();
-
-               const auto & stop = data.find("stop");
-               if (stop != data.end() && stop->is_array()) {
-                   for (const auto & word : *stop) {
-                       if (!word.empty()) {
-                           params.antiprompt.push_back(word);
-                       }
-                   }
-               }
+               const auto stop = data.find("stop");
+               if (stop != data.end()) {
+                   params.antiprompt = to_string_vec(*stop);
+               }
            }

            {
-               const auto & samplers = data.find("samplers");
+               const auto samplers = data.find("samplers");
                if (samplers != data.end()) {
                    if (samplers->is_array()) {
-                       std::vector<std::string> sampler_names;
-                       for (const auto & name : *samplers) {
-                           if (name.is_string()) {
-                               sampler_names.emplace_back(name);
-                           }
-                       }
-                       params.sampling.samplers = common_sampler_types_from_names(sampler_names, false);
+                       params.sampling.samplers = common_sampler_types_from_names(to_string_vec(*samplers), false);
                    } else if (samplers->is_string()){
-                       std::string sampler_string;
-                       for (const auto & name : *samplers) {
-                           sampler_string += name;
-                       }
-                       params.sampling.samplers = common_sampler_types_from_chars(sampler_string);
+                       params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
                    }
                } else {
                    params.sampling.samplers = defaults.sampling.samplers;
```
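With the parsing above, a completion request can carry all three fields as plain string arrays. A hypothetical client-side sketch, assuming a llama-server built from this branch listening on localhost:8080; the field names come from the hunk above, while the prompt, host, and sampler names are assumptions:

```python
# Hypothetical request sketch against llama-server's /completion endpoint.
import requests

res = requests.post("http://localhost:8080/completion", json={
    "prompt": "Write a haiku about autumn.",
    "n_predict": 64,
    "stop": ["</s>"],                           # -> params.antiprompt via to_string_vec()
    "grammar_trigger_words": ["[TOOL_CALLS]"],  # -> params.sampling.grammar_trigger_words
    "samplers": ["top_k", "top_p", "temperature"],  # array form -> common_sampler_types_from_names()
})
res.raise_for_status()
print(res.json()["content"])
```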
```diff
@@ -546,7 +537,7 @@ struct server_task_result_cmpl_final : server_task_result {
            llama_tool_calls parsed_tool_calls;
            json tool_calls;
            json message_content;
-           if (!oaicompat_tools.is_null()) {
+           if (oaicompat_tool_call_style != llama_tool_call_style::None && !oaicompat_tools.is_null()) {
                parsed_tool_calls = parse_tool_calls(oaicompat_tool_call_style, oaicompat_tools, content);
                if (!parsed_tool_calls.tool_calls.empty()) {
                    finish_reason = "tool_calls";
```
```diff
@@ -1759,7 +1750,7 @@ struct server_context {
        {
            slot.antiprompts.clear();
-           slot.antiprompts.build(ctx, slot.params.antiprompt, slot.params.grammar_triggers);
+           slot.antiprompts.build(ctx, slot.params.antiprompt, slot.params.sampling.grammar_trigger_words);
        }

        {
```
```diff
@@ -1805,7 +1796,7 @@ struct server_context {
        if (match.pos != std::string::npos && !match.is_partial) {
            if (match.is_grammar_trigger) {
-               common_sampler_trigger_grammar(model, slot.smpl, common_token_to_piece(ctx, result.tok, params_base.special));
+               common_sampler_trigger_grammar(model, slot.smpl, token_str);
            } else {
                // slot.stopped_word = true;
                slot.stopping_word = match.pattern;
```
```diff
@@ -2014,7 +2005,7 @@ struct server_context {
            {"mirostat_eta", slot.params.sampling.mirostat_eta},
            {"penalize_nl", slot.params.sampling.penalize_nl},
            {"stop", slot.params.antiprompt},
-           {"grammar_trigger", slot.params.grammar_triggers},
+           {"grammar_trigger_words", slot.params.sampling.grammar_trigger_words},
            {"max_tokens", slot.params.n_predict}, // User configured n_predict
            {"n_keep", slot.params.n_keep},
            {"n_discard", slot.params.n_discard},
```
```diff
@@ -3564,7 +3555,7 @@ int main(int argc, char ** argv) {
        task.params.oaicompat = oaicompat;
        task.params.oaicompat_chat = oaicompat_chat;
        task.params.oaicompat_cmpl_id = completion_id;
-       task.params.oaicompat_tools = json_value(data, "tools", json::array());
+       task.params.oaicompat_tools = json_value(data, "tools", json());
        task.params.oaicompat_tool_call_style = tool_call_style;
        // oaicompat_model is already populated by params_from_json_cmpl
```

File 4 of 5

```diff
@@ -202,23 +202,24 @@ CODE_INTEPRETER_TOOL = {
 @pytest.mark.parametrize("template_name,n_predict,tool,expected_arguments", [
     ("meetkai-functionary-medium-v3.1", 32, TEST_TOOL, {} ),
-    ("meetkai-functionary-medium-v3.1", 32, PYTHON_TOOL, {"code": ". She was so excited to go to the park and s"} ),
-    ("meetkai-functionary-medium-v3.2", 32, TEST_TOOL, {} ),
-    ("meetkai-functionary-medium-v3.2", 32, PYTHON_TOOL, {"code": "Yes,"} ),
+    ("meetkai-functionary-medium-v3.1", 32, PYTHON_TOOL, {"code": " and played all day.\" exclasted her pare"} ),
+    ("meetkai-functionary-medium-v3.2", 128, TEST_TOOL, {} ),
+    ("meetkai-functionary-medium-v3.2", 128, PYTHON_TOOL, {"code": "Sure, I cannything,"} ),
     ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, TEST_TOOL, {} ),
-    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, {"code": "Yes,"} ),
+    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, {"code": " out the owl cried. Jack said "} ),
     ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, TEST_TOOL, {} ),
-    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, {"code": "Yes,"} ),
+    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, {"code": " out the owl cried. Jack said "} ),
     ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, TEST_TOOL, {} ),
-    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, {"code": "It's a shark."} ),
+    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, {"code": "Let's feel out cooking fun together,"} ),
     ("meta-llama-Llama-3.2-3B-Instruct", 128, TEST_TOOL, {} ),
-    ("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, {"code": "It's a shark."} ),
+    ("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, {"code": "Well you fight. Peopballs donto cheep and come again."} ),
     ("mistralai-Mistral-Nemo-Instruct-2407", 128, TEST_TOOL, {} ),
-    ("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, {"code": "It's a small cost."} ),
+    ("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, {"code": "I can cannot count."} ),
 ])
 def test_completion_with_required_tool(template_name: str, n_predict: int, tool: dict, expected_arguments: dict):
     global server
     server.use_jinja = True
+    server.n_predict = n_predict
     server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
     server.start()
     res = server.make_request("POST", "/chat/completions", data={
```
```diff
@@ -227,13 +228,14 @@ def test_completion_with_required_tool(template_name: str, n_predict: int, tool:
             {"role": "system", "content": "You are a coding assistant."},
             {"role": "user", "content": "Write an example"},
         ],
-        "tool_choice": tool["function"]["name"],
+        "tool_choice": "required",
         "tools": [tool],
+        "parallel_tool_calls": False,
     })
     assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
     choice = res.body["choices"][0]
     tool_calls = choice["message"].get("tool_calls")
-    assert tool_calls and len(tool_calls==1), f'Expected 1 tool call in {choice["message"]}'
+    assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
     tool_call = tool_calls[0]
     assert tool["function"]["name"] == tool_call["function"]["name"]
     actual_arguments = json.loads(tool_call["function"]["arguments"])
```
```diff
@@ -254,6 +256,7 @@ def test_completion_with_required_tool(template_name: str, n_predict: int, tool:
 def test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
     global server
     server.use_jinja = True
+    server.n_predict = n_predict
     server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
     server.start()
     res = server.make_request("POST", "/chat/completions", data={
```
```diff
@@ -267,7 +270,7 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
     })
     assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
     choice = res.body["choices"][0]
-    assert "tool_calls" not in choice["message"], f'Expected no tool call in {choice["message"]}'
+    assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'


 @pytest.mark.slow
```
```diff
@@ -296,6 +299,7 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
 def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
     global server
     server.use_jinja = True
+    server.n_predict = 128
     server.model_hf_repo = hf_repo
     server.model_hf_file = hf_file
     if template_override:
```
```diff
@@ -314,7 +318,7 @@ def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: st
     assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
     choice = res.body["choices"][0]
     tool_calls = choice["message"].get("tool_calls")
-    assert tool_calls and len(tool_calls==1), f'Expected 1 tool call in {choice["message"]}'
+    assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
     tool_call = tool_calls[0]
     assert tool["function"]["name"] == tool_call["function"]["name"]
     actual_arguments = json.loads(tool_call["function"]["arguments"])
```
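The `len(tool_calls==1)` fix appearing in two hunks above is easy to misread; a quick standalone illustration of why the original spelling could never pass cleanly:

```python
# Standalone illustration: `tool_calls == 1` compares a list against an int
# (always False), and len(False) raises, so the old assert errored out
# rather than counting tool calls.
tool_calls = [{"function": {"name": "test", "arguments": "{}"}}]

assert tool_calls and len(tool_calls) == 1  # fixed form: counts the calls

try:
    len(tool_calls == 1)                    # old form: len(False)
except TypeError as e:
    print(e)                                # object of type 'bool' has no len()
```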

File 5 of 5

```diff
@@ -494,7 +494,7 @@ static json oaicompat_completion_params_parse(
    auto tools = json_value(body, "tools", json());
    auto has_tools = tools.is_array() && !tools.empty();

-   auto stream = json_value(body, "stream", json());
+   auto stream = json_value(body, "stream", false);
    if (stream && has_tools) {
        throw std::runtime_error("Cannot use tools with stream");
    }
```
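With `stream` now defaulting to `false` instead of a null `json()`, the guard above only fires when a client explicitly asks for both. A hypothetical probe, assuming an OpenAI-compatible llama-server on localhost:8080 (the tool body is a made-up minimal example):

```python
# Hypothetical probe of the guard above: explicitly requesting stream
# together with tools should be rejected with "Cannot use tools with stream".
import requests

res = requests.post("http://localhost:8080/v1/chat/completions", json={
    "messages": [{"role": "user", "content": "hi"}],
    "tools": [{"type": "function", "function": {
        "name": "noop",
        "parameters": {"type": "object", "properties": {}},
    }}],
    "stream": True,
})
assert res.status_code != 200  # server raises instead of streaming
```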
```diff
@@ -561,11 +561,12 @@ static json oaicompat_completion_params_parse(
        llama_params["stop"].push_back(stop);
    }
    if (!handler.grammar_triggers.empty()) {
-       auto triggers = json::array();
+       auto trigger_words = json::array();
        for (const auto & word : handler.grammar_triggers) {
-           triggers.push_back(word);
+           trigger_words.push_back(word);
        }
-       llama_params["grammar_triggers"] = triggers;
+       llama_params["grammar_trigger_words"] = trigger_words;
    }
    if (!handler.grammar.empty()) {
        if (llama_params.contains("grammar")) {
```