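"""Streamlit demo for Burmese GPT.

Downloads a pretrained checkpoint on first run, then offers two modes:
free-form Burmese text generation and a simple chat interface.
"""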
import os

import streamlit as st
import torch
from transformers import AutoTokenizer

from burmese_gpt.config import ModelConfig
from burmese_gpt.models import BurmeseGPT
from scripts.download import download_pretrained_model
# Configuration
VOCAB_SIZE = 119547  # vocab size of the bert-base-multilingual-cased tokenizer
CHECKPOINT_DIR = "checkpoints"
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "best_model.pth")

# Create the checkpoints directory if it doesn't exist
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
# --- App Layout ---
st.set_page_config(
    page_title="Burmese GPT",
    page_icon=":speech_balloon:",
    layout="wide",
)
# --- Text Generation Function ---
def generate_text(model, tokenizer, device, prompt, max_length=50, temperature=0.7):
    """Autoregressively sample up to max_length new tokens continuing the prompt."""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids)
            logits = outputs[:, -1, :]

            # Apply temperature scaling, then sample from the softmax distribution
            logits = logits / temperature
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            input_ids = torch.cat((input_ids, next_token), dim=-1)

            # Stop at end-of-sequence, if the tokenizer defines one
            # (the BERT tokenizer used here has no EOS token by default)
            if tokenizer.eos_token_id is not None and next_token.item() == tokenizer.eos_token_id:
                break

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
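
# Example usage outside Streamlit (hypothetical, assuming model, tokenizer, and
# device have already been loaded as in load_model_safely() below):
#   sample = generate_text(model, tokenizer, device, "မြန်မာ", max_length=80)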
# --- Download Screen ---
def show_download_screen():
    """Show a download screen until the model checkpoint is ready."""
    st.title("Burmese GPT")
    st.warning("Downloading required model files...")

    # Placeholders for progress reporting (not currently updated by the download)
    progress_bar = st.progress(0)
    status_text = st.empty()

    try:
        download_pretrained_model()

        # Verify the download completed
        if os.path.exists(CHECKPOINT_PATH):
            st.success("Download completed successfully!")
            st.rerun()  # Restart the app now that the checkpoint exists
        else:
            st.error("Download failed - file not found")
            st.stop()
    except Exception as e:
        st.error(f"Download failed: {e}")
        st.stop()
# --- Main App ---
def main_app():
    """Main app UI once the checkpoint is available."""

    def load_model_safely():
        """Build the model and load the checkpoint, trying safe loading first."""
        model_config = ModelConfig()

        tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model_config.vocab_size = VOCAB_SIZE
        model = BurmeseGPT(model_config)

        # Attempt safe (weights-only) loading first
        try:
            checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)
        except Exception:
            st.warning("Using less secure loading method - only do this with trusted checkpoints")
            checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=False)

        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        return model, tokenizer, device

    @st.cache_resource
    def load_model():
        return load_model_safely()

    # Load the model once; st.cache_resource reuses it across reruns
    with st.spinner("Loading model..."):
        model, tokenizer, device = load_model()

    # Sidebar
    st.sidebar.title("Burmese GPT")
    st.sidebar.write("A language model for generating and chatting in Burmese")

    # View selection
    view_options = ["Text Generation", "Chat Mode"]
    selected_view = st.sidebar.selectbox("Select Mode", view_options)

    # Generation parameters
    st.sidebar.header("Generation Settings")
    max_length = st.sidebar.slider("Max Length", 20, 500, 100)
    temperature = st.sidebar.slider("Temperature", 0.1, 2.0, 0.7, 0.1)
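    # Temperature below 1.0 sharpens the sampling distribution (more conservative
    # output); above 1.0 flattens it (more varied, less coherent output).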
    # Main content area
    if selected_view == "Text Generation":
        st.header("Burmese Text Generation")

        # Prompt input
        prompt = st.text_area(
            "Enter your prompt in Burmese:",
            value="မြန်မာစာပေ",  # "Burmese literature"
            height=100,
        )

        # Generate button
        if st.button("Generate Text"):
            if prompt.strip():
                with st.spinner("Generating..."):
                    generated = generate_text(
                        model=model,
                        tokenizer=tokenizer,
                        device=device,
                        prompt=prompt,
                        max_length=max_length,
                        temperature=temperature,
                    )
                st.subheader("Generated Text:")
                st.write(generated)
            else:
                st.warning("Please enter a prompt")
    elif selected_view == "Chat Mode":
        st.header("Chat in Burmese")

        # Initialize chat history with a greeting: "Hello! Please start chatting."
        if "messages" not in st.session_state:
            st.session_state.messages = [
                {"role": "assistant", "content": "မင်္ဂလာပါ! ကျေးဇူးပြု၍ စကားပြောပါ။"}
            ]

        # Display chat messages
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        # Chat input
        if prompt := st.chat_input("Type your message..."):
            # Add the user message to the chat history
            st.session_state.messages.append({"role": "user", "content": prompt})

            # Display the user message
            with st.chat_message("user"):
                st.markdown(prompt)

            # Generate the assistant response
            with st.chat_message("assistant"):
                message_placeholder = st.empty()
                full_response = ""

                with st.spinner("Thinking..."):
                    # Combine the chat history into a plain-text prompt for context
                    chat_history = "\n".join(
                        f"{msg['role']}: {msg['content']}"
                        for msg in st.session_state.messages[:-1]
                    )
                    full_prompt = f"{chat_history}\nuser: {prompt}\nassistant:"

                    # Generate the response
                    full_response = generate_text(
                        model=model,
                        tokenizer=tokenizer,
                        device=device,
                        prompt=full_prompt,
                        max_length=max_length,
                        temperature=temperature,
                    )

                # Display the response
                message_placeholder.markdown(full_response)

            # Add the assistant response to the chat history
            st.session_state.messages.append(
                {"role": "assistant", "content": full_response}
            )
# --- App Flow Control ---
if not os.path.exists(CHECKPOINT_PATH):
    show_download_screen()
else:
    main_app()
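
# To launch locally (assuming this file is saved as app.py):
#   streamlit run app.py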