Commit 0fac5e0 · Parent(s): 2d902d5
feat: load model from HuggingFace Hub
Files changed:
- app.py (+11 -11)
- requirements.txt (+1 -0)
app.py
CHANGED
@@ -5,7 +5,7 @@ from torchvision import transforms
 from PIL import Image
 import numpy as np
 import os
-from
+from huggingface_hub import hf_hub_download
 
 # --- Model Definition ---
 
@@ -241,7 +241,7 @@ device = None
 
 
 def load_model():
-    """Loads the model and caches it in a global variable."""
+    """Loads the model from HuggingFace Hub and caches it in a global variable."""
     global model, device
     if model is not None:
         return model, device
@@ -249,19 +249,19 @@ def load_model():
     try:
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-
+        # Download model from HuggingFace Hub
+        model_path = hf_hub_download(
+            repo_id="divitmittal/HybridTransformer-MFIF",
+            filename="best_model.pth",
+            cache_dir="./model_cache"
+        )
 
         model_instance = FocalCrossViTHybrid(img_size=224).to(device)
 
-
-            raise FileNotFoundError(
-                "Model checkpoint 'best_model.pth' not found. Please ensure it is available in the Space."
-            )
-
-        checkpoint = torch.load(MODEL_PATH, map_location=device)
+        checkpoint = torch.load(model_path, map_location=device)
 
         state_dict = checkpoint.get("model_state_dict", checkpoint)
-        if any(key.startswith("module.")):
+        if any(key.startswith("module.") for key in state_dict.keys()):
             state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
 
         model_instance.load_state_dict(state_dict)
@@ -271,7 +271,7 @@ def load_model():
         return model, device
     except Exception as e:
         # Catch any exception during loading and show it in the UI
-        raise gr.Error(f"Failed to load the model: {e}")
+        raise gr.Error(f"Failed to load the model from HuggingFace Hub: {e}")
 
 
 # Image processing functions
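For reference, a minimal standalone sketch of the loading flow this commit introduces: download the checkpoint from the Hub, read it with torch.load, and strip any DataParallel "module." prefixes before load_state_dict. It reuses the repo_id, filename, and cache_dir from the diff above; the helper name fetch_state_dict and the standalone layout are illustrative and not part of app.py, and FocalCrossViTHybrid is the model class defined earlier in that file.

import torch
from huggingface_hub import hf_hub_download


def fetch_state_dict(device: torch.device) -> dict:
    # Download the checkpoint from the Hub (reuses the local cache on later calls).
    model_path = hf_hub_download(
        repo_id="divitmittal/HybridTransformer-MFIF",
        filename="best_model.pth",
        cache_dir="./model_cache",
    )

    # The file may hold either a raw state_dict or a dict wrapping it
    # under "model_state_dict"; handle both.
    checkpoint = torch.load(model_path, map_location=device)
    state_dict = checkpoint.get("model_state_dict", checkpoint)

    # Checkpoints saved from nn.DataParallel prefix every key with "module.".
    if any(key.startswith("module.") for key in state_dict):
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    return state_dict


# Usage as in app.py (FocalCrossViTHybrid is defined there):
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = FocalCrossViTHybrid(img_size=224).to(device)
# model.load_state_dict(fetch_state_dict(device))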
requirements.txt
CHANGED
@@ -3,3 +3,4 @@ torchvision
 gradio
 numpy
 Pillow
+huggingface_hub
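Note that app.py now imports hf_hub_download at module import time, so the Space will fail to start if huggingface_hub is missing. A quick local sanity check (illustrative, not part of the repo):

import importlib.util

# Confirm the new dependency resolves before launching the Space.
assert importlib.util.find_spec("huggingface_hub") is not None, \
    "huggingface_hub not installed; run: pip install -r requirements.txt"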