divitmittal commited on
Commit
0fac5e0
·
1 Parent(s): 2d902d5

feat: load model from HuggingFace Hub

Browse files
Files changed (2) hide show
  1. app.py +11 -11
  2. requirements.txt +1 -0
app.py CHANGED
@@ -5,7 +5,7 @@ from torchvision import transforms
5
  from PIL import Image
6
  import numpy as np
7
  import os
8
- from urllib.request import urlretrieve
9
 
10
  # --- Model Definition ---
11
 
@@ -241,7 +241,7 @@ device = None
241
 
242
 
243
  def load_model():
244
- """Loads the model and caches it in a global variable."""
245
  global model, device
246
  if model is not None:
247
  return model, device
@@ -249,19 +249,19 @@ def load_model():
249
  try:
250
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
251
 
252
- MODEL_PATH = "best_model.pth"
 
 
 
 
 
253
 
254
  model_instance = FocalCrossViTHybrid(img_size=224).to(device)
255
 
256
- if not os.path.exists(MODEL_PATH):
257
- raise FileNotFoundError(
258
- "Model checkpoint 'best_model.pth' not found. Please ensure it is available in the Space."
259
- )
260
-
261
- checkpoint = torch.load(MODEL_PATH, map_location=device)
262
 
263
  state_dict = checkpoint.get("model_state_dict", checkpoint)
264
- if any(key.startswith("module.")):
265
  state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
266
 
267
  model_instance.load_state_dict(state_dict)
@@ -271,7 +271,7 @@ def load_model():
271
  return model, device
272
  except Exception as e:
273
  # Catch any exception during loading and show it in the UI
274
- raise gr.Error(f"Failed to load the model: {e}")
275
 
276
 
277
  # Image processing functions
 
5
  from PIL import Image
6
  import numpy as np
7
  import os
8
+ from huggingface_hub import hf_hub_download
9
 
10
  # --- Model Definition ---
11
 
 
241
 
242
 
243
  def load_model():
244
+ """Loads the model from HuggingFace Hub and caches it in a global variable."""
245
  global model, device
246
  if model is not None:
247
  return model, device
 
249
  try:
250
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
251
 
252
+ # Download model from HuggingFace Hub
253
+ model_path = hf_hub_download(
254
+ repo_id="divitmittal/HybridTransformer-MFIF",
255
+ filename="best_model.pth",
256
+ cache_dir="./model_cache"
257
+ )
258
 
259
  model_instance = FocalCrossViTHybrid(img_size=224).to(device)
260
 
261
+ checkpoint = torch.load(model_path, map_location=device)
 
 
 
 
 
262
 
263
  state_dict = checkpoint.get("model_state_dict", checkpoint)
264
+ if any(key.startswith("module.") for key in state_dict.keys()):
265
  state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
266
 
267
  model_instance.load_state_dict(state_dict)
 
271
  return model, device
272
  except Exception as e:
273
  # Catch any exception during loading and show it in the UI
274
+ raise gr.Error(f"Failed to load the model from HuggingFace Hub: {e}")
275
 
276
 
277
  # Image processing functions
requirements.txt CHANGED
@@ -3,3 +3,4 @@ torchvision
3
  gradio
4
  numpy
5
  Pillow
 
 
3
  gradio
4
  numpy
5
  Pillow
6
+ huggingface_hub