File size: 5,464 Bytes
040d94c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import streamlit as st
from PIL import Image, ImageDraw
import torch
import numpy as np
from torchvision import transforms, models
from ultralytics import YOLO
from transformers import pipeline
import tempfile
from huggingface_hub import hf_hub_download
# ─── 0) CLASS NAMES ─────────────────────────────────────────
# Human-readable species names for the ResNet classifier's 100 output units.
# load_resnet_model() sizes its final Linear layer with len(class_names), and
# the UI maps the ResNet argmax index back through this list, so the ORDER of
# entries must match the label indices used when the checkpoint was trained.
# NOTE(review): entries are in ASCII-sorted order, which suggests the labels
# came from sorted class-folder names (e.g. an ImageFolder-style dataset) —
# confirm against the training code before reordering or editing entries.
class_names = [
    "Barn Swallow", "Bay breasted Warbler", "Black and white Warbler", "Black billed Cuckoo",
    "Black throated Blue Warbler", "Black throated Sparrow", "Blue Grosbeak", "Blue Jay",
    "Bobolink", "Bohemian Waxwing", "Bronzed Cowbird", "Brown Creeper",
    "Brown Pelican", "Brown Thrasher", "Canada Warbler", "Cape Glossy Starling",
    "Cape May Warbler", "Cardinal", "Carolina Wren", "Caspian Tern",
    "Cedar Waxwing", "Cerulean Warbler", "Chuck will Widow", "Clark Nutcracker",
    "Common Yellowthroat", "Eared Grebe", "Eastern Towhee", "European Goldfinch",
    "Evening Grosbeak", "Forsters Tern", "Fox Sparrow", "Geococcyx",
    "Golden winged Warbler", "Gray Kingbird", "Gray crowned Rosy Finch", "Green Jay",
    "Green Violetear", "Green tailed Towhee", "Harris Sparrow", "Heermann Gull",
    "Hooded Merganser", "Hooded Oriole", "Hooded Warbler", "Horned Grebe",
    "Horned Lark", "Horned Puffin", "Ivory Gull", "Lazuli Bunting",
    "Le Conte Sparrow", "Least Auklet", "Loggerhead Shrike", "Magnolia Warbler",
    "Mallard", "Myrtle Warbler", "Nashville Warbler", "Nelson Sharp tailed Sparrow",
    "Nighthawk", "Ovenbird", "Pacific Loon", "Painted Bunting",
    "Palm Warbler", "Parakeet Auklet", "Pied Kingfisher", "Pied billed Grebe",
    "Pine Grosbeak", "Pine Warbler", "Prairie Warbler", "Prothonotary Warbler",
    "Purple Finch", "Red bellied Woodpecker", "Red cockaded Woodpecker", "Red eyed Vireo",
    "Red winged Blackbird", "Rhinoceros Auklet", "Rose breasted Grosbeak", "Ruby throated Hummingbird",
    "Rufous Hummingbird", "Savannah Sparrow", "Sayornis", "Scarlet Tanager",
    "Scissor tailed Flycatcher", "Spotted Catbird", "Summer Tanager", "Tree Swallow",
    "Tropical Kingbird", "Vermilion Flycatcher", "Vesper Sparrow", "Warbling Vireo",
    "Western Meadowlark", "Western Wood Pewee", "Whip poor Will", "White Pelican",
    "White breasted Kingfisher", "White breasted Nuthatch", "White crowned Sparrow", "White throated Sparrow",
    "Yellow Warbler", "Yellow breasted Chat", "Yellow headed Blackbird", "Yellow throated Vireo"
]
# ─── 1) CACHING MODEL LOADERS ──────────────────────────────
@st.cache_resource
def load_yolo_model() -> YOLO:
    """Fetch the YOLO detector checkpoint from the Hugging Face Hub and load it.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the same
    detector instead of downloading and re-initializing it each time.
    """
    checkpoint_file = hf_hub_download(repo_id="arpit8210/cubmodels", filename="epoch45.pt")
    detector = YOLO(checkpoint_file)
    return detector
@st.cache_resource
def load_resnet_model(device: str = "cpu") -> torch.nn.Module:
    """Download the fine-tuned ResNet-152 weights and build the classifier.

    The backbone is created without pretrained weights, its final ``fc`` layer
    is replaced with one producing ``len(class_names)`` logits, and the
    downloaded state dict is loaded into it. Decorated with
    ``st.cache_resource``, so the download/load runs once per distinct
    ``device`` argument.

    Args:
        device: Torch device string the model is moved to (default ``"cpu"``).
            Tensors passed to the returned model must be on this same device.

    Returns:
        The model wrapped in ``torch.nn.DataParallel``, moved to ``device`` and
        put in eval mode.
    """
    sd_path = hf_hub_download(
        repo_id="arpit8210/cubmodels",
        filename="model_state_dict_best.pth",
    )
    model = models.resnet152(weights=None)
    model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))

    # The checkpoint comes from a remote repository and torch.load is
    # pickle-based; the file is used purely as a tensor state dict, so
    # restrict deserialization to tensors and plain containers.
    raw_sd = torch.load(sd_path, map_location="cpu", weights_only=True)

    # Keys saved from a DataParallel-wrapped model carry a "module." prefix;
    # strip it so they match the bare ResNet before it is wrapped below.
    prefix = "module."
    new_sd = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in raw_sd.items()
    }
    model.load_state_dict(new_sd)
    model = torch.nn.DataParallel(model).to(device).eval()
    return model
@st.cache_resource
def load_swin_pipeline():
    """Build the Hugging Face image-classification pipeline for the Swin model.

    Uses the ``Emiel/cub-200-bird-classifier-swin`` checkpoint; cached with
    ``st.cache_resource`` so it is constructed once and reused across reruns.
    """
    task_name = "image-classification"
    checkpoint_id = "Emiel/cub-200-bird-classifier-swin"
    classifier = pipeline(task_name, model=checkpoint_id)
    return classifier
# ─── 2) TRANSFORM FOR RESNET ─────────────────────────────────
# Preprocessing applied to each detected crop before ResNet inference:
# resize to a fixed 256x256 (aspect ratio is not preserved), convert to a
# tensor, then normalize per channel with the commonly used ImageNet mean/std.
resnet_transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# ─── 3) STREAMLIT UI ────────────────────────────────────────
st.title("🐦 Bird Species Classification")

uploaded = st.file_uploader("Upload a bird image", type=["jpg", "jpeg", "png"])
if not uploaded:
    st.info("Please upload an image.")
    st.stop()

img = Image.open(uploaded).convert("RGB")
st.image(img, caption="Uploaded Image", use_container_width=True)

# Load models (cached via st.cache_resource; ResNet uses its default device).
yolo_model = load_yolo_model()
resnet_model = load_resnet_model()
hf_pipe = load_swin_pipeline()

# Run YOLO detection. Pass the PIL image itself: Ultralytics interprets raw
# numpy arrays as BGR (OpenCV channel order), so np.array(img) would feed the
# detector RGB data with the red and blue channels swapped.
results = yolo_model(img)
boxes = results[0].boxes.xyxy.cpu().numpy()
if len(boxes) == 0:
    st.warning("No birds were detected in this image.")

annotated = img.copy()
draw = ImageDraw.Draw(annotated)

# Precompute the Swin pipeline's top-1 prediction on the full image; it is
# compared against the ResNet prediction for each detected crop below.
hf_res_full = hf_pipe(img)
hf_label, hf_conf = hf_res_full[0]["label"], hf_res_full[0]["score"]

# Crop tensors must live on the same device as the ResNet weights.
resnet_device = next(resnet_model.parameters()).device

for x1, y1, x2, y2 in boxes.astype(int):
    # Truncating float box coordinates to ints can collapse a thin box to
    # zero width/height; an empty crop cannot be resized, so skip it.
    if x2 <= x1 or y2 <= y1:
        continue
    crop = img.crop((x1, y1, x2, y2))

    # ResNet prediction on the crop
    inp = resnet_transform(crop).unsqueeze(0).to(resnet_device)
    with torch.no_grad():
        out = resnet_model(inp)
    probs = torch.softmax(out, dim=1)[0]
    ridx = torch.argmax(probs).item()
    rlabel, rconf = class_names[ridx], probs[ridx].item()

    # Keep whichever model reports the higher confidence.
    # NOTE(review): hf_conf is a full-image score from a different model, and
    # its label strings come from that model's config (format may differ from
    # class_names), so this comparison is a heuristic — confirm it is intended.
    if hf_conf > rconf:
        label, score = hf_label, hf_conf
    else:
        label, score = rlabel, rconf

    draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
    draw.text((x1, max(0, y1 - 15)), f"{label} ({score:.2f})", fill="red")

st.image(annotated, caption="Detections & Predictions", use_container_width=True)