import * as webllm from "@mlc-ai/web-llm";
import rehypeStringify from "rehype-stringify";
import remarkFrontmatter from "remark-frontmatter";
import remarkGfm from "remark-gfm";
import RemarkBreaks from "remark-breaks";
import remarkParse from "remark-parse";
import remarkRehype from "remark-rehype";
import RehypeKatex from "rehype-katex";
import { unified } from "unified";
import remarkMath from "remark-math";
import rehypeHighlight from "rehype-highlight";

/*************** WebLLM logic ***************/

// Markdown -> HTML pipeline used to render assistant messages.
const messageFormatter = unified()
  .use(remarkParse)
  .use(remarkFrontmatter)
  .use(remarkMath)
  .use(remarkGfm)
  .use(RemarkBreaks)
  .use(remarkRehype)
  .use(rehypeStringify)
  .use(RehypeKatex)
  .use(rehypeHighlight, {
    detect: true,
    ignoreMissing: true,
  });

// Conversation history sent with every completion request.
const messages = [
  {
    content: "You are a helpful AI agent helping users.",
    role: "system",
  },
];

/**
 * Progress callback registered on the engine: logs the progress value and
 * shows the report text in the "download-status" element.
 *
 * @param {{progress: number, text: string}} report - Progress report passed
 *   in by the engine.
 */
function updateEngineInitProgressCallback(report) {
  console.log("initialize", report.progress);
  document.getElementById("download-status").textContent = report.text;
}

// Create engine instance
let modelLoaded = false;
// True from the moment a message is sent until its reply finishes or fails.
let isGenerating = false;
// Id of the most recent write to the last message bubble; older async renders
// compare against it and drop their result when they have been superseded.
let latestRenderId = 0;

const engine = new webllm.MLCEngine();
engine.setInitProgressCallback(updateEngineInitProgressCallback);

/**
 * Reads the model settings from the form controls and (re)loads the selected
 * model into the engine.
 *
 * `modelLoaded` is cleared first and only set again once `engine.reload`
 * resolves, so no message can be sent while a load is in flight or after a
 * failed load.
 *
 * @returns {Promise<void>} Resolves when the engine has finished reloading;
 *   rejects with whatever `engine.reload` rejects with.
 */
async function initializeWebLLMEngine() {
  modelLoaded = false;

  const readValue = (id) => document.getElementById(id).value;

  const quantization = readValue("quantization");
  const context_window_size = Number.parseInt(readValue("context"), 10);
  const temperature = Number.parseFloat(readValue("temperature"));
  const top_p = Number.parseFloat(readValue("top_p"));
  const presence_penalty = Number.parseFloat(readValue("presence_penalty"));
  const frequency_penalty = Number.parseFloat(readValue("frequency_penalty"));

  document.getElementById("download-status").classList.remove("hidden");

  const selectedModel = `Phi-3.5-mini-instruct-${quantization}_1-MLC`;
  const config = {
    temperature,
    top_p,
    frequency_penalty,
    presence_penalty,
    context_window_size,
  };
  console.log(`Loading Model: ${selectedModel}`);
  console.log(`Config: ${JSON.stringify(config)}`);

  await engine.reload(selectedModel, config);
  modelLoaded = true;
}

/**
 * Requests a streamed chat completion and reports progress through callbacks.
 *
 * @param {Array<{role: string, content: string}>} messages - Conversation to
 *   send to the engine.
 * @param {(text: string) => void} onUpdate - Called with the accumulated text
 *   each time a chunk adds content.
 * @param {(finalMessage: string, usage: object | undefined) => void} onFinish -
 *   Called once with the engine's final message and the usage object, which is
 *   `undefined` when no chunk carried one.
 * @param {(err: unknown) => void} onError - Called with anything thrown while
 *   requesting or consuming the stream, or thrown synchronously by a callback.
 * @returns {Promise<void>}
 */
async function streamingGenerating(messages, onUpdate, onFinish, onError) {
  try {
    let curMessage = "";
    let usage;
    const completion = await engine.chat.completions.create({
      stream: true,
      messages,
      stream_options: { include_usage: true },
    });
    for await (const chunk of completion) {
      const curDelta = chunk.choices[0]?.delta.content;
      if (curDelta) {
        curMessage += curDelta;
        // Only re-render when the text actually changed; chunks without
        // content (for example one that only carries usage) are skipped.
        onUpdate(curMessage);
      }
      if (chunk.usage) {
        usage = chunk.usage;
      }
    }
    const finalMessage = await engine.getMessage();
    onFinish(finalMessage, usage);
  } catch (err) {
    onError(err);
  }
}

/*************** UI logic ***************/

/**
 * Switches the send button and the input placeholder between the
 * "generating" and "idle" states and records the state in `isGenerating`.
 *
 * @param {boolean} active - `true` while a reply is being generated.
 */
function setGeneratingState(active) {
  isGenerating = active;
  document.getElementById("send").disabled = active;
  document
    .getElementById("user-input")
    .setAttribute("placeholder", active ? "Generating..." : "Type a message...");
}

/**
 * Shows token statistics in the "chat-stats" element.
 *
 * Does nothing when `usage` or `usage.extra` is missing, so a stream that
 * carried no usage information cannot break the end-of-reply handling.
 *
 * @param {object | undefined} usage - Usage object taken from the stream.
 */
function showUsage(usage) {
  if (!usage?.extra) {
    return;
  }
  const usageText =
    `prompt_tokens: ${usage.prompt_tokens}, ` +
    `completion_tokens: ${usage.completion_tokens}, ` +
    `prefill: ${usage.extra.prefill_tokens_per_s.toFixed(4)} tokens/sec, ` +
    `decoding: ${usage.extra.decode_tokens_per_s.toFixed(4)} tokens/sec`;
  const stats = document.getElementById("chat-stats");
  stats.classList.remove("hidden");
  stats.textContent = usageText;
}

/**
 * Sends the text in the input box as a user message and streams the reply
 * into a new assistant bubble.
 *
 * Ignored when no model is loaded, when a reply is already being generated
 * (the Enter key would otherwise bypass the disabled send button), or when
 * the trimmed input is empty.
 */
function onMessageSend() {
  if (!modelLoaded || isGenerating) {
    return;
  }
  const inputElement = document.getElementById("user-input");
  const input = inputElement.value.trim();
  if (input.length === 0) {
    return;
  }

  const message = {
    content: input,
    role: "user",
  };
  setGeneratingState(true);
  messages.push(message);
  appendMessage(message);
  inputElement.value = "";

  // Placeholder bubble that the streamed reply is rendered into.
  appendMessage({
    content: "typing...",
    role: "assistant",
  });

  const onFinishGenerating = (finalMessage, usage) => {
    // Restore the UI first so a rendering problem cannot leave it locked.
    setGeneratingState(false);
    // Keep the reply in the history so later requests include it.
    messages.push({ content: finalMessage, role: "assistant" });
    updateLastMessage(finalMessage);
    showUsage(usage);
  };

  const onGenerationError = (err) => {
    console.error(err);
    // Drop the unanswered user message so a retry does not send it twice.
    const failedIndex = messages.lastIndexOf(message);
    if (failedIndex !== -1) {
      messages.splice(failedIndex, 1);
    }
    setLastMessageText(`Error: ${err?.message ?? String(err)}`);
    setGeneratingState(false);
  };

  streamingGenerating(
    messages,
    updateLastMessage,
    onFinishGenerating,
    onGenerationError
  );
}

/**
 * Appends a message bubble to the chat box and scrolls it into view.
 *
 * @param {{role: string, content: string}} message - Message to display; the
 *   bubble is styled as "user" when `role` is "user", otherwise "assistant".
 */
function appendMessage(message) {
  const chatBox = document.getElementById("chat-box");
  const container = document.createElement("div");
  container.classList.add("message-container");
  const newMessage = document.createElement("div");
  newMessage.classList.add("message");
  newMessage.textContent = message.content;

  if (message.role === "user") {
    container.classList.add("user");
  } else {
    container.classList.add("assistant");
  }

  container.appendChild(newMessage);
  chatBox.appendChild(container);
  chatBox.scrollTop = chatBox.scrollHeight; // Scroll to the latest message
}

/**
 * Returns the last `.message` element inside the chat box.
 *
 * @returns {Element | undefined} `undefined` when the chat box has no messages.
 */
function getLastMessageElement() {
  const messageDoms = document
    .getElementById("chat-box")
    .querySelectorAll(".message");
  return messageDoms[messageDoms.length - 1];
}

/**
 * Writes plain text into the last message bubble and invalidates any
 * markdown render that is still pending for it.
 *
 * @param {string} text - Text to show verbatim.
 */
function setLastMessageText(text) {
  latestRenderId += 1;
  const lastMessageDom = getLastMessageElement();
  if (lastMessageDom) {
    lastMessageDom.textContent = text;
  }
}

/**
 * Renders `content` through the markdown formatter and puts the resulting
 * HTML into the last message bubble.
 *
 * This is called once per streamed chunk without being awaited, so several
 * renders can be pending at once. Each call takes a render id and only writes
 * to the DOM if no newer write has started in the meantime. If formatting
 * fails, the error is logged and the raw text is shown instead.
 *
 * @param {string} content - Markdown text of the message so far.
 * @returns {Promise<void>}
 */
async function updateLastMessage(content) {
  latestRenderId += 1;
  const renderId = latestRenderId;

  let formattedHtml = null;
  try {
    formattedHtml = String(await messageFormatter.process(content));
  } catch (err) {
    console.error(err);
  }

  if (renderId !== latestRenderId) {
    return; // A newer update owns the bubble now.
  }
  const lastMessageDom = getLastMessageElement();
  if (!lastMessageDom) {
    return;
  }
  if (formattedHtml === null) {
    lastMessageDom.textContent = content;
  } else {
    // NOTE(review): model output reaches innerHTML via the formatter; this
    // presumably relies on remark-rehype not passing raw HTML through with
    // its default options — confirm before changing the pipeline options.
    lastMessageDom.innerHTML = formattedHtml;
  }
}

/*************** UI binding ***************/

document.getElementById("download").addEventListener("click", () => {
  const sendButton = document.getElementById("send");
  sendButton.disabled = true;
  initializeWebLLMEngine()
    .then(() => {
      sendButton.disabled = false;
    })
    .catch((err) => {
      // Surface the failure instead of leaving an unhandled rejection; the
      // send button stays disabled because no model is loaded.
      console.error(err);
      document.getElementById("download-status").textContent =
        `Failed to load model: ${err?.message ?? String(err)}`;
    });
});

document.getElementById("send").addEventListener("click", () => {
  onMessageSend();
});

document.getElementById("user-input").addEventListener("keydown", (event) => {
  // Skip the Enter that confirms an IME composition.
  if (event.key === "Enter" && !event.isComposing) {
    onMessageSend();
  }
});

window.onload = function () {
  document.getElementById("download").textContent = "Download";
  document.getElementById("download").disabled = false;
};