File size: 6,883 Bytes
6936ef7
 
 
 
 
575ca09
fac5833
6936ef7
7afa131
6936ef7
7afa131
 
6936ef7
7afa131
 
6936ef7
7afa131
 
 
 
 
 
6936ef7
 
7afa131
 
6936ef7
 
 
 
 
 
7afa131
 
 
 
 
 
 
6936ef7
 
 
 
 
 
 
 
7afa131
 
 
 
 
6936ef7
7afa131
 
6936ef7
7afa131
 
6936ef7
7afa131
 
 
 
 
 
 
6936ef7
7afa131
 
 
6936ef7
 
7afa131
 
 
6936ef7
f5e1dde
 
7afa131
 
 
 
 
 
 
 
 
f5e1dde
 
 
 
 
 
 
7afa131
 
 
 
 
 
 
 
f5e1dde
 
 
 
7afa131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6936ef7
7afa131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import torch
from transformers import AutoTokenizer
import streamlit as st
from burmese_gpt.config import ModelConfig
from burmese_gpt.models import BurmeseGPT
from scripts.download import download_pretrained_model
import os

# Configuration
VOCAB_SIZE = 119547  # matches the bert-base-multilingual-cased tokenizer vocab — TODO confirm against checkpoint
CHECKPOINT_DIR = "checkpoints"
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "best_model.pth")

# Create checkpoints directory if it doesn't exist
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# --- App Layout ---
# NOTE: st.set_page_config must be the first Streamlit command in the script.
st.set_page_config(
    page_title="Burmese GPT",
    page_icon=":speech_balloon:",
    layout="wide"
)


# --- Text Generation Function ---
def generate_text(model, tokenizer, device, prompt, max_length=50, temperature=0.7):
    """Autoregressively sample a continuation of ``prompt``.

    Args:
        model: Callable returning logits of shape (batch, seq, vocab).
        tokenizer: HF-style tokenizer providing encode/decode/eos_token_id.
        device: torch.device the model lives on.
        prompt: Seed text to continue.
        max_length: Maximum number of NEW tokens to sample.
        temperature: Softmax temperature; lower values are greedier.

    Returns:
        The decoded string (prompt plus generated continuation).
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # FIX: bert-base-multilingual-cased has no EOS token, so eos_token_id can
    # be None; guard explicitly so the early-stop check is meaningful (and
    # hoist the attribute lookup out of the loop).
    eos_id = tokenizer.eos_token_id

    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids)
            # Only the logits at the last position matter for the next token.
            logits = outputs[:, -1, :] / temperature
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            input_ids = torch.cat((input_ids, next_token), dim=-1)

            # Stop early once the model emits an end-of-sequence token.
            if eos_id is not None and next_token.item() == eos_id:
                break

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)


# --- Download Screen ---
def show_download_screen():
    """Block the app on first run until the model checkpoint is downloaded.

    Downloads the pretrained checkpoint, verifies the file landed at
    CHECKPOINT_PATH, then reruns the app so main_app() takes over. On any
    failure the script is halted with an error message.
    """
    st.title("Burmese GPT")
    st.warning("Downloading required model files...")

    try:
        # FIX: the previous progress bar was created but never advanced,
        # leaving a misleading 0% bar on screen. download_pretrained_model()
        # exposes no progress callback, so show a spinner instead.
        with st.spinner("Downloading checkpoint..."):
            download_pretrained_model()

        # The downloader may finish without producing the file; verify it.
        if os.path.exists(CHECKPOINT_PATH):
            st.success("Download completed successfully!")
            st.rerun()  # Restart the app so the main UI loads the model
        else:
            st.error("Download failed - file not found")
            st.stop()

    except Exception as e:
        st.error(f"Download failed: {str(e)}")
        st.stop()


# --- Main App ---
def main_app():
    """Render the main UI: load the model (cached), build the sidebar
    controls, and dispatch to the Text Generation or Chat Mode view."""

    def load_model_safely():
        """Build tokenizer and model, load the checkpoint, and move the
        model to the best available device.

        Returns:
            (model, tokenizer, device) ready for inference.
        """
        model_config = ModelConfig()
        tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")

        # BERT tokenizers ship without a pad token; fall back to EOS so
        # padding-dependent code paths don't crash.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model_config.vocab_size = VOCAB_SIZE
        model = BurmeseGPT(model_config)

        # Prefer weights_only=True: it restricts unpickling to tensor data and
        # prevents arbitrary-code-execution via a malicious checkpoint.
        try:
            checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)
        except Exception as e:
            # FIX: surface *why* safe loading failed instead of silently
            # discarding the exception. SECURITY: weights_only=False can run
            # arbitrary code during unpickling - trusted checkpoints only.
            st.warning(
                f"Safe load failed ({e}); using less secure loading method "
                "- only do this with trusted checkpoints"
            )
            checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=False)

        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()  # inference mode: disables dropout etc.

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)

        return model, tokenizer, device

    @st.cache_resource
    def load_model():
        # Cached across Streamlit reruns so the checkpoint loads only once.
        return load_model_safely()

    # Load model with spinner
    with st.spinner("Loading model..."):
        model, tokenizer, device = load_model()

    # Sidebar
    st.sidebar.title("Burmese GPT")
    st.sidebar.write("A language model for generating and chatting in Burmese")

    # View selection
    view_options = ["Text Generation", "Chat Mode"]
    selected_view = st.sidebar.selectbox("Select Mode", view_options)

    # Generation parameters shared by both views
    st.sidebar.header("Generation Settings")
    max_length = st.sidebar.slider("Max Length", 20, 500, 100)
    temperature = st.sidebar.slider("Temperature", 0.1, 2.0, 0.7, 0.1)

    # Main content area
    if selected_view == "Text Generation":
        st.header("Burmese Text Generation")

        # Prompt input
        prompt = st.text_area(
            "Enter your prompt in Burmese:",
            value="မြန်မာစာပေ",
            height=100
        )

        # Generate button
        if st.button("Generate Text"):
            if prompt.strip():
                with st.spinner("Generating..."):
                    generated = generate_text(
                        model=model,
                        tokenizer=tokenizer,
                        device=device,
                        prompt=prompt,
                        max_length=max_length,
                        temperature=temperature
                    )
                st.subheader("Generated Text:")
                st.write(generated)
            else:
                st.warning("Please enter a prompt")

    elif selected_view == "Chat Mode":
        st.header("Chat in Burmese")

        # Initialize chat history once per session
        if "messages" not in st.session_state:
            st.session_state.messages = [
                {"role": "assistant", "content": "α€™α€„α€Ία€Ήα€‚α€œα€¬α€•α€«! ကျေးဇူးပြု၍ စကားပြောပါ။"}
            ]

        # Replay the stored conversation on every rerun
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        # Chat input
        if prompt := st.chat_input("Type your message..."):
            # Add user message to chat history
            st.session_state.messages.append({"role": "user", "content": prompt})

            # Display user message
            with st.chat_message("user"):
                st.markdown(prompt)

            # Generate assistant response
            with st.chat_message("assistant"):
                message_placeholder = st.empty()
                full_response = ""

                with st.spinner("Thinking..."):
                    # Flatten prior turns into a plain-text transcript so the
                    # model sees conversational context ([:-1] excludes the
                    # user turn just appended; it is re-added below).
                    chat_history = "\n".join(
                        f"{msg['role']}: {msg['content']}"
                        for msg in st.session_state.messages[:-1]
                    )
                    full_prompt = f"{chat_history}\nuser: {prompt}\nassistant:"

                    # Generate response
                    full_response = generate_text(
                        model=model,
                        tokenizer=tokenizer,
                        device=device,
                        prompt=full_prompt,
                        max_length=max_length,
                        temperature=temperature
                    )

                # Display response
                message_placeholder.markdown(full_response)

            # Persist assistant response so it survives the next rerun
            st.session_state.messages.append(
                {"role": "assistant", "content": full_response}
            )


# --- App Flow Control ---
# Run the main UI when the checkpoint is already on disk; otherwise show the
# one-time download screen (which reruns the app once the file arrives).
if os.path.exists(CHECKPOINT_PATH):
    main_app()
else:
    show_download_screen()