Zai committed
Commit f5e1dde · 1 Parent(s): 7afa131

Solve torch model loading issue

Files changed (1)
  1. interface.py (+13, -3)
interface.py CHANGED
@@ -74,8 +74,8 @@ def show_download_screen():
 def main_app():
     """Main app UI after model is loaded"""
 
-    @st.cache_resource
-    def load_model():
+    def load_model_safely():
+        """Load model with proper safety settings"""
         model_config = ModelConfig()
         tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
 
@@ -85,7 +85,13 @@ def main_app():
         model_config.vocab_size = VOCAB_SIZE
         model = BurmeseGPT(model_config)
 
-        checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
+        # Attempt safe loading first
+        try:
+            checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)
+        except Exception as e:
+            st.warning("Using less secure loading method - only do this with trusted checkpoints")
+            checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=False)
+
         model.load_state_dict(checkpoint["model_state_dict"])
         model.eval()
 
@@ -94,6 +100,10 @@ def main_app():
 
         return model, tokenizer, device
 
+    @st.cache_resource
+    def load_model():
+        return load_model_safely()
+
     # Load model with spinner
     with st.spinner("Loading model..."):
         model, tokenizer, device = load_model()
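
For context, newer PyTorch releases (2.6+) default torch.load to weights_only=True, which rejects checkpoints that contain arbitrary pickled objects; the commit adopts a try-safe-first, fall-back-if-needed pattern. Below is a minimal standalone sketch of that same pattern outside Streamlit; the checkpoint path and the printed message are illustrative placeholders, while the "model_state_dict" key mirrors the diff.

import torch

CHECKPOINT_PATH = "checkpoints/model.pth"  # hypothetical path, for illustration only

try:
    # Safe path: weights_only=True restricts unpickling to tensors and plain
    # containers, so a malicious checkpoint cannot execute code during loading.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)
except Exception:
    # Fallback: full unpickling can run arbitrary code, so only use it
    # with checkpoints from a trusted source.
    print("Falling back to weights_only=False - only do this with trusted checkpoints")
    checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=False)

state_dict = checkpoint["model_state_dict"]

An alternative to the weights_only=False fallback is to allow-list the specific classes a checkpoint needs via torch.serialization.add_safe_globals, which keeps the restricted loader in use.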