Arpit-Shukla-20233080 commited on
Commit
040d94c
Β·
1 Parent(s): 3609d20

deploy app

Browse files
Files changed (2) hide show
  1. app.py +130 -0
  2. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image, ImageDraw
3
+ import torch
4
+ import numpy as np
5
+ from torchvision import transforms, models
6
+ from ultralytics import YOLO
7
+ from transformers import pipeline
8
+ import tempfile
9
+ from huggingface_hub import hf_hub_download
10
+
11
+ # ─── 0)class NAMES───_____─────────────
12
+ class_names = [
13
+ "Barn Swallow", "Bay breasted Warbler", "Black and white Warbler", "Black billed Cuckoo",
14
+ "Black throated Blue Warbler", "Black throated Sparrow", "Blue Grosbeak", "Blue Jay",
15
+ "Bobolink", "Bohemian Waxwing", "Bronzed Cowbird", "Brown Creeper",
16
+ "Brown Pelican", "Brown Thrasher", "Canada Warbler", "Cape Glossy Starling",
17
+ "Cape May Warbler", "Cardinal", "Carolina Wren", "Caspian Tern",
18
+ "Cedar Waxwing", "Cerulean Warbler", "Chuck will Widow", "Clark Nutcracker",
19
+ "Common Yellowthroat", "Eared Grebe", "Eastern Towhee", "European Goldfinch",
20
+ "Evening Grosbeak", "Forsters Tern", "Fox Sparrow", "Geococcyx",
21
+ "Golden winged Warbler", "Gray Kingbird", "Gray crowned Rosy Finch", "Green Jay",
22
+ "Green Violetear", "Green tailed Towhee", "Harris Sparrow", "Heermann Gull",
23
+ "Hooded Merganser", "Hooded Oriole", "Hooded Warbler", "Horned Grebe",
24
+ "Horned Lark", "Horned Puffin", "Ivory Gull", "Lazuli Bunting",
25
+ "Le Conte Sparrow", "Least Auklet", "Loggerhead Shrike", "Magnolia Warbler",
26
+ "Mallard", "Myrtle Warbler", "Nashville Warbler", "Nelson Sharp tailed Sparrow",
27
+ "Nighthawk", "Ovenbird", "Pacific Loon", "Painted Bunting",
28
+ "Palm Warbler", "Parakeet Auklet", "Pied Kingfisher", "Pied billed Grebe",
29
+ "Pine Grosbeak", "Pine Warbler", "Prairie Warbler", "Prothonotary Warbler",
30
+ "Purple Finch", "Red bellied Woodpecker", "Red cockaded Woodpecker", "Red eyed Vireo",
31
+ "Red winged Blackbird", "Rhinoceros Auklet", "Rose breasted Grosbeak", "Ruby throated Hummingbird",
32
+ "Rufous Hummingbird", "Savannah Sparrow", "Sayornis", "Scarlet Tanager",
33
+ "Scissor tailed Flycatcher", "Spotted Catbird", "Summer Tanager", "Tree Swallow",
34
+ "Tropical Kingbird", "Vermilion Flycatcher", "Vesper Sparrow", "Warbling Vireo",
35
+ "Western Meadowlark", "Western Wood Pewee", "Whip poor Will", "White Pelican",
36
+ "White breasted Kingfisher", "White breasted Nuthatch", "White crowned Sparrow", "White throated Sparrow",
37
+ "Yellow Warbler", "Yellow breasted Chat", "Yellow headed Blackbird", "Yellow throated Vireo"
38
+ ]
39
+
40
+ # ─── 1) CACHING MODEL LOADERS ──────────────────────────────
41
+ @st.cache_resource
42
+ def load_yolo_model() -> YOLO:
43
+ path = hf_hub_download(
44
+ repo_id="arpit8210/cubmodels",
45
+ filename="epoch45.pt",
46
+ )
47
+ return YOLO(path)
48
+
49
+ @st.cache_resource
50
+ def load_resnet_model(device: str = "cpu") -> torch.nn.Module:
51
+ sd_path = hf_hub_download(
52
+ repo_id="arpit8210/cubmodels",
53
+ filename="model_state_dict_best.pth",
54
+ )
55
+ model = models.resnet152(weights=None)
56
+ model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))
57
+
58
+ raw_sd = torch.load(sd_path, map_location="cpu")
59
+ new_sd = {
60
+ (k[len("module."): ] if k.startswith("module.") else k): v
61
+ for k, v in raw_sd.items()
62
+ }
63
+ model.load_state_dict(new_sd)
64
+
65
+ model = torch.nn.DataParallel(model).to(device).eval()
66
+ return model
67
+
68
+ @st.cache_resource
69
+ def load_swin_pipeline():
70
+ return pipeline("image-classification",
71
+ model="Emiel/cub-200-bird-classifier-swin")
72
+
73
+ # ─── 2) TRANSFORM FOR RESNET ─────────────────────────────────
74
+ resnet_transform = transforms.Compose([
75
+ transforms.Resize((256, 256)),
76
+ transforms.ToTensor(),
77
+ transforms.Normalize([0.485, 0.456, 0.406],
78
+ [0.229, 0.224, 0.225])
79
+ ])
80
+
81
+ # ─── 3) STREAMLIT UI ────────────────────────────────────────
82
+ st.title("🐦 Bird Species Classification")
83
+
84
+ uploaded = st.file_uploader("Upload a bird image", type=["jpg", "jpeg", "png"])
85
+ if not uploaded:
86
+ st.info("Please upload an image.")
87
+ st.stop()
88
+
89
+ img = Image.open(uploaded).convert("RGB")
90
+ st.image(img, caption="Uploaded Image", use_container_width=True)
91
+
92
+ # Load models (no args here)
93
+ yolo_model = load_yolo_model()
94
+ resnet_model = load_resnet_model()
95
+ hf_pipe = load_swin_pipeline()
96
+
97
+ # Run YOLO detection
98
+ results = yolo_model(np.array(img))
99
+ boxes = results[0].boxes.xyxy.cpu().numpy()
100
+
101
+ annotated = img.copy()
102
+ draw = ImageDraw.Draw(annotated)
103
+
104
+ # Precompute HF result on the full image
105
+ hf_res_full = hf_pipe(img)
106
+ hf_label, hf_conf = hf_res_full[0]["label"], hf_res_full[0]["score"]
107
+
108
+ for x1, y1, x2, y2 in boxes.astype(int):
109
+ crop = img.crop((x1, y1, x2, y2))
110
+
111
+ # ResNet prediction on the crop
112
+ inp = resnet_transform(crop).unsqueeze(0)
113
+ with torch.no_grad():
114
+ out = resnet_model(inp)
115
+ probs = torch.softmax(out, dim=1)[0]
116
+ ridx = torch.argmax(probs).item()
117
+ rlabel, rconf = class_names[ridx], probs[ridx].item()
118
+
119
+ # Choose best
120
+ if hf_conf > rconf:
121
+ label, score = hf_label, hf_conf
122
+ else:
123
+ label, score = rlabel, rconf
124
+
125
+ draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
126
+ draw.text((x1, max(0, y1-15)),
127
+ f"{label} ({score:.2f})",
128
+ fill="red")
129
+
130
+ st.image(annotated, caption="Detections & Predictions", use_container_width=True)
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit>=1.32
2
+ torch>=2.1
3
+ torchvision>=0.16
4
+ ultralytics>=8.0
5
+ transformers>=4.39
6
+ Pillow>=9.0
7
+ numpy>=1.24
8
+ scikit-learn>=1.3
9
+ huggingface-hub>=0.16.0
10
+ matplotlib>=3.7
11
+ tqdm>=4.66