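"""Streamlit app for Burmese GPT.

Offers two views: free-form sampling from the pre-trained model and a simple
chat interface. Launch with `streamlit run <path-to-this-file>`.
"""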
import torch
from transformers import AutoTokenizer
import streamlit as st
from burmese_gpt.config import ModelConfig
from burmese_gpt.models import BurmeseGPT

# Model configuration
VOCAB_SIZE = 119547  # vocabulary size of bert-base-multilingual-cased
CHECKPOINT_PATH = "checkpoints/best_model.pth"

# Load model function (cached to avoid reloading on every interaction)
@st.cache_resource
def load_model():
    model_config = ModelConfig()

    tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
    # Fallback for tokenizers that ship without a pad token; mBERT already
    # defines [PAD], so this branch is not taken here.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model_config.vocab_size = VOCAB_SIZE
    model = BurmeseGPT(model_config)

    # Load checkpoint
    checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    # Move to device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    return model, tokenizer, device


def generate_sample(model, tokenizer, device, prompt="မြန်မာ", max_length=50, temperature=0.0):
    """Generate text from a prompt: greedy decoding when temperature <= 0,
    otherwise sampling from the temperature-scaled next-token distribution."""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids)
            logits = outputs[:, -1, :]
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = logits.argmax(dim=-1, keepdim=True)
            input_ids = torch.cat((input_ids, next_token), dim=-1)
            # mBERT defines no EOS token (eos_token_id is None), so this early
            # stop only fires for tokenizers that have one.
            if next_token.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

# Set up the page layout
st.set_page_config(
    page_title="Burmese GPT", page_icon=":speech_balloon:", layout="wide"
)
# Create a sidebar with a title and a brief description
st.sidebar.title("Burmese GPT")
st.sidebar.write("A language model app for generating Burmese text and chatting in Burmese.")
# Create a selectbox to choose the view
view_options = ["Sampling", "Chat Interface"]
selected_view = st.sidebar.selectbox("Select a view:", view_options)
# Load the model once (cached)
model, tokenizer, device = load_model()
# Create a main area
if selected_view == "Sampling":
    st.title("Sampling")
    st.write("Generate text with the pre-trained model:")

    # Create a text input field for the prompt
    prompt = st.text_input("Prompt:", value="မြန်မာ")

    # Add additional generation parameters
    col1, col2 = st.columns(2)
    with col1:
        max_length = st.slider("Max Length:", min_value=10, max_value=500, value=50)
    with col2:
        temperature = st.slider(
            "Temperature:", min_value=0.1, max_value=2.0, value=0.7, step=0.1
        )

    # Create a button to generate text
    if st.button("Generate"):
        if prompt.strip():
            with st.spinner("Generating text..."):
                generated = generate_sample(
                    model=model,
                    tokenizer=tokenizer,
                    device=device,
                    prompt=prompt,
                    max_length=max_length,
                    temperature=temperature,
                )
            st.text_area("Generated Text:", value=generated, height=200)
        else:
            st.warning("Please enter a prompt")

elif selected_view == "Chat Interface":
    st.title("Chat Interface")
    st.write("Chat with the fine-tuned model:")

    # Initialize chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Accept user input
    if prompt := st.chat_input("What is up?"):
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Display user message in chat message container
        with st.chat_message("user"):
            st.markdown(prompt)
        # Display assistant response in chat message container
        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            full_response = ""

            with st.spinner("Thinking..."):
                # Generate response
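                # Note: only the latest user prompt is passed to the model;
                # earlier turns are displayed but not used as generation context.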
                generated = generate_sample(
                    model=model,
                    tokenizer=tokenizer,
                    device=device,
                    prompt=prompt,
                    max_length=100,
                )
                full_response = generated

            message_placeholder.markdown(full_response)

        # Add assistant response to chat history
        st.session_state.messages.append(
            {"role": "assistant", "content": full_response}
        )