diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 922ba0288..f1a97deec 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -655,19 +655,21 @@ async def step_tool_called(context, expected_name, expected_arguments): expected_name = expected_name if expected_name else None expected_arguments = json.loads(expected_arguments) if expected_arguments else None - def check(tool_calls): - if tool_calls is None: - assert expected_name is None and expected_arguments is None, f'expected_name = {expected_name}, expected_arguments = {expected_arguments}' - else: - assert len(tool_calls) == 1, f"tool calls: {tool_calls}" - tool_call = tool_calls[0] - actual_name = tool_call.function.name - actual_arguments = json.loads(tool_call.function.arguments) - assert expected_name == actual_name, f"tool name: {actual_name}, expected: {expected_name}" - assert json.dumps(expected_arguments) == json.dumps(actual_arguments), f"tool arguments: {json.dumps(actual_arguments)}, expected: {json.dumps(expected_arguments)}" - for i in range(n_completions): - assert_n_tokens_predicted(context.tasks_result.pop(), tool_calls_check=check) + result = context.tasks_result.pop() + + def check(tool_calls): + if tool_calls is None: + assert expected_name is None and expected_arguments is None, f'expected_name = {expected_name}, expected_arguments = {expected_arguments}, result = {result}' + else: + assert len(tool_calls) == 1, f"tool calls: {tool_calls}" + tool_call = tool_calls[0] + actual_name = tool_call.function.name + actual_arguments = json.loads(tool_call.function.arguments) + assert expected_name == actual_name, f"tool name: {actual_name}, expected: {expected_name}, result = {result}" + assert json.dumps(expected_arguments) == json.dumps(actual_arguments), f"tool arguments: {json.dumps(actual_arguments)}, expected: {json.dumps(expected_arguments)}" + + assert_n_tokens_predicted(result, tool_calls_check=check) assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests" @step('no tool is called')