Arpit-Shukla-20233080
commited on
Commit
Β·
040d94c
1
Parent(s):
3609d20
deploy app
Browse files- app.py +130 -0
- requirements.txt +11 -0
app.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from PIL import Image, ImageDraw
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from torchvision import transforms, models
|
6 |
+
from ultralytics import YOLO
|
7 |
+
from transformers import pipeline
|
8 |
+
import tempfile
|
9 |
+
from huggingface_hub import hf_hub_download
|
10 |
+
|
11 |
+
# βββ 0)class NAMESβββ_____βββββββββββββ
|
12 |
+
class_names = [
|
13 |
+
"Barn Swallow", "Bay breasted Warbler", "Black and white Warbler", "Black billed Cuckoo",
|
14 |
+
"Black throated Blue Warbler", "Black throated Sparrow", "Blue Grosbeak", "Blue Jay",
|
15 |
+
"Bobolink", "Bohemian Waxwing", "Bronzed Cowbird", "Brown Creeper",
|
16 |
+
"Brown Pelican", "Brown Thrasher", "Canada Warbler", "Cape Glossy Starling",
|
17 |
+
"Cape May Warbler", "Cardinal", "Carolina Wren", "Caspian Tern",
|
18 |
+
"Cedar Waxwing", "Cerulean Warbler", "Chuck will Widow", "Clark Nutcracker",
|
19 |
+
"Common Yellowthroat", "Eared Grebe", "Eastern Towhee", "European Goldfinch",
|
20 |
+
"Evening Grosbeak", "Forsters Tern", "Fox Sparrow", "Geococcyx",
|
21 |
+
"Golden winged Warbler", "Gray Kingbird", "Gray crowned Rosy Finch", "Green Jay",
|
22 |
+
"Green Violetear", "Green tailed Towhee", "Harris Sparrow", "Heermann Gull",
|
23 |
+
"Hooded Merganser", "Hooded Oriole", "Hooded Warbler", "Horned Grebe",
|
24 |
+
"Horned Lark", "Horned Puffin", "Ivory Gull", "Lazuli Bunting",
|
25 |
+
"Le Conte Sparrow", "Least Auklet", "Loggerhead Shrike", "Magnolia Warbler",
|
26 |
+
"Mallard", "Myrtle Warbler", "Nashville Warbler", "Nelson Sharp tailed Sparrow",
|
27 |
+
"Nighthawk", "Ovenbird", "Pacific Loon", "Painted Bunting",
|
28 |
+
"Palm Warbler", "Parakeet Auklet", "Pied Kingfisher", "Pied billed Grebe",
|
29 |
+
"Pine Grosbeak", "Pine Warbler", "Prairie Warbler", "Prothonotary Warbler",
|
30 |
+
"Purple Finch", "Red bellied Woodpecker", "Red cockaded Woodpecker", "Red eyed Vireo",
|
31 |
+
"Red winged Blackbird", "Rhinoceros Auklet", "Rose breasted Grosbeak", "Ruby throated Hummingbird",
|
32 |
+
"Rufous Hummingbird", "Savannah Sparrow", "Sayornis", "Scarlet Tanager",
|
33 |
+
"Scissor tailed Flycatcher", "Spotted Catbird", "Summer Tanager", "Tree Swallow",
|
34 |
+
"Tropical Kingbird", "Vermilion Flycatcher", "Vesper Sparrow", "Warbling Vireo",
|
35 |
+
"Western Meadowlark", "Western Wood Pewee", "Whip poor Will", "White Pelican",
|
36 |
+
"White breasted Kingfisher", "White breasted Nuthatch", "White crowned Sparrow", "White throated Sparrow",
|
37 |
+
"Yellow Warbler", "Yellow breasted Chat", "Yellow headed Blackbird", "Yellow throated Vireo"
|
38 |
+
]
|
39 |
+
|
40 |
+
# βββ 1) CACHING MODEL LOADERS ββββββββββββββββββββββββββββββ
|
41 |
+
@st.cache_resource
|
42 |
+
def load_yolo_model() -> YOLO:
|
43 |
+
path = hf_hub_download(
|
44 |
+
repo_id="arpit8210/cubmodels",
|
45 |
+
filename="epoch45.pt",
|
46 |
+
)
|
47 |
+
return YOLO(path)
|
48 |
+
|
49 |
+
@st.cache_resource
|
50 |
+
def load_resnet_model(device: str = "cpu") -> torch.nn.Module:
|
51 |
+
sd_path = hf_hub_download(
|
52 |
+
repo_id="arpit8210/cubmodels",
|
53 |
+
filename="model_state_dict_best.pth",
|
54 |
+
)
|
55 |
+
model = models.resnet152(weights=None)
|
56 |
+
model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))
|
57 |
+
|
58 |
+
raw_sd = torch.load(sd_path, map_location="cpu")
|
59 |
+
new_sd = {
|
60 |
+
(k[len("module."): ] if k.startswith("module.") else k): v
|
61 |
+
for k, v in raw_sd.items()
|
62 |
+
}
|
63 |
+
model.load_state_dict(new_sd)
|
64 |
+
|
65 |
+
model = torch.nn.DataParallel(model).to(device).eval()
|
66 |
+
return model
|
67 |
+
|
68 |
+
@st.cache_resource
|
69 |
+
def load_swin_pipeline():
|
70 |
+
return pipeline("image-classification",
|
71 |
+
model="Emiel/cub-200-bird-classifier-swin")
|
72 |
+
|
73 |
+
# βββ 2) TRANSFORM FOR RESNET βββββββββββββββββββββββββββββββββ
|
74 |
+
resnet_transform = transforms.Compose([
|
75 |
+
transforms.Resize((256, 256)),
|
76 |
+
transforms.ToTensor(),
|
77 |
+
transforms.Normalize([0.485, 0.456, 0.406],
|
78 |
+
[0.229, 0.224, 0.225])
|
79 |
+
])
|
80 |
+
|
81 |
+
# βββ 3) STREAMLIT UI ββββββββββββββββββββββββββββββββββββββββ
|
82 |
+
st.title("π¦ Bird Species Classification")
|
83 |
+
|
84 |
+
uploaded = st.file_uploader("Upload a bird image", type=["jpg", "jpeg", "png"])
|
85 |
+
if not uploaded:
|
86 |
+
st.info("Please upload an image.")
|
87 |
+
st.stop()
|
88 |
+
|
89 |
+
img = Image.open(uploaded).convert("RGB")
|
90 |
+
st.image(img, caption="Uploaded Image", use_container_width=True)
|
91 |
+
|
92 |
+
# Load models (no args here)
|
93 |
+
yolo_model = load_yolo_model()
|
94 |
+
resnet_model = load_resnet_model()
|
95 |
+
hf_pipe = load_swin_pipeline()
|
96 |
+
|
97 |
+
# Run YOLO detection
|
98 |
+
results = yolo_model(np.array(img))
|
99 |
+
boxes = results[0].boxes.xyxy.cpu().numpy()
|
100 |
+
|
101 |
+
annotated = img.copy()
|
102 |
+
draw = ImageDraw.Draw(annotated)
|
103 |
+
|
104 |
+
# Precompute HF result on the full image
|
105 |
+
hf_res_full = hf_pipe(img)
|
106 |
+
hf_label, hf_conf = hf_res_full[0]["label"], hf_res_full[0]["score"]
|
107 |
+
|
108 |
+
for x1, y1, x2, y2 in boxes.astype(int):
|
109 |
+
crop = img.crop((x1, y1, x2, y2))
|
110 |
+
|
111 |
+
# ResNet prediction on the crop
|
112 |
+
inp = resnet_transform(crop).unsqueeze(0)
|
113 |
+
with torch.no_grad():
|
114 |
+
out = resnet_model(inp)
|
115 |
+
probs = torch.softmax(out, dim=1)[0]
|
116 |
+
ridx = torch.argmax(probs).item()
|
117 |
+
rlabel, rconf = class_names[ridx], probs[ridx].item()
|
118 |
+
|
119 |
+
# Choose best
|
120 |
+
if hf_conf > rconf:
|
121 |
+
label, score = hf_label, hf_conf
|
122 |
+
else:
|
123 |
+
label, score = rlabel, rconf
|
124 |
+
|
125 |
+
draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
|
126 |
+
draw.text((x1, max(0, y1-15)),
|
127 |
+
f"{label} ({score:.2f})",
|
128 |
+
fill="red")
|
129 |
+
|
130 |
+
st.image(annotated, caption="Detections & Predictions", use_container_width=True)
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit>=1.32
|
2 |
+
torch>=2.1
|
3 |
+
torchvision>=0.16
|
4 |
+
ultralytics>=8.0
|
5 |
+
transformers>=4.39
|
6 |
+
Pillow>=9.0
|
7 |
+
numpy>=1.24
|
8 |
+
scikit-learn>=1.3
|
9 |
+
huggingface-hub>=0.16.0
|
10 |
+
matplotlib>=3.7
|
11 |
+
tqdm>=4.66
|