openai: fix message merging for mixtral (parallel calls)

This commit is contained in:
ochafik 2024-03-29 17:01:20 +00:00
parent ea34bd3e5c
commit 80c793047b

View file

@@ -131,31 +131,40 @@ class ChatTemplate(BaseModel):
new_messages=[]
i = 0
n = len(messages)
while i < n:
if messages[i].role == 'system':
assert messages[i+1].role == 'user'
new_messages.append(Message(
role="user",
content=f'[SYS]{messages[i].content}[/SYS]\n{messages[i+1].content}'
))
i += 2
elif messages[i].role == 'assistant' and messages[i].tool_calls and messages[i].content:
tc = '\n'.join(f'<tool_call>{json.dumps(tc.model_dump())}</tool_call>' for tc in messages[i].tool_calls)
new_messages.append(Message(
role="assistant",
content=f'{messages[i].content}\n{tc}'
))
i += 1
elif messages[i].role == 'tool':
new_messages.append(Message(
role="user",
content=f'TOOL RESULT(name={messages[i].name}, id={messages[i].tool_call_id}): {messages[i].content}',
))
i += 1
current_role = 'user'
current_content = []
def flush():
nonlocal current_content
nonlocal current_role
new_messages.append(Message(
role=current_role,
content='\n'.join(current_content)
))
current_content = []
for i, message in enumerate(messages):
if message.role == current_role:
current_content.append(message.content)
elif message.role in ('user', 'assistant'):
flush()
current_role = 'assistant' if current_role == 'user' else 'user'
current_content.append(message.content)
else:
new_messages.append(messages[i])
i += 1
# print(f'new_messages={json.dumps(new_messages, indent=2)}')
if current_role == 'assistant':
flush()
current_role = 'user'
if message.role == 'system':
current_content.append(f'[SYS]{messages[i].content}[/SYS]')
elif message.role == 'tool':
current_content.append(f'[TOOL RESULT(name={messages[i].name}, id={messages[i].tool_call_id}]{messages[i].content}[/TOOL RESULT]')
else:
sys.stderr.write(f'Unexpected message role: {message.role}\n')
current_content.append(f'[ROLE={messages[i].role}]{messages[i].content}[/ROLE]')
if current_content:
flush()
messages = new_messages
result = self._template.render(