| import { useState, useRef, useEffect, useCallback } from 'react' | |
| import { Send, Eraser, Loader2, X } from 'lucide-react' | |
| import { | |
| ChatMessage, | |
| TextGenerationWorkerInput, | |
| WorkerMessage | |
| } from '../../types' | |
| import { useModel } from '../../contexts/ModelContext' | |
| import { useTextGeneration } from '../../contexts/TextGenerationContext' | |
/**
 * Text-generation panel driven by the worker exposed through `useModel()`.
 *
 * The panel renders one of two UIs, chosen by `modelInfo.hasChatTemplate`:
 * - Chat mode: a message list kept in `TextGenerationContext`. The full
 *   history is sent to the worker as `messages`.
 * - Prompt mode: a single free-form prompt sent as `prompt`. The result is
 *   kept in local `generatedText` state.
 *
 * Requests are posted to `activeWorker` as `{ type: 'generate', ... }`.
 * Results arrive on the worker's `message` event with `status: 'output'`.
 * A `'ready'` or `'error'` status also ends the in-flight generation.
 */
function TextGeneration() {
  const { config, messages, setMessages } = useTextGeneration()
  const [currentMessage, setCurrentMessage] = useState<string>('')
  const [prompt, setPrompt] = useState<string>('')
  const [generatedText, setGeneratedText] = useState<string>('')
  const [isGenerating, setIsGenerating] = useState<boolean>(false)
  const {
    activeWorker,
    status,
    modelInfo,
    hasBeenLoaded,
    selectedQuantization
  } = useModel()
  const messagesEndRef = useRef<HTMLDivElement>(null)

  // A new request may only be sent when the model reports 'ready' and nothing
  // is in flight. The buttons AND the Enter key both use this same gate.
  const busy = status !== 'ready' || isGenerating
  const hasChatTemplate = modelInfo?.hasChatTemplate

  // Keep the newest chat message / generated text scrolled into view.
  useEffect(() => {
    messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' })
  }, [messages, generatedText])

  /** Ask the worker to stop the current generation and unlock the UI. */
  const stopGeneration = useCallback(() => {
    if (activeWorker && isGenerating) {
      activeWorker.postMessage({ type: 'stop' })
      setIsGenerating(false)
    }
  }, [activeWorker, isGenerating])

  /** Chat mode: append the user's message and send the whole history. */
  const handleSendMessage = useCallback(() => {
    // `busy` includes `status !== 'ready'`. The Send button already enforces
    // this; the guard also covers the Enter-key path, which bypasses the
    // button's `disabled` state.
    if (!currentMessage.trim() || !modelInfo || !activeWorker || busy) return

    const userMessage: ChatMessage = {
      role: 'user',
      content: currentMessage.trim()
    }
    const updatedMessages = [...messages, userMessage]
    setMessages(updatedMessages)
    setCurrentMessage('')
    setIsGenerating(true)

    const message: TextGenerationWorkerInput = {
      type: 'generate',
      messages: updatedMessages,
      hasChatTemplate: modelInfo.hasChatTemplate,
      model: modelInfo.id,
      // Fall back to full precision when no quantization was picked.
      dtype: selectedQuantization ?? 'fp32',
      config
    }
    activeWorker.postMessage(message)
  }, [
    currentMessage,
    messages,
    setMessages,
    modelInfo,
    activeWorker,
    config,
    busy,
    selectedQuantization
  ])

  /** Prompt mode: send the trimmed prompt for plain completion. */
  const handleGenerateText = useCallback(() => {
    // Same gate as the Generate button (see handleSendMessage).
    if (!prompt.trim() || !modelInfo || !activeWorker || busy) return

    setIsGenerating(true)
    const message: TextGenerationWorkerInput = {
      type: 'generate',
      prompt: prompt.trim(),
      hasChatTemplate: modelInfo.hasChatTemplate,
      model: modelInfo.id,
      config,
      // Fall back to full precision when no quantization was picked.
      dtype: selectedQuantization ?? 'fp32'
    }
    activeWorker.postMessage(message)
  }, [prompt, modelInfo, activeWorker, config, busy, selectedQuantization])

  // Route worker results into chat history or the prompt-mode output box.
  useEffect(() => {
    if (!activeWorker) return

    const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
      // Renamed so it does not shadow the model `status` from useModel().
      const { status: workerStatus, output } = e.data
      if (workerStatus === 'output' && output) {
        setIsGenerating(false)
        if (modelInfo?.hasChatTemplate) {
          const assistantMessage: ChatMessage = {
            role: 'assistant',
            content: output.content
          }
          setMessages((prev) => [...prev, assistantMessage])
        } else {
          setGeneratedText(output.content)
        }
      } else if (workerStatus === 'ready' || workerStatus === 'error') {
        setIsGenerating(false)
      }
    }

    activeWorker.addEventListener('message', onMessageReceived)
    return () => activeWorker.removeEventListener('message', onMessageReceived)
  }, [activeWorker, modelInfo?.hasChatTemplate, setMessages])

  /**
   * Enter submits; Shift+Enter inserts a newline.
   * Uses onKeyDown because React's onKeyPress is deprecated.
   */
  const handleKeyDown = (e: React.KeyboardEvent) => {
    if (e.key !== 'Enter' || e.shiftKey) return
    // While an IME composition is active, Enter confirms the candidate and
    // must not submit. Safari reports the confirming keydown with
    // isComposing=false but keyCode 229, so check both.
    if (e.nativeEvent.isComposing || e.keyCode === 229) return
    e.preventDefault()
    if (hasChatTemplate) {
      handleSendMessage()
    } else {
      handleGenerateText()
    }
  }

  /** Clear the conversation (keeping system messages) or the prompt/output. */
  const clearChat = () => {
    if (hasChatTemplate) {
      setMessages((prev) => prev.filter((msg) => msg.role === 'system'))
    } else {
      setPrompt('')
      setGeneratedText('')
    }
  }

  return (
    <div className="flex flex-col min-h-[30dvh] max-h-[calc(100dvh-128px)] w-full p-4">
      <div className="flex items-center justify-between mb-4">
        <h1 className="text-2xl font-bold">
          Text Generation {hasChatTemplate ? '(Chat)' : ''}
        </h1>
        <div className="flex gap-2">
          <button
            onClick={clearChat}
            className="p-2 bg-red-100 hover:bg-red-200 rounded-lg transition-colors"
            title={hasChatTemplate ? 'Clear Chat' : 'Clear Text'}
          >
            <Eraser className="w-4 h-4" />
          </button>
          {isGenerating && (
            <button
              onClick={stopGeneration}
              className="p-2 bg-orange-100 hover:bg-orange-200 rounded-lg transition-colors"
              title="Stop Generation"
            >
              <X className="w-4 h-4" />
            </button>
          )}
        </div>
      </div>

      {hasChatTemplate ? (
        <>
          <div className="flex-1 overflow-y-auto border border-gray-300 rounded-lg p-4 mb-4 bg-white">
            <div className="space-y-4">
              {messages
                .filter((msg) => msg.role !== 'system')
                .map((message, index) => (
                  <div
                    key={index}
                    className={`flex ${message.role === 'user' ? 'justify-end' : 'justify-start'}`}
                  >
                    <div
                      className={`max-w-[80%] p-3 rounded-lg ${
                        message.role === 'user'
                          ? 'bg-blue-500 text-white'
                          : 'bg-gray-100 text-gray-800'
                      }`}
                    >
                      <div className="text-xs font-medium mb-1 opacity-70">
                        {message.role === 'user' ? 'You' : 'Assistant'}
                      </div>
                      <div className="whitespace-pre-wrap">
                        {message.content}
                      </div>
                    </div>
                  </div>
                ))}
              {isGenerating && (
                <div className="flex justify-start">
                  <div className="bg-gray-100 text-gray-800 p-3 rounded-lg">
                    <div className="text-xs font-medium mb-1 opacity-70">
                      Assistant
                    </div>
                    <div className="flex items-center space-x-2">
                      <Loader2 className="w-4 h-4 animate-spin" />
                      <div>Loading...</div>
                    </div>
                  </div>
                </div>
              )}
            </div>
            <div ref={messagesEndRef} />
          </div>
          <div className="flex gap-2">
            <textarea
              value={currentMessage}
              onChange={(e) => setCurrentMessage(e.target.value)}
              onKeyDown={handleKeyDown}
              placeholder="Type your message... (Press Enter to send, Shift+Enter for new line)"
              className="flex-1 p-3 border border-gray-300 rounded-lg resize-none focus:outline-hidden focus:ring-2 focus:ring-blue-500 focus:border-blue-500 disabled:bg-gray-100 disabled:cursor-not-allowed"
              rows={2}
              disabled={!hasBeenLoaded || isGenerating}
            />
            <button
              onClick={handleSendMessage}
              disabled={!currentMessage.trim() || busy || !hasBeenLoaded}
              className="px-4 py-2 bg-blue-500 hover:bg-blue-600 disabled:bg-gray-300 disabled:cursor-not-allowed text-white rounded-lg transition-colors flex items-center justify-center"
            >
              {isGenerating ? (
                <Loader2 className="w-4 h-4 animate-spin" />
              ) : (
                <Send className="w-4 h-4" />
              )}
            </button>
          </div>
        </>
      ) : (
        <>
          <div className="mb-4">
            <label className="block text-sm font-medium text-gray-700 mb-2">
              Enter your prompt:
            </label>
            <textarea
              value={prompt}
              onChange={(e) => setPrompt(e.target.value)}
              onKeyDown={handleKeyDown}
              placeholder="Enter your text prompt here... (Press Enter to generate, Shift+Enter for new line)"
              className="w-full p-3 border border-gray-300 rounded-lg resize-none focus:outline-hidden focus:ring-2 focus:ring-blue-500 focus:border-blue-500 disabled:bg-gray-100 disabled:cursor-not-allowed"
              rows={4}
              disabled={!hasBeenLoaded || isGenerating}
            />
          </div>
          <div className="mb-4">
            <button
              onClick={handleGenerateText}
              disabled={!prompt.trim() || busy || !hasBeenLoaded}
              className="px-6 py-2 bg-green-500 hover:bg-green-600 disabled:bg-gray-300 disabled:cursor-not-allowed text-white rounded-lg transition-colors flex items-center gap-2"
            >
              {isGenerating ? (
                <>
                  <Loader2 className="w-4 h-4 animate-spin" />
                  Generating...
                </>
              ) : (
                <>
                  <Send className="w-4 h-4" />
                  Generate Text
                </>
              )}
            </button>
          </div>
          <div className="flex-1 overflow-y-auto border border-gray-300 rounded-lg p-4 bg-white">
            <div className="mb-2">
              <label className="block text-sm font-medium text-gray-700">
                Generated Text:
              </label>
            </div>
            {generatedText ? (
              <div className="whitespace-pre-wrap text-gray-800 bg-gray-50 p-3 rounded-sm border">
                {generatedText}
              </div>
            ) : (
              <div className="text-gray-500 italic flex items-center gap-2">
                {isGenerating ? (
                  <>
                    <Loader2 className="w-4 h-4 animate-spin" />
                    Generating text...
                  </>
                ) : (
                  'Generated text will appear here'
                )}
              </div>
            )}
            <div ref={messagesEndRef} />
          </div>
        </>
      )}

      {!hasBeenLoaded && (
        <div className="text-center text-gray-500 text-sm mt-2">
          Please load a model first to start{' '}
          {hasChatTemplate ? 'chatting' : 'generating text'}
        </div>
      )}
    </div>
  )
}

export default TextGeneration