openai: function call arguments must be returned stringified!

This commit is contained in:
ochafik 2024-05-18 18:19:27 +01:00
parent e41b6ceee9
commit a1d64cfb92
3 changed files with 14 additions and 46 deletions

View file

@ -109,10 +109,11 @@ def completion_with_tool_usage(
if content: if content:
print(f'💭 {content}') print(f'💭 {content}')
pretty_call = f'{tool_call.function.name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in tool_call.function.arguments.items())})' args = json.loads(tool_call.function.arguments)
pretty_call = f'{tool_call.function.name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
sys.stdout.write(f'⚙️ {pretty_call}') sys.stdout.write(f'⚙️ {pretty_call}')
sys.stdout.flush() sys.stdout.flush()
tool_result = tool_map[tool_call.function.name](**tool_call.function.arguments) tool_result = tool_map[tool_call.function.name](**args)
sys.stdout.write(f"{tool_result}\n") sys.stdout.write(f"{tool_result}\n")
messages.append(Message( messages.append(Message(
tool_call_id=tool_call.id, tool_call_id=tool_call.id,

View file

@ -4,7 +4,8 @@ from pydantic import BaseModel, Json, TypeAdapter
class FunctionCall(BaseModel): class FunctionCall(BaseModel):
name: str name: str
arguments: Union[Dict[str, Any], str] arguments: str
# arguments: Union[Dict[str, Any], str]
class ToolCall(BaseModel): class ToolCall(BaseModel):
id: Optional[str] = None id: Optional[str] = None

View file

@ -56,7 +56,6 @@ class ChatTemplate(BaseModel):
bos_token: str bos_token: str
inferred_tool_style: Annotated[Optional['ToolsPromptStyle'], Field(exclude=True)] = None inferred_tool_style: Annotated[Optional['ToolsPromptStyle'], Field(exclude=True)] = None
expects_stringified_function_arguments: Annotated[Optional[bool], Field(exclude=True)] = None
expects_strict_user_assistant_alternance: Annotated[Optional[bool], Field(exclude=True)] = None expects_strict_user_assistant_alternance: Annotated[Optional[bool], Field(exclude=True)] = None
formats_tool_call: Annotated[Optional[bool], Field(exclude=True)] = None formats_tool_call: Annotated[Optional[bool], Field(exclude=True)] = None
formats_tool_call_content: Annotated[Optional[bool], Field(exclude=True)] = None formats_tool_call_content: Annotated[Optional[bool], Field(exclude=True)] = None
@ -108,7 +107,7 @@ class ChatTemplate(BaseModel):
thought = "Precious thought" thought = "Precious thought"
fn_name = "callMeMaybe" fn_name = "callMeMaybe"
toolcall = ToolCall(id="call_531873", type="function", function=FunctionCall(name=fn_name, arguments={"lol": 123})) toolcall = ToolCall(id="call_531873", type="function", function=FunctionCall(name=fn_name, arguments=json.dumps({"lol": 123})))
toolcall_msg = Message(role="assistant", content=None, tool_calls=[toolcall]) toolcall_msg = Message(role="assistant", content=None, tool_calls=[toolcall])
tool_result = "Tool result" tool_result = "Tool result"
tool_name = "additioner" tool_name = "additioner"
@ -119,8 +118,6 @@ class ChatTemplate(BaseModel):
self.formats_tool_call = succeeds([user_msg, toolcall_msg], (fn_name,)) self.formats_tool_call = succeeds([user_msg, toolcall_msg], (fn_name,))
if self.formats_tool_call: if self.formats_tool_call:
self.formats_tool_call_content = succeeds([user_msg, toolcall_content_msg], (thought,)) self.formats_tool_call_content = succeeds([user_msg, toolcall_content_msg], (thought,))
self.expects_stringified_function_arguments = \
not succeeds([user_msg, toolcall_content_msg]) and succeeds([user_msg, stringified_toolcall_msg], (fn_name,))
self.formats_tool_result = succeeds([user_msg, assistant_msg, tool_msg], (tool_result,)) self.formats_tool_result = succeeds([user_msg, assistant_msg, tool_msg], (tool_result,))
self.formats_tool_name = succeeds([user_msg, assistant_msg, tool_msg], (tool_name,)) self.formats_tool_name = succeeds([user_msg, assistant_msg, tool_msg], (tool_name,))
@ -246,24 +243,6 @@ class ChatHandler(ABC):
]) ])
]) ])
) )
elif self.args.chat_template.expects_stringified_function_arguments:
return Message(
role=m.role,
content=m.content,
name=m.name,
tool_call_id=m.tool_call_id,
tool_calls=[
ToolCall(
id=tc.id,
type=tc.type,
function=FunctionCall(
name=tc.function.name,
arguments=json.dumps(tc.function.arguments)
)
)
for tc in m.tool_calls
],
)
else: else:
return m return m
elif self.args.chat_template.expects_strict_user_assistant_alternance and m.role not in ('user', 'assistant'): elif self.args.chat_template.expects_strict_user_assistant_alternance and m.role not in ('user', 'assistant'):
@ -313,25 +292,6 @@ class ChatHandler(ABC):
# JSON! # JSON!
# messages = [m.model_dump() for m in messages] # messages = [m.model_dump() for m in messages]
# if self.inferred_tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
if self.args.chat_template.expects_stringified_function_arguments:
messages = [
Message(**{
**m.model_dump(),
"tool_calls": [
ToolCall(**{
**tc.model_dump(),
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
}
})
for tc in m.tool_calls
] if m.tool_calls else None
})
for m in messages
]
return self.args.chat_template.raw_render( return self.args.chat_template.raw_render(
messages=messages, messages=messages,
add_generation_prompt=True, add_generation_prompt=True,
@ -429,7 +389,9 @@ class ToolCallTagsChatHandler(ChatHandler):
tool_calls.append( tool_calls.append(
ToolCall( ToolCall(
id=gen_callid(), id=gen_callid(),
function=FunctionCall(**fc))) function=FunctionCall(
name=fc["name"],
arguments=json.dumps(fc["arguments"]))))
content_str = '\n'.join(content).strip() content_str = '\n'.join(content).strip()
return Message(role="assistant", content=content_str if content_str else None, tool_calls=tool_calls) return Message(role="assistant", content=content_str if content_str else None, tool_calls=tool_calls)
@ -653,7 +615,11 @@ class ThoughtfulStepsToolsChatHandler(ChatHandler):
role="assistant", role="assistant",
content=data.get(_THOUGHT_KEY), content=data.get(_THOUGHT_KEY),
tool_calls=[ tool_calls=[
ToolCall(id=gen_callid(), function=FunctionCall(**tc)) ToolCall(
id=gen_callid(),
function=FunctionCall(
name=tc["name"],
arguments=json.dumps(tc["arguments"])))
for tc in next_step['tool_calls'] for tc in next_step['tool_calls']
] ]
) )