Commit 6936ef7 · Zai committed · 1 parent: e66483a

Make scripts for upload and download

burmese_gpt/config.py CHANGED
@@ -19,4 +19,4 @@ class TrainingConfig:
     log_dir: str = "logs"
     save_every: int = 1
     eval_every: int = 1
-    dataset_url: str = "zaibutcooler/wiki-burmese"
+    dataset_url: str = "zaibutcooler/fine-burmese"
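Presumably the dataset id is consumed by the training pipeline via datasets.load_dataset (an assumption; the training code is not shown in this commit):

from datasets import load_dataset

# Assumption: TrainingConfig.dataset_url is passed to load_dataset
# somewhere in the training pipeline.
dataset = load_dataset("zaibutcooler/fine-burmese")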
scripts/download.py ADDED
@@ -0,0 +1,19 @@
+from huggingface_hub import hf_hub_download
+import shutil
+import os
+
+
+def download_pretrained_model():
+    downloaded_path = hf_hub_download(
+        repo_id="zaibutcooler/burmese-gpt", filename="GPT.pth", cache_dir="checkpoint"
+    )
+
+    os.makedirs("checkpoints", exist_ok=True)  # ensure the target directory exists
+    target_path = os.path.join("checkpoints", "best_model.pth")
+    shutil.copy(downloaded_path, target_path)
+
+    print(f"Saved to {target_path}")
+
+
+if __name__ == "__main__":
+    download_pretrained_model()
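After the script runs, the checkpoint can be sanity-checked before sampling (a minimal sketch; scripts/sample.py below expects a "model_state_dict" key):

import torch

# Load the freshly copied checkpoint on CPU and inspect its keys.
checkpoint = torch.load("checkpoints/best_model.pth", map_location="cpu")
print(sorted(checkpoint.keys()))  # should include "model_state_dict"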
scripts/sample.py CHANGED
@@ -6,10 +6,12 @@ from burmese_gpt.models import BurmeseGPT
 VOCAB_SIZE = 119547
 CHECKPOINT_PATH = "checkpoints/best_model.pth"
 
-def download_pretrained_model(path:str):
+
+def download_pretrained_model(path: str):
     pass
 
-def load_model(path:str):
+
+def load_model(path: str):
     model_config = ModelConfig()
 
     tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
@@ -22,7 +24,7 @@ def load_model(path:str):
 
     # Load checkpoint
     checkpoint = torch.load(path, map_location="cpu")
-    model.load_state_dict(checkpoint['model_state_dict'])
+    model.load_state_dict(checkpoint["model_state_dict"])
     model.eval()
 
     # Move to device
@@ -47,6 +49,7 @@ def generate_sample(model, tokenizer, device, prompt="မြန်မာ", max_l
 
     return tokenizer.decode(input_ids[0], skip_special_tokens=True)
 
+
 if __name__ == "__main__":
     # Download the pretrained model
     # download_pretrained_model(CHECKPOINT_PATH)
@@ -56,10 +59,10 @@ if __name__ == "__main__":
 
     while True:
         prompt = input("\nEnter prompt (or 'quit' to exit): ")
-        if prompt.lower() == 'quit':
+        if prompt.lower() == "quit":
            break
 
        print("\nGenerating...")
        generated = generate_sample(model, tokenizer, device, prompt)
        print(f"\nPrompt: {prompt}")
-       print(f"Generated: {generated}")
+       print(f"Generated: {generated}")
scripts/space.py DELETED
@@ -1,50 +0,0 @@
-import streamlit as st
-
-# Set up the page layout
-st.set_page_config(
-    page_title="Burmese GPT", page_icon=":speech_balloon:", layout="wide"
-)
-
-# Create a sidebar with a title and a brief description
-st.sidebar.title("Burmese GPT")
-st.sidebar.write("A language models app for generating and chatting in Burmese.")
-
-# Create a selectbox to choose the view
-view_options = ["Sampling", "Chat Interface"]
-selected_view = st.sidebar.selectbox("Select a view:", view_options)
-
-# Create a main area
-if selected_view == "Sampling":
-    st.title("Sampling")
-    st.write("Generate text using the pre-trained models:")
-
-    # Create a text input field for the prompt
-    prompt = st.text_input("Prompt:", value="")
-
-    # Create a slider to choose the temperature
-    temperature = st.slider("Temperature:", min_value=0.0, max_value=1.0, value=0.5)
-
-    # Create a button to generate text
-    generate_button = st.button("Generate")
-
-    # Create an output area to display the generated text
-    output_area = st.text_area("Generated Text:", height=200, disabled=True)
-
-    # Add some space between the input and output areas
-    st.write("")
-
-elif selected_view == "Chat Interface":
-    st.title("Chat Interface")
-    st.write("Chat with the fine-tuned models:")
-
-    # Create a text input field for the user input
-    user_input = st.text_input("You:", value="")
-
-    # Create a button to send the input to the models
-    send_button = st.button("Send")
-
-    # Create an output area to display the models's response
-    response_area = st.text_area("Model:", height=200, disabled=True)
-
-    # Add some space between the input and output areas
-    st.write("")
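The UI removed here is superseded by the expanded space.py added at the repository root below, which wires the same widgets to the actual model.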
scripts/upload.py ADDED
@@ -0,0 +1,19 @@
+from huggingface_hub import upload_file
+import os
+
+
+def upload_model():
+    if not os.path.exists("checkpoints/best_model.pth"):
+        print("File does not exist.")
+        return
+
+    upload_file(
+        path_or_fileobj="checkpoints/best_model.pth",
+        path_in_repo="GPT.pth",
+        repo_id="zaibutcooler/burmese-gpt",
+        repo_type="model",
+    )
+
+
+if __name__ == "__main__":
+    upload_model()
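upload_file needs write access to the target repo, supplied either by a cached huggingface-cli login or by an explicit token (a sketch; HF_TOKEN is an assumed environment variable):

import os
from huggingface_hub import upload_file

# Same call as in the script, but with an explicit token instead of a cached login.
upload_file(
    path_or_fileobj="checkpoints/best_model.pth",
    path_in_repo="GPT.pth",
    repo_id="zaibutcooler/burmese-gpt",
    repo_type="model",
    token=os.environ["HF_TOKEN"],  # assumed env var holding a write token
)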
space.py ADDED
@@ -0,0 +1,142 @@
+import torch
+from transformers import AutoTokenizer
+import streamlit as st
+from burmese_gpt.config import ModelConfig
+from burmese_gpt.models import BurmeseGPT
+
+# Model configuration
+VOCAB_SIZE = 119547
+CHECKPOINT_PATH = "checkpoints/best_model.pth"
+
+
+# Load model function (cached to avoid reloading on every interaction)
+@st.cache_resource
+def load_model():
+    model_config = ModelConfig()
+
+    tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    model_config.vocab_size = VOCAB_SIZE
+    model = BurmeseGPT(model_config)
+
+    # Load checkpoint
+    checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
+    model.load_state_dict(checkpoint["model_state_dict"])
+    model.eval()
+
+    # Move to device
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+
+    return model, tokenizer, device
+
+
+def generate_sample(model, tokenizer, device, prompt="မြန်မာ", max_length=50):
+    """Generate text from prompt"""
+    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
+
+    with torch.no_grad():
+        for _ in range(max_length):
+            outputs = model(input_ids)
+            next_token = outputs[:, -1, :].argmax(dim=-1, keepdim=True)
+            input_ids = torch.cat((input_ids, next_token), dim=-1)
+
+            if next_token.item() == tokenizer.eos_token_id:
+                break
+
+    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
+
+
+# Set up the page layout
+st.set_page_config(
+    page_title="Burmese GPT", page_icon=":speech_balloon:", layout="wide"
+)
+
+# Create a sidebar with a title and a brief description
+st.sidebar.title("Burmese GPT")
+st.sidebar.write("A language model app for generating and chatting in Burmese.")
+
+# Create a selectbox to choose the view
+view_options = ["Sampling", "Chat Interface"]
+selected_view = st.sidebar.selectbox("Select a view:", view_options)
+
+# Load the model once (cached)
+model, tokenizer, device = load_model()
+
+# Create a main area
+if selected_view == "Sampling":
+    st.title("Sampling")
+    st.write("Generate text using the pre-trained models:")
+
+    # Create a text input field for the prompt
+    prompt = st.text_input("Prompt:", value="မြန်မာ")
+
+    # Add additional generation parameters
+    col1, col2 = st.columns(2)
+    with col1:
+        max_length = st.slider("Max Length:", min_value=10, max_value=500, value=50)
+    with col2:
+        temperature = st.slider(
+            "Temperature:", min_value=0.1, max_value=2.0, value=0.7, step=0.1
+        )
+
+    # Create a button to generate text
+    if st.button("Generate"):
+        if prompt.strip():
+            with st.spinner("Generating text..."):
+                generated = generate_sample(
+                    model=model,
+                    tokenizer=tokenizer,
+                    device=device,
+                    prompt=prompt,
+                    max_length=max_length,
+                )
+                st.text_area("Generated Text:", value=generated, height=200)
+        else:
+            st.warning("Please enter a prompt")
+
+elif selected_view == "Chat Interface":
+    st.title("Chat Interface")
+    st.write("Chat with the fine-tuned models:")
+
+    # Initialize chat history
+    if "messages" not in st.session_state:
+        st.session_state.messages = []
+
+    # Display chat messages from history on app rerun
+    for message in st.session_state.messages:
+        with st.chat_message(message["role"]):
+            st.markdown(message["content"])
+
+    # Accept user input
+    if prompt := st.chat_input("What is up?"):
+        # Add user message to chat history
+        st.session_state.messages.append({"role": "user", "content": prompt})
+        # Display user message in chat message container
+        with st.chat_message("user"):
+            st.markdown(prompt)
+
+        # Display assistant response in chat message container
+        with st.chat_message("assistant"):
+            message_placeholder = st.empty()
+            full_response = ""
+
+            with st.spinner("Thinking..."):
+                # Generate response
+                generated = generate_sample(
+                    model=model,
+                    tokenizer=tokenizer,
+                    device=device,
+                    prompt=prompt,
+                    max_length=100,
+                )
+                full_response = generated
+
+            message_placeholder.markdown(full_response)
+
+        # Add assistant response to chat history
+        st.session_state.messages.append(
+            {"role": "assistant", "content": full_response}
+        )