Refactor modelInfo handling across components for improved null safety
Browse files
src/App.tsx
CHANGED
|
@@ -3,7 +3,6 @@ import PipelineSelector from './components/PipelineSelector'
|
|
| 3 |
import ZeroShotClassification from './components/ZeroShotClassification'
|
| 4 |
import TextClassification from './components/TextClassification'
|
| 5 |
import Header from './Header'
|
| 6 |
-
import Footer from './Footer'
|
| 7 |
import { useModel } from './contexts/ModelContext'
|
| 8 |
import { getModelsByPipeline } from './lib/huggingface'
|
| 9 |
import ModelSelector from './components/ModelSelector'
|
|
@@ -11,9 +10,10 @@ import ModelInfo from './components/ModelInfo'
|
|
| 11 |
import ModelReadme from './components/ModelReadme'
|
| 12 |
|
| 13 |
function App() {
|
| 14 |
-
const { pipeline, setPipeline, setModels, modelInfo } = useModel()
|
| 15 |
|
| 16 |
useEffect(() => {
|
|
|
|
| 17 |
const fetchModels = async () => {
|
| 18 |
const fetchedModels = await getModelsByPipeline(pipeline)
|
| 19 |
setModels(fetchedModels)
|
|
|
|
| 3 |
import ZeroShotClassification from './components/ZeroShotClassification'
|
| 4 |
import TextClassification from './components/TextClassification'
|
| 5 |
import Header from './Header'
|
|
|
|
| 6 |
import { useModel } from './contexts/ModelContext'
|
| 7 |
import { getModelsByPipeline } from './lib/huggingface'
|
| 8 |
import ModelSelector from './components/ModelSelector'
|
|
|
|
| 10 |
import ModelReadme from './components/ModelReadme'
|
| 11 |
|
| 12 |
function App() {
|
| 13 |
+
const { pipeline, setPipeline, setModels, setModelInfo, modelInfo } = useModel()
|
| 14 |
|
| 15 |
useEffect(() => {
|
| 16 |
+
setModelInfo(null)
|
| 17 |
const fetchModels = async () => {
|
| 18 |
const fetchedModels = await getModelsByPipeline(pipeline)
|
| 19 |
setModels(fetchedModels)
|
src/components/ModelInfo.tsx
CHANGED
|
@@ -64,7 +64,7 @@ const ModelInfo = () => {
|
|
| 64 |
</div>
|
| 65 |
)
|
| 66 |
|
| 67 |
-
if (!modelInfo
|
| 68 |
return <ModelInfoSkeleton />
|
| 69 |
}
|
| 70 |
|
|
|
|
| 64 |
</div>
|
| 65 |
)
|
| 66 |
|
| 67 |
+
if (!modelInfo) {
|
| 68 |
return <ModelInfoSkeleton />
|
| 69 |
}
|
| 70 |
|
src/components/ModelLoader.tsx
CHANGED
|
@@ -19,6 +19,8 @@ const ModelLoader = () => {
|
|
| 19 |
} = useModel()
|
| 20 |
|
| 21 |
useEffect(() => {
|
|
|
|
|
|
|
| 22 |
if (modelInfo.isCompatible && modelInfo.supportedQuantizations.length > 0) {
|
| 23 |
const quantizations = modelInfo.supportedQuantizations
|
| 24 |
let defaultQuant: QuantizationType = 'fp32'
|
|
@@ -34,12 +36,13 @@ const ModelLoader = () => {
|
|
| 34 |
setSelectedQuantization(defaultQuant)
|
| 35 |
}
|
| 36 |
}, [
|
| 37 |
-
modelInfo
|
| 38 |
-
modelInfo.isCompatible,
|
| 39 |
setSelectedQuantization
|
| 40 |
])
|
| 41 |
|
| 42 |
useEffect(() => {
|
|
|
|
|
|
|
| 43 |
const newWorker = getWorker(pipeline)
|
| 44 |
if (!newWorker) {
|
| 45 |
return
|
|
@@ -70,10 +73,10 @@ const ModelLoader = () => {
|
|
| 70 |
newWorker.removeEventListener('message', onMessageReceived)
|
| 71 |
// terminateWorker(pipeline);
|
| 72 |
}
|
| 73 |
-
}, [pipeline, modelInfo
|
| 74 |
|
| 75 |
const loadModel = useCallback(() => {
|
| 76 |
-
if (!modelInfo
|
| 77 |
|
| 78 |
setStatus('loading')
|
| 79 |
const message = {
|
|
@@ -82,12 +85,12 @@ const ModelLoader = () => {
|
|
| 82 |
quantization: selectedQuantization
|
| 83 |
}
|
| 84 |
activeWorker?.postMessage(message)
|
| 85 |
-
}, [modelInfo
|
| 86 |
|
| 87 |
const ready: boolean = status === 'ready'
|
| 88 |
const busy: boolean = status === 'loading'
|
| 89 |
|
| 90 |
-
if (!modelInfo
|
| 91 |
return null
|
| 92 |
}
|
| 93 |
|
|
|
|
| 19 |
} = useModel()
|
| 20 |
|
| 21 |
useEffect(() => {
|
| 22 |
+
if (!modelInfo) return
|
| 23 |
+
|
| 24 |
if (modelInfo.isCompatible && modelInfo.supportedQuantizations.length > 0) {
|
| 25 |
const quantizations = modelInfo.supportedQuantizations
|
| 26 |
let defaultQuant: QuantizationType = 'fp32'
|
|
|
|
| 36 |
setSelectedQuantization(defaultQuant)
|
| 37 |
}
|
| 38 |
}, [
|
| 39 |
+
modelInfo,
|
|
|
|
| 40 |
setSelectedQuantization
|
| 41 |
])
|
| 42 |
|
| 43 |
useEffect(() => {
|
| 44 |
+
if (!modelInfo) return
|
| 45 |
+
|
| 46 |
const newWorker = getWorker(pipeline)
|
| 47 |
if (!newWorker) {
|
| 48 |
return
|
|
|
|
| 73 |
newWorker.removeEventListener('message', onMessageReceived)
|
| 74 |
// terminateWorker(pipeline);
|
| 75 |
}
|
| 76 |
+
}, [pipeline, modelInfo, selectedQuantization, setActiveWorker, setStatus, setProgress])
|
| 77 |
|
| 78 |
const loadModel = useCallback(() => {
|
| 79 |
+
if (!modelInfo || !selectedQuantization) return
|
| 80 |
|
| 81 |
setStatus('loading')
|
| 82 |
const message = {
|
|
|
|
| 85 |
quantization: selectedQuantization
|
| 86 |
}
|
| 87 |
activeWorker?.postMessage(message)
|
| 88 |
+
}, [modelInfo, selectedQuantization, setStatus, activeWorker])
|
| 89 |
|
| 90 |
const ready: boolean = status === 'ready'
|
| 91 |
const busy: boolean = status === 'loading'
|
| 92 |
|
| 93 |
+
if (!modelInfo?.isCompatible || modelInfo.supportedQuantizations.length === 0) {
|
| 94 |
return null
|
| 95 |
}
|
| 96 |
|
src/components/ModelSelector.tsx
CHANGED
|
@@ -119,7 +119,7 @@ const ModelSelector: React.FC = () => {
|
|
| 119 |
}
|
| 120 |
|
| 121 |
const selectedModel =
|
| 122 |
-
models.find((model) => model.id === modelInfo
|
| 123 |
|
| 124 |
return (
|
| 125 |
<div className="relative">
|
|
@@ -132,7 +132,7 @@ const ModelSelector: React.FC = () => {
|
|
| 132 |
<div className="flex items-center justify-between w-full">
|
| 133 |
<div className="flex flex-col flex-1 min-w-0">
|
| 134 |
<span className="truncate font-medium">
|
| 135 |
-
{modelInfo
|
| 136 |
</span>
|
| 137 |
</div>
|
| 138 |
|
|
|
|
| 119 |
}
|
| 120 |
|
| 121 |
const selectedModel =
|
| 122 |
+
models.find((model) => model.id === modelInfo?.id) || models[0]
|
| 123 |
|
| 124 |
return (
|
| 125 |
<div className="relative">
|
|
|
|
| 132 |
<div className="flex items-center justify-between w-full">
|
| 133 |
<div className="flex flex-col flex-1 min-w-0">
|
| 134 |
<span className="truncate font-medium">
|
| 135 |
+
{modelInfo?.id || 'Select a model'}
|
| 136 |
</span>
|
| 137 |
</div>
|
| 138 |
|
src/components/TextClassification.tsx
CHANGED
|
@@ -58,6 +58,7 @@ function TextClassification() {
|
|
| 58 |
}, [setStatus])
|
| 59 |
|
| 60 |
const classify = useCallback(() => {
|
|
|
|
| 61 |
setStatus('loading')
|
| 62 |
setResults([]) // Clear previous results
|
| 63 |
const message: TextClassificationWorkerInput = {
|
|
@@ -66,7 +67,7 @@ function TextClassification() {
|
|
| 66 |
model: modelInfo.id
|
| 67 |
}
|
| 68 |
workerRef.current?.postMessage(message)
|
| 69 |
-
}, [text, modelInfo
|
| 70 |
|
| 71 |
const busy: boolean = status !== 'ready'
|
| 72 |
|
|
|
|
| 58 |
}, [setStatus])
|
| 59 |
|
| 60 |
const classify = useCallback(() => {
|
| 61 |
+
if (!modelInfo) return
|
| 62 |
setStatus('loading')
|
| 63 |
setResults([]) // Clear previous results
|
| 64 |
const message: TextClassificationWorkerInput = {
|
|
|
|
| 67 |
model: modelInfo.id
|
| 68 |
}
|
| 69 |
workerRef.current?.postMessage(message)
|
| 70 |
+
}, [text, modelInfo, setStatus])
|
| 71 |
|
| 72 |
const busy: boolean = status !== 'ready'
|
| 73 |
|
src/components/ZeroShotClassification.tsx
CHANGED
|
@@ -104,6 +104,8 @@ function ZeroShotClassification() {
|
|
| 104 |
}, [sections])
|
| 105 |
|
| 106 |
const classify = useCallback(() => {
|
|
|
|
|
|
|
| 107 |
setStatus('loading')
|
| 108 |
const message: ZeroShotWorkerInput = {
|
| 109 |
text,
|
|
@@ -113,7 +115,7 @@ function ZeroShotClassification() {
|
|
| 113 |
model: modelInfo.name
|
| 114 |
}
|
| 115 |
worker.current?.postMessage(message)
|
| 116 |
-
}, [text, sections, modelInfo
|
| 117 |
|
| 118 |
const busy: boolean = status !== 'ready'
|
| 119 |
|
|
|
|
| 104 |
}, [sections])
|
| 105 |
|
| 106 |
const classify = useCallback(() => {
|
| 107 |
+
if (!modelInfo) return
|
| 108 |
+
|
| 109 |
setStatus('loading')
|
| 110 |
const message: ZeroShotWorkerInput = {
|
| 111 |
text,
|
|
|
|
| 115 |
model: modelInfo.name
|
| 116 |
}
|
| 117 |
worker.current?.postMessage(message)
|
| 118 |
+
}, [text, sections, modelInfo])
|
| 119 |
|
| 120 |
const busy: boolean = status !== 'ready'
|
| 121 |
|
src/contexts/ModelContext.tsx
CHANGED
|
@@ -16,8 +16,8 @@ interface ModelContextType {
|
|
| 16 |
setStatus: (status: WorkerStatus) => void
|
| 17 |
progress: number
|
| 18 |
setProgress: (progress: number) => void
|
| 19 |
-
modelInfo: ModelInfo
|
| 20 |
-
setModelInfo: (model: ModelInfo) => void
|
| 21 |
pipeline: string
|
| 22 |
setPipeline: (pipeline: string) => void
|
| 23 |
models: ModelInfoResponse[]
|
|
@@ -33,7 +33,7 @@ const ModelContext = createContext<ModelContextType | undefined>(undefined)
|
|
| 33 |
export function ModelProvider({ children }: { children: React.ReactNode }) {
|
| 34 |
const [progress, setProgress] = useState<number>(0)
|
| 35 |
const [status, setStatus] = useState<WorkerStatus>('initiate')
|
| 36 |
-
const [modelInfo, setModelInfo] = useState<ModelInfo>(
|
| 37 |
const [models, setModels] = useState<ModelInfoResponse[]>(
|
| 38 |
[] as ModelInfoResponse[]
|
| 39 |
)
|
|
@@ -45,7 +45,7 @@ export function ModelProvider({ children }: { children: React.ReactNode }) {
|
|
| 45 |
// set progress to 0 when model is changed
|
| 46 |
useEffect(() => {
|
| 47 |
setProgress(0)
|
| 48 |
-
}, [modelInfo
|
| 49 |
|
| 50 |
return (
|
| 51 |
<ModelContext.Provider
|
|
|
|
| 16 |
setStatus: (status: WorkerStatus) => void
|
| 17 |
progress: number
|
| 18 |
setProgress: (progress: number) => void
|
| 19 |
+
modelInfo: ModelInfo | null
|
| 20 |
+
setModelInfo: (model: ModelInfo | null) => void
|
| 21 |
pipeline: string
|
| 22 |
setPipeline: (pipeline: string) => void
|
| 23 |
models: ModelInfoResponse[]
|
|
|
|
| 33 |
export function ModelProvider({ children }: { children: React.ReactNode }) {
|
| 34 |
const [progress, setProgress] = useState<number>(0)
|
| 35 |
const [status, setStatus] = useState<WorkerStatus>('initiate')
|
| 36 |
+
const [modelInfo, setModelInfo] = useState<ModelInfo | null>(null)
|
| 37 |
const [models, setModels] = useState<ModelInfoResponse[]>(
|
| 38 |
[] as ModelInfoResponse[]
|
| 39 |
)
|
|
|
|
| 45 |
// set progress to 0 when model is changed
|
| 46 |
useEffect(() => {
|
| 47 |
setProgress(0)
|
| 48 |
+
}, [modelInfo?.name])
|
| 49 |
|
| 50 |
return (
|
| 51 |
<ModelContext.Provider
|