import os
import logging
from flask import Flask, render_template_string, send_file, abort, request, jsonify
from huggingface_hub import hf_hub_download, login as hf_login
from dotenv import load_dotenv
load_dotenv()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
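
# Architecture note: this Flask app does no LLM inference itself. It only
# (1) downloads the Gemma model file from Hugging Face, (2) serves it at
# /download, and (3) serves a small chat page whose JavaScript runs the model
# entirely in the browser with MediaPipe LLM Inference (tasks-genai).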
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    try:
        hf_login(token=hf_token)
        logger.info("Hugging Face login successful")
    except Exception as e:
        logger.error(f"Hugging Face login error: {e}")
else:
    logger.warning("HF_TOKEN is not set; downloading gated models such as Gemma may fail.")
app = Flask(__name__)
MODEL_FILENAME = 'gemma3-1b-it-int4.task'
HUGGINGFACE_REPO = 'litert-community/Gemma3-1B-IT'
MODEL_LOCAL_PATH = os.path.join(os.getcwd(), MODEL_FILENAME)
def download_model_file():
    """Download the model from Hugging Face if it is not already cached locally."""
    if not os.path.exists(MODEL_LOCAL_PATH):
        logger.info("Model file not found locally. Downloading from Hugging Face...")
        try:
            # local_dir_use_symlinks is deprecated in recent huggingface_hub,
            # so the file is simply materialized in the working directory.
            hf_hub_download(repo_id=HUGGINGFACE_REPO, filename=MODEL_FILENAME, local_dir=".")
            logger.info(f"Download completed: {MODEL_LOCAL_PATH}")
        except Exception as e:
            logger.error(f"Error downloading model file: {e}")
            raise
    else:
        logger.info("Model file already exists locally.")
    return MODEL_LOCAL_PATH
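
# Fetch the model once at startup so the /download route can serve it
# immediately when the browser asks for it.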
download_model_file()
@app.route('/download')
def download_model():
    """Serve the cached model file so the browser-side MediaPipe runtime can load it."""
    if os.path.exists(MODEL_LOCAL_PATH):
        return send_file(MODEL_LOCAL_PATH, as_attachment=True, download_name=MODEL_FILENAME)
    abort(404)
def perform_inference(input_text: str) -> str:
    """Stub: inference runs in the browser via MediaPipe, not on this server."""
    return "Backend inference is not used with MediaPipe. Inference happens in the browser."
# The original markup was truncated to its title. This minimal skeleton is an
# assumption: it provides just the element ids the script below expects.
HTML_CONTENT = """
<!DOCTYPE html>
<html lang="en">
<head><meta charset="utf-8"><title>LLM Chatbot Demo with MediaPipe</title></head>
<body>
  <h1>LLM Chatbot Demo with MediaPipe</h1>
  <div id="model-status" class="loading">Loading model...</div>
  <div id="chat-history"></div>
  <div id="output" style="display: none;"></div>
  <span id="loading-indicator" style="display: none;">Generating...</span>
  <textarea id="input" placeholder="Type a message..." disabled></textarea>
  <button id="submit" disabled>Send</button>
  <button id="clear-chat-button">Clear chat</button>
  <script type="module" src="/index.js"></script>
</body>
</html>
"""
JS_CONTENT = """
// MediaPipe GenAI tasks bundle; the bare package URL resolves to the ESM entry point.
import { FilesetResolver, LlmInference } from 'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai';
const chatInput = document.getElementById('input');
const sendButton = document.getElementById('submit');
const chatHistory = document.getElementById('chat-history');
const loadingIndicator = document.getElementById('loading-indicator');
const clearChatButton = document.getElementById('clear-chat-button');
const modelStatus = document.getElementById('model-status');
const outputElement = document.getElementById('output');
let isModelLoaded = false;
let messageHistory = [];
let llmInference;
function createMessageElement(text, isUserMessage) {
    const messageDiv = document.createElement('div');
    messageDiv.classList.add('message');
    // Escape the text first (avoids HTML injection), then turn newlines into <br>.
    messageDiv.textContent = text;
    messageDiv.innerHTML = messageDiv.innerHTML.replace(/\\n/g, '<br>');
    messageDiv.classList.add(isUserMessage ? 'user-message' : 'bot-message');
    const timestampDiv = document.createElement('div');
    timestampDiv.classList.add('timestamp');
    const now = new Date();
    const hours = String(now.getHours()).padStart(2, '0');
    const minutes = String(now.getMinutes()).padStart(2, '0');
    timestampDiv.textContent = `${hours}:${minutes}`;
    messageDiv.appendChild(timestampDiv);
    return messageDiv;
}
function displayBotMessage(text) {
    const botMessageElement = createMessageElement(text, false);
    chatHistory.appendChild(botMessageElement);
    chatHistory.scrollTop = chatHistory.scrollHeight;
    messageHistory.push({ text: text, isUserMessage: false });
    sendButton.disabled = false;
    loadingIndicator.style.display = 'none';
}
function renderMessageHistory() {
    chatHistory.innerHTML = '';
    messageHistory.forEach(message => {
        const messageElement = createMessageElement(message.text, message.isUserMessage);
        chatHistory.appendChild(messageElement);
    });
    chatHistory.scrollTop = chatHistory.scrollHeight;
}

function clearChatHistory() {
    messageHistory = [];
    renderMessageHistory();
}
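
// Progress listener for llmInference.generateResponse(): MediaPipe calls it
// repeatedly with the next chunk of generated text, and with complete=true
// once generation has finished.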
function displayPartialResults(partialResults, complete) {
    outputElement.textContent += partialResults;
    if (complete) {
        if (!outputElement.textContent) {
            outputElement.textContent = 'Result is empty';
        }
        sendButton.disabled = false;
        loadingIndicator.style.display = 'none';
        outputElement.style.display = 'none';
        displayBotMessage(outputElement.textContent);
    }
}
async function initializeChatbot() {
    try {
        const filesetResolver = await FilesetResolver.forGenAiTasks('https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai/wasm');
        // Load the model from the Flask /download route.
        llmInference = await LlmInference.createFromOptions(filesetResolver, {
            baseOptions: { modelAssetPath: '/download' }
        });
        isModelLoaded = true;
        sendButton.disabled = false;
        chatInput.disabled = false;
        modelStatus.textContent = "Model Loaded";
        modelStatus.classList.remove('loading');
        modelStatus.classList.add('loaded');
        console.log('MediaPipe GenAI Model Initialized.');
    } catch (error) {
        console.error("Error initializing MediaPipe GenAI:", error);
        modelStatus.textContent = "Model Load Error";
        modelStatus.classList.remove('loading');
        modelStatus.classList.add('error');
    } finally {
        loadingIndicator.style.display = 'none';
        renderMessageHistory();
    }
}
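
// UI wiring: send on button click or Enter (Shift+Enter inserts a newline),
// and allow clearing the chat history.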
sendButton.onclick = async () => {
    if (!isModelLoaded) {
        alert('Chatbot is not initialized yet.');
        return;
    }
    const userMessageText = chatInput.value.trim();
    if (!userMessageText) return;
    const userMessageElement = createMessageElement(userMessageText, true);
    chatHistory.appendChild(userMessageElement);
    chatHistory.scrollTop = chatHistory.scrollHeight;
    messageHistory.push({ text: userMessageText, isUserMessage: true });
    chatInput.value = '';
    sendButton.disabled = true;
    loadingIndicator.style.display = 'inline-block';
    outputElement.style.display = 'block';
    outputElement.textContent = '';
    try {
        await llmInference.generateResponse(userMessageText, displayPartialResults);
    } catch (error) {
        console.error("Inference error:", error);
        displayBotMessage('Error generating response. Please try again.');
        sendButton.disabled = false;
        loadingIndicator.style.display = 'none';
        outputElement.style.display = 'none';
    }
};
chatInput.addEventListener('keydown', (event) => {
    if (event.key === 'Enter' && !event.shiftKey) {
        event.preventDefault();
        sendButton.click();
    }
});
clearChatButton.onclick = clearChatHistory;
document.addEventListener('DOMContentLoaded', initializeChatbot);
"""
@app.route('/')
def index():
    return render_template_string(HTML_CONTENT)

@app.route('/index.js')
def serve_js():
    # Module scripts must be served with a JavaScript MIME type.
    return JS_CONTENT, 200, {'Content-Type': 'application/javascript'}

@app.route('/api/infer', methods=['POST'])
def api_infer():
    # Inference happens client-side; report 501 Not Implemented.
    return jsonify({'error': perform_inference(request.get_data(as_text=True))}), 501

if __name__ == '__main__':
    logger.info("Starting Flask application on port 7860")
    # debug=True is convenient for local development; disable it in production.
    app.run(debug=True, host="0.0.0.0", port=7860)