allow multiple generations at the same time

This commit is contained in:
Xuan Son Nguyen 2025-02-06 11:18:00 +01:00
parent 518e077a92
commit c8dc8d7f55
6 changed files with 77 additions and 50 deletions

View file

@ -20,10 +20,12 @@ export default function ChatMessage({
msg,
id,
scrollToBottom,
isPending,
}: {
msg: Message | PendingMessage;
id?: string;
scrollToBottom: (requiresNearBottom: boolean) => void;
isPending?: boolean;
}) {
const { viewingConversation, replaceMessageAndGenerate, config } =
useAppContext();
@ -42,8 +44,6 @@ export default function ChatMessage({
[msg.timings]
);
const isPending: boolean = !!(msg as PendingMessage).convId;
// for reasoning model, we split the message into content and thought
// TODO: implement this as remark/rehype plugin in the future
const { content, thought, isThinking }: SplitMessage = useMemo(() => {

View file

@ -3,6 +3,7 @@ import { useAppContext } from '../utils/app.context';
import StorageUtils from '../utils/storage';
import { useNavigate } from 'react-router';
import ChatMessage from './ChatMessage';
import { PendingMessage } from '../utils/types';
export default function ChatScreen() {
const {
@ -10,12 +11,15 @@ export default function ChatScreen() {
sendMessage,
isGenerating,
stopGenerating,
pendingMessage,
pendingMessages,
} = useAppContext();
const [inputMsg, setInputMsg] = useState('');
const containerRef = useRef<HTMLDivElement>(null);
const navigate = useNavigate();
const currConvId = viewingConversation?.id ?? '';
const pendingMsg: PendingMessage | undefined = pendingMessages[currConvId];
const scrollToBottom = (requiresNearBottom: boolean) => {
if (!containerRef.current) return;
const msgListElem = containerRef.current;
@ -70,14 +74,14 @@ export default function ChatScreen() {
<ChatMessage key={msg.id} msg={msg} scrollToBottom={scrollToBottom} />
))}
{pendingMessage !== null &&
pendingMessage.convId === viewingConversation?.id && (
<ChatMessage
msg={pendingMessage}
scrollToBottom={scrollToBottom}
id="pending-msg"
/>
)}
{pendingMsg && (
<ChatMessage
msg={pendingMsg}
scrollToBottom={scrollToBottom}
isPending
id="pending-msg"
/>
)}
</div>
{/* chat input */}
@ -97,8 +101,11 @@ export default function ChatScreen() {
id="msg-input"
dir="auto"
></textarea>
{isGenerating ? (
<button className="btn btn-neutral ml-2" onClick={stopGenerating}>
{isGenerating(currConvId) ? (
<button
className="btn btn-neutral ml-2"
onClick={() => stopGenerating(currConvId)}
>
Stop
</button>
) : (

View file

@ -27,9 +27,10 @@ export default function Header() {
}, [selectedTheme]);
const { isGenerating, viewingConversation } = useAppContext();
const isCurrConvGenerating = isGenerating(viewingConversation?.id ?? '');
const removeConversation = () => {
if (isGenerating || !viewingConversation) return;
if (isCurrConvGenerating || !viewingConversation) return;
const convId = viewingConversation.id;
if (window.confirm('Are you sure to delete this conversation?')) {
StorageUtils.remove(convId);
@ -38,7 +39,7 @@ export default function Header() {
};
const downloadConversation = () => {
if (isGenerating || !viewingConversation) return;
if (isCurrConvGenerating || !viewingConversation) return;
const convId = viewingConversation.id;
const conversationJson = JSON.stringify(viewingConversation, null, 2);
const blob = new Blob([conversationJson], { type: 'application/json' });
@ -81,7 +82,7 @@ export default function Header() {
tabIndex={0}
role="button"
className="btn m-1"
disabled={isGenerating}
disabled={isCurrConvGenerating}
>
<svg
xmlns="http://www.w3.org/2000/svg"
@ -108,11 +109,7 @@ export default function Header() {
</ul>
</div>
<div className="tooltip tooltip-bottom" data-tip="Settings">
<button
className="btn"
disabled={isGenerating}
onClick={() => setShowSettingDialog(true)}
>
<button className="btn" onClick={() => setShowSettingDialog(true)}>
{/* settings button */}
<svg
xmlns="http://www.w3.org/2000/svg"

View file

@ -3,6 +3,7 @@ import { classNames } from '../utils/misc';
import { Conversation } from '../utils/types';
import StorageUtils from '../utils/storage';
import { useNavigate, useParams } from 'react-router';
import { useAppContext } from '../utils/app.context';
export default function Sidebar() {
const params = useParams();

View file

@ -10,15 +10,15 @@ import { BASE_URL, CONFIG_DEFAULT, isDev } from '../Config';
import { matchPath, useLocation } from 'react-router';
interface AppContextValue {
isGenerating: boolean;
viewingConversation: Conversation | null;
pendingMessage: PendingMessage | null;
pendingMessages: Record<Conversation['id'], PendingMessage>;
isGenerating: (convId: string) => boolean;
sendMessage: (
convId: string,
content: string,
onChunk?: CallbackGeneratedChunk
) => Promise<boolean>;
stopGenerating: () => void;
stopGenerating: (convId: string) => void;
replaceMessageAndGenerate: (
convId: string,
origMsgId: Message['id'],
@ -45,13 +45,14 @@ export const AppContextProvider = ({
const params = matchPath('/chat/:convId', pathname);
const convId = params?.params?.convId;
const [isGenerating, setIsGenerating] = useState(false);
const [viewingConversation, setViewingConversation] =
useState<Conversation | null>(null);
const [pendingMessage, setPendingMessage] = useState<PendingMessage | null>(
null
);
const [abortController, setAbortController] = useState(new AbortController());
const [pendingMessages, setPendingMessages] = useState<
Record<Conversation['id'], PendingMessage>
>({});
const [aborts, setAborts] = useState<
Record<Conversation['id'], AbortController>
>({});
const [config, setConfig] = useState(StorageUtils.getConfig());
useEffect(() => {
@ -66,11 +67,41 @@ export const AppContextProvider = ({
};
}, [convId]);
const setPending = (convId: string, pendingMsg: PendingMessage | null) => {
// if pendingMsg is null, remove the key from the object
if (!pendingMsg) {
setPendingMessages((prev) => {
const newState = { ...prev };
delete newState[convId];
return newState;
});
} else {
setPendingMessages((prev) => ({ ...prev, [convId]: pendingMsg }));
}
};
const setAbort = (convId: string, controller: AbortController | null) => {
if (!controller) {
setAborts((prev) => {
const newState = { ...prev };
delete newState[convId];
return newState;
});
} else {
setAborts((prev) => ({ ...prev, [convId]: controller }));
}
};
////////////////////////////////////////////////////////////////////////
// public functions
const isGenerating = (convId: string) => !!pendingMessages[convId];
const generateMessage = async (
convId: string,
onChunk?: CallbackGeneratedChunk
) => {
if (isGenerating) return;
if (isGenerating(convId)) return;
const config = StorageUtils.getConfig();
const currConversation = StorageUtils.getOneConversation(convId);
@ -79,16 +110,14 @@ export const AppContextProvider = ({
}
const abortController = new AbortController();
setIsGenerating(true);
setAbortController(abortController);
setAbort(convId, abortController);
let pendingMsg: PendingMessage = {
convId,
id: Date.now() + 1,
role: 'assistant',
content: null,
};
setPendingMessage(pendingMsg);
setPending(convId, pendingMsg);
try {
// prepare messages for API
@ -157,7 +186,6 @@ export const AppContextProvider = ({
const lastContent = pendingMsg.content || '';
if (addedContent) {
pendingMsg = {
convId,
id: pendingMsg.id,
role: 'assistant',
content: lastContent + addedContent,
@ -173,18 +201,15 @@ export const AppContextProvider = ({
predicted_ms: timings.predicted_ms,
};
}
setPendingMessage(pendingMsg);
setPending(convId, pendingMsg);
onChunk?.();
}
} catch (err) {
console.error(err);
setPendingMessage(null);
setIsGenerating(false);
setPending(convId, null);
if ((err as Error).name === 'AbortError') {
// user stopped the generation via stopGeneration() function
// we can safely ignore this error
} else {
setIsGenerating(false);
console.error(err);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
alert((err as any)?.message ?? 'Unknown error');
@ -200,8 +225,7 @@ export const AppContextProvider = ({
timings: pendingMsg.timings,
});
}
setPendingMessage(null);
setIsGenerating(false);
setPending(convId, null);
onChunk?.(); // trigger scroll to bottom
};
@ -210,7 +234,7 @@ export const AppContextProvider = ({
content: string,
onChunk?: CallbackGeneratedChunk
): Promise<boolean> => {
if (isGenerating || content.trim().length === 0) return false;
if (isGenerating(convId) || content.trim().length === 0) return false;
StorageUtils.appendMsg(convId, {
id: Date.now(),
@ -228,10 +252,9 @@ export const AppContextProvider = ({
return false;
};
const stopGenerating = () => {
setIsGenerating(false);
setPendingMessage(null);
abortController.abort();
const stopGenerating = (convId: string) => {
setPending(convId, null);
aborts[convId]?.abort();
};
// if content is undefined, we remove last assistant message
@ -241,7 +264,7 @@ export const AppContextProvider = ({
content?: string,
onChunk?: CallbackGeneratedChunk
) => {
if (isGenerating) return;
if (isGenerating(convId)) return;
StorageUtils.filterAndKeepMsgs(convId, (msg) => msg.id < origMsgId);
if (content) {
@ -265,7 +288,7 @@ export const AppContextProvider = ({
value={{
isGenerating,
viewingConversation,
pendingMessage,
pendingMessages,
sendMessage,
stopGenerating,
replaceMessageAndGenerate,

View file

@ -21,6 +21,5 @@ export interface Conversation {
}
export type PendingMessage = Omit<Message, 'content'> & {
convId: string;
content: string | null;
};