Spaces:
Running
Running
| import * as webllm from "@mlc-ai/web-llm"; | |
| import rehypeStringify from "rehype-stringify"; | |
| import remarkFrontmatter from "remark-frontmatter"; | |
| import remarkGfm from "remark-gfm"; | |
| import RemarkBreaks from "remark-breaks"; | |
| import remarkParse from "remark-parse"; | |
| import remarkRehype from "remark-rehype"; | |
| import RehypeKatex from "rehype-katex"; | |
| import { unified } from "unified"; | |
| import remarkMath from "remark-math"; | |
| import rehypeHighlight from "rehype-highlight"; | |
| /*************** WebLLM logic ***************/ | |
// Options handed to rehype-highlight when it is registered below.
const highlightOptions = {
  detect: true,
  ignoreMissing: true,
};

/**
 * Markdown -> HTML pipeline used to render chat messages.
 *
 * Markdown is parsed (with frontmatter, math, GFM and line-break support),
 * bridged to an HTML tree, post-processed by KaTeX and syntax highlighting,
 * and finally serialized to an HTML string.
 */
const messageFormatter = unified()
  // Markdown side of the pipeline.
  .use(remarkParse)
  .use(remarkFrontmatter)
  .use(remarkMath)
  .use(remarkGfm)
  .use(RemarkBreaks)
  // Bridge from the markdown tree to the HTML tree.
  .use(remarkRehype)
  // HTML-tree transforms.
  .use(RehypeKatex)
  .use(rehypeHighlight, highlightOptions)
  // Serializer; it only registers the compiler, so its position does not
  // affect the transforms above.
  .use(rehypeStringify);
/**
 * Running conversation history passed to the engine on every request.
 * It starts with a single system message; user turns are appended as they
 * are sent.
 */
const messages = [
  { role: "system", content: "You are a helpful AI agent helping users." },
];
/**
 * Initialization progress hook for the engine.
 *
 * Logs `report.progress` to the console and shows `report.text` in the
 * `#download-status` element.
 *
 * @param {object} report - Progress report; its `progress` and `text`
 *   properties are read.
 */
function updateEngineInitProgressCallback(report) {
  const { progress, text } = report;
  console.log("initialize", progress);

  const statusLabel = document.getElementById("download-status");
  statusLabel.textContent = text;
}
// Single engine instance shared by the whole page; initialization progress
// reports are routed to updateEngineInitProgressCallback.
const engine = new webllm.MLCEngine();
engine.setInitProgressCallback(updateEngineInitProgressCallback);

// Flipped to true once a model has been loaded; onMessageSend refuses to
// send while this is false.
let modelLoaded = false;
/**
 * Reads the model settings from the form controls and (re)loads the selected
 * Phi-3.5-mini build into the shared engine.
 *
 * Form controls read: `#quantization`, `#context`, `#temperature`, `#top_p`,
 * `#presence_penalty`, `#frequency_penalty`. The `#download-status` element
 * is un-hidden before loading starts.
 *
 * `modelLoaded` is cleared while the reload is in flight and only set again
 * once `engine.reload` resolves, so a failed or pending reload never leaves
 * the page believing a model is ready.
 *
 * @returns {Promise<void>} Resolves when the model has been loaded; rejects
 *   with whatever `engine.reload` rejects with.
 */
async function initializeWebLLMEngine() {
  const readField = (id) => document.getElementById(id).value;

  const quantization = readField("quantization");
  const requestedConfig = {
    temperature: Number.parseFloat(readField("temperature")),
    top_p: Number.parseFloat(readField("top_p")),
    frequency_penalty: Number.parseFloat(readField("frequency_penalty")),
    presence_penalty: Number.parseFloat(readField("presence_penalty")),
    // Explicit radix: the field holds a decimal token count.
    context_window_size: Number.parseInt(readField("context"), 10),
  };
  // A blank or malformed field parses to NaN. Leave such options out so the
  // engine falls back to its own default instead of receiving NaN.
  const config = Object.fromEntries(
    Object.entries(requestedConfig).filter(([, value]) =>
      Number.isFinite(value),
    ),
  );

  document.getElementById("download-status").classList.remove("hidden");
  const selectedModel = `Phi-3.5-mini-instruct-${quantization}_1-MLC`;
  console.log(`Loading Model: ${selectedModel}`);
  console.log(`Config: ${JSON.stringify(config)}`);

  // No model is usable while a reload is in progress (or after it fails).
  modelLoaded = false;
  await engine.reload(selectedModel, config);
  modelLoaded = true;
}
/**
 * Requests a streamed chat completion for `messages` and reports progress
 * through the supplied callbacks.
 *
 * Callbacks may be async: each one is awaited before the next chunk is
 * consumed, so updates are applied strictly in order, the final message is
 * never overtaken by a late streaming update, and a rejection from a
 * callback is routed to `onError` instead of being lost.
 *
 * @param {Array<object>} messages - Conversation passed to the engine.
 * @param {Function} onUpdate - Called after every chunk with the text
 *   accumulated so far.
 * @param {Function} onFinish - Called once with the engine's final message
 *   and the last `usage` object seen on a chunk (undefined if none was seen).
 * @param {Function} onError - Called with any error thrown while requesting,
 *   streaming, or running the callbacks above.
 * @returns {Promise<void>}
 */
async function streamingGenerating(messages, onUpdate, onFinish, onError) {
  try {
    let curMessage = "";
    let usage;
    const completion = await engine.chat.completions.create({
      stream: true,
      messages,
      stream_options: { include_usage: true },
    });
    for await (const chunk of completion) {
      // A chunk may carry no choice (or no delta) at all, e.g. when it only
      // reports usage.
      const curDelta = chunk.choices[0]?.delta?.content;
      if (curDelta) {
        curMessage += curDelta;
      }
      if (chunk.usage) {
        usage = chunk.usage;
      }
      await onUpdate(curMessage);
    }
    const finalMessage = await engine.getMessage();
    await onFinish(finalMessage, usage);
  } catch (err) {
    onError(err);
  }
}
| /*************** UI logic ***************/ | |
/**
 * Sends the trimmed contents of `#user-input` to the model and streams the
 * reply into the chat box.
 *
 * Does nothing when no model is loaded, when a generation is already running
 * (the send button is disabled), or when the input is empty.
 *
 * While generating, the send button is disabled and the input placeholder
 * reads "Generating...". Both are restored when generation finishes or fails.
 * On success the assistant reply is appended to `messages` so later turns
 * include it, and usage statistics are shown in `#chat-stats` when available.
 */
function onMessageSend() {
  if (!modelLoaded) {
    return;
  }
  const sendButton = document.getElementById("send");
  const inputBox = document.getElementById("user-input");

  // The Enter key handler bypasses the disabled button, so check it here to
  // avoid starting a second generation while one is still running.
  if (sendButton.disabled) {
    return;
  }
  const input = inputBox.value.trim();
  if (input.length === 0) {
    return;
  }

  const message = {
    content: input,
    role: "user",
  };
  sendButton.disabled = true;
  messages.push(message);
  appendMessage(message);
  inputBox.value = "";
  inputBox.setAttribute("placeholder", "Generating...");

  // Placeholder bubble that the streaming updates overwrite.
  appendMessage({
    content: "typing...",
    role: "assistant",
  });

  const unlockInput = () => {
    sendButton.disabled = false;
    inputBox.setAttribute("placeholder", "Type a message...");
  };

  const onFinishGenerating = async (finalMessage, usage) => {
    const rendering = updateLastMessage(finalMessage);
    // Record the reply so the next request carries the full conversation.
    messages.push({ content: finalMessage, role: "assistant" });
    unlockInput();
    // Usage may be missing; never let that block restoring the UI.
    if (usage?.extra) {
      const usageText =
        `prompt_tokens: ${usage.prompt_tokens}, ` +
        `completion_tokens: ${usage.completion_tokens}, ` +
        `prefill: ${usage.extra.prefill_tokens_per_s.toFixed(4)} tokens/sec, ` +
        `decoding: ${usage.extra.decode_tokens_per_s.toFixed(4)} tokens/sec`;
      const stats = document.getElementById("chat-stats");
      stats.classList.remove("hidden");
      stats.textContent = usageText;
    }
    await rendering;
  };

  const onGenerationError = (err) => {
    console.error(err);
    // Without this the send button would stay disabled forever.
    unlockInput();
  };

  streamingGenerating(
    messages,
    updateLastMessage,
    onFinishGenerating,
    onGenerationError,
  );
}
/**
 * Adds a new message bubble to the end of `#chat-box` and scrolls it into
 * view.
 *
 * The bubble shows `message.content` as plain text. Its container is tagged
 * "user" when `message.role` is "user" and "assistant" for any other role.
 *
 * @param {{content: string, role: string}} message - Message to display.
 */
function appendMessage(message) {
  const chatBox = document.getElementById("chat-box");
  const roleClass = message.role === "user" ? "user" : "assistant";

  const wrapper = document.createElement("div");
  wrapper.classList.add("message-container", roleClass);

  const bubble = document.createElement("div");
  bubble.classList.add("message");
  bubble.textContent = message.content;

  wrapper.appendChild(bubble);
  chatBox.appendChild(wrapper);

  // Keep the newest message visible.
  chatBox.scrollTop = chatBox.scrollHeight;
}
/**
 * Replaces the contents of the most recent message bubble.
 *
 * `content` is run through `messageFormatter`, and the result is written as
 * the inner HTML of the last `.message` element inside `#chat-box`.
 *
 * @param {string} content - Text to format and display.
 * @returns {Promise<void>}
 */
async function updateLastMessage(content) {
  const rendered = await messageFormatter.process(content);

  // Look the bubble up only after formatting has finished.
  const chatBox = document.getElementById("chat-box");
  const newestBubble = Array.from(chatBox.querySelectorAll(".message")).at(-1);
  newestBubble.innerHTML = rendered;
}
| /*************** UI binding ***************/ | |
// Download button: load the model configured in the form. Sending stays
// disabled until the load succeeds.
document.getElementById("download").addEventListener("click", () => {
  const sendButton = document.getElementById("send");
  sendButton.disabled = true;
  initializeWebLLMEngine()
    .then(() => {
      sendButton.disabled = false;
    })
    .catch((err) => {
      // Surface load failures instead of leaving an unhandled rejection.
      console.error(err);
    });
});

// Send button: submit the current input.
document.getElementById("send").addEventListener("click", () => {
  onMessageSend();
});

// Enter in the input box submits as well.
document.getElementById("user-input").addEventListener("keydown", (event) => {
  // While an IME composition is active, Enter confirms the composed text and
  // must not submit the message.
  if (event.key === "Enter" && !event.isComposing) {
    onMessageSend();
  }
});

// Once the page has loaded, label and enable the download button.
window.addEventListener("load", () => {
  const downloadButton = document.getElementById("download");
  downloadButton.textContent = "Download";
  downloadButton.disabled = false;
});