tool-call: better error reporting for server tests

This commit is contained in:
ochafik 2024-09-28 18:31:22 +01:00
parent 7cef90cf9c
commit 55cf337560

View file

@@ -655,19 +655,21 @@ async def step_tool_called(context, expected_name, expected_arguments):
expected_name = expected_name if expected_name else None
expected_arguments = json.loads(expected_arguments) if expected_arguments else None
def check(tool_calls):
if tool_calls is None:
assert expected_name is None and expected_arguments is None, f'expected_name = {expected_name}, expected_arguments = {expected_arguments}'
else:
assert len(tool_calls) == 1, f"tool calls: {tool_calls}"
tool_call = tool_calls[0]
actual_name = tool_call.function.name
actual_arguments = json.loads(tool_call.function.arguments)
assert expected_name == actual_name, f"tool name: {actual_name}, expected: {expected_name}"
assert json.dumps(expected_arguments) == json.dumps(actual_arguments), f"tool arguments: {json.dumps(actual_arguments)}, expected: {json.dumps(expected_arguments)}"
for i in range(n_completions):
assert_n_tokens_predicted(context.tasks_result.pop(), tool_calls_check=check)
result = context.tasks_result.pop()
def check(tool_calls):
if tool_calls is None:
assert expected_name is None and expected_arguments is None, f'expected_name = {expected_name}, expected_arguments = {expected_arguments}, result = {result}'
else:
assert len(tool_calls) == 1, f"tool calls: {tool_calls}"
tool_call = tool_calls[0]
actual_name = tool_call.function.name
actual_arguments = json.loads(tool_call.function.arguments)
assert expected_name == actual_name, f"tool name: {actual_name}, expected: {expected_name}, result = {result}"
assert json.dumps(expected_arguments) == json.dumps(actual_arguments), f"tool arguments: {json.dumps(actual_arguments)}, expected: {json.dumps(expected_arguments)}"
assert_n_tokens_predicted(result, tool_calls_check=check)
assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"
@step('no tool is called')