From a1d64cfb924c6edff412ecca5a330f74939a9192 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sat, 18 May 2024 18:19:27 +0100 Subject: [PATCH] openai: function call arguments must be returned stringified! --- examples/agent/agent.py | 5 ++-- examples/openai/api.py | 3 ++- examples/openai/prompting.py | 52 +++++++----------------------------- 3 files changed, 14 insertions(+), 46 deletions(-) diff --git a/examples/agent/agent.py b/examples/agent/agent.py index ebb51e111..bf2ee907e 100644 --- a/examples/agent/agent.py +++ b/examples/agent/agent.py @@ -109,10 +109,11 @@ def completion_with_tool_usage( if content: print(f'💭 {content}') - pretty_call = f'{tool_call.function.name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in tool_call.function.arguments.items())})' + args = json.loads(tool_call.function.arguments) + pretty_call = f'{tool_call.function.name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})' sys.stdout.write(f'⚙️ {pretty_call}') sys.stdout.flush() - tool_result = tool_map[tool_call.function.name](**tool_call.function.arguments) + tool_result = tool_map[tool_call.function.name](**args) sys.stdout.write(f" → {tool_result}\n") messages.append(Message( tool_call_id=tool_call.id, diff --git a/examples/openai/api.py b/examples/openai/api.py index 705c5654b..cafe12752 100644 --- a/examples/openai/api.py +++ b/examples/openai/api.py @@ -4,7 +4,8 @@ from pydantic import BaseModel, Json, TypeAdapter class FunctionCall(BaseModel): name: str - arguments: Union[Dict[str, Any], str] + arguments: str + # arguments: Union[Dict[str, Any], str] class ToolCall(BaseModel): id: Optional[str] = None diff --git a/examples/openai/prompting.py b/examples/openai/prompting.py index f1b7d17eb..386d540e8 100644 --- a/examples/openai/prompting.py +++ b/examples/openai/prompting.py @@ -56,7 +56,6 @@ class ChatTemplate(BaseModel): bos_token: str inferred_tool_style: Annotated[Optional['ToolsPromptStyle'], Field(exclude=True)] = None - expects_stringified_function_arguments: Annotated[Optional[bool], Field(exclude=True)] = None expects_strict_user_assistant_alternance: Annotated[Optional[bool], Field(exclude=True)] = None formats_tool_call: Annotated[Optional[bool], Field(exclude=True)] = None formats_tool_call_content: Annotated[Optional[bool], Field(exclude=True)] = None @@ -108,7 +107,7 @@ class ChatTemplate(BaseModel): thought = "Precious thought" fn_name = "callMeMaybe" - toolcall = ToolCall(id="call_531873", type="function", function=FunctionCall(name=fn_name, arguments={"lol": 123})) + toolcall = ToolCall(id="call_531873", type="function", function=FunctionCall(name=fn_name, arguments=json.dumps({"lol": 123}))) toolcall_msg = Message(role="assistant", content=None, tool_calls=[toolcall]) tool_result = "Tool result" tool_name = "additioner" @@ -119,8 +118,6 @@ class ChatTemplate(BaseModel): self.formats_tool_call = succeeds([user_msg, toolcall_msg], (fn_name,)) if self.formats_tool_call: self.formats_tool_call_content = succeeds([user_msg, toolcall_content_msg], (thought,)) - self.expects_stringified_function_arguments = \ - not succeeds([user_msg, toolcall_content_msg]) and succeeds([user_msg, stringified_toolcall_msg], (fn_name,)) self.formats_tool_result = succeeds([user_msg, assistant_msg, tool_msg], (tool_result,)) self.formats_tool_name = succeeds([user_msg, assistant_msg, tool_msg], (tool_name,)) @@ -246,24 +243,6 @@ class ChatHandler(ABC): ]) ]) ) - elif self.args.chat_template.expects_stringified_function_arguments: - return Message( - role=m.role, - content=m.content, - name=m.name, - tool_call_id=m.tool_call_id, - tool_calls=[ - ToolCall( - id=tc.id, - type=tc.type, - function=FunctionCall( - name=tc.function.name, - arguments=json.dumps(tc.function.arguments) - ) - ) - for tc in m.tool_calls - ], - ) else: return m elif self.args.chat_template.expects_strict_user_assistant_alternance and m.role not in ('user', 'assistant'): @@ -313,25 +292,6 @@ class ChatHandler(ABC): # JSON! # messages = [m.model_dump() for m in messages] - # if self.inferred_tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2: - if self.args.chat_template.expects_stringified_function_arguments: - messages = [ - Message(**{ - **m.model_dump(), - "tool_calls": [ - ToolCall(**{ - **tc.model_dump(), - "function": { - "name": tc.function.name, - "arguments": tc.function.arguments, - } - }) - for tc in m.tool_calls - ] if m.tool_calls else None - }) - for m in messages - ] - return self.args.chat_template.raw_render( messages=messages, add_generation_prompt=True, @@ -429,7 +389,9 @@ class ToolCallTagsChatHandler(ChatHandler): tool_calls.append( ToolCall( id=gen_callid(), - function=FunctionCall(**fc))) + function=FunctionCall( + name=fc["name"], + arguments=json.dumps(fc["arguments"])))) content_str = '\n'.join(content).strip() return Message(role="assistant", content=content_str if content_str else None, tool_calls=tool_calls) @@ -653,7 +615,11 @@ class ThoughtfulStepsToolsChatHandler(ChatHandler): role="assistant", content=data.get(_THOUGHT_KEY), tool_calls=[ - ToolCall(id=gen_callid(), function=FunctionCall(**tc)) + ToolCall( + id=gen_callid(), + function=FunctionCall( + name=tc["name"], + arguments=json.dumps(tc["arguments"]))) for tc in next_step['tool_calls'] ] )