Shio-Koube commited on
Commit
11b3d03
·
verified ·
1 Parent(s): 1648368

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
3
+ from PIL import Image
4
+ import torch
5
+
6
+ # Load a pre-trained image classification model
7
+ model_name = "microsoft/resnet-50"
8
+ image_processor = AutoImageProcessor.from_pretrained(model_name)
9
+ model = AutoModelForImageClassification.from_pretrained(model_name)
10
+
11
+ def classify_image(image):
12
+ # Ensure the image is in RGB mode
13
+ if image is None:
14
+ return "No image uploaded"
15
+
16
+ # Convert image to RGB if needed
17
+ if image.mode != 'RGB':
18
+ image = image.convert('RGB')
19
+
20
+ # Preprocess the image
21
+ inputs = image_processor(images=image, return_tensors="pt")
22
+
23
+ # Perform prediction
24
+ with torch.no_grad():
25
+ outputs = model(**inputs)
26
+ logits = outputs.logits
27
+
28
+ # Get predictions
29
+ probabilities = torch.nn.functional.softmax(logits, dim=-1)
30
+
31
+ # Get class labels and handle fewer than 5 classes
32
+ labels = model.config.id2label
33
+ num_classes = len(labels)
34
+
35
+ # Determine number of predictions to show
36
+ top_k = min(num_classes, 3)
37
+
38
+ # Get top predictions
39
+ top_prob, top_indices = probabilities.topk(top_k)
40
+
41
+ # Format results
42
+ results = []
43
+ for prob, idx in zip(top_prob[0], top_indices[0]):
44
+ label = labels[idx.item()]
45
+ percentage = prob.item() * 100
46
+ results.append(f"{label}: {percentage:.2f}%")
47
+
48
+ return "\n".join(results)
49
+
50
+ # Create Gradio interface
51
+ iface = gr.Interface(
52
+ fn=classify_image,
53
+ inputs=gr.Image(type="pil"),
54
+ outputs=gr.Textbox(label="Top Predictions"),
55
+ title="Image Classification with Hugging Face",
56
+ description="Upload an image to get classification predictions"
57
+ )
58
+
59
+ # Launch the app
60
+ if __name__ == "__main__":
61
+ iface.launch()