allow multiple generations at the same time
This commit is contained in:
parent
518e077a92
commit
c8dc8d7f55
6 changed files with 77 additions and 50 deletions
|
@ -20,10 +20,12 @@ export default function ChatMessage({
|
|||
msg,
|
||||
id,
|
||||
scrollToBottom,
|
||||
isPending,
|
||||
}: {
|
||||
msg: Message | PendingMessage;
|
||||
id?: string;
|
||||
scrollToBottom: (requiresNearBottom: boolean) => void;
|
||||
isPending?: boolean;
|
||||
}) {
|
||||
const { viewingConversation, replaceMessageAndGenerate, config } =
|
||||
useAppContext();
|
||||
|
@ -42,8 +44,6 @@ export default function ChatMessage({
|
|||
[msg.timings]
|
||||
);
|
||||
|
||||
const isPending: boolean = !!(msg as PendingMessage).convId;
|
||||
|
||||
// for reasoning model, we split the message into content and thought
|
||||
// TODO: implement this as remark/rehype plugin in the future
|
||||
const { content, thought, isThinking }: SplitMessage = useMemo(() => {
|
||||
|
|
|
@ -3,6 +3,7 @@ import { useAppContext } from '../utils/app.context';
|
|||
import StorageUtils from '../utils/storage';
|
||||
import { useNavigate } from 'react-router';
|
||||
import ChatMessage from './ChatMessage';
|
||||
import { PendingMessage } from '../utils/types';
|
||||
|
||||
export default function ChatScreen() {
|
||||
const {
|
||||
|
@ -10,12 +11,15 @@ export default function ChatScreen() {
|
|||
sendMessage,
|
||||
isGenerating,
|
||||
stopGenerating,
|
||||
pendingMessage,
|
||||
pendingMessages,
|
||||
} = useAppContext();
|
||||
const [inputMsg, setInputMsg] = useState('');
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const navigate = useNavigate();
|
||||
|
||||
const currConvId = viewingConversation?.id ?? '';
|
||||
const pendingMsg: PendingMessage | undefined = pendingMessages[currConvId];
|
||||
|
||||
const scrollToBottom = (requiresNearBottom: boolean) => {
|
||||
if (!containerRef.current) return;
|
||||
const msgListElem = containerRef.current;
|
||||
|
@ -70,14 +74,14 @@ export default function ChatScreen() {
|
|||
<ChatMessage key={msg.id} msg={msg} scrollToBottom={scrollToBottom} />
|
||||
))}
|
||||
|
||||
{pendingMessage !== null &&
|
||||
pendingMessage.convId === viewingConversation?.id && (
|
||||
<ChatMessage
|
||||
msg={pendingMessage}
|
||||
scrollToBottom={scrollToBottom}
|
||||
id="pending-msg"
|
||||
/>
|
||||
)}
|
||||
{pendingMsg && (
|
||||
<ChatMessage
|
||||
msg={pendingMsg}
|
||||
scrollToBottom={scrollToBottom}
|
||||
isPending
|
||||
id="pending-msg"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* chat input */}
|
||||
|
@ -97,8 +101,11 @@ export default function ChatScreen() {
|
|||
id="msg-input"
|
||||
dir="auto"
|
||||
></textarea>
|
||||
{isGenerating ? (
|
||||
<button className="btn btn-neutral ml-2" onClick={stopGenerating}>
|
||||
{isGenerating(currConvId) ? (
|
||||
<button
|
||||
className="btn btn-neutral ml-2"
|
||||
onClick={() => stopGenerating(currConvId)}
|
||||
>
|
||||
Stop
|
||||
</button>
|
||||
) : (
|
||||
|
|
|
@ -27,9 +27,10 @@ export default function Header() {
|
|||
}, [selectedTheme]);
|
||||
|
||||
const { isGenerating, viewingConversation } = useAppContext();
|
||||
const isCurrConvGenerating = isGenerating(viewingConversation?.id ?? '');
|
||||
|
||||
const removeConversation = () => {
|
||||
if (isGenerating || !viewingConversation) return;
|
||||
if (isCurrConvGenerating || !viewingConversation) return;
|
||||
const convId = viewingConversation.id;
|
||||
if (window.confirm('Are you sure to delete this conversation?')) {
|
||||
StorageUtils.remove(convId);
|
||||
|
@ -38,7 +39,7 @@ export default function Header() {
|
|||
};
|
||||
|
||||
const downloadConversation = () => {
|
||||
if (isGenerating || !viewingConversation) return;
|
||||
if (isCurrConvGenerating || !viewingConversation) return;
|
||||
const convId = viewingConversation.id;
|
||||
const conversationJson = JSON.stringify(viewingConversation, null, 2);
|
||||
const blob = new Blob([conversationJson], { type: 'application/json' });
|
||||
|
@ -81,7 +82,7 @@ export default function Header() {
|
|||
tabIndex={0}
|
||||
role="button"
|
||||
className="btn m-1"
|
||||
disabled={isGenerating}
|
||||
disabled={isCurrConvGenerating}
|
||||
>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
@ -108,11 +109,7 @@ export default function Header() {
|
|||
</ul>
|
||||
</div>
|
||||
<div className="tooltip tooltip-bottom" data-tip="Settings">
|
||||
<button
|
||||
className="btn"
|
||||
disabled={isGenerating}
|
||||
onClick={() => setShowSettingDialog(true)}
|
||||
>
|
||||
<button className="btn" onClick={() => setShowSettingDialog(true)}>
|
||||
{/* settings button */}
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
|
|
@ -3,6 +3,7 @@ import { classNames } from '../utils/misc';
|
|||
import { Conversation } from '../utils/types';
|
||||
import StorageUtils from '../utils/storage';
|
||||
import { useNavigate, useParams } from 'react-router';
|
||||
import { useAppContext } from '../utils/app.context';
|
||||
|
||||
export default function Sidebar() {
|
||||
const params = useParams();
|
||||
|
|
|
@ -10,15 +10,15 @@ import { BASE_URL, CONFIG_DEFAULT, isDev } from '../Config';
|
|||
import { matchPath, useLocation } from 'react-router';
|
||||
|
||||
interface AppContextValue {
|
||||
isGenerating: boolean;
|
||||
viewingConversation: Conversation | null;
|
||||
pendingMessage: PendingMessage | null;
|
||||
pendingMessages: Record<Conversation['id'], PendingMessage>;
|
||||
isGenerating: (convId: string) => boolean;
|
||||
sendMessage: (
|
||||
convId: string,
|
||||
content: string,
|
||||
onChunk?: CallbackGeneratedChunk
|
||||
) => Promise<boolean>;
|
||||
stopGenerating: () => void;
|
||||
stopGenerating: (convId: string) => void;
|
||||
replaceMessageAndGenerate: (
|
||||
convId: string,
|
||||
origMsgId: Message['id'],
|
||||
|
@ -45,13 +45,14 @@ export const AppContextProvider = ({
|
|||
const params = matchPath('/chat/:convId', pathname);
|
||||
const convId = params?.params?.convId;
|
||||
|
||||
const [isGenerating, setIsGenerating] = useState(false);
|
||||
const [viewingConversation, setViewingConversation] =
|
||||
useState<Conversation | null>(null);
|
||||
const [pendingMessage, setPendingMessage] = useState<PendingMessage | null>(
|
||||
null
|
||||
);
|
||||
const [abortController, setAbortController] = useState(new AbortController());
|
||||
const [pendingMessages, setPendingMessages] = useState<
|
||||
Record<Conversation['id'], PendingMessage>
|
||||
>({});
|
||||
const [aborts, setAborts] = useState<
|
||||
Record<Conversation['id'], AbortController>
|
||||
>({});
|
||||
const [config, setConfig] = useState(StorageUtils.getConfig());
|
||||
|
||||
useEffect(() => {
|
||||
|
@ -66,11 +67,41 @@ export const AppContextProvider = ({
|
|||
};
|
||||
}, [convId]);
|
||||
|
||||
const setPending = (convId: string, pendingMsg: PendingMessage | null) => {
|
||||
// if pendingMsg is null, remove the key from the object
|
||||
if (!pendingMsg) {
|
||||
setPendingMessages((prev) => {
|
||||
const newState = { ...prev };
|
||||
delete newState[convId];
|
||||
return newState;
|
||||
});
|
||||
} else {
|
||||
setPendingMessages((prev) => ({ ...prev, [convId]: pendingMsg }));
|
||||
}
|
||||
};
|
||||
|
||||
const setAbort = (convId: string, controller: AbortController | null) => {
|
||||
if (!controller) {
|
||||
setAborts((prev) => {
|
||||
const newState = { ...prev };
|
||||
delete newState[convId];
|
||||
return newState;
|
||||
});
|
||||
} else {
|
||||
setAborts((prev) => ({ ...prev, [convId]: controller }));
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
// public functions
|
||||
|
||||
const isGenerating = (convId: string) => !!pendingMessages[convId];
|
||||
|
||||
const generateMessage = async (
|
||||
convId: string,
|
||||
onChunk?: CallbackGeneratedChunk
|
||||
) => {
|
||||
if (isGenerating) return;
|
||||
if (isGenerating(convId)) return;
|
||||
|
||||
const config = StorageUtils.getConfig();
|
||||
const currConversation = StorageUtils.getOneConversation(convId);
|
||||
|
@ -79,16 +110,14 @@ export const AppContextProvider = ({
|
|||
}
|
||||
|
||||
const abortController = new AbortController();
|
||||
setIsGenerating(true);
|
||||
setAbortController(abortController);
|
||||
setAbort(convId, abortController);
|
||||
|
||||
let pendingMsg: PendingMessage = {
|
||||
convId,
|
||||
id: Date.now() + 1,
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
};
|
||||
setPendingMessage(pendingMsg);
|
||||
setPending(convId, pendingMsg);
|
||||
|
||||
try {
|
||||
// prepare messages for API
|
||||
|
@ -157,7 +186,6 @@ export const AppContextProvider = ({
|
|||
const lastContent = pendingMsg.content || '';
|
||||
if (addedContent) {
|
||||
pendingMsg = {
|
||||
convId,
|
||||
id: pendingMsg.id,
|
||||
role: 'assistant',
|
||||
content: lastContent + addedContent,
|
||||
|
@ -173,18 +201,15 @@ export const AppContextProvider = ({
|
|||
predicted_ms: timings.predicted_ms,
|
||||
};
|
||||
}
|
||||
setPendingMessage(pendingMsg);
|
||||
setPending(convId, pendingMsg);
|
||||
onChunk?.();
|
||||
}
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
setPendingMessage(null);
|
||||
setIsGenerating(false);
|
||||
setPending(convId, null);
|
||||
if ((err as Error).name === 'AbortError') {
|
||||
// user stopped the generation via stopGeneration() function
|
||||
// we can safely ignore this error
|
||||
} else {
|
||||
setIsGenerating(false);
|
||||
console.error(err);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
alert((err as any)?.message ?? 'Unknown error');
|
||||
|
@ -200,8 +225,7 @@ export const AppContextProvider = ({
|
|||
timings: pendingMsg.timings,
|
||||
});
|
||||
}
|
||||
setPendingMessage(null);
|
||||
setIsGenerating(false);
|
||||
setPending(convId, null);
|
||||
onChunk?.(); // trigger scroll to bottom
|
||||
};
|
||||
|
||||
|
@ -210,7 +234,7 @@ export const AppContextProvider = ({
|
|||
content: string,
|
||||
onChunk?: CallbackGeneratedChunk
|
||||
): Promise<boolean> => {
|
||||
if (isGenerating || content.trim().length === 0) return false;
|
||||
if (isGenerating(convId) || content.trim().length === 0) return false;
|
||||
|
||||
StorageUtils.appendMsg(convId, {
|
||||
id: Date.now(),
|
||||
|
@ -228,10 +252,9 @@ export const AppContextProvider = ({
|
|||
return false;
|
||||
};
|
||||
|
||||
const stopGenerating = () => {
|
||||
setIsGenerating(false);
|
||||
setPendingMessage(null);
|
||||
abortController.abort();
|
||||
const stopGenerating = (convId: string) => {
|
||||
setPending(convId, null);
|
||||
aborts[convId]?.abort();
|
||||
};
|
||||
|
||||
// if content is undefined, we remove last assistant message
|
||||
|
@ -241,7 +264,7 @@ export const AppContextProvider = ({
|
|||
content?: string,
|
||||
onChunk?: CallbackGeneratedChunk
|
||||
) => {
|
||||
if (isGenerating) return;
|
||||
if (isGenerating(convId)) return;
|
||||
|
||||
StorageUtils.filterAndKeepMsgs(convId, (msg) => msg.id < origMsgId);
|
||||
if (content) {
|
||||
|
@ -265,7 +288,7 @@ export const AppContextProvider = ({
|
|||
value={{
|
||||
isGenerating,
|
||||
viewingConversation,
|
||||
pendingMessage,
|
||||
pendingMessages,
|
||||
sendMessage,
|
||||
stopGenerating,
|
||||
replaceMessageAndGenerate,
|
||||
|
|
|
@ -21,6 +21,5 @@ export interface Conversation {
|
|||
}
|
||||
|
||||
export type PendingMessage = Omit<Message, 'content'> & {
|
||||
convId: string;
|
||||
content: string | null;
|
||||
};
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue