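"""Streamlit demo app for BurmeseGPT: prompt-based sampling and a simple chat interface."""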
import streamlit as st
import torch
from transformers import AutoTokenizer

from burmese_gpt.config import ModelConfig
from burmese_gpt.models import BurmeseGPT
# Model configuration
VOCAB_SIZE = 119547
CHECKPOINT_PATH = "checkpoints/best_model.pth"
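# Note: 119547 matches the vocabulary size of the bert-base-multilingual-cased
# tokenizer loaded below; the checkpoint was presumably trained with that
# tokenizer, so the two sizes must agree for the embedding weights to load.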
# Load model function (cached with st.cache_resource to avoid reloading on every interaction)
@st.cache_resource
def load_model():
    model_config = ModelConfig()

    tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model_config.vocab_size = VOCAB_SIZE
    model = BurmeseGPT(model_config)

    # Load checkpoint
    checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    # Move to device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    return model, tokenizer, device
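# st.cache_resource keeps a single shared copy of the model/tokenizer across
# reruns and sessions, so the checkpoint is read from disk only once.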
def generate_sample(model, tokenizer, device, prompt="မြန်မာ", max_length=50, temperature=0.0):
    """Generate text from a prompt, one token at a time.

    A temperature above 0 samples from the temperature-scaled distribution;
    a temperature of 0 (the default) falls back to greedy argmax decoding.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids)
            logits = outputs[:, -1, :]
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = logits.argmax(dim=-1, keepdim=True)
            input_ids = torch.cat((input_ids, next_token), dim=-1)
            # Stop early only if the tokenizer actually defines an EOS token
            if tokenizer.eos_token_id is not None and next_token.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
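# Note: each decoding step feeds the full, growing sequence back through the
# model, and nothing truncates input_ids to the model's context window. If
# BurmeseGPT has a fixed maximum sequence length, long prompts or generations
# would need input_ids clipped to the last N positions before each forward pass.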
# Set up the page layout
st.set_page_config(
    page_title="Burmese GPT", page_icon=":speech_balloon:", layout="wide"
)
# Create a sidebar with a title and a brief description
st.sidebar.title("Burmese GPT")
st.sidebar.write("A language model app for generating and chatting in Burmese.")
# Create a selectbox to choose the view
view_options = ["Sampling", "Chat Interface"]
selected_view = st.sidebar.selectbox("Select a view:", view_options)

# Load the model once (cached)
model, tokenizer, device = load_model()
# Create a main area
if selected_view == "Sampling":
    st.title("Sampling")
    st.write("Generate text using the pre-trained model:")

    # Create a text input field for the prompt
    prompt = st.text_input("Prompt:", value="မြန်မာ")

    # Add additional generation parameters
    col1, col2 = st.columns(2)
    with col1:
        max_length = st.slider("Max Length:", min_value=10, max_value=500, value=50)
    with col2:
        temperature = st.slider(
            "Temperature:", min_value=0.1, max_value=2.0, value=0.7, step=0.1
        )
    # Create a button to generate text
    if st.button("Generate"):
        if prompt.strip():
            with st.spinner("Generating text..."):
                generated = generate_sample(
                    model=model,
                    tokenizer=tokenizer,
                    device=device,
                    prompt=prompt,
                    max_length=max_length,
                    temperature=temperature,
                )
            st.text_area("Generated Text:", value=generated, height=200)
        else:
            st.warning("Please enter a prompt")
elif selected_view == "Chat Interface":
    st.title("Chat Interface")
    st.write("Chat with the fine-tuned model:")

    # Initialize chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Accept user input
    if prompt := st.chat_input("What is up?"):
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})
        # Display user message in chat message container
        with st.chat_message("user"):
            st.markdown(prompt)

        # Display assistant response in chat message container
        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            full_response = ""
            with st.spinner("Thinking..."):
                # Generate response
                generated = generate_sample(
                    model=model,
                    tokenizer=tokenizer,
                    device=device,
                    prompt=prompt,
                    max_length=100,
                )
                full_response = generated
            message_placeholder.markdown(full_response)

        # Add assistant response to chat history
        st.session_state.messages.append(
            {"role": "assistant", "content": full_response}
        )
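# To run the app locally (assuming this file is saved as app.py):
#   streamlit run app.py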