tool-call: add server tests for llama 3.1
parent 9e366b3d03
commit a774093a99

3 changed files with 129 additions and 16 deletions
@@ -316,7 +316,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
             tool_rules.push_back(
                 builder.add_rule(
                     name + "-call",
-                    "\"\\n{\\\"name\\\": " + name + "\\\", \\\"parameters\\\", \" " +
+                    "\"\\n{\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
                     builder.add_schema(name + "-args", parameters) +
                     " \"}\""));
             if (allow_content) {
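The C++ change above fixes the GBNF literal that constrains tool-call output: the old concatenation dropped the opening escaped quote before the tool name and emitted a comma instead of a colon after "parameters", so the grammar forced the model into invalid JSON. A minimal Python sketch (mirroring the C++ concatenation; the tool name and args rule name are illustrative) shows the before/after grammar fragments:

# Sketch only: Python mirror of the C++ string concatenation, to show the GBNF rule body it builds.
name = "ipython"              # example tool name
args_rule = name + "-args"    # stands in for builder.add_schema(name + "-args", parameters)

old = "\"\\n{\\\"name\\\": " + name + "\\\", \\\"parameters\\\", \" " + args_rule + " \"}\""
new = "\"\\n{\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + args_rule + " \"}\""

print(old)  # "\n{\"name\": ipython\", \"parameters\", " ipython-args "}"   -> forces invalid JSON
print(new)  # "\n{\"name\": \"ipython\", \"parameters\": " ipython-args "}" -> forces {"name": "ipython", "parameters": {...}}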
@@ -80,6 +80,8 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.temperature = None
     context.lora_file = None
     context.disable_ctx_shift = False
+    context.use_jinja = False
+    context.chat_template_file = None

     context.tasks_result = []
     context.concurrent_tasks = []
@@ -159,6 +161,16 @@ def step_slot_save_path(context, slot_save_path: str):
     context.slot_save_path = slot_save_path


+@step('jinja templates are enabled')
+def step_use_jinja(context):
+    context.use_jinja = True
+
+
+@step('chat template file {file}')
+def step_use_jinja(context, file):
+    context.chat_template_file = file
+
+
 @step('using slot id {id_slot:d}')
 def step_id_slot(context, id_slot: int):
     context.id_slot = id_slot
@@ -369,7 +381,7 @@ def step_response_format(context, response_format):
 def step_tools(context, tools):
     context.tools = json.loads(tools)

-@step('tool choice {tool_choice}')
+@step('a tool choice {tool_choice}')
 def step_tool_choice(context, tool_choice):
     context.tool_choice = tool_choice

@@ -490,8 +502,11 @@ async def step_oai_chat_completions(context, api_error):
     expect_api_error = api_error == 'raised'
     seeds = await completions_seed(context, num_seeds=1)
     completion = await oai_chat_completions(context.prompts.pop(),
-                                            seeds[0] if seeds is not None else seeds,
-                                            context.system_prompt,
+                                            seeds[0] if seeds else None,
+                                            context.system_prompt
+                                            if hasattr(context, 'system_prompt') else None,
                                             context.base_url,
                                             '/v1/chat',
                                             False,
@@ -631,6 +646,43 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
     assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"


+@step('tool {expected_name} is called with arguments {expected_arguments}')
+@async_run_until_complete
+async def step_tool_called(context, expected_name, expected_arguments):
+    n_completions = await gather_tasks_results(context)
+    assert n_completions > 0
+
+    expected_name = expected_name if expected_name else None
+    expected_arguments = json.loads(expected_arguments) if expected_arguments else None
+
+    def check(tool_calls):
+        if tool_calls is None:
+            assert expected_name is None and expected_arguments is None, f'expected_name = {expected_name}, expected_arguments = {expected_arguments}'
+        else:
+            assert len(tool_calls) == 1, f"tool calls: {tool_calls}"
+            tool_call = tool_calls[0]
+            actual_name = tool_call.name
+            actual_arguments = json.loads(tool_call.arguments)
+            assert expected_name == actual_name, f"tool name: {actual_name}, expected: {expected_name}"
+            assert json.dumps(expected_arguments) == json.dumps(actual_arguments), f"tool arguments: {json.dumps(actual_arguments)}, expected: {json.dumps(expected_arguments)}"
+
+    for i in range(n_completions):
+        assert_n_tokens_predicted(context.tasks_result.pop(), tool_calls_check=check)
+    assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"
+
+@step('no tool is called')
+@async_run_until_complete
+async def step_tool_called(context):
+    n_completions = await gather_tasks_results(context)
+    assert n_completions > 0
+
+    def check(tool_calls):
+        assert tool_calls is None
+
+    for i in range(n_completions):
+        assert_n_tokens_predicted(context.tasks_result.pop(), tool_calls_check=check)
+    assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"
+
 @step('embeddings are computed for')
 @async_run_until_complete
 async def step_compute_embedding(context):
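Both new steps delegate per-completion verification to assert_n_tokens_predicted (extended further down in this diff) through the new tool_calls_check callback. A hedged, self-contained sketch of the contract the callback relies on, using an illustrative response dict and a stand-in call object:

# Sketch only: the contract assumed by the new tool_calls_check hook in
# assert_n_tokens_predicted. The hook receives whatever 'tool_calls' the completion
# response carried: None when no tool was invoked, or a list of call objects exposing
# .name and .arguments (a JSON string), matching how check() above inspects them.
import json
from types import SimpleNamespace

def example_check(tool_calls):
    assert tool_calls is not None and len(tool_calls) == 1
    assert tool_calls[0].name == "test"
    assert json.loads(tool_calls[0].arguments) == {}

# illustrative completion response, in the shape built by oai_chat_completions below
example_response = {
    'content': '',
    'tool_calls': [SimpleNamespace(name="test", arguments="{}")],
    'timings': {'predicted_n': 5, 'prompt_n': 12},
}
example_check(example_response.get('tool_calls'))  # what assert_n_tokens_predicted will do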
@@ -1001,19 +1053,23 @@ async def oai_chat_completions(user_prompt,
     print(f"Sending OAI Chat completions request: {user_prompt}")
     # openai client always expects an api key
     user_api_key = user_api_key if user_api_key is not None else 'nope'
+    assert isinstance(seed, int), f'seed: {seed}'
     seed = seed if seed is not None else 42
+
     enable_streaming = enable_streaming if enable_streaming is not None else False
+    messages = []
+    if system_prompt:
+        messages.append({
+            "role": "system",
+            "content": system_prompt,
+        })
+    if user_prompt:
+        messages.append({
+            "role": "user",
+            "content": user_prompt,
+        })
     payload = {
-        "messages": [
-            {
-                "role": "system",
-                "content": system_prompt,
-            },
-            {
-                "role": "user",
-                "content": user_prompt,
-            }
-        ],
+        "messages": messages,
         "model": model,
         "max_tokens": n_predict,
         "stream": enable_streaming,
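This payload change makes the system message optional: previously every request carried both a system and a user message even when a scenario set no system prompt. A small standalone sketch of the new construction, with illustrative inputs:

# Sketch only: same conditional message-building as the hunk above, with example inputs.
system_prompt = None                      # e.g. a scenario that sets no system prompt
user_prompt = "write a hello world in python"

messages = []
if system_prompt:
    messages.append({"role": "system", "content": system_prompt})
if user_prompt:
    messages.append({"role": "user", "content": user_prompt})

payload = {"messages": messages, "model": "test", "max_tokens": 16, "stream": False}
# -> payload["messages"] now holds only the user turn; no empty system message is sent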
@@ -1115,6 +1171,7 @@ async def oai_chat_completions(user_prompt,
         assert chat_completion.usage is not None
         completion_response = {
             'content': chat_completion.choices[0].message.content,
+            'tool_calls': chat_completion.choices[0].message.tool_calls,
             'timings': {
                 'predicted_n': chat_completion.usage.completion_tokens,
                 'prompt_n': chat_completion.usage.prompt_tokens
@@ -1181,11 +1238,13 @@ async def request_oai_embeddings(input, seed,
     return [e.embedding for e in oai_embeddings.data]


-def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
+def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None, tool_calls_check=None):
     content = completion_response['content']
+    tool_calls = completion_response.get('tool_calls')
     n_predicted = completion_response['timings']['predicted_n']
-    assert len(content) > 0, "no token predicted"
+    assert (content and len(content) > 0) or (tool_calls and len(tool_calls) > 0), "no token predicted"
     if re_content is not None:
+        assert content
         p = re.compile(re_content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL)
         matches = p.finditer(content)
         last_match = 0
@@ -1201,6 +1260,8 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
         if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
             print(f"Checking completion response: {highlighted}")
         assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```'
+    if tool_calls_check:
+        tool_calls_check(tool_calls)
     if expected_predicted_n and expected_predicted_n > 0:
         assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
                                                      f' {n_predicted} <> {expected_predicted_n}')
@@ -1409,6 +1470,10 @@ def start_server_background(context):
         server_args.extend(['--grp-attn-w', context.n_ga_w])
     if context.debug:
         server_args.append('--verbose')
+    if context.use_jinja:
+        server_args.append('--jinja')
+    if context.chat_template_file:
+        server_args.extend(['--chat-template-file', context.chat_template_file])
     if context.lora_file:
         server_args.extend(['--lora', context.lora_file])
     if context.disable_ctx_shift:
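With the two new flags wired in, a Background that enables jinja templates and names a chat template file translates into extra server arguments. An illustrative sketch of the argument list that results, using the values from the tool_call.feature Background below (the prefix flags are abbreviated):

# Sketch only: how the new context fields extend server_args.
server_args = ['--host', 'localhost', '--port', '8080']   # abbreviated prefix
use_jinja = True
chat_template_file = '../../../tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja'

if use_jinja:
    server_args.append('--jinja')
if chat_template_file:
    server_args.extend(['--chat-template-file', chat_template_file])

print(server_args[-3:])  # ['--jinja', '--chat-template-file', '.../meta-llama-Meta-Llama-3.1-8B-Instruct.jinja']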
examples/server/tests/features/tool_call.feature (new file, 48 lines)
@@ -0,0 +1,48 @@
+@llama.cpp
+@server
+Feature: llama.cpp server
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
+    And a model file test-model.gguf
+    And a model alias tinyllama-2
+    And BOS token is 1
+    And 42 as server seed
+    And 8192 KV cache size
+    And 32 as batch size
+    And 2 slots
+    And 64 server max tokens to predict
+    And prometheus compatible metrics exposed
+    And jinja templates are enabled
+    And chat template file ../../../tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja
+    Then the server is starting
+    Then the server is healthy
+
+  Scenario: Health
+    Then the server is ready
+    And all slots are idle
+
+  Scenario Outline: OAI Compatibility w/ required tool
+    Given a model test
+    And <n> max tokens to predict
+    And a user prompt write a hello world in python
+    And a tool choice <tool_choice>
+    And tools <tools>
+    Given an OAI compatible chat completions request with no api error
+    Then tool <tool_name> is called with arguments <tool_arguments>
+
+    Examples: Prompts
+      | n  | tool_name | tool_arguments      | tool_choice | tools |
+      | 64 | test      | {}                  | required    | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
+      | 16 | ipython   | {"code": "it and "} | required    | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
+
+  Scenario: OAI Compatibility w/ no tool
+    Given a model test
+    And 16 max tokens to predict
+    And a user prompt write a hello world in python
+    And a tool choice <tool_choice>
+    And tools []
+    Given an OAI compatible chat completions request with no api error
+    Then no tool is called
+
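For reference, the "required tool" scenario above amounts to an OpenAI-style chat-completions request along these lines. The exact wiring of tools and tool_choice into the test helper is not part of the hunks shown here, so treat this as an assumed shape based on the standard /v1/chat/completions API:

# Sketch only: assumed request body for the first Examples row (n=64, tool "test").
payload = {
    "model": "test",
    "max_tokens": 64,
    "messages": [{"role": "user", "content": "write a hello world in python"}],
    "tool_choice": "required",
    "tools": [{"type": "function",
               "function": {"name": "test", "description": "",
                            "parameters": {"type": "object", "properties": {}}}}],
    "stream": False,
}
# Expected per the scenario: exactly one tool call, name "test", arguments {}.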