openai: function call arguments must be returned stringified!
This commit is contained in:
parent
e41b6ceee9
commit
a1d64cfb92
3 changed files with 14 additions and 46 deletions
|
@ -109,10 +109,11 @@ def completion_with_tool_usage(
|
||||||
if content:
|
if content:
|
||||||
print(f'💭 {content}')
|
print(f'💭 {content}')
|
||||||
|
|
||||||
pretty_call = f'{tool_call.function.name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in tool_call.function.arguments.items())})'
|
args = json.loads(tool_call.function.arguments)
|
||||||
|
pretty_call = f'{tool_call.function.name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
|
||||||
sys.stdout.write(f'⚙️ {pretty_call}')
|
sys.stdout.write(f'⚙️ {pretty_call}')
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
tool_result = tool_map[tool_call.function.name](**tool_call.function.arguments)
|
tool_result = tool_map[tool_call.function.name](**args)
|
||||||
sys.stdout.write(f" → {tool_result}\n")
|
sys.stdout.write(f" → {tool_result}\n")
|
||||||
messages.append(Message(
|
messages.append(Message(
|
||||||
tool_call_id=tool_call.id,
|
tool_call_id=tool_call.id,
|
||||||
|
|
|
@ -4,7 +4,8 @@ from pydantic import BaseModel, Json, TypeAdapter
|
||||||
|
|
||||||
class FunctionCall(BaseModel):
|
class FunctionCall(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
arguments: Union[Dict[str, Any], str]
|
arguments: str
|
||||||
|
# arguments: Union[Dict[str, Any], str]
|
||||||
|
|
||||||
class ToolCall(BaseModel):
|
class ToolCall(BaseModel):
|
||||||
id: Optional[str] = None
|
id: Optional[str] = None
|
||||||
|
|
|
@ -56,7 +56,6 @@ class ChatTemplate(BaseModel):
|
||||||
bos_token: str
|
bos_token: str
|
||||||
|
|
||||||
inferred_tool_style: Annotated[Optional['ToolsPromptStyle'], Field(exclude=True)] = None
|
inferred_tool_style: Annotated[Optional['ToolsPromptStyle'], Field(exclude=True)] = None
|
||||||
expects_stringified_function_arguments: Annotated[Optional[bool], Field(exclude=True)] = None
|
|
||||||
expects_strict_user_assistant_alternance: Annotated[Optional[bool], Field(exclude=True)] = None
|
expects_strict_user_assistant_alternance: Annotated[Optional[bool], Field(exclude=True)] = None
|
||||||
formats_tool_call: Annotated[Optional[bool], Field(exclude=True)] = None
|
formats_tool_call: Annotated[Optional[bool], Field(exclude=True)] = None
|
||||||
formats_tool_call_content: Annotated[Optional[bool], Field(exclude=True)] = None
|
formats_tool_call_content: Annotated[Optional[bool], Field(exclude=True)] = None
|
||||||
|
@ -108,7 +107,7 @@ class ChatTemplate(BaseModel):
|
||||||
|
|
||||||
thought = "Precious thought"
|
thought = "Precious thought"
|
||||||
fn_name = "callMeMaybe"
|
fn_name = "callMeMaybe"
|
||||||
toolcall = ToolCall(id="call_531873", type="function", function=FunctionCall(name=fn_name, arguments={"lol": 123}))
|
toolcall = ToolCall(id="call_531873", type="function", function=FunctionCall(name=fn_name, arguments=json.dumps({"lol": 123})))
|
||||||
toolcall_msg = Message(role="assistant", content=None, tool_calls=[toolcall])
|
toolcall_msg = Message(role="assistant", content=None, tool_calls=[toolcall])
|
||||||
tool_result = "Tool result"
|
tool_result = "Tool result"
|
||||||
tool_name = "additioner"
|
tool_name = "additioner"
|
||||||
|
@ -119,8 +118,6 @@ class ChatTemplate(BaseModel):
|
||||||
self.formats_tool_call = succeeds([user_msg, toolcall_msg], (fn_name,))
|
self.formats_tool_call = succeeds([user_msg, toolcall_msg], (fn_name,))
|
||||||
if self.formats_tool_call:
|
if self.formats_tool_call:
|
||||||
self.formats_tool_call_content = succeeds([user_msg, toolcall_content_msg], (thought,))
|
self.formats_tool_call_content = succeeds([user_msg, toolcall_content_msg], (thought,))
|
||||||
self.expects_stringified_function_arguments = \
|
|
||||||
not succeeds([user_msg, toolcall_content_msg]) and succeeds([user_msg, stringified_toolcall_msg], (fn_name,))
|
|
||||||
|
|
||||||
self.formats_tool_result = succeeds([user_msg, assistant_msg, tool_msg], (tool_result,))
|
self.formats_tool_result = succeeds([user_msg, assistant_msg, tool_msg], (tool_result,))
|
||||||
self.formats_tool_name = succeeds([user_msg, assistant_msg, tool_msg], (tool_name,))
|
self.formats_tool_name = succeeds([user_msg, assistant_msg, tool_msg], (tool_name,))
|
||||||
|
@ -246,24 +243,6 @@ class ChatHandler(ABC):
|
||||||
])
|
])
|
||||||
])
|
])
|
||||||
)
|
)
|
||||||
elif self.args.chat_template.expects_stringified_function_arguments:
|
|
||||||
return Message(
|
|
||||||
role=m.role,
|
|
||||||
content=m.content,
|
|
||||||
name=m.name,
|
|
||||||
tool_call_id=m.tool_call_id,
|
|
||||||
tool_calls=[
|
|
||||||
ToolCall(
|
|
||||||
id=tc.id,
|
|
||||||
type=tc.type,
|
|
||||||
function=FunctionCall(
|
|
||||||
name=tc.function.name,
|
|
||||||
arguments=json.dumps(tc.function.arguments)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
for tc in m.tool_calls
|
|
||||||
],
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return m
|
return m
|
||||||
elif self.args.chat_template.expects_strict_user_assistant_alternance and m.role not in ('user', 'assistant'):
|
elif self.args.chat_template.expects_strict_user_assistant_alternance and m.role not in ('user', 'assistant'):
|
||||||
|
@ -313,25 +292,6 @@ class ChatHandler(ABC):
|
||||||
# JSON!
|
# JSON!
|
||||||
# messages = [m.model_dump() for m in messages]
|
# messages = [m.model_dump() for m in messages]
|
||||||
|
|
||||||
# if self.inferred_tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
|
|
||||||
if self.args.chat_template.expects_stringified_function_arguments:
|
|
||||||
messages = [
|
|
||||||
Message(**{
|
|
||||||
**m.model_dump(),
|
|
||||||
"tool_calls": [
|
|
||||||
ToolCall(**{
|
|
||||||
**tc.model_dump(),
|
|
||||||
"function": {
|
|
||||||
"name": tc.function.name,
|
|
||||||
"arguments": tc.function.arguments,
|
|
||||||
}
|
|
||||||
})
|
|
||||||
for tc in m.tool_calls
|
|
||||||
] if m.tool_calls else None
|
|
||||||
})
|
|
||||||
for m in messages
|
|
||||||
]
|
|
||||||
|
|
||||||
return self.args.chat_template.raw_render(
|
return self.args.chat_template.raw_render(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
|
@ -429,7 +389,9 @@ class ToolCallTagsChatHandler(ChatHandler):
|
||||||
tool_calls.append(
|
tool_calls.append(
|
||||||
ToolCall(
|
ToolCall(
|
||||||
id=gen_callid(),
|
id=gen_callid(),
|
||||||
function=FunctionCall(**fc)))
|
function=FunctionCall(
|
||||||
|
name=fc["name"],
|
||||||
|
arguments=json.dumps(fc["arguments"]))))
|
||||||
|
|
||||||
content_str = '\n'.join(content).strip()
|
content_str = '\n'.join(content).strip()
|
||||||
return Message(role="assistant", content=content_str if content_str else None, tool_calls=tool_calls)
|
return Message(role="assistant", content=content_str if content_str else None, tool_calls=tool_calls)
|
||||||
|
@ -653,7 +615,11 @@ class ThoughtfulStepsToolsChatHandler(ChatHandler):
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content=data.get(_THOUGHT_KEY),
|
content=data.get(_THOUGHT_KEY),
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
ToolCall(id=gen_callid(), function=FunctionCall(**tc))
|
ToolCall(
|
||||||
|
id=gen_callid(),
|
||||||
|
function=FunctionCall(
|
||||||
|
name=tc["name"],
|
||||||
|
arguments=json.dumps(tc["arguments"])))
|
||||||
for tc in next_step['tool_calls']
|
for tc in next_step['tool_calls']
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue