Spaces:
Running
Running
import * as webllm from "@mlc-ai/web-llm"; | |
import rehypeStringify from "rehype-stringify"; | |
import remarkFrontmatter from "remark-frontmatter"; | |
import remarkGfm from "remark-gfm"; | |
import RemarkBreaks from "remark-breaks"; | |
import remarkParse from "remark-parse"; | |
import remarkRehype from "remark-rehype"; | |
import RehypeKatex from "rehype-katex"; | |
import { unified } from "unified"; | |
import remarkMath from "remark-math"; | |
import rehypeHighlight from "rehype-highlight"; | |
/*************** WebLLM logic ***************/ | |
/**
 * Shared unified pipeline that turns assistant Markdown into an HTML string.
 *
 * - Parse (mdast): Markdown, front matter, math, GFM extensions, and
 *   remark-breaks so single newlines become line breaks.
 * - Transform: remark-rehype converts mdast to hast, then KaTeX renders the
 *   math nodes and rehype-highlight adds highlighting to code blocks.
 * - Compile: rehype-stringify serialises the hast tree to HTML.
 *
 * unified runs the compiler after every transformer regardless of where it
 * is registered, so rehype-stringify is listed last to make the chain read
 * in execution order.
 *
 * NOTE(review): the HTML produced here is assigned with innerHTML and is not
 * sanitised. remark-rehype drops raw HTML by default, but link URLs such as
 * `javascript:` may still pass through — consider rehype-sanitize.
 */
const messageFormatter = unified()
  .use(remarkParse)
  .use(remarkFrontmatter)
  .use(remarkMath)
  .use(remarkGfm)
  .use(RemarkBreaks)
  .use(remarkRehype)
  .use(RehypeKatex)
  .use(rehypeHighlight, {
    // Guess the language of fenced code blocks that don't declare one.
    detect: true,
    // Don't fail on unregistered languages. This option name is only
    // recognised by older rehype-highlight releases — TODO confirm version.
    ignoreMissing: true,
  })
  .use(rehypeStringify);
// Chat history handed to the engine with every request. It is seeded with the
// system prompt, and onMessageSend pushes each user turn onto it.
const messages = [
  { content: "You are a helpful AI agent helping users.", role: "system" },
];
/**
 * Progress listener registered with the engine. It logs the numeric progress
 * value and mirrors the status text into the #download-status element.
 *
 * @param {{progress: number, text: string}} report - init progress report;
 *   shape assumed from the fields read here (TODO confirm against WebLLM types)
 */
function updateEngineInitProgressCallback(report) {
  const { progress, text } = report;
  console.log("initialize", progress);
  const statusElement = document.getElementById("download-status");
  statusElement.textContent = text;
}
// The single engine instance used by the whole page. Its init progress
// reports are forwarded to updateEngineInitProgressCallback.
const engine = new webllm.MLCEngine();
engine.setInitProgressCallback(updateEngineInitProgressCallback);

// initializeWebLLMEngine sets this to true once engine.reload() resolves.
// onMessageSend returns early while it is false.
let modelLoaded = false;
/**
 * Read a numeric setting from the form control with the given id.
 *
 * @param {string} id - element id of the input/select
 * @param {(raw: string) => number} parse - string-to-number parser
 * @returns {number | undefined} the parsed value, or undefined when the field
 *   is empty or not a finite number
 */
function readNumericSetting(id, parse) {
  const parsed = parse(document.getElementById(id).value);
  return Number.isFinite(parsed) ? parsed : undefined;
}

/**
 * Load (or reload) Phi-3.5-mini using the quantization and sampling settings
 * currently entered in the page's form controls.
 *
 * Fields that are empty or non-numeric are left out of the config instead of
 * being passed as NaN.
 *
 * `modelLoaded` is cleared for the duration of the reload. The Enter-key
 * handler calls onMessageSend even while the Send button is disabled, so this
 * stops messages reaching an engine that is mid-reload. The flag is set again
 * only after reload succeeds.
 *
 * @returns {Promise<void>} resolves once engine.reload() has resolved
 * @throws propagates any rejection from engine.reload(); `modelLoaded` stays
 *   false in that case
 */
async function initializeWebLLMEngine() {
  const quantization = document.getElementById("quantization").value;
  const toInt = (raw) => Number.parseInt(raw, 10);
  const settings = {
    temperature: readNumericSetting("temperature", Number.parseFloat),
    top_p: readNumericSetting("top_p", Number.parseFloat),
    frequency_penalty: readNumericSetting("frequency_penalty", Number.parseFloat),
    presence_penalty: readNumericSetting("presence_penalty", Number.parseFloat),
    context_window_size: readNumericSetting("context", toInt),
  };
  // Drop unset options entirely. A key that is present but `undefined` could
  // still override the model's default when option objects are spread together.
  const config = Object.fromEntries(
    Object.entries(settings).filter(([, value]) => value !== undefined),
  );
  document.getElementById("download-status").classList.remove("hidden");
  const selectedModel = `Phi-3.5-mini-instruct-${quantization}_1-MLC`;
  console.log(`Loading Model: ${selectedModel}`);
  console.log(`Config: ${JSON.stringify(config)}`);
  modelLoaded = false;
  await engine.reload(selectedModel, config);
  modelLoaded = true;
}
/**
 * Stream a chat completion for `messages` from the shared engine.
 *
 * @param {Array<{role: string, content: string}>} messages - chat history sent
 *   as the request's `messages`
 * @param {(text: string) => void} onUpdate - called with the accumulated reply
 *   text each time a chunk adds non-empty content
 * @param {(finalMessage: string, usage: object | undefined) => void} onFinish -
 *   called once with engine.getMessage() and the last `usage` object seen in
 *   the stream (undefined if no chunk carried one)
 * @param {(err: unknown) => void} onError - receives anything thrown while
 *   creating or consuming the stream, including synchronous throws from
 *   onUpdate/onFinish
 * @returns {Promise<void>}
 */
async function streamingGenerating(messages, onUpdate, onFinish, onError) {
  try {
    let curMessage = "";
    let usage;
    const completion = await engine.chat.completions.create({
      stream: true,
      messages,
      stream_options: { include_usage: true },
    });
    for await (const chunk of completion) {
      const curDelta = chunk.choices[0]?.delta?.content;
      if (curDelta) {
        curMessage += curDelta;
        // Only re-render when text actually changed. Role-only chunks would
        // otherwise blank the placeholder, and the trailing usage-only chunk
        // (empty `choices`) would force a redundant full Markdown re-render.
        onUpdate(curMessage);
      }
      if (chunk.usage) {
        usage = chunk.usage;
      }
    }
    const finalMessage = await engine.getMessage();
    onFinish(finalMessage, usage);
  } catch (err) {
    onError(err);
  }
}
/*************** UI logic ***************/ | |
/**
 * Send the text typed into #user-input as a new user turn and stream the
 * assistant's reply into a new chat bubble.
 *
 * Does nothing in three cases:
 * - no model is loaded;
 * - the trimmed input is empty;
 * - #send is disabled, meaning a reply is still streaming or a model is
 *   loading. This check matters because the Enter-key handler calls this
 *   function directly, bypassing the disabled button.
 *
 * When the reply finishes, it is appended to `messages` so later turns keep
 * the full conversation, the input is re-enabled, and token usage is shown in
 * #chat-stats. On error the input is re-enabled and the reply bubble shows
 * the error.
 */
function onMessageSend() {
  const sendButton = document.getElementById("send");
  if (!modelLoaded || sendButton.disabled) {
    return;
  }
  const userInput = document.getElementById("user-input");
  const input = userInput.value.trim();
  if (input.length === 0) {
    return;
  }

  const message = { content: input, role: "user" };
  sendButton.disabled = true;
  messages.push(message);
  appendMessage(message);
  userInput.value = "";
  userInput.setAttribute("placeholder", "Generating...");

  // Placeholder bubble; streamed text replaces it through updateLastMessage.
  appendMessage({ content: "typing...", role: "assistant" });

  // Restore the input controls whether generation succeeded or failed.
  const restoreInput = () => {
    sendButton.disabled = false;
    userInput.setAttribute("placeholder", "Type a message...");
  };

  const onFinishGenerating = (finalMessage, usage) => {
    updateLastMessage(finalMessage);
    // Record the reply so the next request carries the whole conversation.
    messages.push({ content: finalMessage, role: "assistant" });
    restoreInput();
    // `usage` is undefined when no streamed chunk carried usage stats.
    if (!usage) {
      return;
    }
    const stats = [
      `prompt_tokens: ${usage.prompt_tokens}`,
      `completion_tokens: ${usage.completion_tokens}`,
    ];
    if (usage.extra) {
      stats.push(
        `prefill: ${usage.extra.prefill_tokens_per_s.toFixed(4)} tokens/sec`,
        `decoding: ${usage.extra.decode_tokens_per_s.toFixed(4)} tokens/sec`,
      );
    }
    const chatStats = document.getElementById("chat-stats");
    chatStats.classList.remove("hidden");
    chatStats.textContent = stats.join(", ");
  };

  const onGenerationError = (err) => {
    console.error(err);
    const reason = err instanceof Error ? err.message : String(err);
    updateLastMessage(`Error: ${reason}`);
    restoreInput();
  };

  streamingGenerating(
    messages,
    updateLastMessage,
    onFinishGenerating,
    onGenerationError,
  );
}
/**
 * Add a chat bubble to #chat-box and scroll the box to its bottom.
 *
 * The text is assigned through textContent, so it is not parsed as HTML.
 * The wrapper gets the class "user" for role "user" and "assistant" for any
 * other role.
 *
 * @param {{content: string, role: string}} message
 */
function appendMessage(message) {
  const bubble = document.createElement("div");
  bubble.classList.add("message");
  bubble.textContent = message.content;

  const wrapper = document.createElement("div");
  const roleClass = message.role === "user" ? "user" : "assistant";
  wrapper.classList.add("message-container", roleClass);
  wrapper.appendChild(bubble);

  const chatBox = document.getElementById("chat-box");
  chatBox.appendChild(wrapper);
  // Keep the newest message in view.
  chatBox.scrollTop = chatBox.scrollHeight;
}
/**
 * Render Markdown `content` with messageFormatter and put the result into the
 * last `.message` element inside #chat-box.
 *
 * The DOM lookup happens after formatting resolves, so the target is whichever
 * bubble is last at that moment.
 *
 * NOTE(review): the formatter output is written via innerHTML without
 * sanitisation — confirm this is acceptable for model-generated text.
 *
 * @param {string} content - Markdown source (partial or final reply text)
 * @returns {Promise<void>}
 */
async function updateLastMessage(content) {
  const rendered = await messageFormatter.process(content);
  const chatBox = document.getElementById("chat-box");
  const bubbles = chatBox.querySelectorAll(".message");
  const newestBubble = bubbles[bubbles.length - 1];
  newestBubble.innerHTML = rendered;
}
/*************** UI binding ***************/ | |
// "Download" loads the model with the current settings.
// - Send stays disabled until loading succeeds.
// - Download itself is disabled while loading, so a second click can't start
//   an overlapping reload.
// - A failed load is logged and reported in #download-status instead of
//   becoming a silent unhandled rejection.
document.getElementById("download").addEventListener("click", async () => {
  const downloadButton = document.getElementById("download");
  const sendButton = document.getElementById("send");
  downloadButton.disabled = true;
  sendButton.disabled = true;
  try {
    await initializeWebLLMEngine();
    sendButton.disabled = false;
  } catch (err) {
    console.error(err);
    const reason = err instanceof Error ? err.message : String(err);
    document.getElementById("download-status").textContent =
      `Failed to load model: ${reason}`;
  } finally {
    downloadButton.disabled = false;
  }
});
// Clicking "Send" submits whatever is in the input box.
document.getElementById("send").addEventListener("click", () => onMessageSend());
// Pressing Enter in the input box submits the message, with two exceptions:
// - Enter that confirms an IME composition (e.g. Chinese/Japanese input) is
//   ignored, so picking a candidate doesn't send a half-typed message.
// - Enter does nothing while Send is disabled (reply streaming or model
//   loading), mirroring the button.
document.getElementById("user-input").addEventListener("keydown", (event) => {
  if (event.key !== "Enter" || event.isComposing) {
    return;
  }
  if (document.getElementById("send").disabled) {
    return;
  }
  onMessageSend();
});
// Once the page has finished loading, label and enable the "Download" button.
// addEventListener is used instead of assigning window.onload, so this
// neither replaces nor gets replaced by other load handlers.
window.addEventListener("load", () => {
  const downloadButton = document.getElementById("download");
  downloadButton.textContent = "Download";
  downloadButton.disabled = false;
});