tool-call: test MistralNemo in forced tools server tests (w/ parallel tool calls disabled)
parent 30bd00bcf7
commit 080982ebf3

4 changed files with 57 additions and 29 deletions
@@ -1047,7 +1047,7 @@ std::string build_grammar(const std::function<void(const llama_grammar_builder &
             return converter.add_rule(name, rule);
         },
         /* .add_schema = */ [&](const std::string & name, const nlohmann::ordered_json & schema) {
-            return converter.visit(schema, name);
+            return converter.visit(schema, name == "root" ? "" : name);
         },
         /* .resolve_refs = */ [&](nlohmann::ordered_json & schema) {
             converter.resolve_refs(schema, "");
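For context, GBNF grammars generated from JSON schemas start at a rule named root, and the schema converter's existing convention is that an empty rule name means "emit this schema as the root rule". The change above lets call sites write add_schema("root", schema) while still landing on that convention. A minimal Python sketch of the mapping, with visit standing in for the converter (an assumption for illustration, not the converter's real API):

    def visit(schema: dict, name: str) -> str:
        # Stand-in for the schema-to-grammar converter: an empty name is
        # assumed to mean "this schema becomes the grammar's root rule".
        return "root" if name == "" else name

    def add_schema(name: str, schema: dict) -> str:
        # Mirrors the C++ lambda: an explicit "root" maps to the empty name.
        return visit(schema, "" if name == "root" else name)

    assert add_schema("root", {"type": "object"}) == "root"
    assert add_schema("tool-call", {"type": "object"}) == "tool-call"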
@@ -267,26 +267,31 @@ static llama_tool_calls parse_generic_tool_calls(const std::string& input) {
 static llama_tool_calls parse_mistral_nemo_tool_calls(const std::string& input) {
     auto content_end = input.find("[TOOL_CALLS]");
     size_t tc_start = std::string::npos;
 
+    llama_tool_calls result;
+    const auto process_tool_calls = [&](const json & tool_calls) {
+        for (const auto & tool_call : tool_calls) {
+            const auto & arguments = tool_call["arguments"];
+            result.tool_calls.push_back({
+                tool_call["name"],
+                arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
+                tool_call["id"],
+            });
+        }
+    };
     if (content_end != std::string::npos) {
         tc_start = content_end + 12;
+        result.content = input.substr(0, content_end);
+        auto tool_calls = json::parse(input.substr(tc_start));
+        process_tool_calls(tool_calls);
     } else {
         // Somehow not getting [TOOL_CALLS] in the output. Oh well, just do without it.
-        content_end = input.find("[{\"");
-        if (content_end == std::string::npos || content_end > 0) {
-            return {input, {}};
+        try {
+            auto tool_calls = json::parse(input);
+            process_tool_calls(tool_calls);
+        } catch (const json::exception & e) {
+            throw std::runtime_error("Failed to parse tool calls: " + std::string(e.what()) + ":\n" + input);
         }
-        tc_start = content_end;
     }
-    llama_tool_calls result;
-    result.content = input.substr(0, content_end);
-    auto tool_calls = json::parse(input.substr(tc_start));
-    for (const auto & tool_call : tool_calls) {
-        const auto & arguments = tool_call["arguments"];
-        result.tool_calls.push_back({
-            tool_call["name"],
-            arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
-            tool_call["id"],
-        });
-    }
     return result;
 }
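For reference, Mistral Nemo emits tool calls as the literal marker [TOOL_CALLS] followed by a JSON array of {name, arguments, id} objects; the + 12 above is simply the length of that marker. A rough Python sketch of the same split-and-parse logic (an illustration of the output format, not the C++ implementation):

    import json

    def parse_mistral_nemo_tool_calls(output: str) -> dict:
        # Split assistant text from the [TOOL_CALLS] marker, then parse the
        # trailing JSON array of {name, arguments, id} objects.
        marker = "[TOOL_CALLS]"
        pos = output.find(marker)
        if pos != -1:
            content, calls_json = output[:pos], output[pos + len(marker):]
        else:
            # No marker: assume the whole output is the JSON array.
            content, calls_json = "", output
        calls = json.loads(calls_json)
        return {
            "content": content,
            "tool_calls": [
                {
                    "name": c["name"],
                    # arguments may already be a string, or an object to re-serialize
                    "arguments": c["arguments"] if isinstance(c["arguments"], str) else json.dumps(c["arguments"]),
                    "id": c["id"],
                }
                for c in calls
            ],
        }

    # Illustrative output only; the id format is an assumption.
    example = '[TOOL_CALLS][{"name": "ipython", "arguments": {"code": "print(1)"}, "id": "123456789"}]'
    print(parse_mistral_nemo_tool_calls(example))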
@@ -403,7 +408,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
                }
                : tool_call;
            handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
-               builder.add_schema("", schema);
+               builder.add_schema("root", schema);
            });
            // TODO: add schema to system prompt.
            auto tweaked_messages = add_system(
@@ -450,11 +455,12 @@ llama_tool_call_handler llama_tool_call_handler_init(
                if (!parallel) {
                    schema["maxItems"] = 1;
                }
-               builder.add_schema("", schema);
+               builder.add_schema("root", schema);
            });
            if (allow_content) {
                handler.grammar_trigger_words.push_back("[TOOL_CALLS]");
                handler.grammar_trigger_words.push_back("[{\"");
+               handler.grammar_trigger_words.push_back("[ { \"");
            }
            auto tweaked_messages = add_system(messages, "Prefix any tool calls with [TOOL_CALLS]");
            handler.prompt = tmpl.apply(tweaked_messages, tools, /* add_generation_prompt= */ true);
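The grammar here is built from a JSON schema describing an array of tool-call objects, so disabling parallel tool calls just caps the array at one element via maxItems; the extra "[ { \"" trigger word covers outputs where the model inserts whitespace before the array. A rough sketch of that schema shape (field names are illustrative; the real schema is assembled in C++ from the declared tools):

    def tool_calls_schema(tool_schemas: list[dict], parallel: bool) -> dict:
        # Array of tool-call objects, one per call. With parallel disabled,
        # the grammar only admits a single element.
        schema = {
            "type": "array",
            "items": {"anyOf": tool_schemas},
            "minItems": 1,
        }
        if not parallel:
            schema["maxItems"] = 1
        return schema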
@@ -78,6 +78,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.response_format = None
     context.tools = None
     context.tool_choice = None
+    context.parallel_tool_calls = None
     context.temperature = None
     context.lora_file = None
     context.disable_ctx_shift = False
@@ -393,6 +394,17 @@ def step_tools(context, tools):
 def step_tool_choice(context, tool_choice):
     context.tool_choice = tool_choice
 
+@step('parallel tool calls is {enable_parallel_tool_calls}')
+def step_parallel_tool_calls(context, enable_parallel_tool_calls):
+    if enable_parallel_tool_calls == 'enabled':
+        context.parallel_tool_calls = True
+    elif enable_parallel_tool_calls == 'disabled':
+        context.parallel_tool_calls = False
+    elif enable_parallel_tool_calls == '':
+        context.parallel_tool_calls = None
+    else:
+        raise ValueError(f"invalid value for enable_parallel_tool_calls: {enable_parallel_tool_calls}")
+
 @step('{temperature:f} temperature')
 def step_temperature(context, temperature):
     context.temperature = temperature
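behave's default step matcher uses the parse library, so the {enable_parallel_tool_calls} placeholder captures whatever follows "parallel tool calls is" in a feature line ('enabled', 'disabled', or empty). A standalone check of that assumption:

    import parse  # the matcher behave uses by default for {placeholder} steps

    result = parse.parse('parallel tool calls is {enable_parallel_tool_calls}',
                         'parallel tool calls is disabled')
    assert result['enable_parallel_tool_calls'] == 'disabled'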
@@ -541,6 +553,7 @@ async def step_oai_chat_completions(context, api_error):
                                  if hasattr(context, 'tools') else None,
 
                                  tool_choice=context.tool_choice,
+                                 parallel_tool_calls=context.parallel_tool_calls,
 
                                  user_api_key=context.user_api_key
                                  if hasattr(context, 'user_api_key') else None,
@@ -615,6 +628,7 @@ async def step_oai_chat_completions(context):
                                  tools=context.tools
                                  if hasattr(context, 'tools') else None,
                                  tool_choice=context.tool_choice,
+                                 parallel_tool_calls=context.parallel_tool_calls,
                                  user_api_key=context.user_api_key
                                  if hasattr(context, 'user_api_key') else None)
 
@@ -638,6 +652,7 @@ async def step_oai_chat_completions(context):
                                      # if hasattr(context, 'response_format') else None,
                                      tools=context.tools,# if hasattr(context, 'tools') else None,
                                      tool_choice=context.tool_choice, # if hasattr(context, 'tool_choice') else None,
+                                     parallel_tool_calls=context.parallel_tool_calls,
                                      user_api_key=context.user_api_key)
                                      # if hasattr(context, 'user_api_key') else None)
 
@@ -1099,6 +1114,7 @@ async def oai_chat_completions(user_prompt,
                                response_format=None,
                                tools=None,
                                tool_choice=None,
+                               parallel_tool_calls=None,
                                user_api_key=None,
                                expect_api_error=None) -> int | dict[str, Any]:
     if debug:
@@ -1133,6 +1149,8 @@ async def oai_chat_completions(user_prompt,
         payload['tools'] = tools
     if tool_choice is not None:
         payload['tool_choice'] = tool_choice
+    if parallel_tool_calls is not None:
+        payload['parallel_tool_calls'] = parallel_tool_calls
     completion_response = {
         'content': '',
         'timings': {
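With these two lines, a scenario that sets the flag produces an OpenAI-style request body along these lines (values are illustrative, taken from the feature file below; parallel_tool_calls is only present when the step set it to a non-None value):

    payload = {
        "messages": [{"role": "user", "content": "write a hello world in python"}],
        "tools": [{"type": "function",
                   "function": {"name": "ipython",
                                "description": "",
                                "parameters": {"type": "object",
                                               "properties": {"code": {"type": "string", "description": ""}},
                                               "required": ["code"]}}}],
        "tool_choice": "required",
        "parallel_tool_calls": False,
    }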
@@ -1199,6 +1217,7 @@ async def oai_chat_completions(user_prompt,
                response_format=payload.get('response_format') or openai.NOT_GIVEN,
                tools=payload.get('tools') or openai.NOT_GIVEN,
                tool_choice=payload.get('tool_choice') or openai.NOT_GIVEN,
+               parallel_tool_calls=payload.get('parallel_tool_calls', openai.NOT_GIVEN),
                seed=seed,
                temperature=payload['temperature']
            )
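Note the asymmetry with the neighbouring arguments: the usual payload.get('x') or openai.NOT_GIVEN pattern would also discard a legitimate False, which is exactly the value these scenarios send, so parallel_tool_calls uses dict.get with a default instead. A minimal demonstration of the difference:

    NOT_GIVEN = object()  # stand-in for openai.NOT_GIVEN
    payload = {'parallel_tool_calls': False}

    dropped = payload.get('parallel_tool_calls') or NOT_GIVEN   # False is falsy, so NOT_GIVEN wins
    kept = payload.get('parallel_tool_calls', NOT_GIVEN)        # False is preserved

    assert dropped is NOT_GIVEN
    assert kept is False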
@@ -16,7 +16,7 @@ Feature: llama.cpp server
     And jinja templates are enabled
 
 
-  Scenario Outline: OAI Compatibility w/ tools and required tool_choice
+  Scenario Outline: OAI Compatibility w/ tools and required tool_choice (<template_name> template, <tool_name> tool)
     Given a chat template file ../../../tests/chat/templates/<template_name>.jinja
     And the server is starting
     And the server is healthy
@@ -25,22 +25,25 @@ Feature: llama.cpp server
     And a user prompt write a hello world in python
     And a tool choice required
     And tools <tools>
+    And parallel tool calls is <parallel_tool_calls>
     And an OAI compatible chat completions request with no api error
     Then tool <tool_name> is called with arguments <tool_arguments>
 
     Examples: Prompts
-      | template_name | n_predict | tool_name | tool_arguments | tools |
+      | template_name | n_predict | tool_name | tool_arguments | tools | parallel_tool_calls |
-      | meetkai-functionary-medium-v3.1 | 128 | test | {} | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
+      | meetkai-functionary-medium-v3.1 | 128 | test | {} | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] | disabled |
-      | meetkai-functionary-medium-v3.1 | 128 | ipython | {"code": "Yes, you can."} | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
+      | meetkai-functionary-medium-v3.1 | 128 | ipython | {"code": "Yes, you can."} | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] | disabled |
-      | meetkai-functionary-medium-v3.2 | 128 | test | {} | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
+      | meetkai-functionary-medium-v3.2 | 128 | test | {} | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] | disabled |
-      | meetkai-functionary-medium-v3.2 | 128 | ipython | {"code": "Yes,"} | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
+      | meetkai-functionary-medium-v3.2 | 128 | ipython | {"code": "Yes,"} | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] | disabled |
-      | meta-llama-Meta-Llama-3.1-8B-Instruct | 64 | test | {} | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
+      | meta-llama-Meta-Llama-3.1-8B-Instruct | 64 | test | {} | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] | disabled |
-      | meta-llama-Meta-Llama-3.1-8B-Instruct | 64 | ipython | {"code": "it and realed at the otter. Asked Dave Dasty, Daisy is a big, shiny blue. As"} | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
+      | meta-llama-Meta-Llama-3.1-8B-Instruct | 64 | ipython | {"code": "it and realed at the otter. Asked Dave Dasty, Daisy is a big, shiny blue. As"} | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] | disabled |
-      | meta-llama-Llama-3.2-3B-Instruct | 64 | test | {} | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
+      | meta-llama-Llama-3.2-3B-Instruct | 64 | test | {} | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] | disabled |
-      | meta-llama-Llama-3.2-3B-Instruct | 64 | ipython | {"code": "Yes,"} | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
+      | meta-llama-Llama-3.2-3B-Instruct | 64 | ipython | {"code": "Yes,"} | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] | disabled |
+      | mistralai-Mistral-Nemo-Instruct-2407 | 128 | test | {} | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] | disabled |
+      | mistralai-Mistral-Nemo-Instruct-2407 | 128 | ipython | {"code": "It's a small cable."} | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] | disabled |
 
 
-  Scenario Outline: OAI Compatibility w/ tools and auto tool_choice
+  Scenario Outline: OAI Compatibility w/ tools and auto tool_choice (<template_name> template)
     Given a chat template file ../../../tests/chat/templates/<template_name>.jinja
     And the server is starting
     And the server is healthy
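The "Then tool <tool_name> is called with arguments <tool_arguments>" step asserts on the tool calls of the first choice in the OpenAI-style response; with parallel tool calls disabled the grammar admits at most one call. For the second Mistral-Nemo row, a passing response would look roughly like this (the id and surrounding fields are illustrative, not captured output):

    expected_like = {
        "choices": [{
            "message": {
                "role": "assistant",
                "content": None,
                "tool_calls": [{
                    "id": "call_123",  # illustrative id
                    "type": "function",
                    "function": {
                        "name": "ipython",
                        "arguments": "{\"code\": \"It's a small cable.\"}",
                    },
                }],
            },
        }],
    }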