wip: enhance model selection and display across components
Browse files- src/App.tsx +23 -12
- src/components/ModelSelector.tsx +177 -22
- src/components/PipelineSelector.tsx +4 -0
- src/components/TextClassification.tsx +10 -10
- src/components/ZeroShotClassification.tsx +4 -2
- src/contexts/ModelContext.tsx +12 -2
- src/lib/huggingface.ts +34 -25
- src/types.ts +29 -0
src/App.tsx
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
import { useState } from 'react'
|
| 2 |
import PipelineSelector from './components/PipelineSelector'
|
| 3 |
import ZeroShotClassification from './components/ZeroShotClassification'
|
| 4 |
import TextClassification from './components/TextClassification'
|
|
@@ -6,11 +6,11 @@ import Header from './Header'
|
|
| 6 |
import Footer from './Footer'
|
| 7 |
import { useModel } from './contexts/ModelContext'
|
| 8 |
import { Bot, Heart, Download, Cpu, DatabaseIcon } from 'lucide-react'
|
| 9 |
-
import { getModelSize } from './lib/huggingface'
|
|
|
|
| 10 |
|
| 11 |
function App() {
|
| 12 |
-
const
|
| 13 |
-
const { progress, status, modelInfo } = useModel()
|
| 14 |
|
| 15 |
const formatNumber = (num: number) => {
|
| 16 |
if (num >= 1000000000) {
|
|
@@ -23,6 +23,14 @@ function App() {
|
|
| 23 |
return num.toString()
|
| 24 |
}
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
return (
|
| 27 |
<div className="min-h-screen bg-gradient-to-br from-blue-50 to-indigo-100">
|
| 28 |
<Header />
|
|
@@ -32,10 +40,6 @@ function App() {
|
|
| 32 |
<div className="mb-8">
|
| 33 |
<div className="bg-white rounded-lg shadow-sm border p-6">
|
| 34 |
<div className="flex items-center justify-between mb-4">
|
| 35 |
-
<h2 className="text-lg font-semibold text-gray-900">
|
| 36 |
-
Choose a Pipeline
|
| 37 |
-
</h2>
|
| 38 |
-
|
| 39 |
{/* Model Info Display */}
|
| 40 |
{modelInfo.name && (
|
| 41 |
<div className="bg-gradient-to-r from-blue-50 to-indigo-50 px-4 py-3 rounded-lg border border-blue-200 space-y-2">
|
|
@@ -77,9 +81,10 @@ function App() {
|
|
| 77 |
<div className="flex items-center space-x-1">
|
| 78 |
<DatabaseIcon className="w-3 h-3 text-purple-500" />
|
| 79 |
<span>
|
| 80 |
-
{`~${getModelSize(
|
| 81 |
-
|
| 82 |
-
|
|
|
|
| 83 |
</span>
|
| 84 |
</div>
|
| 85 |
)}
|
|
@@ -88,7 +93,13 @@ function App() {
|
|
| 88 |
)}
|
| 89 |
</div>
|
| 90 |
|
| 91 |
-
<
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
{/* Model Loading Progress */}
|
| 94 |
{status === 'progress' && (
|
|
|
|
| 1 |
+
import { useEffect, useState } from 'react'
|
| 2 |
import PipelineSelector from './components/PipelineSelector'
|
| 3 |
import ZeroShotClassification from './components/ZeroShotClassification'
|
| 4 |
import TextClassification from './components/TextClassification'
|
|
|
|
| 6 |
import Footer from './Footer'
|
| 7 |
import { useModel } from './contexts/ModelContext'
|
| 8 |
import { Bot, Heart, Download, Cpu, DatabaseIcon } from 'lucide-react'
|
| 9 |
+
import { getModelsByPipeline, getModelSize } from './lib/huggingface'
|
| 10 |
+
import ModelSelector from './components/ModelSelector'
|
| 11 |
|
| 12 |
function App() {
|
| 13 |
+
const { pipeline, setPipeline, progress, status, modelInfo, setModels } = useModel()
|
|
|
|
| 14 |
|
| 15 |
const formatNumber = (num: number) => {
|
| 16 |
if (num >= 1000000000) {
|
|
|
|
| 23 |
return num.toString()
|
| 24 |
}
|
| 25 |
|
| 26 |
+
useEffect(() => {
|
| 27 |
+
const fetchModels = async () => {
|
| 28 |
+
const fetchedModels = await getModelsByPipeline(pipeline);
|
| 29 |
+
setModels(fetchedModels);
|
| 30 |
+
};
|
| 31 |
+
fetchModels();
|
| 32 |
+
}, [setModels, pipeline]);
|
| 33 |
+
|
| 34 |
return (
|
| 35 |
<div className="min-h-screen bg-gradient-to-br from-blue-50 to-indigo-100">
|
| 36 |
<Header />
|
|
|
|
| 40 |
<div className="mb-8">
|
| 41 |
<div className="bg-white rounded-lg shadow-sm border p-6">
|
| 42 |
<div className="flex items-center justify-between mb-4">
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
{/* Model Info Display */}
|
| 44 |
{modelInfo.name && (
|
| 45 |
<div className="bg-gradient-to-r from-blue-50 to-indigo-50 px-4 py-3 rounded-lg border border-blue-200 space-y-2">
|
|
|
|
| 81 |
<div className="flex items-center space-x-1">
|
| 82 |
<DatabaseIcon className="w-3 h-3 text-purple-500" />
|
| 83 |
<span>
|
| 84 |
+
{`~${getModelSize(
|
| 85 |
+
modelInfo.parameters,
|
| 86 |
+
'INT8'
|
| 87 |
+
).toFixed(1)}MB`}
|
| 88 |
</span>
|
| 89 |
</div>
|
| 90 |
)}
|
|
|
|
| 93 |
)}
|
| 94 |
</div>
|
| 95 |
|
| 96 |
+
<div className="flex flex-row items-center space-x-4">
|
| 97 |
+
<span className="text-lg font-semibold text-gray-900">
|
| 98 |
+
Choose a Pipeline
|
| 99 |
+
</span>
|
| 100 |
+
<PipelineSelector pipeline={pipeline} setPipeline={setPipeline} />
|
| 101 |
+
</div>
|
| 102 |
+
<ModelSelector />
|
| 103 |
|
| 104 |
{/* Model Loading Progress */}
|
| 105 |
{status === 'progress' && (
|
src/components/ModelSelector.tsx
CHANGED
|
@@ -1,25 +1,180 @@
|
|
| 1 |
-
import React from 'react'
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
const ModelSelector: React.FC<ModelSelectorProps> = ({
|
| 10 |
-
model,
|
| 11 |
-
setModel,
|
| 12 |
-
models
|
| 13 |
-
}) => {
|
| 14 |
return (
|
| 15 |
-
<
|
| 16 |
-
{
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React, { useEffect, useState } from 'react'
|
| 2 |
+
import { useModel } from '../contexts/ModelContext'
|
| 3 |
+
import { getModelInfo } from '../lib/huggingface'
|
| 4 |
+
import { Heart, Download, ChevronDown } from 'lucide-react'
|
| 5 |
|
| 6 |
+
const ModelSelector: React.FC = () => {
|
| 7 |
+
const { models, setModelInfo, modelInfo } = useModel()
|
| 8 |
+
const [isOpen, setIsOpen] = useState(false)
|
| 9 |
+
const [modelStats, setModelStats] = useState<
|
| 10 |
+
Record<string, { likes: number; downloads: number; createdAt: string }>
|
| 11 |
+
>({})
|
| 12 |
+
|
| 13 |
+
const formatNumber = (num: number) => {
|
| 14 |
+
if (num >= 1000000000) {
|
| 15 |
+
return (num / 1000000000).toFixed(1) + 'B'
|
| 16 |
+
} else if (num >= 1000000) {
|
| 17 |
+
return (num / 1000000).toFixed(1) + 'M'
|
| 18 |
+
} else if (num >= 1000) {
|
| 19 |
+
return (num / 1000).toFixed(1) + 'K'
|
| 20 |
+
}
|
| 21 |
+
return num.toString()
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
// Separate function to fetch only stats without updating selected model
|
| 25 |
+
const fetchModelStats = async (modelId: string) => {
|
| 26 |
+
try {
|
| 27 |
+
const modelInfoResponse = await getModelInfo(modelId)
|
| 28 |
+
|
| 29 |
+
setModelStats((prev) => ({
|
| 30 |
+
...prev,
|
| 31 |
+
[modelId]: {
|
| 32 |
+
likes: modelInfoResponse.likes || 0,
|
| 33 |
+
downloads: modelInfoResponse.downloads || 0,
|
| 34 |
+
createdAt: modelInfoResponse.createdAt || ''
|
| 35 |
+
}
|
| 36 |
+
}))
|
| 37 |
+
} catch (error) {
|
| 38 |
+
console.error('Error fetching model stats:', error)
|
| 39 |
+
}
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
// Function to fetch full model info and set as selected
|
| 43 |
+
const fetchModelAndSetInfo = async (modelId: string) => {
|
| 44 |
+
try {
|
| 45 |
+
const modelInfoResponse = await getModelInfo(modelId)
|
| 46 |
+
let parameters = 0
|
| 47 |
+
if (modelInfoResponse.safetensors) {
|
| 48 |
+
const safetensors = modelInfoResponse.safetensors
|
| 49 |
+
parameters =
|
| 50 |
+
safetensors.parameters.F16 ||
|
| 51 |
+
safetensors.parameters.F32 ||
|
| 52 |
+
safetensors.parameters.total ||
|
| 53 |
+
0
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
// Transform ModelInfoResponse to ModelInfo
|
| 57 |
+
const modelInfo = {
|
| 58 |
+
id: modelId,
|
| 59 |
+
name: modelInfoResponse.id || modelId,
|
| 60 |
+
architecture: modelInfoResponse.config?.architectures?.[0] || 'Unknown',
|
| 61 |
+
parameters,
|
| 62 |
+
likes: modelInfoResponse.likes || 0,
|
| 63 |
+
downloads: modelInfoResponse.downloads || 0,
|
| 64 |
+
createdAt: modelInfoResponse.createdAt || ''
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
// Also update stats
|
| 68 |
+
setModelStats((prev) => ({
|
| 69 |
+
...prev,
|
| 70 |
+
[modelId]: {
|
| 71 |
+
likes: modelInfoResponse.likes || 0,
|
| 72 |
+
downloads: modelInfoResponse.downloads || 0,
|
| 73 |
+
createdAt: modelInfoResponse.createdAt || ''
|
| 74 |
+
}
|
| 75 |
+
}))
|
| 76 |
+
|
| 77 |
+
console.log(modelInfoResponse)
|
| 78 |
+
|
| 79 |
+
setModelInfo(modelInfo)
|
| 80 |
+
} catch (error) {
|
| 81 |
+
console.error('Error fetching model info:', error)
|
| 82 |
+
}
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
// Fetch stats for all models when component mounts (without setting as selected)
|
| 86 |
+
useEffect(() => {
|
| 87 |
+
models.forEach((model) => {
|
| 88 |
+
if (!modelStats[model.id]) {
|
| 89 |
+
fetchModelStats(model.id)
|
| 90 |
+
}
|
| 91 |
+
})
|
| 92 |
+
}, [models])
|
| 93 |
+
|
| 94 |
+
// Only fetch full info when a model is actually selected
|
| 95 |
+
useEffect(() => {
|
| 96 |
+
if (!modelInfo.id) return
|
| 97 |
+
// Only fetch if we don't already have the full info
|
| 98 |
+
if (!modelStats[modelInfo.id]) {
|
| 99 |
+
fetchModelAndSetInfo(modelInfo.id)
|
| 100 |
+
}
|
| 101 |
+
}, [modelInfo.id])
|
| 102 |
+
|
| 103 |
+
const handleModelSelect = (modelId: string) => {
|
| 104 |
+
fetchModelAndSetInfo(modelId)
|
| 105 |
+
setIsOpen(false)
|
| 106 |
+
}
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
return (
|
| 109 |
+
<div className="relative">
|
| 110 |
+
{/* Custom Dropdown Button */}
|
| 111 |
+
<button
|
| 112 |
+
onClick={() => setIsOpen(!isOpen)}
|
| 113 |
+
className="w-full px-3 py-2 border border-gray-300 rounded-md focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent bg-white text-left flex items-center justify-between"
|
| 114 |
+
>
|
| 115 |
+
<span className="truncate">{modelInfo.id || 'Select a model'}</span>
|
| 116 |
+
<ChevronDown
|
| 117 |
+
className={`w-4 h-4 transition-transform ${
|
| 118 |
+
isOpen ? 'rotate-180' : ''
|
| 119 |
+
}`}
|
| 120 |
+
/>
|
| 121 |
+
</button>
|
| 122 |
+
|
| 123 |
+
{/* Custom Dropdown Options */}
|
| 124 |
+
{isOpen && (
|
| 125 |
+
<div className="absolute z-10 w-full mt-1 bg-white border border-gray-300 rounded-md shadow-lg max-h-60 overflow-auto">
|
| 126 |
+
{models.map((model) => (
|
| 127 |
+
<div
|
| 128 |
+
key={model.id}
|
| 129 |
+
onClick={() => handleModelSelect(model.id)}
|
| 130 |
+
className="px-3 py-2 hover:bg-gray-50 cursor-pointer border-b border-gray-100 last:border-b-0"
|
| 131 |
+
>
|
| 132 |
+
<div className="flex items-center justify-between">
|
| 133 |
+
<span className="text-sm font-medium truncate flex-1 mr-2">
|
| 134 |
+
{model.id}
|
| 135 |
+
</span>
|
| 136 |
+
|
| 137 |
+
{/* Stats Display */}
|
| 138 |
+
{modelStats[model.id] &&
|
| 139 |
+
(modelStats[model.id].likes > 0 ||
|
| 140 |
+
modelStats[model.id].downloads > 0) && (
|
| 141 |
+
<div className="flex items-center space-x-3 text-xs text-gray-500 flex-shrink-0">
|
| 142 |
+
{modelStats[model.id].likes > 0 && (
|
| 143 |
+
<div className="flex items-center space-x-1">
|
| 144 |
+
<Heart className="w-3 h-3 text-red-500" />
|
| 145 |
+
<span>
|
| 146 |
+
{formatNumber(modelStats[model.id].likes)}
|
| 147 |
+
</span>
|
| 148 |
+
</div>
|
| 149 |
+
)}
|
| 150 |
+
|
| 151 |
+
{modelStats[model.id].downloads > 0 && (
|
| 152 |
+
<div className="flex items-center space-x-1">
|
| 153 |
+
<Download className="w-3 h-3 text-green-500" />
|
| 154 |
+
<span>
|
| 155 |
+
{formatNumber(modelStats[model.id].downloads)}
|
| 156 |
+
</span>
|
| 157 |
+
</div>
|
| 158 |
+
)}
|
| 159 |
+
{modelStats[model.id].createdAt !== '' && (
|
| 160 |
+
<span className="text-xs text-gray-400">
|
| 161 |
+
{modelStats[model.id].createdAt.split('T')[0]}
|
| 162 |
+
</span>
|
| 163 |
+
)}
|
| 164 |
+
</div>
|
| 165 |
+
)}
|
| 166 |
+
</div>
|
| 167 |
+
</div>
|
| 168 |
+
))}
|
| 169 |
+
</div>
|
| 170 |
+
)}
|
| 171 |
+
|
| 172 |
+
{/* Click outside to close */}
|
| 173 |
+
{isOpen && (
|
| 174 |
+
<div className="fixed inset-0 z-0" onClick={() => setIsOpen(false)} />
|
| 175 |
+
)}
|
| 176 |
+
</div>
|
| 177 |
+
)
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
export default ModelSelector
|
src/components/PipelineSelector.tsx
CHANGED
|
@@ -3,6 +3,10 @@ import React from 'react';
|
|
| 3 |
const pipelines = [
|
| 4 |
'zero-shot-classification',
|
| 5 |
'text-classification',
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
'image-classification',
|
| 7 |
'question-answering',
|
| 8 |
'translation'
|
|
|
|
| 3 |
const pipelines = [
|
| 4 |
'zero-shot-classification',
|
| 5 |
'text-classification',
|
| 6 |
+
'text-generation',
|
| 7 |
+
'summarization',
|
| 8 |
+
'feature-extraction',
|
| 9 |
+
'sentiment-analysis',
|
| 10 |
'image-classification',
|
| 11 |
'question-answering',
|
| 12 |
'translation'
|
src/components/TextClassification.tsx
CHANGED
|
@@ -3,7 +3,6 @@ import {
|
|
| 3 |
ClassificationOutput,
|
| 4 |
TextClassificationWorkerInput,
|
| 5 |
WorkerMessage,
|
| 6 |
-
ModelInfo
|
| 7 |
} from '../types';
|
| 8 |
import { useModel } from '../contexts/ModelContext';
|
| 9 |
import { getModelInfo } from '../lib/huggingface';
|
|
@@ -25,13 +24,14 @@ const PLACEHOLDER_TEXTS: string[] = [
|
|
| 25 |
function TextClassification() {
|
| 26 |
const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
|
| 27 |
const [results, setResults] = useState<ClassificationOutput[]>([])
|
| 28 |
-
const { setProgress, status, setStatus, modelInfo, setModelInfo} = useModel()
|
|
|
|
|
|
|
| 29 |
useEffect(() => {
|
| 30 |
-
|
| 31 |
const fetchModelInfo = async () => {
|
| 32 |
try {
|
| 33 |
-
const modelInfoResponse = await getModelInfo(
|
| 34 |
-
console.log(modelInfoResponse)
|
| 35 |
let parameters = 0
|
| 36 |
if (modelInfoResponse.safetensors) {
|
| 37 |
const safetensors = modelInfoResponse.safetensors
|
|
@@ -42,8 +42,8 @@ function TextClassification() {
|
|
| 42 |
0)
|
| 43 |
}
|
| 44 |
setModelInfo({
|
| 45 |
-
|
| 46 |
-
architecture: modelInfoResponse.config
|
| 47 |
parameters,
|
| 48 |
likes: modelInfoResponse.likes,
|
| 49 |
downloads: modelInfoResponse.downloads
|
|
@@ -54,7 +54,7 @@ function TextClassification() {
|
|
| 54 |
}
|
| 55 |
|
| 56 |
fetchModelInfo()
|
| 57 |
-
}, [setModelInfo])
|
| 58 |
|
| 59 |
// Create a reference to the worker object.
|
| 60 |
const worker = useRef<Worker | null>(null)
|
|
@@ -110,9 +110,9 @@ function TextClassification() {
|
|
| 110 |
const classify = useCallback(() => {
|
| 111 |
setStatus('processing')
|
| 112 |
setResults([]) // Clear previous results
|
| 113 |
-
const message: TextClassificationWorkerInput = { text, model: modelInfo.
|
| 114 |
worker.current?.postMessage(message)
|
| 115 |
-
}, [text, modelInfo.
|
| 116 |
|
| 117 |
const busy: boolean = status !== 'idle'
|
| 118 |
|
|
|
|
| 3 |
ClassificationOutput,
|
| 4 |
TextClassificationWorkerInput,
|
| 5 |
WorkerMessage,
|
|
|
|
| 6 |
} from '../types';
|
| 7 |
import { useModel } from '../contexts/ModelContext';
|
| 8 |
import { getModelInfo } from '../lib/huggingface';
|
|
|
|
| 24 |
function TextClassification() {
|
| 25 |
const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
|
| 26 |
const [results, setResults] = useState<ClassificationOutput[]>([])
|
| 27 |
+
const { setProgress, status, setStatus, modelInfo, setModelInfo, models, setModels} = useModel()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
useEffect(() => {
|
| 31 |
+
if (!modelInfo.id) return;
|
| 32 |
const fetchModelInfo = async () => {
|
| 33 |
try {
|
| 34 |
+
const modelInfoResponse = await getModelInfo(modelInfo.id)
|
|
|
|
| 35 |
let parameters = 0
|
| 36 |
if (modelInfoResponse.safetensors) {
|
| 37 |
const safetensors = modelInfoResponse.safetensors
|
|
|
|
| 42 |
0)
|
| 43 |
}
|
| 44 |
setModelInfo({
|
| 45 |
+
...modelInfo,
|
| 46 |
+
architecture: modelInfoResponse.config?.architectures[0] ?? '',
|
| 47 |
parameters,
|
| 48 |
likes: modelInfoResponse.likes,
|
| 49 |
downloads: modelInfoResponse.downloads
|
|
|
|
| 54 |
}
|
| 55 |
|
| 56 |
fetchModelInfo()
|
| 57 |
+
}, [modelInfo.id, setModelInfo])
|
| 58 |
|
| 59 |
// Create a reference to the worker object.
|
| 60 |
const worker = useRef<Worker | null>(null)
|
|
|
|
| 110 |
const classify = useCallback(() => {
|
| 111 |
setStatus('processing')
|
| 112 |
setResults([]) // Clear previous results
|
| 113 |
+
const message: TextClassificationWorkerInput = { text, model: modelInfo.id }
|
| 114 |
worker.current?.postMessage(message)
|
| 115 |
+
}, [text, modelInfo.id])
|
| 116 |
|
| 117 |
const busy: boolean = status !== 'idle'
|
| 118 |
|
src/components/ZeroShotClassification.tsx
CHANGED
|
@@ -68,11 +68,13 @@ function ZeroShotClassification() {
|
|
| 68 |
0
|
| 69 |
}
|
| 70 |
setModelInfo({
|
|
|
|
| 71 |
name: modelName,
|
| 72 |
-
architecture: modelInfoResponse.config
|
| 73 |
parameters,
|
| 74 |
likes: modelInfoResponse.likes,
|
| 75 |
-
downloads: modelInfoResponse.downloads
|
|
|
|
| 76 |
})
|
| 77 |
} catch (error) {
|
| 78 |
console.error('Error fetching model info:', error)
|
|
|
|
| 68 |
0
|
| 69 |
}
|
| 70 |
setModelInfo({
|
| 71 |
+
id: modelInfoResponse.id,
|
| 72 |
name: modelName,
|
| 73 |
+
architecture: modelInfoResponse.config?.architectures[0] ?? '',
|
| 74 |
parameters,
|
| 75 |
likes: modelInfoResponse.likes,
|
| 76 |
+
downloads: modelInfoResponse.downloads,
|
| 77 |
+
createdAt: modelInfoResponse.createdAt,
|
| 78 |
})
|
| 79 |
} catch (error) {
|
| 80 |
console.error('Error fetching model info:', error)
|
src/contexts/ModelContext.tsx
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
import React, { createContext, useContext, useEffect, useState } from 'react'
|
| 2 |
-
import { ModelInfo } from '../types'
|
| 3 |
|
| 4 |
interface ModelContextType {
|
| 5 |
progress: number
|
|
@@ -8,6 +8,10 @@ interface ModelContextType {
|
|
| 8 |
setStatus: (status: string) => void
|
| 9 |
modelInfo: ModelInfo
|
| 10 |
setModelInfo: (model: ModelInfo) => void
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
}
|
| 12 |
|
| 13 |
const ModelContext = createContext<ModelContextType | undefined>(undefined)
|
|
@@ -16,6 +20,8 @@ export function ModelProvider({ children }: { children: React.ReactNode }) {
|
|
| 16 |
const [progress, setProgress] = useState<number>(0)
|
| 17 |
const [status, setStatus] = useState<string>('idle')
|
| 18 |
const [modelInfo, setModelInfo] = useState<ModelInfo>({} as ModelInfo)
|
|
|
|
|
|
|
| 19 |
|
| 20 |
// set progress to 0 when model is changed
|
| 21 |
useEffect(() => {
|
|
@@ -30,7 +36,11 @@ export function ModelProvider({ children }: { children: React.ReactNode }) {
|
|
| 30 |
status,
|
| 31 |
setStatus,
|
| 32 |
modelInfo,
|
| 33 |
-
setModelInfo
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
}}
|
| 35 |
>
|
| 36 |
{children}
|
|
|
|
| 1 |
import React, { createContext, useContext, useEffect, useState } from 'react'
|
| 2 |
+
import { ModelInfo, ModelInfoResponse } from '../types'
|
| 3 |
|
| 4 |
interface ModelContextType {
|
| 5 |
progress: number
|
|
|
|
| 8 |
setStatus: (status: string) => void
|
| 9 |
modelInfo: ModelInfo
|
| 10 |
setModelInfo: (model: ModelInfo) => void
|
| 11 |
+
pipeline: string
|
| 12 |
+
setPipeline: (pipeline: string) => void
|
| 13 |
+
models: ModelInfoResponse[]
|
| 14 |
+
setModels: (models: ModelInfoResponse[]) => void
|
| 15 |
}
|
| 16 |
|
| 17 |
const ModelContext = createContext<ModelContextType | undefined>(undefined)
|
|
|
|
| 20 |
const [progress, setProgress] = useState<number>(0)
|
| 21 |
const [status, setStatus] = useState<string>('idle')
|
| 22 |
const [modelInfo, setModelInfo] = useState<ModelInfo>({} as ModelInfo)
|
| 23 |
+
const [models, setModels] = useState<ModelInfoResponse[]>([] as ModelInfoResponse[])
|
| 24 |
+
const [pipeline, setPipeline] = useState<string>('zero-shot-classification')
|
| 25 |
|
| 26 |
// set progress to 0 when model is changed
|
| 27 |
useEffect(() => {
|
|
|
|
| 36 |
status,
|
| 37 |
setStatus,
|
| 38 |
modelInfo,
|
| 39 |
+
setModelInfo,
|
| 40 |
+
models,
|
| 41 |
+
setModels,
|
| 42 |
+
pipeline,
|
| 43 |
+
setPipeline,
|
| 44 |
}}
|
| 45 |
>
|
| 46 |
{children}
|
src/lib/huggingface.ts
CHANGED
|
@@ -1,27 +1,5 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
config: {
|
| 4 |
-
architectures: string[]
|
| 5 |
-
model_type: string
|
| 6 |
-
}
|
| 7 |
-
lastModified: string
|
| 8 |
-
pipeline_tag: string
|
| 9 |
-
tags: string[]
|
| 10 |
-
transformersInfo: {
|
| 11 |
-
pipeline_tag: string
|
| 12 |
-
auto_model: string
|
| 13 |
-
processor: string
|
| 14 |
-
}
|
| 15 |
-
safetensors?: {
|
| 16 |
-
parameters: {
|
| 17 |
-
F16?: number
|
| 18 |
-
F32?: number
|
| 19 |
-
total?: number
|
| 20 |
-
}
|
| 21 |
-
}
|
| 22 |
-
likes: number
|
| 23 |
-
downloads: number
|
| 24 |
-
}
|
| 25 |
|
| 26 |
const getModelInfo = async (modelName: string): Promise<ModelInfoResponse> => {
|
| 27 |
const token = process.env.REACT_APP_HUGGINGFACE_TOKEN
|
|
@@ -48,6 +26,37 @@ const getModelInfo = async (modelName: string): Promise<ModelInfoResponse> => {
|
|
| 48 |
return response.json()
|
| 49 |
}
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
// Define the possible quantization types for clarity and type safety
|
| 52 |
type QuantizationType = 'FP32' | 'FP16' | 'INT8' | 'Q4'
|
| 53 |
function getModelSize(
|
|
@@ -81,5 +90,5 @@ function getModelSize(
|
|
| 81 |
}
|
| 82 |
|
| 83 |
|
| 84 |
-
export { getModelInfo, getModelSize }
|
| 85 |
|
|
|
|
| 1 |
+
import { Mode } from "fs"
|
| 2 |
+
import { ModelInfoResponse } from "../types"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
const getModelInfo = async (modelName: string): Promise<ModelInfoResponse> => {
|
| 5 |
const token = process.env.REACT_APP_HUGGINGFACE_TOKEN
|
|
|
|
| 26 |
return response.json()
|
| 27 |
}
|
| 28 |
|
| 29 |
+
const getModelsByPipeline = async (
|
| 30 |
+
pipeline_tag: string
|
| 31 |
+
): Promise<ModelInfoResponse[]> => {
|
| 32 |
+
const token = process.env.REACT_APP_HUGGINGFACE_TOKEN
|
| 33 |
+
|
| 34 |
+
if (!token) {
|
| 35 |
+
throw new Error(
|
| 36 |
+
'Hugging Face token not found. Please set REACT_APP_HUGGINGFACE_TOKEN in your .env file'
|
| 37 |
+
)
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
const response = await fetch(
|
| 41 |
+
`https://huggingface.co/api/models?filter=${pipeline_tag}&filter=transformers.js&sort=downloads`,
|
| 42 |
+
{
|
| 43 |
+
method: 'GET',
|
| 44 |
+
headers: {
|
| 45 |
+
Authorization: `Bearer ${token}`
|
| 46 |
+
}
|
| 47 |
+
}
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
if (!response.ok) {
|
| 51 |
+
throw new Error(`Failed to fetch models for pipeline: ${response.statusText}`)
|
| 52 |
+
}
|
| 53 |
+
const models = await response.json()
|
| 54 |
+
if (pipeline_tag === 'text-classification') {
|
| 55 |
+
return models.filter((model: ModelInfoResponse) => !model.tags.includes('reranker') && !model.id.includes('reranker')).slice(0, 10)
|
| 56 |
+
}
|
| 57 |
+
return models.slice(0, 10)
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
// Define the possible quantization types for clarity and type safety
|
| 61 |
type QuantizationType = 'FP32' | 'FP16' | 'INT8' | 'Q4'
|
| 62 |
function getModelSize(
|
|
|
|
| 90 |
}
|
| 91 |
|
| 92 |
|
| 93 |
+
export { getModelInfo, getModelSize, getModelsByPipeline }
|
| 94 |
|
src/types.ts
CHANGED
|
@@ -28,9 +28,38 @@ export interface TextClassificationWorkerInput {
|
|
| 28 |
export type AppStatus = 'idle' | 'loading' | 'processing'
|
| 29 |
|
| 30 |
export interface ModelInfo {
|
|
|
|
| 31 |
name: string
|
| 32 |
architecture: string
|
| 33 |
parameters: number
|
| 34 |
likes: number
|
| 35 |
downloads: number
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
}
|
|
|
|
| 28 |
export type AppStatus = 'idle' | 'loading' | 'processing'
|
| 29 |
|
| 30 |
export interface ModelInfo {
|
| 31 |
+
id: string
|
| 32 |
name: string
|
| 33 |
architecture: string
|
| 34 |
parameters: number
|
| 35 |
likes: number
|
| 36 |
downloads: number
|
| 37 |
+
createdAt: string
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
export interface ModelInfoResponse {
|
| 42 |
+
id: string
|
| 43 |
+
createdAt: string
|
| 44 |
+
config?: {
|
| 45 |
+
architectures: string[]
|
| 46 |
+
model_type: string
|
| 47 |
+
}
|
| 48 |
+
lastModified: string
|
| 49 |
+
pipeline_tag: string
|
| 50 |
+
tags: string[]
|
| 51 |
+
transformersInfo: {
|
| 52 |
+
pipeline_tag: string
|
| 53 |
+
auto_model: string
|
| 54 |
+
processor: string
|
| 55 |
+
}
|
| 56 |
+
safetensors?: {
|
| 57 |
+
parameters: {
|
| 58 |
+
F16?: number
|
| 59 |
+
F32?: number
|
| 60 |
+
total?: number
|
| 61 |
+
}
|
| 62 |
+
}
|
| 63 |
+
likes: number
|
| 64 |
+
downloads: number
|
| 65 |
}
|