openai: fix message merging for mixtral (parallel calls)

This commit is contained in:
ochafik 2024-03-29 17:01:20 +00:00
parent ea34bd3e5c
commit 80c793047b

View file

@@ -131,31 +131,40 @@ class ChatTemplate(BaseModel):
new_messages=[] new_messages=[]
i = 0 i = 0
n = len(messages) n = len(messages)
while i < n: current_role = 'user'
if messages[i].role == 'system': current_content = []
assert messages[i+1].role == 'user'
def flush():
nonlocal current_content
nonlocal current_role
new_messages.append(Message( new_messages.append(Message(
role="user", role=current_role,
content=f'[SYS]{messages[i].content}[/SYS]\n{messages[i+1].content}' content='\n'.join(current_content)
)) ))
i += 2 current_content = []
elif messages[i].role == 'assistant' and messages[i].tool_calls and messages[i].content:
tc = '\n'.join(f'<tool_call>{json.dumps(tc.model_dump())}</tool_call>' for tc in messages[i].tool_calls) for i, message in enumerate(messages):
new_messages.append(Message( if message.role == current_role:
role="assistant", current_content.append(message.content)
content=f'{messages[i].content}\n{tc}' elif message.role in ('user', 'assistant'):
)) flush()
i += 1 current_role = 'assistant' if current_role == 'user' else 'user'
elif messages[i].role == 'tool': current_content.append(message.content)
new_messages.append(Message(
role="user",
content=f'TOOL RESULT(name={messages[i].name}, id={messages[i].tool_call_id}): {messages[i].content}',
))
i += 1
else: else:
new_messages.append(messages[i]) if current_role == 'assistant':
i += 1 flush()
# print(f'new_messages={json.dumps(new_messages, indent=2)}') current_role = 'user'
if message.role == 'system':
current_content.append(f'[SYS]{messages[i].content}[/SYS]')
elif message.role == 'tool':
current_content.append(f'[TOOL RESULT(name={messages[i].name}, id={messages[i].tool_call_id}]{messages[i].content}[/TOOL RESULT]')
else:
sys.stderr.write(f'Unexpected message role: {message.role}\n')
current_content.append(f'[ROLE={messages[i].role}]{messages[i].content}[/ROLE]')
if current_content:
flush()
messages = new_messages messages = new_messages
result = self._template.render( result = self._template.render(