Add ModelContext for managing model loading state and progress
Browse files- src/App.tsx +48 -0
- src/components/TextClassification.tsx +5 -5
- src/components/ZeroShotClassification.tsx +4 -5
- src/contexts/ModelContext.tsx +39 -0
- src/index.tsx +4 -0
src/App.tsx
CHANGED
|
@@ -4,9 +4,11 @@ import ZeroShotClassification from './components/ZeroShotClassification';
|
|
| 4 |
import TextClassification from './components/TextClassification';
|
| 5 |
import Header from './Header';
|
| 6 |
import Footer from './Footer';
|
|
|
|
| 7 |
|
| 8 |
function App() {
|
| 9 |
const [pipeline, setPipeline] = useState('zero-shot-classification');
|
|
|
|
| 10 |
|
| 11 |
return (
|
| 12 |
<div className="min-h-screen bg-gradient-to-br from-blue-50 to-indigo-100">
|
|
@@ -21,6 +23,47 @@ function App() {
|
|
| 21 |
</h2>
|
| 22 |
<PipelineSelector pipeline={pipeline} setPipeline={setPipeline} />
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
{/* Pipeline Description */}
|
| 25 |
<div className="mt-4 p-4 bg-gray-50 rounded-lg">
|
| 26 |
<div className="flex items-start space-x-3">
|
|
@@ -42,6 +85,11 @@ function App() {
|
|
| 42 |
{pipeline === 'zero-shot-classification'
|
| 43 |
? 'Zero-Shot Classification'
|
| 44 |
: 'Text-Classification'}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
</h3>
|
| 46 |
<p className="text-sm text-gray-600 mt-1">
|
| 47 |
{pipeline === 'zero-shot-classification'
|
|
|
|
| 4 |
import TextClassification from './components/TextClassification';
|
| 5 |
import Header from './Header';
|
| 6 |
import Footer from './Footer';
|
| 7 |
+
import { useModel } from './contexts/ModelContext';
|
| 8 |
|
| 9 |
function App() {
|
| 10 |
const [pipeline, setPipeline] = useState('zero-shot-classification');
|
| 11 |
+
const { progress, status, model } = useModel();
|
| 12 |
|
| 13 |
return (
|
| 14 |
<div className="min-h-screen bg-gradient-to-br from-blue-50 to-indigo-100">
|
|
|
|
| 23 |
</h2>
|
| 24 |
<PipelineSelector pipeline={pipeline} setPipeline={setPipeline} />
|
| 25 |
|
| 26 |
+
{/* Model Loading Progress */}
|
| 27 |
+
{status === 'progress' && (
|
| 28 |
+
<div className="mt-4 p-4 bg-blue-50 rounded-lg">
|
| 29 |
+
<div className="flex items-center space-x-3">
|
| 30 |
+
<div className="flex-shrink-0">
|
| 31 |
+
<svg
|
| 32 |
+
className="animate-spin h-5 w-5 text-blue-500"
|
| 33 |
+
fill="none"
|
| 34 |
+
viewBox="0 0 24 24"
|
| 35 |
+
>
|
| 36 |
+
<circle
|
| 37 |
+
className="opacity-25"
|
| 38 |
+
cx="12"
|
| 39 |
+
cy="12"
|
| 40 |
+
r="10"
|
| 41 |
+
stroke="currentColor"
|
| 42 |
+
strokeWidth="4"
|
| 43 |
+
></circle>
|
| 44 |
+
<path
|
| 45 |
+
className="opacity-75"
|
| 46 |
+
fill="currentColor"
|
| 47 |
+
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
|
| 48 |
+
></path>
|
| 49 |
+
</svg>
|
| 50 |
+
</div>
|
| 51 |
+
<div className="flex-1">
|
| 52 |
+
<p className="text-sm font-medium text-blue-900">
|
| 53 |
+
Loading Model...
|
| 54 |
+
</p>
|
| 55 |
+
<div className="mt-2 bg-blue-200 rounded-full h-2">
|
| 56 |
+
<div
|
| 57 |
+
className="bg-blue-500 h-2 rounded-full transition-all duration-300"
|
| 58 |
+
style={{ width: `${progress.toFixed(2)}%` }}
|
| 59 |
+
></div>
|
| 60 |
+
</div>
|
| 61 |
+
<p className="text-xs text-blue-700 mt-1">{progress.toFixed(2)}%</p>
|
| 62 |
+
</div>
|
| 63 |
+
</div>
|
| 64 |
+
</div>
|
| 65 |
+
)}
|
| 66 |
+
|
| 67 |
{/* Pipeline Description */}
|
| 68 |
<div className="mt-4 p-4 bg-gray-50 rounded-lg">
|
| 69 |
<div className="flex items-start space-x-3">
|
|
|
|
| 85 |
{pipeline === 'zero-shot-classification'
|
| 86 |
? 'Zero-Shot Classification'
|
| 87 |
: 'Text-Classification'}
|
| 88 |
+
{model && (
|
| 89 |
+
<span className="ml-2 text-xs text-gray-500 font-normal">
|
| 90 |
+
({model})
|
| 91 |
+
</span>
|
| 92 |
+
)}
|
| 93 |
</h3>
|
| 94 |
<p className="text-sm text-gray-600 mt-1">
|
| 95 |
{pipeline === 'zero-shot-classification'
|
src/components/TextClassification.tsx
CHANGED
|
@@ -4,6 +4,8 @@ import {
|
|
| 4 |
TextClassificationWorkerInput,
|
| 5 |
WorkerMessage
|
| 6 |
} from '../types';
|
|
|
|
|
|
|
| 7 |
|
| 8 |
const PLACEHOLDER_TEXTS: string[] = [
|
| 9 |
'I absolutely love this product! It exceeded all my expectations.',
|
|
@@ -21,8 +23,8 @@ const PLACEHOLDER_TEXTS: string[] = [
|
|
| 21 |
function TextClassification() {
|
| 22 |
const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'));
|
| 23 |
const [results, setResults] = useState<ClassificationOutput[]>([]);
|
| 24 |
-
const
|
| 25 |
-
|
| 26 |
|
| 27 |
// Create a reference to the worker object.
|
| 28 |
const worker = useRef<Worker | null>(null);
|
|
@@ -54,6 +56,7 @@ function TextClassification() {
|
|
| 54 |
)
|
| 55 |
setProgress(e.data.output.progress);
|
| 56 |
} else if (status === 'output') {
|
|
|
|
| 57 |
const result = e.data.output!;
|
| 58 |
setResults((prevResults) => [...prevResults, result]);
|
| 59 |
console.log(result);
|
|
@@ -111,9 +114,6 @@ function TextClassification() {
|
|
| 111 |
? 'Model loading...'
|
| 112 |
: 'Processing...'}
|
| 113 |
</button>
|
| 114 |
-
{status === 'progress' && (
|
| 115 |
-
<div className="text-sm font-medium">{progress}%</div>
|
| 116 |
-
)}
|
| 117 |
<button
|
| 118 |
className="py-2 px-4 bg-gray-500 hover:bg-gray-600 rounded text-white font-medium transition-colors"
|
| 119 |
onClick={handleClear}
|
|
|
|
| 4 |
TextClassificationWorkerInput,
|
| 5 |
WorkerMessage
|
| 6 |
} from '../types';
|
| 7 |
+
import { useModel } from '../contexts/ModelContext';
|
| 8 |
+
|
| 9 |
|
| 10 |
const PLACEHOLDER_TEXTS: string[] = [
|
| 11 |
'I absolutely love this product! It exceeded all my expectations.',
|
|
|
|
| 23 |
function TextClassification() {
|
| 24 |
const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'));
|
| 25 |
const [results, setResults] = useState<ClassificationOutput[]>([]);
|
| 26 |
+
const { setProgress, status, setStatus, setModel } = useModel();
|
| 27 |
+
setModel('Xenova/bert-base-multilingual-uncased-sentiment')
|
| 28 |
|
| 29 |
// Create a reference to the worker object.
|
| 30 |
const worker = useRef<Worker | null>(null);
|
|
|
|
| 56 |
)
|
| 57 |
setProgress(e.data.output.progress);
|
| 58 |
} else if (status === 'output') {
|
| 59 |
+
setStatus('output');
|
| 60 |
const result = e.data.output!;
|
| 61 |
setResults((prevResults) => [...prevResults, result]);
|
| 62 |
console.log(result);
|
|
|
|
| 114 |
? 'Model loading...'
|
| 115 |
: 'Processing...'}
|
| 116 |
</button>
|
|
|
|
|
|
|
|
|
|
| 117 |
<button
|
| 118 |
className="py-2 px-4 bg-gray-500 hover:bg-gray-600 rounded text-white font-medium transition-colors"
|
| 119 |
onClick={handleClear}
|
src/components/ZeroShotClassification.tsx
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
// src/App.tsx
|
| 2 |
import { useState, useRef, useEffect, useCallback } from 'react';
|
| 3 |
import { Section, WorkerMessage, ZeroShotWorkerInput } from '../types';
|
|
|
|
| 4 |
|
| 5 |
const PLACEHOLDER_REVIEWS: string[] = [
|
| 6 |
// battery/charging problems
|
|
@@ -44,8 +45,8 @@ function ZeroShotClassification() {
|
|
| 44 |
PLACEHOLDER_SECTIONS.map((title) => ({ title, items: [] }))
|
| 45 |
);
|
| 46 |
|
| 47 |
-
const
|
| 48 |
-
|
| 49 |
|
| 50 |
// Create a reference to the worker object.
|
| 51 |
const worker = useRef<Worker | null>(null);
|
|
@@ -77,6 +78,7 @@ function ZeroShotClassification() {
|
|
| 77 |
)
|
| 78 |
setProgress(e.data.output.progress);
|
| 79 |
} else if (status === 'output') {
|
|
|
|
| 80 |
const { sequence, labels, scores } = e.data.output!;
|
| 81 |
|
| 82 |
// Threshold for classification
|
|
@@ -175,9 +177,6 @@ function ZeroShotClassification() {
|
|
| 175 |
? 'Model loading...'
|
| 176 |
: 'Processing'}
|
| 177 |
</button>
|
| 178 |
-
{status === 'progress' && (
|
| 179 |
-
<div className="text-sm font-medium">{progress}%</div>
|
| 180 |
-
)}
|
| 181 |
<div className="flex gap-1">
|
| 182 |
<button
|
| 183 |
className="border py-1 px-2 bg-green-400 rounded text-white text-sm font-medium cursor-pointer"
|
|
|
|
| 1 |
// src/App.tsx
|
| 2 |
import { useState, useRef, useEffect, useCallback } from 'react';
|
| 3 |
import { Section, WorkerMessage, ZeroShotWorkerInput } from '../types';
|
| 4 |
+
import { useModel } from '../contexts/ModelContext';
|
| 5 |
|
| 6 |
const PLACEHOLDER_REVIEWS: string[] = [
|
| 7 |
// battery/charging problems
|
|
|
|
| 45 |
PLACEHOLDER_SECTIONS.map((title) => ({ title, items: [] }))
|
| 46 |
);
|
| 47 |
|
| 48 |
+
const { setProgress, status, setStatus, setModel } = useModel();
|
| 49 |
+
setModel('MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33')
|
| 50 |
|
| 51 |
// Create a reference to the worker object.
|
| 52 |
const worker = useRef<Worker | null>(null);
|
|
|
|
| 78 |
)
|
| 79 |
setProgress(e.data.output.progress);
|
| 80 |
} else if (status === 'output') {
|
| 81 |
+
setStatus('output');
|
| 82 |
const { sequence, labels, scores } = e.data.output!;
|
| 83 |
|
| 84 |
// Threshold for classification
|
|
|
|
| 177 |
? 'Model loading...'
|
| 178 |
: 'Processing'}
|
| 179 |
</button>
|
|
|
|
|
|
|
|
|
|
| 180 |
<div className="flex gap-1">
|
| 181 |
<button
|
| 182 |
className="border py-1 px-2 bg-green-400 rounded text-white text-sm font-medium cursor-pointer"
|
src/contexts/ModelContext.tsx
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React, { createContext, useContext, useEffect, useState } from 'react';
|
| 2 |
+
|
| 3 |
+
interface ModelContextType {
|
| 4 |
+
progress: number;
|
| 5 |
+
status: string;
|
| 6 |
+
setProgress: (progress: number) => void;
|
| 7 |
+
setStatus: (status: string) => void;
|
| 8 |
+
model: string;
|
| 9 |
+
setModel: (model: string) => void;
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
const ModelContext = createContext<ModelContextType | undefined>(undefined);
|
| 13 |
+
|
| 14 |
+
export function ModelProvider({ children }: { children: React.ReactNode }) {
|
| 15 |
+
const [progress, setProgress] = useState<number>(0);
|
| 16 |
+
const [status, setStatus] = useState<string>('idle');
|
| 17 |
+
const [model, setModel] = useState<string>('');
|
| 18 |
+
|
| 19 |
+
// set progress to 0 when model is changed
|
| 20 |
+
useEffect(() => {
|
| 21 |
+
setProgress(0);
|
| 22 |
+
}, [model]);
|
| 23 |
+
|
| 24 |
+
return (
|
| 25 |
+
<ModelContext.Provider
|
| 26 |
+
value={{ progress, setProgress, status, setStatus, model, setModel }}
|
| 27 |
+
>
|
| 28 |
+
{children}
|
| 29 |
+
</ModelContext.Provider>
|
| 30 |
+
);
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
export function useModel() {
|
| 34 |
+
const context = useContext(ModelContext);
|
| 35 |
+
if (context === undefined) {
|
| 36 |
+
throw new Error('useModel must be used within a ModelProvider');
|
| 37 |
+
}
|
| 38 |
+
return context;
|
| 39 |
+
}
|
src/index.tsx
CHANGED
|
@@ -3,14 +3,18 @@ import ReactDOM from 'react-dom/client';
|
|
| 3 |
import './index.css';
|
| 4 |
import App from './App';
|
| 5 |
import reportWebVitals from './reportWebVitals';
|
|
|
|
|
|
|
| 6 |
|
| 7 |
const root = ReactDOM.createRoot(
|
| 8 |
document.getElementById('root') as HTMLElement
|
| 9 |
);
|
| 10 |
root.render(
|
|
|
|
| 11 |
<React.StrictMode>
|
| 12 |
<App />
|
| 13 |
</React.StrictMode>
|
|
|
|
| 14 |
);
|
| 15 |
|
| 16 |
// If you want to start measuring performance in your app, pass a function
|
|
|
|
| 3 |
import './index.css';
|
| 4 |
import App from './App';
|
| 5 |
import reportWebVitals from './reportWebVitals';
|
| 6 |
+
import { ModelProvider } from './contexts/ModelContext';
|
| 7 |
+
|
| 8 |
|
| 9 |
const root = ReactDOM.createRoot(
|
| 10 |
document.getElementById('root') as HTMLElement
|
| 11 |
);
|
| 12 |
root.render(
|
| 13 |
+
<ModelProvider>
|
| 14 |
<React.StrictMode>
|
| 15 |
<App />
|
| 16 |
</React.StrictMode>
|
| 17 |
+
</ModelProvider>
|
| 18 |
);
|
| 19 |
|
| 20 |
// If you want to start measuring performance in your app, pass a function
|