openai: function call arguments must be returned stringified!

This commit is contained in:
ochafik 2024-05-18 18:19:27 +01:00
parent e41b6ceee9
commit a1d64cfb92
3 changed files with 14 additions and 46 deletions

View file

@ -109,10 +109,11 @@ def completion_with_tool_usage(
if content:
print(f'💭 {content}')
pretty_call = f'{tool_call.function.name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in tool_call.function.arguments.items())})'
args = json.loads(tool_call.function.arguments)
pretty_call = f'{tool_call.function.name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
sys.stdout.write(f'⚙️ {pretty_call}')
sys.stdout.flush()
tool_result = tool_map[tool_call.function.name](**tool_call.function.arguments)
tool_result = tool_map[tool_call.function.name](**args)
sys.stdout.write(f"{tool_result}\n")
messages.append(Message(
tool_call_id=tool_call.id,

View file

@ -4,7 +4,8 @@ from pydantic import BaseModel, Json, TypeAdapter
class FunctionCall(BaseModel):
name: str
arguments: Union[Dict[str, Any], str]
arguments: str
# arguments: Union[Dict[str, Any], str]
class ToolCall(BaseModel):
id: Optional[str] = None

View file

@ -56,7 +56,6 @@ class ChatTemplate(BaseModel):
bos_token: str
inferred_tool_style: Annotated[Optional['ToolsPromptStyle'], Field(exclude=True)] = None
expects_stringified_function_arguments: Annotated[Optional[bool], Field(exclude=True)] = None
expects_strict_user_assistant_alternance: Annotated[Optional[bool], Field(exclude=True)] = None
formats_tool_call: Annotated[Optional[bool], Field(exclude=True)] = None
formats_tool_call_content: Annotated[Optional[bool], Field(exclude=True)] = None
@ -108,7 +107,7 @@ class ChatTemplate(BaseModel):
thought = "Precious thought"
fn_name = "callMeMaybe"
toolcall = ToolCall(id="call_531873", type="function", function=FunctionCall(name=fn_name, arguments={"lol": 123}))
toolcall = ToolCall(id="call_531873", type="function", function=FunctionCall(name=fn_name, arguments=json.dumps({"lol": 123})))
toolcall_msg = Message(role="assistant", content=None, tool_calls=[toolcall])
tool_result = "Tool result"
tool_name = "additioner"
@ -119,8 +118,6 @@ class ChatTemplate(BaseModel):
self.formats_tool_call = succeeds([user_msg, toolcall_msg], (fn_name,))
if self.formats_tool_call:
self.formats_tool_call_content = succeeds([user_msg, toolcall_content_msg], (thought,))
self.expects_stringified_function_arguments = \
not succeeds([user_msg, toolcall_content_msg]) and succeeds([user_msg, stringified_toolcall_msg], (fn_name,))
self.formats_tool_result = succeeds([user_msg, assistant_msg, tool_msg], (tool_result,))
self.formats_tool_name = succeeds([user_msg, assistant_msg, tool_msg], (tool_name,))
@ -246,24 +243,6 @@ class ChatHandler(ABC):
])
])
)
elif self.args.chat_template.expects_stringified_function_arguments:
return Message(
role=m.role,
content=m.content,
name=m.name,
tool_call_id=m.tool_call_id,
tool_calls=[
ToolCall(
id=tc.id,
type=tc.type,
function=FunctionCall(
name=tc.function.name,
arguments=json.dumps(tc.function.arguments)
)
)
for tc in m.tool_calls
],
)
else:
return m
elif self.args.chat_template.expects_strict_user_assistant_alternance and m.role not in ('user', 'assistant'):
@ -313,25 +292,6 @@ class ChatHandler(ABC):
# JSON!
# messages = [m.model_dump() for m in messages]
# if self.inferred_tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
if self.args.chat_template.expects_stringified_function_arguments:
messages = [
Message(**{
**m.model_dump(),
"tool_calls": [
ToolCall(**{
**tc.model_dump(),
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
}
})
for tc in m.tool_calls
] if m.tool_calls else None
})
for m in messages
]
return self.args.chat_template.raw_render(
messages=messages,
add_generation_prompt=True,
@ -429,7 +389,9 @@ class ToolCallTagsChatHandler(ChatHandler):
tool_calls.append(
ToolCall(
id=gen_callid(),
function=FunctionCall(**fc)))
function=FunctionCall(
name=fc["name"],
arguments=json.dumps(fc["arguments"]))))
content_str = '\n'.join(content).strip()
return Message(role="assistant", content=content_str if content_str else None, tool_calls=tool_calls)
@ -653,7 +615,11 @@ class ThoughtfulStepsToolsChatHandler(ChatHandler):
role="assistant",
content=data.get(_THOUGHT_KEY),
tool_calls=[
ToolCall(id=gen_callid(), function=FunctionCall(**tc))
ToolCall(
id=gen_callid(),
function=FunctionCall(
name=tc["name"],
arguments=json.dumps(tc["arguments"])))
for tc in next_step['tool_calls']
]
)