/*
 * Author: neetnestor
 * Commit 60c6432: "feat: fix log output of model configuration"
 * File size: 5.93 kB
 */
import * as webllm from "@mlc-ai/web-llm";
import rehypeStringify from "rehype-stringify";
import remarkFrontmatter from "remark-frontmatter";
import remarkGfm from "remark-gfm";
import RemarkBreaks from "remark-breaks";
import remarkParse from "remark-parse";
import remarkRehype from "remark-rehype";
import RehypeKatex from "rehype-katex";
import { unified } from "unified";
import remarkMath from "remark-math";
import rehypeHighlight from "rehype-highlight";
/*************** WebLLM logic ***************/
// Markdown -> HTML pipeline used to render chat messages.
// Parsing side: markdown with front matter, math, GFM and hard line breaks.
// HTML side: math rendered by rehype-katex, code blocks highlighted by
// rehype-highlight (language auto-detection on), then serialized to a string.
// rehypeStringify only installs the compiler, so listing it last is purely
// for readability; the transformers still run in the order they are used.
const messageFormatter = unified()
  .use(remarkParse)
  .use(remarkFrontmatter)
  .use(remarkMath)
  .use(remarkGfm)
  .use(RemarkBreaks)
  .use(remarkRehype)
  .use(RehypeKatex)
  .use(rehypeHighlight, { detect: true, ignoreMissing: true })
  .use(rehypeStringify);
// Conversation history passed to the engine on every request. It starts with
// the system prompt; onMessageSend appends turns to it.
const messages = [
  { content: "You are a helpful AI agent helping users.", role: "system" },
];
/**
 * Progress callback registered with the engine while a model loads.
 * Logs the numeric progress and shows the status text in #download-status.
 *
 * @param {{progress: number, text: string}} report - progress report from the engine.
 */
function updateEngineInitProgressCallback(report) {
  const statusLine = report.text;
  console.log("initialize", report.progress);
  const statusElement = document.getElementById("download-status");
  statusElement.textContent = statusLine;
}
// Create engine instance: one MLCEngine shared by the whole page, reporting
// load progress through updateEngineInitProgressCallback.
const engine = new webllm.MLCEngine();
engine.setInitProgressCallback(updateEngineInitProgressCallback);
// Set to true by initializeWebLLMEngine() once engine.reload() has finished.
let modelLoaded = false;
/**
 * Read the model and sampling settings from the form and (re)load the
 * Phi-3.5-mini model with the chosen quantization into the shared engine.
 *
 * `modelLoaded` is cleared for the whole reload and only set again on
 * success. onMessageSend is also reachable through the Enter key, which
 * bypasses the disabled Send button, so this flag is what actually prevents
 * generating against a half-loaded (or failed) engine.
 *
 * Numeric fields that do not parse (e.g. an empty input) are left out of the
 * config instead of being passed as NaN.
 *
 * @returns {Promise<void>} resolves once engine.reload() has finished;
 *   rejects with whatever engine.reload() rejects with.
 */
async function initializeWebLLMEngine() {
  modelLoaded = false;

  const readField = (id) => document.getElementById(id).value;
  const quantization = readField("quantization");
  const parsedOptions = {
    temperature: Number.parseFloat(readField("temperature")),
    top_p: Number.parseFloat(readField("top_p")),
    frequency_penalty: Number.parseFloat(readField("frequency_penalty")),
    presence_penalty: Number.parseFloat(readField("presence_penalty")),
    context_window_size: Number.parseInt(readField("context"), 10),
  };
  // Omit unparseable values; the engine presumably falls back to the model's
  // defaults for options that are absent — confirm against WebLLM ChatOptions.
  const config = Object.fromEntries(
    Object.entries(parsedOptions).filter(([, value]) => Number.isFinite(value)),
  );

  document.getElementById("download-status").classList.remove("hidden");
  const selectedModel = `Phi-3.5-mini-instruct-${quantization}_1-MLC`;
  console.log(`Loading Model: ${selectedModel}`);
  console.log(`Config: ${JSON.stringify(config)}`);
  await engine.reload(selectedModel, config);
  modelLoaded = true;
}
/**
 * Stream a chat completion for `messages` from the shared engine.
 *
 * @param {Array<{role: string, content: string}>} messages - conversation to send.
 * @param {(text: string) => void} onUpdate - called after every streamed chunk
 *   with all text received so far.
 * @param {(finalMessage: string, usage: object|undefined) => void} onFinish -
 *   called once, with engine.getMessage() and the last usage report seen.
 * @param {(err: unknown) => void} onError - receives any error thrown while
 *   requesting/streaming or by the callbacks above.
 */
async function streamingGenerating(messages, onUpdate, onFinish, onError) {
  try {
    const stream = await engine.chat.completions.create({
      stream: true,
      messages,
      // Ask for a final chunk carrying token usage statistics.
      stream_options: { include_usage: true },
    });

    let textSoFar = "";
    let lastUsage;
    for await (const chunk of stream) {
      // The usage-only chunk has an empty `choices` array.
      const piece = chunk.choices[0]?.delta.content;
      textSoFar += piece || "";
      lastUsage = chunk.usage || lastUsage;
      onUpdate(textSoFar);
    }

    onFinish(await engine.getMessage(), lastUsage);
  } catch (failure) {
    onError(failure);
  }
}
/*************** UI logic ***************/
/**
 * Send the text in #user-input to the model and stream the reply into the
 * chat box.
 *
 * The user's turn is appended to `messages` and shown immediately, followed
 * by a "typing..." placeholder bubble that is overwritten as the reply streams
 * in. When generation finishes, the reply is appended to `messages` as an
 * assistant turn so the next request carries the whole conversation, and
 * token statistics (when reported) are written to #chat-stats. Whether
 * generation succeeds or fails, the Send button and the input placeholder are
 * restored; on failure the error is logged and shown in the reply bubble.
 *
 * Returns early when no model is loaded, when the input is blank, or while a
 * reply is still generating. The Send button is disabled during generation,
 * but the Enter-key handler calls this function directly and would otherwise
 * bypass that.
 */
function onMessageSend() {
  const sendButton = document.getElementById("send");
  if (!modelLoaded || sendButton.disabled) {
    return;
  }
  const userInput = document.getElementById("user-input");
  const input = userInput.value.trim();
  if (input.length === 0) {
    return;
  }

  const message = {
    content: input,
    role: "user",
  };
  sendButton.disabled = true;
  messages.push(message);
  appendMessage(message);
  userInput.value = "";
  userInput.setAttribute("placeholder", "Generating...");
  appendMessage({
    content: "typing...",
    role: "assistant",
  });

  // Rendering is async; log failures instead of leaving rejections unhandled.
  const renderReply = (text) => {
    updateLastMessage(text).catch(console.error);
  };

  const restoreInput = () => {
    sendButton.disabled = false;
    userInput.setAttribute("placeholder", "Type a message...");
  };

  const onFinishGenerating = (finalMessage, usage) => {
    // Keep the assistant's turn so follow-up requests include it.
    messages.push({ content: finalMessage, role: "assistant" });
    renderReply(finalMessage);
    restoreInput();
    if (!usage) {
      return;
    }
    const usageText =
      `prompt_tokens: ${usage.prompt_tokens}, ` +
      `completion_tokens: ${usage.completion_tokens}, ` +
      `prefill: ${usage.extra.prefill_tokens_per_s.toFixed(4)} tokens/sec, ` +
      `decoding: ${usage.extra.decode_tokens_per_s.toFixed(4)} tokens/sec`;
    const chatStats = document.getElementById("chat-stats");
    chatStats.classList.remove("hidden");
    chatStats.textContent = usageText;
  };

  const onGenerationError = (err) => {
    console.error(err);
    // Replace the "typing..." placeholder so the failure is visible.
    renderReply(`Error: ${err?.message ?? err}`);
    restoreInput();
  };

  streamingGenerating(
    messages,
    renderReply,
    onFinishGenerating,
    onGenerationError
  );
}
/**
 * Add a chat bubble to #chat-box and scroll the box to the bottom.
 *
 * Builds `<div class="message-container user|assistant"><div class="message">`,
 * setting the text through textContent so it is not parsed as HTML.
 * Any role other than "user" gets the "assistant" class.
 *
 * @param {{content: string, role: string}} message - the turn to display.
 */
function appendMessage(message) {
  const roleClass = message.role === "user" ? "user" : "assistant";

  const bubble = document.createElement("div");
  bubble.classList.add("message");
  bubble.textContent = message.content;

  const wrapper = document.createElement("div");
  wrapper.classList.add("message-container", roleClass);
  wrapper.appendChild(bubble);

  const chatBox = document.getElementById("chat-box");
  chatBox.appendChild(wrapper);
  chatBox.scrollTop = chatBox.scrollHeight;
}
/**
 * Render `content` through messageFormatter and replace the HTML of the
 * newest chat bubble (the last `.message` element inside #chat-box).
 *
 * The DOM is queried only after formatting completes, so the bubble updated
 * is whichever one is last at that moment.
 *
 * @param {string} content - markdown text to render.
 * @returns {Promise<void>}
 */
async function updateLastMessage(content) {
  const html = await messageFormatter.process(content);
  const bubbles = document.getElementById("chat-box").querySelectorAll(".message");
  const newest = bubbles[bubbles.length - 1];
  newest.innerHTML = html;
}
/*************** UI binding ***************/
// Download/load the model with the current settings. Send and Download are
// both disabled while loading, which prevents messages and a second concurrent
// reload. Send is re-enabled only on success. On failure the error is logged
// and shown in #download-status. Download is re-enabled either way so the user
// can retry or pick other settings.
document.getElementById("download").addEventListener("click", function () {
  const sendButton = document.getElementById("send");
  const downloadButton = document.getElementById("download");
  sendButton.disabled = true;
  downloadButton.disabled = true;
  initializeWebLLMEngine()
    .then(() => {
      sendButton.disabled = false;
    })
    .catch((err) => {
      console.error(err);
      document.getElementById("download-status").textContent =
        `Failed to load model: ${err?.message ?? err}`;
    })
    .finally(() => {
      downloadButton.disabled = false;
    });
});
// Clicking Send submits whatever is in the input box.
document.getElementById("send").addEventListener("click", () => onMessageSend());
// Pressing Enter in the input box sends the message. Keydown events fired
// during an IME composition (e.g. pressing Enter to confirm a CJK candidate)
// are ignored so half-composed text is not sent.
document.getElementById("user-input").addEventListener("keydown", (event) => {
  if (event.key === "Enter" && !event.isComposing) {
    onMessageSend();
  }
});
// Once the page has loaded, label and enable the Download button. A listener
// is used (as elsewhere in this file) so that no other load handler is
// overwritten.
window.addEventListener("load", () => {
  const downloadButton = document.getElementById("download");
  downloadButton.textContent = "Download";
  downloadButton.disabled = false;
});