feat: add dtype support for quantization in model inputs and remove debug logs
Browse files
src/components/ModelLoader.tsx
CHANGED
|
@@ -64,8 +64,6 @@ const ModelLoader = () => {
|
|
| 64 |
|
| 65 |
const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
|
| 66 |
const { status, output } = e.data
|
| 67 |
-
console.log('Received output from worker', e.data)
|
| 68 |
-
|
| 69 |
if (status === 'ready') {
|
| 70 |
setStatus('ready')
|
| 71 |
if (e.data.output) console.log(e.data.output)
|
|
|
|
| 64 |
|
| 65 |
const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
|
| 66 |
const { status, output } = e.data
|
|
|
|
|
|
|
| 67 |
if (status === 'ready') {
|
| 68 |
setStatus('ready')
|
| 69 |
if (e.data.output) console.log(e.data.output)
|
src/components/TextClassification.tsx
CHANGED
|
@@ -3,7 +3,7 @@ import {
|
|
| 3 |
TextClassificationWorkerInput,
|
| 4 |
} from '../types'
|
| 5 |
import { useModel } from '../contexts/ModelContext'
|
| 6 |
-
|
| 7 |
const PLACEHOLDER_TEXTS: string[] = [
|
| 8 |
'I absolutely love this product! It exceeded all my expectations.',
|
| 9 |
"This is the worst purchase I've ever made. Complete waste of money.",
|
|
@@ -19,7 +19,7 @@ const PLACEHOLDER_TEXTS: string[] = [
|
|
| 19 |
|
| 20 |
function TextClassification() {
|
| 21 |
const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
|
| 22 |
-
const { activeWorker, status, modelInfo, results, setResults, hasBeenLoaded} = useModel()
|
| 23 |
|
| 24 |
const classify = useCallback(() => {
|
| 25 |
if (!modelInfo || !activeWorker) {
|
|
@@ -30,10 +30,11 @@ function TextClassification() {
|
|
| 30 |
const message: TextClassificationWorkerInput = {
|
| 31 |
type: 'classify',
|
| 32 |
text,
|
| 33 |
-
model: modelInfo.id
|
|
|
|
| 34 |
}
|
| 35 |
activeWorker.postMessage(message)
|
| 36 |
-
}, [text, modelInfo, activeWorker, setResults])
|
| 37 |
|
| 38 |
const busy: boolean = status !== 'ready'
|
| 39 |
|
|
|
|
| 3 |
TextClassificationWorkerInput,
|
| 4 |
} from '../types'
|
| 5 |
import { useModel } from '../contexts/ModelContext'
|
| 6 |
+
|
| 7 |
const PLACEHOLDER_TEXTS: string[] = [
|
| 8 |
'I absolutely love this product! It exceeded all my expectations.',
|
| 9 |
"This is the worst purchase I've ever made. Complete waste of money.",
|
|
|
|
| 19 |
|
| 20 |
function TextClassification() {
|
| 21 |
const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
|
| 22 |
+
const { activeWorker, status, modelInfo, results, setResults, hasBeenLoaded, selectedQuantization} = useModel()
|
| 23 |
|
| 24 |
const classify = useCallback(() => {
|
| 25 |
if (!modelInfo || !activeWorker) {
|
|
|
|
| 30 |
const message: TextClassificationWorkerInput = {
|
| 31 |
type: 'classify',
|
| 32 |
text,
|
| 33 |
+
model: modelInfo.id,
|
| 34 |
+
dtype: selectedQuantization ?? 'fp32'
|
| 35 |
}
|
| 36 |
activeWorker.postMessage(message)
|
| 37 |
+
}, [text, modelInfo, activeWorker, selectedQuantization, setResults])
|
| 38 |
|
| 39 |
const busy: boolean = status !== 'ready'
|
| 40 |
|
src/components/TextGeneration.tsx
CHANGED
|
@@ -30,7 +30,7 @@ function TextGeneration() {
|
|
| 30 |
// Generation state
|
| 31 |
const [isGenerating, setIsGenerating] = useState<boolean>(false)
|
| 32 |
|
| 33 |
-
const { activeWorker, status, modelInfo, hasBeenLoaded } = useModel()
|
| 34 |
const messagesEndRef = useRef<HTMLDivElement>(null)
|
| 35 |
|
| 36 |
const scrollToBottom = () => {
|
|
@@ -73,10 +73,11 @@ function TextGeneration() {
|
|
| 73 |
top_p: topP,
|
| 74 |
top_k: topK,
|
| 75 |
do_sample: doSample,
|
|
|
|
| 76 |
}
|
| 77 |
|
| 78 |
activeWorker.postMessage(message)
|
| 79 |
-
}, [currentMessage, messages, modelInfo, activeWorker, temperature, maxTokens, topP, topK, doSample, isGenerating])
|
| 80 |
|
| 81 |
const handleGenerateText = useCallback(() => {
|
| 82 |
if (!prompt.trim() || !modelInfo || !activeWorker || isGenerating) {
|
|
@@ -94,11 +95,12 @@ function TextGeneration() {
|
|
| 94 |
max_new_tokens: maxTokens,
|
| 95 |
top_p: topP,
|
| 96 |
top_k: topK,
|
| 97 |
-
do_sample: doSample
|
|
|
|
| 98 |
}
|
| 99 |
|
| 100 |
activeWorker.postMessage(message)
|
| 101 |
-
}, [prompt, modelInfo, activeWorker, temperature, maxTokens, topP, topK, doSample, isGenerating])
|
| 102 |
|
| 103 |
useEffect(() => {
|
| 104 |
if (!activeWorker) return
|
|
|
|
| 30 |
// Generation state
|
| 31 |
const [isGenerating, setIsGenerating] = useState<boolean>(false)
|
| 32 |
|
| 33 |
+
const { activeWorker, status, modelInfo, hasBeenLoaded, selectedQuantization } = useModel()
|
| 34 |
const messagesEndRef = useRef<HTMLDivElement>(null)
|
| 35 |
|
| 36 |
const scrollToBottom = () => {
|
|
|
|
| 73 |
top_p: topP,
|
| 74 |
top_k: topK,
|
| 75 |
do_sample: doSample,
|
| 76 |
+
dtype: selectedQuantization ?? 'fp32'
|
| 77 |
}
|
| 78 |
|
| 79 |
activeWorker.postMessage(message)
|
| 80 |
+
}, [currentMessage, messages, modelInfo, activeWorker, temperature, maxTokens, topP, topK, doSample, isGenerating, selectedQuantization])
|
| 81 |
|
| 82 |
const handleGenerateText = useCallback(() => {
|
| 83 |
if (!prompt.trim() || !modelInfo || !activeWorker || isGenerating) {
|
|
|
|
| 95 |
max_new_tokens: maxTokens,
|
| 96 |
top_p: topP,
|
| 97 |
top_k: topK,
|
| 98 |
+
do_sample: doSample,
|
| 99 |
+
dtype: selectedQuantization ?? 'fp32'
|
| 100 |
}
|
| 101 |
|
| 102 |
activeWorker.postMessage(message)
|
| 103 |
+
}, [prompt, modelInfo, activeWorker, temperature, maxTokens, topP, topK, doSample, isGenerating, selectedQuantization])
|
| 104 |
|
| 105 |
useEffect(() => {
|
| 106 |
if (!activeWorker) return
|
src/components/ZeroShotClassification.tsx
CHANGED
|
@@ -48,7 +48,7 @@ function ZeroShotClassification() {
|
|
| 48 |
PLACEHOLDER_SECTIONS.map((title) => ({ title, items: [] }))
|
| 49 |
)
|
| 50 |
|
| 51 |
-
const { activeWorker, status, modelInfo, hasBeenLoaded } = useModel()
|
| 52 |
|
| 53 |
const classify = useCallback(() => {
|
| 54 |
if (!modelInfo || !activeWorker) {
|
|
@@ -70,10 +70,11 @@ function ZeroShotClassification() {
|
|
| 70 |
labels: sections
|
| 71 |
.slice(0, sections.length - 1)
|
| 72 |
.map((section) => section.title),
|
| 73 |
-
model: modelInfo.id
|
|
|
|
| 74 |
}
|
| 75 |
activeWorker.postMessage(message)
|
| 76 |
-
}, [text, sections, modelInfo, activeWorker])
|
| 77 |
|
| 78 |
// Handle worker messages
|
| 79 |
useEffect(() => {
|
|
|
|
| 48 |
PLACEHOLDER_SECTIONS.map((title) => ({ title, items: [] }))
|
| 49 |
)
|
| 50 |
|
| 51 |
+
const { activeWorker, status, modelInfo, hasBeenLoaded, selectedQuantization } = useModel()
|
| 52 |
|
| 53 |
const classify = useCallback(() => {
|
| 54 |
if (!modelInfo || !activeWorker) {
|
|
|
|
| 70 |
labels: sections
|
| 71 |
.slice(0, sections.length - 1)
|
| 72 |
.map((section) => section.title),
|
| 73 |
+
model: modelInfo.id,
|
| 74 |
+
dtype: selectedQuantization ?? 'fp32'
|
| 75 |
}
|
| 76 |
activeWorker.postMessage(message)
|
| 77 |
+
}, [text, sections, modelInfo, activeWorker, selectedQuantization])
|
| 78 |
|
| 79 |
// Handle worker messages
|
| 80 |
useEffect(() => {
|
src/types.ts
CHANGED
|
@@ -39,12 +39,14 @@ export interface ZeroShotWorkerInput {
|
|
| 39 |
text: string
|
| 40 |
labels: string[]
|
| 41 |
model: string
|
|
|
|
| 42 |
}
|
| 43 |
|
| 44 |
export interface TextClassificationWorkerInput {
|
| 45 |
type: 'classify'
|
| 46 |
text: string
|
| 47 |
model: string
|
|
|
|
| 48 |
}
|
| 49 |
|
| 50 |
export interface TextGenerationWorkerInput {
|
|
@@ -58,6 +60,7 @@ export interface TextGenerationWorkerInput {
|
|
| 58 |
top_p?: number
|
| 59 |
top_k?: number
|
| 60 |
do_sample?: boolean
|
|
|
|
| 61 |
}
|
| 62 |
|
| 63 |
const q8Types = ['q8', 'int8', 'bnb8', 'uint8'] as const
|
|
|
|
| 39 |
text: string
|
| 40 |
labels: string[]
|
| 41 |
model: string
|
| 42 |
+
dtype: QuantizationType
|
| 43 |
}
|
| 44 |
|
| 45 |
export interface TextClassificationWorkerInput {
|
| 46 |
type: 'classify'
|
| 47 |
text: string
|
| 48 |
model: string
|
| 49 |
+
dtype: QuantizationType
|
| 50 |
}
|
| 51 |
|
| 52 |
export interface TextGenerationWorkerInput {
|
|
|
|
| 60 |
top_p?: number
|
| 61 |
top_k?: number
|
| 62 |
do_sample?: boolean
|
| 63 |
+
dtype: QuantizationType
|
| 64 |
}
|
| 65 |
|
| 66 |
const q8Types = ['q8', 'int8', 'bnb8', 'uint8'] as const
|