tool-call: better error reporting for server tests
This commit is contained in:
parent
7cef90cf9c
commit
55cf337560
1 changed file with 14 additions and 12 deletions
|
@@ -655,19 +655,21 @@ async def step_tool_called(context, expected_name, expected_arguments):
|
|||
expected_name = expected_name if expected_name else None
|
||||
expected_arguments = json.loads(expected_arguments) if expected_arguments else None
|
||||
|
||||
def check(tool_calls):
|
||||
if tool_calls is None:
|
||||
assert expected_name is None and expected_arguments is None, f'expected_name = {expected_name}, expected_arguments = {expected_arguments}'
|
||||
else:
|
||||
assert len(tool_calls) == 1, f"tool calls: {tool_calls}"
|
||||
tool_call = tool_calls[0]
|
||||
actual_name = tool_call.function.name
|
||||
actual_arguments = json.loads(tool_call.function.arguments)
|
||||
assert expected_name == actual_name, f"tool name: {actual_name}, expected: {expected_name}"
|
||||
assert json.dumps(expected_arguments) == json.dumps(actual_arguments), f"tool arguments: {json.dumps(actual_arguments)}, expected: {json.dumps(expected_arguments)}"
|
||||
|
||||
for i in range(n_completions):
|
||||
assert_n_tokens_predicted(context.tasks_result.pop(), tool_calls_check=check)
|
||||
result = context.tasks_result.pop()
|
||||
|
||||
def check(tool_calls):
|
||||
if tool_calls is None:
|
||||
assert expected_name is None and expected_arguments is None, f'expected_name = {expected_name}, expected_arguments = {expected_arguments}, result = {result}'
|
||||
else:
|
||||
assert len(tool_calls) == 1, f"tool calls: {tool_calls}"
|
||||
tool_call = tool_calls[0]
|
||||
actual_name = tool_call.function.name
|
||||
actual_arguments = json.loads(tool_call.function.arguments)
|
||||
assert expected_name == actual_name, f"tool name: {actual_name}, expected: {expected_name}, result = {result}"
|
||||
assert json.dumps(expected_arguments) == json.dumps(actual_arguments), f"tool arguments: {json.dumps(actual_arguments)}, expected: {json.dumps(expected_arguments)}"
|
||||
|
||||
assert_n_tokens_predicted(result, tool_calls_check=check)
|
||||
assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"
|
||||
|
||||
@step('no tool is called')
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue