Iruos8805 commited on
Commit
fb8ed98
·
verified ·
1 Parent(s): 55bd520

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -4,12 +4,13 @@ import torch
4
  import gradio as gr
5
  from transformers import pipeline
6
 
 
 
 
7
  # Loading in Model
8
  model_name = "imjeffhi/pokemon_classifier"
9
- model = ViTForImageClassification.from_pretrained(model_name).to(device)
10
  feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
11
- model.eval()
12
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
  model.to(device)
14
 
15
 
 
4
  import gradio as gr
5
  from transformers import pipeline
6
 
7
+
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
  # Loading in Model
11
  model_name = "imjeffhi/pokemon_classifier"
 
12
  feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
13
+ model = ViTForImageClassification.from_pretrained(model_name).to(device)
 
14
  model.to(device)
15
 
16