openai: fix message merging for mixtral (parallel calls)

This commit is contained in:
ochafik 2024-03-29 17:01:20 +00:00
parent ea34bd3e5c
commit 80c793047b

View file

@ -131,31 +131,40 @@ class ChatTemplate(BaseModel):
new_messages=[] new_messages=[]
i = 0 i = 0
n = len(messages) n = len(messages)
while i < n: current_role = 'user'
if messages[i].role == 'system': current_content = []
assert messages[i+1].role == 'user'
new_messages.append(Message( def flush():
role="user", nonlocal current_content
content=f'[SYS]{messages[i].content}[/SYS]\n{messages[i+1].content}' nonlocal current_role
)) new_messages.append(Message(
i += 2 role=current_role,
elif messages[i].role == 'assistant' and messages[i].tool_calls and messages[i].content: content='\n'.join(current_content)
tc = '\n'.join(f'<tool_call>{json.dumps(tc.model_dump())}</tool_call>' for tc in messages[i].tool_calls) ))
new_messages.append(Message( current_content = []
role="assistant",
content=f'{messages[i].content}\n{tc}' for i, message in enumerate(messages):
)) if message.role == current_role:
i += 1 current_content.append(message.content)
elif messages[i].role == 'tool': elif message.role in ('user', 'assistant'):
new_messages.append(Message( flush()
role="user", current_role = 'assistant' if current_role == 'user' else 'user'
content=f'TOOL RESULT(name={messages[i].name}, id={messages[i].tool_call_id}): {messages[i].content}', current_content.append(message.content)
))
i += 1
else: else:
new_messages.append(messages[i]) if current_role == 'assistant':
i += 1 flush()
# print(f'new_messages={json.dumps(new_messages, indent=2)}') current_role = 'user'
if message.role == 'system':
current_content.append(f'[SYS]{messages[i].content}[/SYS]')
elif message.role == 'tool':
current_content.append(f'[TOOL RESULT(name={messages[i].name}, id={messages[i].tool_call_id}]{messages[i].content}[/TOOL RESULT]')
else:
sys.stderr.write(f'Unexpected message role: {message.role}\n')
current_content.append(f'[ROLE={messages[i].role}]{messages[i].content}[/ROLE]')
if current_content:
flush()
messages = new_messages messages = new_messages
result = self._template.render( result = self._template.render(