File size: 5,464 Bytes
040d94c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import streamlit as st
from PIL import Image, ImageDraw
import torch
import numpy as np
from torchvision import transforms, models
from ultralytics import YOLO
from transformers import pipeline
import tempfile
from huggingface_hub import hf_hub_download

# ─── 0) CLASS NAMES ─────────────────────────────────────────
# Species labels for the ResNet classifier (100 entries).
# The list length sizes the classifier's final linear layer, and a predicted
# class index is turned into a label by position in this list, so entries must
# not be reordered, added, or removed independently of the checkpoint.
# NOTE(review): presumably this order matches the label order used when the
# checkpoint was trained — confirm before editing.
class_names = [
    "Barn Swallow", "Bay breasted Warbler", "Black and white Warbler", "Black billed Cuckoo",
    "Black throated Blue Warbler", "Black throated Sparrow", "Blue Grosbeak", "Blue Jay",
    "Bobolink", "Bohemian Waxwing", "Bronzed Cowbird", "Brown Creeper",
    "Brown Pelican", "Brown Thrasher", "Canada Warbler", "Cape Glossy Starling",
    "Cape May Warbler", "Cardinal", "Carolina Wren", "Caspian Tern",
    "Cedar Waxwing", "Cerulean Warbler", "Chuck will Widow", "Clark Nutcracker",
    "Common Yellowthroat", "Eared Grebe", "Eastern Towhee", "European Goldfinch",
    "Evening Grosbeak", "Forsters Tern", "Fox Sparrow", "Geococcyx",
    "Golden winged Warbler", "Gray Kingbird", "Gray crowned Rosy Finch", "Green Jay",
    "Green Violetear", "Green tailed Towhee", "Harris Sparrow", "Heermann Gull",
    "Hooded Merganser", "Hooded Oriole", "Hooded Warbler", "Horned Grebe",
    "Horned Lark", "Horned Puffin", "Ivory Gull", "Lazuli Bunting",
    "Le Conte Sparrow", "Least Auklet", "Loggerhead Shrike", "Magnolia Warbler",
    "Mallard", "Myrtle Warbler", "Nashville Warbler", "Nelson Sharp tailed Sparrow",
    "Nighthawk", "Ovenbird", "Pacific Loon", "Painted Bunting",
    "Palm Warbler", "Parakeet Auklet", "Pied Kingfisher", "Pied billed Grebe",
    "Pine Grosbeak", "Pine Warbler", "Prairie Warbler", "Prothonotary Warbler",
    "Purple Finch", "Red bellied Woodpecker", "Red cockaded Woodpecker", "Red eyed Vireo",
    "Red winged Blackbird", "Rhinoceros Auklet", "Rose breasted Grosbeak", "Ruby throated Hummingbird",
    "Rufous Hummingbird", "Savannah Sparrow", "Sayornis", "Scarlet Tanager",
    "Scissor tailed Flycatcher", "Spotted Catbird", "Summer Tanager", "Tree Swallow",
    "Tropical Kingbird", "Vermilion Flycatcher", "Vesper Sparrow", "Warbling Vireo",
    "Western Meadowlark", "Western Wood Pewee", "Whip poor Will", "White Pelican",
    "White breasted Kingfisher", "White breasted Nuthatch", "White crowned Sparrow", "White throated Sparrow",
    "Yellow Warbler", "Yellow breasted Chat", "Yellow headed Blackbird", "Yellow throated Vireo"
]

# ─── 1) CACHING MODEL LOADERS ──────────────────────────────
@st.cache_resource
def load_yolo_model() -> YOLO:
    """Download the YOLO detector checkpoint and return it as a ``YOLO`` model.

    The checkpoint ``epoch45.pt`` is fetched from the ``arpit8210/cubmodels``
    repository on the Hugging Face Hub. The function is decorated with
    ``st.cache_resource`` so Streamlit can reuse the returned object instead
    of rebuilding it on every script rerun.

    Returns:
        The ``YOLO`` model constructed from the downloaded checkpoint path.
    """
    weights_file = hf_hub_download(repo_id="arpit8210/cubmodels", filename="epoch45.pt")
    detector = YOLO(weights_file)
    return detector

@st.cache_resource
def load_resnet_model(device: str = "cpu") -> torch.nn.Module:
    """Build the ResNet-152 species classifier and load its fine-tuned weights.

    The state dict ``model_state_dict_best.pth`` is downloaded from the
    ``arpit8210/cubmodels`` repository on the Hugging Face Hub. A leading
    ``module.`` on any key (presumably left by saving a ``DataParallel``-wrapped
    model) is stripped so the weights load into the bare network.

    Args:
        device: Device string the model is moved to. Defaults to ``"cpu"``.

    Returns:
        The model in eval mode. For a CUDA ``device`` it is wrapped in
        ``torch.nn.DataParallel`` (unchanged from the previous behaviour);
        for any other device the bare module is returned.
    """
    sd_path = hf_hub_download(
        repo_id="arpit8210/cubmodels",
        filename="model_state_dict_best.pth",
    )
    model = models.resnet152(weights=None)
    # Replace the classification head so its width matches the label list.
    model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))

    # The checkpoint comes from the network: weights_only=True restricts
    # unpickling to tensors and plain containers instead of arbitrary objects.
    raw_sd = torch.load(sd_path, map_location="cpu", weights_only=True)
    prefix = "module."
    new_sd = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in raw_sd.items()
    }
    model.load_state_dict(new_sd)

    target = torch.device(device)
    if target.type == "cuda":
        return torch.nn.DataParallel(model).to(target).eval()

    # Do not wrap for non-CUDA targets: on a host with visible GPUs,
    # DataParallel expects the parameters on its first GPU and raises at
    # forward time when the module has been moved elsewhere (e.g. to the CPU).
    return model.to(target).eval()

@st.cache_resource
def load_swin_pipeline():
    """Create the Hugging Face image-classification pipeline for the Swin model.

    Uses the ``Emiel/cub-200-bird-classifier-swin`` checkpoint. Decorated with
    ``st.cache_resource`` so Streamlit can reuse the pipeline across reruns.

    Returns:
        The object returned by ``transformers.pipeline``.
    """
    swin_model_id = "Emiel/cub-200-bird-classifier-swin"
    swin_classifier = pipeline("image-classification", model=swin_model_id)
    return swin_classifier

# ─── 2) TRANSFORM FOR RESNET ─────────────────────────────────
# Preprocessing applied to each crop before it is fed to the ResNet:
# fixed 256x256 resize, conversion to a tensor, then per-channel
# normalisation with the commonly used ImageNet mean / std values.
_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]

_resnet_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(_RESNET_MEAN, _RESNET_STD),
]
resnet_transform = transforms.Compose(_resnet_steps)

# ─── 3) STREAMLIT UI ────────────────────────────────────────
# Page flow: upload an image, detect birds with YOLO, classify each detected
# crop with the ResNet, compare against the Swin pipeline's whole-image
# prediction, and draw the more confident label on each box.
st.title("🐦 Bird Species Classification")

uploaded = st.file_uploader("Upload a bird image", type=["jpg", "jpeg", "png"])
if uploaded is None:
    st.info("Please upload an image.")
    st.stop()

# A file with an accepted extension can still be corrupt or truncated; PIL
# signals that with OSError (UnidentifiedImageError is a subclass).
try:
    img = Image.open(uploaded).convert("RGB")
except OSError:
    st.error("Could not read the uploaded file as an image.")
    st.stop()
st.image(img, caption="Uploaded Image", use_container_width=True)

# Load models (no args here)
yolo_model   = load_yolo_model()
resnet_model = load_resnet_model()
hf_pipe      = load_swin_pipeline()

# Run YOLO detection
results = yolo_model(np.array(img))
boxes   = results[0].boxes.xyxy.cpu().numpy()

annotated = img.copy()
draw = ImageDraw.Draw(annotated)

# Precompute HF result on the full image; it is compared against every
# per-crop ResNet prediction below.
hf_res_full = hf_pipe(img)
hf_label, hf_conf = hf_res_full[0]["label"], hf_res_full[0]["score"]

# Clamp boxes to the image and drop any that collapse to zero area after
# integer truncation: an empty crop cannot be resized by the transform.
img_w, img_h = img.size
valid_boxes = []
for x1, y1, x2, y2 in boxes.astype(int).tolist():
    x1, y1 = max(0, x1), max(0, y1)
    x2, y2 = min(img_w, x2), min(img_h, y2)
    if x2 > x1 and y2 > y1:
        valid_boxes.append((x1, y1, x2, y2))

if not valid_boxes:
    # Without this the page would show an unannotated copy of the upload and
    # give no hint that detection found nothing.
    st.warning("No birds were detected in this image.")
    st.write(f"Whole-image prediction: {hf_label} ({hf_conf:.2f})")

for x1, y1, x2, y2 in valid_boxes:
    crop = img.crop((x1, y1, x2, y2))

    # ResNet prediction on the crop
    inp = resnet_transform(crop).unsqueeze(0)
    with torch.no_grad():
        out   = resnet_model(inp)
        probs = torch.softmax(out, dim=1)[0]
        ridx  = torch.argmax(probs).item()
        rlabel, rconf = class_names[ridx], probs[ridx].item()

    # Choose best: keep whichever model reports the higher confidence.
    if hf_conf > rconf:
        label, score = hf_label, hf_conf
    else:
        label, score = rlabel, rconf

    draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
    # Place the caption just above the box, clamped to the top of the image.
    draw.text((x1, max(0, y1-15)),
              f"{label} ({score:.2f})",
              fill="red")

st.image(annotated, caption="Detections & Predictions", use_container_width=True)