diff --git a/examples/openai/prompting.py b/examples/openai/prompting.py index 6a8fe46c7..c431a7987 100644 --- a/examples/openai/prompting.py +++ b/examples/openai/prompting.py @@ -131,31 +131,40 @@ class ChatTemplate(BaseModel): new_messages=[] i = 0 n = len(messages) - while i < n: - if messages[i].role == 'system': - assert messages[i+1].role == 'user' - new_messages.append(Message( - role="user", - content=f'[SYS]{messages[i].content}[/SYS]\n{messages[i+1].content}' - )) - i += 2 - elif messages[i].role == 'assistant' and messages[i].tool_calls and messages[i].content: - tc = '\n'.join(f'{json.dumps(tc.model_dump())}' for tc in messages[i].tool_calls) - new_messages.append(Message( - role="assistant", - content=f'{messages[i].content}\n{tc}' - )) - i += 1 - elif messages[i].role == 'tool': - new_messages.append(Message( - role="user", - content=f'TOOL RESULT(name={messages[i].name}, id={messages[i].tool_call_id}): {messages[i].content}', - )) - i += 1 + current_role = 'user' + current_content = [] + + def flush(): + nonlocal current_content + nonlocal current_role + new_messages.append(Message( + role=current_role, + content='\n'.join(current_content) + )) + current_content = [] + + for i, message in enumerate(messages): + if message.role == current_role: + current_content.append(message.content) + elif message.role in ('user', 'assistant'): + flush() + current_role = 'assistant' if current_role == 'user' else 'user' + current_content.append(message.content) else: - new_messages.append(messages[i]) - i += 1 - # print(f'new_messages={json.dumps(new_messages, indent=2)}') + if current_role == 'assistant': + flush() + current_role = 'user' + if message.role == 'system': + current_content.append(f'[SYS]{messages[i].content}[/SYS]') + elif message.role == 'tool': + current_content.append(f'[TOOL RESULT(name={messages[i].name}, id={messages[i].tool_call_id}]{messages[i].content}[/TOOL RESULT]') + else: + sys.stderr.write(f'Unexpected message role: {message.role}\n') + current_content.append(f'[ROLE={messages[i].role}]{messages[i].content}[/ROLE]') + + if current_content: + flush() + messages = new_messages result = self._template.render(