openai: fix message merging for mixtral (parallel calls)
This commit is contained in:
parent
ea34bd3e5c
commit
80c793047b
1 changed files with 33 additions and 24 deletions
|
@ -131,31 +131,40 @@ class ChatTemplate(BaseModel):
|
|||
new_messages=[]
|
||||
i = 0
|
||||
n = len(messages)
|
||||
while i < n:
|
||||
if messages[i].role == 'system':
|
||||
assert messages[i+1].role == 'user'
|
||||
new_messages.append(Message(
|
||||
role="user",
|
||||
content=f'[SYS]{messages[i].content}[/SYS]\n{messages[i+1].content}'
|
||||
))
|
||||
i += 2
|
||||
elif messages[i].role == 'assistant' and messages[i].tool_calls and messages[i].content:
|
||||
tc = '\n'.join(f'<tool_call>{json.dumps(tc.model_dump())}</tool_call>' for tc in messages[i].tool_calls)
|
||||
new_messages.append(Message(
|
||||
role="assistant",
|
||||
content=f'{messages[i].content}\n{tc}'
|
||||
))
|
||||
i += 1
|
||||
elif messages[i].role == 'tool':
|
||||
new_messages.append(Message(
|
||||
role="user",
|
||||
content=f'TOOL RESULT(name={messages[i].name}, id={messages[i].tool_call_id}): {messages[i].content}',
|
||||
))
|
||||
i += 1
|
||||
current_role = 'user'
|
||||
current_content = []
|
||||
|
||||
def flush():
|
||||
nonlocal current_content
|
||||
nonlocal current_role
|
||||
new_messages.append(Message(
|
||||
role=current_role,
|
||||
content='\n'.join(current_content)
|
||||
))
|
||||
current_content = []
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
if message.role == current_role:
|
||||
current_content.append(message.content)
|
||||
elif message.role in ('user', 'assistant'):
|
||||
flush()
|
||||
current_role = 'assistant' if current_role == 'user' else 'user'
|
||||
current_content.append(message.content)
|
||||
else:
|
||||
new_messages.append(messages[i])
|
||||
i += 1
|
||||
# print(f'new_messages={json.dumps(new_messages, indent=2)}')
|
||||
if current_role == 'assistant':
|
||||
flush()
|
||||
current_role = 'user'
|
||||
if message.role == 'system':
|
||||
current_content.append(f'[SYS]{messages[i].content}[/SYS]')
|
||||
elif message.role == 'tool':
|
||||
current_content.append(f'[TOOL RESULT(name={messages[i].name}, id={messages[i].tool_call_id}]{messages[i].content}[/TOOL RESULT]')
|
||||
else:
|
||||
sys.stderr.write(f'Unexpected message role: {message.role}\n')
|
||||
current_content.append(f'[ROLE={messages[i].role}]{messages[i].content}[/ROLE]')
|
||||
|
||||
if current_content:
|
||||
flush()
|
||||
|
||||
messages = new_messages
|
||||
|
||||
result = self._template.render(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue