Zai commited on
Commit
7afa131
Β·
1 Parent(s): 5dca0f4

Solve huggingface space error

Browse files
Files changed (1) hide show
  1. interface.py +166 -116
interface.py CHANGED
@@ -6,50 +6,37 @@ from burmese_gpt.models import BurmeseGPT
6
  from scripts.download import download_pretrained_model
7
  import os
8
 
9
- # Model configuration
10
  VOCAB_SIZE = 119547
11
- CHECKPOINT_PATH = "checkpoints/best_model.pth"
 
12
 
13
- if os.path.exists(CHECKPOINT_PATH):
14
- st.warning("Model already exists, skipping download.")
15
- else:
16
- st.info("Downloading model...")
17
- download_pretrained_model()
18
- st.success("Model downloaded successfully.")
19
-
20
-
21
- # Load model function (cached to avoid reloading on every interaction)
22
- @st.cache_resource
23
- def load_model():
24
- model_config = ModelConfig()
25
-
26
- tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
27
- if tokenizer.pad_token is None:
28
- tokenizer.pad_token = tokenizer.eos_token
29
 
30
- model_config.vocab_size = VOCAB_SIZE
31
- model = BurmeseGPT(model_config)
32
-
33
- # Load checkpoint
34
- checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
35
- model.load_state_dict(checkpoint["model_state_dict"])
36
- model.eval()
37
-
38
- # Move to device
39
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
- model.to(device)
41
-
42
- return model, tokenizer, device
43
 
44
 
45
- def generate_sample(model, tokenizer, device, prompt="မြန်မာ", max_length=50):
 
46
  """Generate text from prompt"""
47
  input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
48
 
49
  with torch.no_grad():
50
  for _ in range(max_length):
51
  outputs = model(input_ids)
52
- next_token = outputs[:, -1, :].argmax(dim=-1, keepdim=True)
 
 
 
 
 
 
53
  input_ids = torch.cat((input_ids, next_token), dim=-1)
54
 
55
  if next_token.item() == tokenizer.eos_token_id:
@@ -58,94 +45,157 @@ def generate_sample(model, tokenizer, device, prompt="မြန်မာ", max_l
58
  return tokenizer.decode(input_ids[0], skip_special_tokens=True)
59
 
60
 
61
- # Set up the page layout
62
- st.set_page_config(
63
- page_title="Burmese GPT", page_icon=":speech_balloon:", layout="wide"
64
- )
 
65
 
66
- # Create a sidebar with a title and a brief description
67
- st.sidebar.title("Burmese GPT")
68
- st.sidebar.write("A language models app for generating and chatting in Burmese.")
69
 
70
- # Create a selectbox to choose the view
71
- view_options = ["Sampling", "Chat Interface"]
72
- selected_view = st.sidebar.selectbox("Select a view:", view_options)
73
 
74
- # Load the model once (cached)
75
- model, tokenizer, device = load_model()
 
 
 
 
 
76
 
77
- # Create a main area
78
- if selected_view == "Sampling":
79
- st.title("Sampling")
80
- st.write("Generate text using the pre-trained models:")
81
 
82
- # Create a text input field for the prompt
83
- prompt = st.text_input("Prompt:", value="မြန်မာ")
84
 
85
- # Add additional generation parameters
86
- col1, col2 = st.columns(2)
87
- with col1:
88
- max_length = st.slider("Max Length:", min_value=10, max_value=500, value=50)
89
- with col2:
90
- temperature = st.slider(
91
- "Temperature:", min_value=0.1, max_value=2.0, value=0.7, step=0.1
92
- )
93
 
94
- # Create a button to generate text
95
- if st.button("Generate"):
96
- if prompt.strip():
97
- with st.spinner("Generating text..."):
98
- generated = generate_sample(
99
- model=model,
100
- tokenizer=tokenizer,
101
- device=device,
102
- prompt=prompt,
103
- max_length=max_length,
104
- )
105
- st.text_area("Generated Text:", value=generated, height=200)
106
- else:
107
- st.warning("Please enter a prompt")
108
-
109
- elif selected_view == "Chat Interface":
110
- st.title("Chat Interface")
111
- st.write("Chat with the fine-tuned models:")
112
-
113
- # Initialize chat history
114
- if "messages" not in st.session_state:
115
- st.session_state.messages = []
116
-
117
- # Display chat messages from history on app rerun
118
- for message in st.session_state.messages:
119
- with st.chat_message(message["role"]):
120
- st.markdown(message["content"])
121
-
122
- # Accept user input
123
- if prompt := st.chat_input("What is up?"):
124
- # Add user message to chat history
125
- st.session_state.messages.append({"role": "user", "content": prompt})
126
- # Display user message in chat message container
127
- with st.chat_message("user"):
128
- st.markdown(prompt)
129
-
130
- # Display assistant response in chat message container
131
- with st.chat_message("assistant"):
132
- message_placeholder = st.empty()
133
- full_response = ""
134
-
135
- with st.spinner("Thinking..."):
136
- # Generate response
137
- generated = generate_sample(
138
- model=model,
139
- tokenizer=tokenizer,
140
- device=device,
141
- prompt=prompt,
142
- max_length=100,
143
- )
144
- full_response = generated
145
-
146
- message_placeholder.markdown(full_response)
147
-
148
- # Add assistant response to chat history
149
- st.session_state.messages.append(
150
- {"role": "assistant", "content": full_response}
151
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from scripts.download import download_pretrained_model
7
  import os
8
 
9
+ # Configuration
10
  VOCAB_SIZE = 119547
11
+ CHECKPOINT_DIR = "checkpoints"
12
+ CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "best_model.pth")
13
 
14
+ # Create checkpoints directory if it doesn't exist
15
+ os.makedirs(CHECKPOINT_DIR, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ # --- App Layout ---
18
+ st.set_page_config(
19
+ page_title="Burmese GPT",
20
+ page_icon=":speech_balloon:",
21
+ layout="wide"
22
+ )
 
 
 
 
 
 
 
23
 
24
 
25
+ # --- Text Generation Function ---
26
+ def generate_text(model, tokenizer, device, prompt, max_length=50, temperature=0.7):
27
  """Generate text from prompt"""
28
  input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
29
 
30
  with torch.no_grad():
31
  for _ in range(max_length):
32
  outputs = model(input_ids)
33
+ logits = outputs[:, -1, :]
34
+
35
+ # Apply temperature
36
+ logits = logits / temperature
37
+ probs = torch.softmax(logits, dim=-1)
38
+ next_token = torch.multinomial(probs, num_samples=1)
39
+
40
  input_ids = torch.cat((input_ids, next_token), dim=-1)
41
 
42
  if next_token.item() == tokenizer.eos_token_id:
 
45
  return tokenizer.decode(input_ids[0], skip_special_tokens=True)
46
 
47
 
48
+ # --- Download Screen ---
49
+ def show_download_screen():
50
+ """Shows download screen until model is ready"""
51
+ st.title("Burmese GPT")
52
+ st.warning("Downloading required model files...")
53
 
54
+ progress_bar = st.progress(0)
55
+ status_text = st.empty()
 
56
 
57
+ try:
58
+ download_pretrained_model()
 
59
 
60
+ # Verify download completed
61
+ if os.path.exists(CHECKPOINT_PATH):
62
+ st.success("Download completed successfully!")
63
+ st.rerun() # Restart the app
64
+ else:
65
+ st.error("Download failed - file not found")
66
+ st.stop()
67
 
68
+ except Exception as e:
69
+ st.error(f"Download failed: {str(e)}")
70
+ st.stop()
 
71
 
 
 
72
 
73
+ # --- Main App ---
74
+ def main_app():
75
+ """Main app UI after model is loaded"""
 
 
 
 
 
76
 
77
+ @st.cache_resource
78
+ def load_model():
79
+ model_config = ModelConfig()
80
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
81
+
82
+ if tokenizer.pad_token is None:
83
+ tokenizer.pad_token = tokenizer.eos_token
84
+
85
+ model_config.vocab_size = VOCAB_SIZE
86
+ model = BurmeseGPT(model_config)
87
+
88
+ checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
89
+ model.load_state_dict(checkpoint["model_state_dict"])
90
+ model.eval()
91
+
92
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
93
+ model.to(device)
94
+
95
+ return model, tokenizer, device
96
+
97
+ # Load model with spinner
98
+ with st.spinner("Loading model..."):
99
+ model, tokenizer, device = load_model()
100
+
101
+ # Sidebar
102
+ st.sidebar.title("Burmese GPT")
103
+ st.sidebar.write("A language model for generating and chatting in Burmese")
104
+
105
+ # View selection
106
+ view_options = ["Text Generation", "Chat Mode"]
107
+ selected_view = st.sidebar.selectbox("Select Mode", view_options)
108
+
109
+ # Generation parameters
110
+ st.sidebar.header("Generation Settings")
111
+ max_length = st.sidebar.slider("Max Length", 20, 500, 100)
112
+ temperature = st.sidebar.slider("Temperature", 0.1, 2.0, 0.7, 0.1)
113
+
114
+ # Main content area
115
+ if selected_view == "Text Generation":
116
+ st.header("Burmese Text Generation")
117
+
118
+ # Prompt input
119
+ prompt = st.text_area(
120
+ "Enter your prompt in Burmese:",
121
+ value="မြန်မာစာပေ",
122
+ height=100
 
 
 
 
 
 
 
 
 
 
 
123
  )
124
+
125
+ # Generate button
126
+ if st.button("Generate Text"):
127
+ if prompt.strip():
128
+ with st.spinner("Generating..."):
129
+ generated = generate_text(
130
+ model=model,
131
+ tokenizer=tokenizer,
132
+ device=device,
133
+ prompt=prompt,
134
+ max_length=max_length,
135
+ temperature=temperature
136
+ )
137
+ st.subheader("Generated Text:")
138
+ st.write(generated)
139
+ else:
140
+ st.warning("Please enter a prompt")
141
+
142
+ elif selected_view == "Chat Mode":
143
+ st.header("Chat in Burmese")
144
+
145
+ # Initialize chat history
146
+ if "messages" not in st.session_state:
147
+ st.session_state.messages = [
148
+ {"role": "assistant", "content": "α€™α€„α€Ία€Ήα€‚α€œα€¬α€•α€«! ကျေးဇူးပြု၍ စကားပြောပါ။"}
149
+ ]
150
+
151
+ # Display chat messages
152
+ for message in st.session_state.messages:
153
+ with st.chat_message(message["role"]):
154
+ st.markdown(message["content"])
155
+
156
+ # Chat input
157
+ if prompt := st.chat_input("Type your message..."):
158
+ # Add user message to chat history
159
+ st.session_state.messages.append({"role": "user", "content": prompt})
160
+
161
+ # Display user message
162
+ with st.chat_message("user"):
163
+ st.markdown(prompt)
164
+
165
+ # Generate assistant response
166
+ with st.chat_message("assistant"):
167
+ message_placeholder = st.empty()
168
+ full_response = ""
169
+
170
+ with st.spinner("Thinking..."):
171
+ # Combine chat history for context
172
+ chat_history = "\n".join(
173
+ f"{msg['role']}: {msg['content']}"
174
+ for msg in st.session_state.messages[:-1]
175
+ )
176
+ full_prompt = f"{chat_history}\nuser: {prompt}\nassistant:"
177
+
178
+ # Generate response
179
+ full_response = generate_text(
180
+ model=model,
181
+ tokenizer=tokenizer,
182
+ device=device,
183
+ prompt=full_prompt,
184
+ max_length=max_length,
185
+ temperature=temperature
186
+ )
187
+
188
+ # Display response
189
+ message_placeholder.markdown(full_response)
190
+
191
+ # Add assistant response to chat history
192
+ st.session_state.messages.append(
193
+ {"role": "assistant", "content": full_response}
194
+ )
195
+
196
+
197
+ # --- App Flow Control ---
198
+ if not os.path.exists(CHECKPOINT_PATH):
199
+ show_download_screen()
200
+ else:
201
+ main_app()