burmese-gpt / interface.py
import os

import streamlit as st
import torch
from transformers import AutoTokenizer

from burmese_gpt.config import ModelConfig
from burmese_gpt.models import BurmeseGPT
from scripts.download import download_pretrained_model

# Configuration
VOCAB_SIZE = 119547  # vocabulary size of the bert-base-multilingual-cased tokenizer
CHECKPOINT_DIR = "checkpoints"
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "best_model.pth")
# Create checkpoints directory if it doesn't exist
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
# --- App Layout ---
st.set_page_config(
    page_title="Burmese GPT",
    page_icon=":speech_balloon:",
    layout="wide",
)

# --- Text Generation Function ---
def generate_text(model, tokenizer, device, prompt, max_length=50, temperature=0.7):
    """Autoregressively sample up to max_length new tokens from the prompt."""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids)
            # Keep only the logits for the last position
            logits = outputs[:, -1, :]
            # Apply temperature: <1 sharpens the distribution, >1 flattens it
            logits = logits / temperature
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat((input_ids, next_token), dim=-1)
            # Stop early on the end-of-sequence token
            if next_token.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

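# Usage sketch outside Streamlit (assumes model, tokenizer, and device were
# loaded the same way main_app() loads them below):
#
#   print(generate_text(model, tokenizer, device, "မြန်မာ", max_length=30))
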
# --- Download Screen ---
def show_download_screen():
    """Show a download screen until the model checkpoint is ready"""
    st.title("Burmese GPT")
    st.warning("Downloading required model files...")
    # Placeholder widgets; the download below does not report progress into them
    progress_bar = st.progress(0)
    status_text = st.empty()
    try:
        download_pretrained_model()
        # Verify the download completed
        if os.path.exists(CHECKPOINT_PATH):
            st.success("Download completed successfully!")
            st.rerun()  # Restart the app now that the checkpoint exists
        else:
            st.error("Download failed - file not found")
            st.stop()
    except Exception as e:
        st.error(f"Download failed: {e}")
        st.stop()

# --- Main App ---
def main_app():
    """Main app UI after the model checkpoint is available"""

    def load_model_safely():
        """Load tokenizer and model, preferring torch's safe loading path"""
        model_config = ModelConfig()
        tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model_config.vocab_size = VOCAB_SIZE
        model = BurmeseGPT(model_config)
        # Attempt safe loading first: weights_only=True refuses to unpickle
        # arbitrary Python objects embedded in the checkpoint
        try:
            checkpoint = torch.load(
                CHECKPOINT_PATH, map_location="cpu", weights_only=True
            )
        except Exception:
            # Fall back to full unpickling; only safe with trusted checkpoints
            st.warning(
                "Using less secure loading method - only do this with trusted checkpoints"
            )
            checkpoint = torch.load(
                CHECKPOINT_PATH, map_location="cpu", weights_only=False
            )
        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        return model, tokenizer, device

    # Cache the loaded model so Streamlit reruns don't reload the checkpoint
    @st.cache_resource
    def load_model():
        return load_model_safely()
    # Load model with spinner
    with st.spinner("Loading model..."):
        model, tokenizer, device = load_model()

    # Sidebar
    st.sidebar.title("Burmese GPT")
    st.sidebar.write("A language model for generating and chatting in Burmese")

    # View selection
    view_options = ["Text Generation", "Chat Mode"]
    selected_view = st.sidebar.selectbox("Select Mode", view_options)

    # Generation parameters
    st.sidebar.header("Generation Settings")
    max_length = st.sidebar.slider("Max Length", 20, 500, 100)
    temperature = st.sidebar.slider("Temperature", 0.1, 2.0, 0.7, 0.1)
    # Main content area
    if selected_view == "Text Generation":
        st.header("Burmese Text Generation")

        # Prompt input; the default "မြန်မာစာပေ" means "Burmese literature"
        prompt = st.text_area(
            "Enter your prompt in Burmese:",
            value="မြန်မာစာပေ",
            height=100,
        )

        # Generate button
        if st.button("Generate Text"):
            if prompt.strip():
                with st.spinner("Generating..."):
                    generated = generate_text(
                        model=model,
                        tokenizer=tokenizer,
                        device=device,
                        prompt=prompt,
                        max_length=max_length,
                        temperature=temperature,
                    )
                st.subheader("Generated Text:")
                st.write(generated)
            else:
                st.warning("Please enter a prompt")
    elif selected_view == "Chat Mode":
        st.header("Chat in Burmese")

        # Initialize chat history; the greeting means "Hello! Please speak."
        if "messages" not in st.session_state:
            st.session_state.messages = [
                {"role": "assistant", "content": "မင်္ဂလာပါ! ကျေးဇူးပြု၍ စကားပြောပါ။"}
            ]

        # Display chat messages
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        # Chat input
        if prompt := st.chat_input("Type your message..."):
            # Add the user message to the chat history
            st.session_state.messages.append({"role": "user", "content": prompt})

            # Display the user message
            with st.chat_message("user"):
                st.markdown(prompt)

            # Generate the assistant response
            with st.chat_message("assistant"):
                message_placeholder = st.empty()
                with st.spinner("Thinking..."):
                    # Combine the chat history into a plain-text context
                    chat_history = "\n".join(
                        f"{msg['role']}: {msg['content']}"
                        for msg in st.session_state.messages[:-1]
                    )
                    full_prompt = f"{chat_history}\nuser: {prompt}\nassistant:"

                    # Generate the response
                    full_response = generate_text(
                        model=model,
                        tokenizer=tokenizer,
                        device=device,
                        prompt=full_prompt,
                        max_length=max_length,
                        temperature=temperature,
                    )
                    # generate_text returns prompt + continuation; keep only the
                    # continuation when the decoded prompt round-trips exactly
                    if full_response.startswith(full_prompt):
                        full_response = full_response[len(full_prompt):].strip()

                # Display the response
                message_placeholder.markdown(full_response)

            # Add the assistant response to the chat history
            st.session_state.messages.append(
                {"role": "assistant", "content": full_response}
            )
# --- App Flow Control ---
if not os.path.exists(CHECKPOINT_PATH):
    show_download_screen()
else:
    main_app()
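
# Launch locally with Streamlit's CLI:
#   streamlit run interface.py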