add more comments

This commit is contained in:
Xuan Son Nguyen 2025-02-10 15:24:25 +01:00
parent d83a8e3731
commit 2eef3e7db4
6 changed files with 72 additions and 39 deletions

Binary file not shown.

View file

@ -13,7 +13,7 @@ interface SplitMessage {
export default function ChatMessage({ export default function ChatMessage({
msg, msg,
siblingLastNodeIds, siblingLeafNodeIds,
siblingCurrIdx, siblingCurrIdx,
id, id,
onRegenerateMessage, onRegenerateMessage,
@ -22,7 +22,7 @@ export default function ChatMessage({
isPending, isPending,
}: { }: {
msg: Message | PendingMessage; msg: Message | PendingMessage;
siblingLastNodeIds: Message['id'][]; siblingLeafNodeIds: Message['id'][];
siblingCurrIdx: number; siblingCurrIdx: number;
id?: string; id?: string;
onRegenerateMessage(msg: Message): void; onRegenerateMessage(msg: Message): void;
@ -45,8 +45,8 @@ export default function ChatMessage({
: null, : null,
[msg.timings] [msg.timings]
); );
const nextSibling = siblingLastNodeIds[siblingCurrIdx + 1]; const nextSibling = siblingLeafNodeIds[siblingCurrIdx + 1];
const prevSibling = siblingLastNodeIds[siblingCurrIdx - 1]; const prevSibling = siblingLeafNodeIds[siblingCurrIdx - 1];
// for reasoning model, we split the message into content and thought // for reasoning model, we split the message into content and thought
// TODO: implement this as remark/rehype plugin in the future // TODO: implement this as remark/rehype plugin in the future
@ -203,7 +203,7 @@ export default function ChatMessage({
'flex-row-reverse': msg.role === 'user', 'flex-row-reverse': msg.role === 'user',
})} })}
> >
{siblingLastNodeIds && siblingLastNodeIds.length > 1 && ( {siblingLeafNodeIds && siblingLeafNodeIds.length > 1 && (
<div className="flex gap-1 items-center opacity-60 text-sm"> <div className="flex gap-1 items-center opacity-60 text-sm">
<button <button
className={classNames({ className={classNames({
@ -215,7 +215,7 @@ export default function ChatMessage({
<ChevronLeftIcon className="h-4 w-4" /> <ChevronLeftIcon className="h-4 w-4" />
</button> </button>
<span> <span>
{siblingCurrIdx + 1} / {siblingLastNodeIds.length} {siblingCurrIdx + 1} / {siblingLeafNodeIds.length}
</span> </span>
<button <button
className={classNames({ className={classNames({

View file

@ -6,24 +6,29 @@ import { classNames, throttle } from '../utils/misc';
import CanvasPyInterpreter from './CanvasPyInterpreter'; import CanvasPyInterpreter from './CanvasPyInterpreter';
import StorageUtils from '../utils/storage'; import StorageUtils from '../utils/storage';
/**
* A message display is a message node with additional information for rendering.
* For example, siblings of the message node are stored as their last node (aka leaf node).
*/
export interface MessageDisplay { export interface MessageDisplay {
msg: Message | PendingMessage; msg: Message | PendingMessage;
siblingLastNodeIds: Message['id'][]; siblingLeafNodeIds: Message['id'][];
siblingCurrIdx: number; siblingCurrIdx: number;
isPending?: boolean; isPending?: boolean;
} }
function getListMessageDisplay( function getListMessageDisplay(
msgs: Readonly<Message[]>, msgs: Readonly<Message[]>,
lastNodeId: Message['id'] leafNodeId: Message['id']
): MessageDisplay[] { ): MessageDisplay[] {
const currNodes = StorageUtils.filterByLastNodeId(msgs, lastNodeId, true); const currNodes = StorageUtils.filterByLeafNodeId(msgs, leafNodeId, true);
const res: MessageDisplay[] = []; const res: MessageDisplay[] = [];
const nodeMap = new Map<Message['id'], Message>(); const nodeMap = new Map<Message['id'], Message>();
for (const msg of msgs) { for (const msg of msgs) {
nodeMap.set(msg.id, msg); nodeMap.set(msg.id, msg);
} }
const findLastNode = (msgId: Message['id']): Message['id'] => { // find leaf node from a message node
const findLeafNode = (msgId: Message['id']): Message['id'] => {
let currNode: Message | undefined = nodeMap.get(msgId); let currNode: Message | undefined = nodeMap.get(msgId);
while (currNode) { while (currNode) {
if (currNode.children.length === 0) break; if (currNode.children.length === 0) break;
@ -39,7 +44,7 @@ function getListMessageDisplay(
if (msg.type !== 'root') { if (msg.type !== 'root') {
res.push({ res.push({
msg, msg,
siblingLastNodeIds: siblings.map(findLastNode), siblingLeafNodeIds: siblings.map(findLeafNode),
siblingCurrIdx: siblings.indexOf(msg.id), siblingCurrIdx: siblings.indexOf(msg.id),
}); });
} }
@ -77,11 +82,12 @@ export default function ChatScreen() {
} = useAppContext(); } = useAppContext();
const [inputMsg, setInputMsg] = useState(''); const [inputMsg, setInputMsg] = useState('');
// keep track of leaf node for rendering
const [currNodeId, setCurrNodeId] = useState<number>(-1); const [currNodeId, setCurrNodeId] = useState<number>(-1);
const messages: MessageDisplay[] = useMemo(() => { const messages: MessageDisplay[] = useMemo(() => {
if (!viewingChat) return []; if (!viewingChat) return [];
else return getListMessageDisplay(viewingChat.messages, currNodeId); else return getListMessageDisplay(viewingChat.messages, currNodeId);
}, [currNodeId, viewingChat?.messages]); }, [currNodeId, viewingChat]);
const currConvId = viewingChat?.conv.id ?? null; const currConvId = viewingChat?.conv.id ?? null;
const pendingMsg: PendingMessage | undefined = const pendingMsg: PendingMessage | undefined =
@ -94,7 +100,10 @@ export default function ChatScreen() {
scrollToBottom(false, 1); scrollToBottom(false, 1);
}, [currConvId]); }, [currConvId]);
const onChunk: CallbackGeneratedChunk = () => { const onChunk: CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => {
if (currLeafNodeId) {
setCurrNodeId(currLeafNodeId);
}
scrollToBottom(true); scrollToBottom(true);
}; };
@ -141,12 +150,14 @@ export default function ChatScreen() {
}; };
const hasCanvas = !!canvasData; const hasCanvas = !!canvasData;
// due to some timing issues of StorageUtils.appendMsg(), we need to make sure the pendingMsg is not duplicated upon rendering (i.e. appears once in the saved conversation and once in the pendingMsg)
const pendingMsgDisplay: MessageDisplay[] = const pendingMsgDisplay: MessageDisplay[] =
pendingMsg && messages.at(-1)?.msg.id !== pendingMsg.id pendingMsg && messages.at(-1)?.msg.id !== pendingMsg.id
? [ ? [
{ {
msg: pendingMsg, msg: pendingMsg,
siblingLastNodeIds: [], siblingLeafNodeIds: [],
siblingCurrIdx: 0, siblingCurrIdx: 0,
isPending: true, isPending: true,
}, },
@ -178,7 +189,7 @@ export default function ChatScreen() {
<ChatMessage <ChatMessage
key={msg.msg.id} key={msg.msg.id}
msg={msg.msg} msg={msg.msg}
siblingLastNodeIds={msg.siblingLastNodeIds} siblingLeafNodeIds={msg.siblingLeafNodeIds}
siblingCurrIdx={msg.siblingCurrIdx} siblingCurrIdx={msg.siblingCurrIdx}
onRegenerateMessage={handleRegenerateMessage} onRegenerateMessage={handleRegenerateMessage}
onEditMessage={handleEditMessage} onEditMessage={handleEditMessage}

View file

@ -23,7 +23,7 @@ interface AppContextValue {
isGenerating: (convId: string) => boolean; isGenerating: (convId: string) => boolean;
sendMessage: ( sendMessage: (
convId: string | null, convId: string | null,
lastNodeId: Message['id'] | null, leafNodeId: Message['id'] | null,
content: string, content: string,
onChunk: CallbackGeneratedChunk onChunk: CallbackGeneratedChunk
) => Promise<boolean>; ) => Promise<boolean>;
@ -47,7 +47,7 @@ interface AppContextValue {
} }
// this callback is used for scrolling to the bottom of the chat and switching to the last node // this callback is used for scrolling to the bottom of the chat and switching to the last node
export type CallbackGeneratedChunk = () => void; export type CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => void;
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
const AppContext = createContext<AppContextValue>({} as any); const AppContext = createContext<AppContextValue>({} as any);
@ -130,7 +130,7 @@ export const AppContextProvider = ({
const generateMessage = async ( const generateMessage = async (
convId: string, convId: string,
lastNodeId: Message['id'], leafNodeId: Message['id'],
onChunk: CallbackGeneratedChunk onChunk: CallbackGeneratedChunk
) => { ) => {
if (isGenerating(convId)) return; if (isGenerating(convId)) return;
@ -141,9 +141,9 @@ export const AppContextProvider = ({
throw new Error('Current conversation is not found'); throw new Error('Current conversation is not found');
} }
const currMessages = StorageUtils.filterByLastNodeId( const currMessages = StorageUtils.filterByLeafNodeId(
await StorageUtils.getMessages(convId), await StorageUtils.getMessages(convId),
lastNodeId, leafNodeId,
false false
); );
const abortController = new AbortController(); const abortController = new AbortController();
@ -161,7 +161,7 @@ export const AppContextProvider = ({
timestamp: pendingId, timestamp: pendingId,
role: 'assistant', role: 'assistant',
content: null, content: null,
parent: lastNodeId, parent: leafNodeId,
children: [], children: [],
}; };
setPending(convId, pendingMsg); setPending(convId, pendingMsg);
@ -264,26 +264,26 @@ export const AppContextProvider = ({
} }
if (pendingMsg.content !== null) { if (pendingMsg.content !== null) {
await StorageUtils.appendMsg(pendingMsg as Message, lastNodeId); await StorageUtils.appendMsg(pendingMsg as Message, leafNodeId);
} }
setPending(convId, null); setPending(convId, null);
onChunk(); // trigger scroll to bottom and switch to the last node onChunk(pendingId); // trigger scroll to bottom and switch to the last node
}; };
const sendMessage = async ( const sendMessage = async (
convId: string | null, convId: string | null,
lastNodeId: Message['id'] | null, leafNodeId: Message['id'] | null,
content: string, content: string,
onChunk: CallbackGeneratedChunk onChunk: CallbackGeneratedChunk
): Promise<boolean> => { ): Promise<boolean> => {
if (isGenerating(convId ?? '') || content.trim().length === 0) return false; if (isGenerating(convId ?? '') || content.trim().length === 0) return false;
if (convId === null || convId.length === 0 || lastNodeId === null) { if (convId === null || convId.length === 0 || leafNodeId === null) {
const conv = await StorageUtils.createConversation( const conv = await StorageUtils.createConversation(
content.substring(0, 256) content.substring(0, 256)
); );
convId = conv.id; convId = conv.id;
lastNodeId = conv.currNode; leafNodeId = conv.currNode;
// if user is creating a new conversation, redirect to the new conversation // if user is creating a new conversation, redirect to the new conversation
navigate(`/chat/${convId}`); navigate(`/chat/${convId}`);
} }
@ -298,12 +298,12 @@ export const AppContextProvider = ({
convId, convId,
role: 'user', role: 'user',
content, content,
parent: lastNodeId, parent: leafNodeId,
children: [], children: [],
}, },
lastNodeId leafNodeId
); );
onChunk(); onChunk(currMsgId);
try { try {
await generateMessage(convId, currMsgId, onChunk); await generateMessage(convId, currMsgId, onChunk);
@ -346,8 +346,8 @@ export const AppContextProvider = ({
); );
parentNodeId = currMsgId; parentNodeId = currMsgId;
} }
onChunk(parentNodeId);
onChunk();
await generateMessage(convId, parentNodeId, onChunk); await generateMessage(convId, parentNodeId, onChunk);
}; };

View file

@ -48,17 +48,17 @@ const StorageUtils = {
return (await db.conversations.where('id').equals(convId).first()) ?? null; return (await db.conversations.where('id').equals(convId).first()) ?? null;
}, },
/** /**
* get messages by convId and timeline * get all message nodes in a conversation
*/ */
async getMessages(convId: string): Promise<Message[]> { async getMessages(convId: string): Promise<Message[]> {
return await db.messages.where({ convId }).toArray(); return await db.messages.where({ convId }).toArray();
}, },
/** /**
* use in conjunction with getMessages to filter messages by lastNodeId * use in conjunction with getMessages to filter messages by leafNodeId
*/ */
filterByLastNodeId( filterByLeafNodeId(
msgs: Readonly<Message[]>, msgs: Readonly<Message[]>,
lastNodeId: Message['id'], leafNodeId: Message['id'],
includeRoot: boolean includeRoot: boolean
): Readonly<Message[]> { ): Readonly<Message[]> {
const res: Message[] = []; const res: Message[] = [];
@ -66,7 +66,7 @@ const StorageUtils = {
for (const msg of msgs) { for (const msg of msgs) {
nodeMap.set(msg.id, msg); nodeMap.set(msg.id, msg);
} }
let startNode: Message | undefined = nodeMap.get(lastNodeId); let startNode: Message | undefined = nodeMap.get(leafNodeId);
if (!startNode) { if (!startNode) {
// if not found, we return the path with the latest timestamp // if not found, we return the path with the latest timestamp
let latestTime = -1; let latestTime = -1;
@ -77,7 +77,7 @@ const StorageUtils = {
} }
} }
} }
// traverse the path from lastNodeId to root // traverse the path from leafNodeId to root
// startNode can never be undefined here // startNode can never be undefined here
let currNode: Message | undefined = startNode; let currNode: Message | undefined = startNode;
while (currNode) { while (currNode) {
@ -89,7 +89,7 @@ const StorageUtils = {
return res; return res;
}, },
/** /**
* create a new conversation with a default timeline number 0 * create a new conversation with a default root node
*/ */
async createConversation(name: string): Promise<Conversation> { async createConversation(name: string): Promise<Conversation> {
const now = Date.now(); const now = Date.now();

View file

@ -7,11 +7,33 @@ export interface TimingReport {
/** /**
* What is conversation "branching"? It is a feature that allows the user to edit an old message in the history, while still keeping the conversation flow. * What is conversation "branching"? It is a feature that allows the user to edit an old message in the history, while still keeping the conversation flow.
* Inspired by ChatGPT UI where you edit a message, a new branch of the conversation is created, and the old message is still visible. * Inspired by ChatGPT / Claude / Hugging Chat where you edit a message, a new branch of the conversation is created, and the old message is still visible.
* *
* We use the same node based structure as ChatGPT, where each message has a parent and children. A "root" message is the first message in a conversation, which will not be displayed in the UI. * We use the same node-based structure like other chat UIs, where each message has a parent and children. A "root" message is the first message in a conversation, which will not be displayed in the UI.
*
* root
* message 1
* message 2
* message 3
* message 4
* message 5
*
* In the above example, assuming that user wants to edit message 2, a new branch will be created:
*
* message 2
* message 3
* message 6
*
* Message 2 and 6 are siblings, and message 6 is the new branch.
*
* We only need to know the last node (aka leaf) to get the current branch. In the above example, message 5 is the leaf of branch containing message 4 and 5.
*
* For the implementation:
* - StorageUtils.getMessages() returns list of all nodes
* - StorageUtils.filterByLeafNodeId() filters the list of nodes from a given leaf node
*/ */
// Note: the term "message" and "node" are used interchangeably in this context
export interface Message { export interface Message {
id: number; id: number;
convId: string; convId: string;