kakasher commited on
Commit
5a72e50
·
1 Parent(s): b1e5ee7

debuging model loading

Browse files
Files changed (1) hide show
  1. app.py +12 -11
app.py CHANGED
@@ -24,22 +24,23 @@ def load_class_labels(file_path='idx_to_class.json'):
24
  print(f"Error: {file_path} not found. Class labels will not be available.")
25
  return {}
26
 
27
- # Load model
28
- def load_model(model_path):
29
- try:
30
- model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=47)
31
- model.load_state_dict(torch.load(model_path, weights_only=True))
32
- model.eval()
33
- return model
34
- except Exception as e:
35
- print(f"Error loading model: {str(e)}")
36
- return None
37
 
38
  # Global variables
39
  toxicity_data = load_toxicity_data()
40
  idx_to_class = load_class_labels()
41
  model_path = Path('vit_b16_224_25e_256bs_0.001lr_adamW_transforms.tar')
42
- model = load_model(model_path)
 
 
 
 
 
 
 
 
 
43
 
44
  # Define the transformation
45
  transform = transforms.Compose([
 
24
  print(f"Error: {file_path} not found. Class labels will not be available.")
25
  return {}
26
 
27
+
28
+
 
 
 
 
 
 
 
 
29
 
30
  # Global variables
31
  toxicity_data = load_toxicity_data()
32
  idx_to_class = load_class_labels()
33
  model_path = Path('vit_b16_224_25e_256bs_0.001lr_adamW_transforms.tar')
34
+ model = None
35
+
36
+ # Load model
37
+ try:
38
+ model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=47)
39
+ model.load_state_dict(torch.load(model_path, weights_only=True))
40
+ model.eval()
41
+ except Exception as e:
42
+ print(f"Error loading model: {str(e)}")
43
+
44
 
45
  # Define the transformation
46
  transform = transforms.Compose([