# app.py
import os
import json
import pickle
import joblib
import warnings
from typing import Tuple, Dict

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as tv_models
import torchvision.transforms as T
import numpy as np
from PIL import Image

# NOTE(review): this silences every warning in the process (including
# torch/timm deprecation notices); consider narrowing the filter.
warnings.filterwarnings("ignore")

# Optional timm
try:
    import timm
    HAS_TIMM = True
except Exception:
    HAS_TIMM = False

# ---------------------------
# Defaults & metadata
# ---------------------------
DEFAULT_CLASSES = [
    "Ayrshire cattle", "Brown Swiss cattle", "Holstein Friesian cattle",
    "Jaffrabadi", "Jersey cattle", "Murrah", "Red Dane cattle",
    "kankarej", "sahiwal", "sahiwal cross", "sibbi",
]

# Static descriptive data shown in the UI for each class name above.
BREED_INFO = {
    "Ayrshire cattle": {
        "type": "Dairy Cow", "origin": "Scotland",
        "characteristics": "Strong, adaptable, excellent udder conformation and superior grazing ability",
        "milk_yield": "6000-7000 liters per lactation",
        "special_features": "Red and white patches, hardy in cold weather, high butterfat content",
        "weight": "450-550 kg", "height": "125-135 cm", "temperament": "Docile and friendly",
    },
    "Brown Swiss cattle": {
        "type": "Dual-purpose (Dairy & Beef)", "origin": "Switzerland",
        "characteristics": "Docile, strong, excellent for cheese production, disease resistant",
        "milk_yield": "10000-14000 liters per lactation",
        "special_features": "Light to dark brown color with creamy white muzzle, exceptional longevity",
        "weight": "600-700 kg", "height": "135-150 cm", "temperament": "Calm and intelligent",
    },
    "Holstein Friesian cattle": {
        "type": "Dairy Cow", "origin": "Netherlands/Germany",
        "characteristics": "Highest milk production, excellent feed conversion, docile temperament",
        "milk_yield": "8000-12000 liters per lactation",
        "special_features": "Distinctive black and white patches, large frame, heat sensitive",
        "weight": "580-700 kg", "height": "140-150 cm", "temperament": "Gentle and manageable",
    },
    "Jaffrabadi": {
        "type": "Indigenous Dairy Buffalo", "origin": "Gujarat, India (Saurashtra region)",
        "characteristics": "Heaviest Indian buffalo breed, adapted to harsh semi-arid conditions",
        "milk_yield": "2000-2500 liters per lactation",
        "special_features": "Black color, dome-shaped forehead, ring-like horns, highest butterfat content",
        "weight": "400-600 kg", "height": "130-140 cm", "temperament": "Hardy and resilient",
    },
    "Jersey cattle": {
        "type": "Dairy Cow", "origin": "Jersey, Channel Islands",
        "characteristics": "Efficient feed conversion, calving ease, heat tolerant, docile",
        "milk_yield": "4500-6500 liters per lactation",
        "special_features": "Light tan to fawn color, smallest dairy breed, highest butterfat percentage",
        "weight": "350-450 kg", "height": "120-125 cm", "temperament": "Alert and intelligent",
    },
    "Murrah": {
        "type": "Indigenous Dairy Buffalo", "origin": "Haryana and Punjab, India",
        "characteristics": "Highest milk yielding buffalo breed, docile nature, good mothers",
        "milk_yield": "2200-3000 liters per lactation",
        "special_features": "Jet black color, tightly curved horns, compact body structure",
        "weight": "450-650 kg", "height": "130-135 cm", "temperament": "Docile and calm",
    },
    "Red Dane cattle": {
        "type": "Dual-purpose (Dairy & Beef)", "origin": "Denmark",
        "characteristics": "Hardy, disease resistant, excellent meat quality, easy calving",
        "milk_yield": "8000-10000 liters per lactation",
        "special_features": "Red to dark mahogany color with white markings, good heat tolerance",
        "weight": "550-650 kg", "height": "135-145 cm", "temperament": "Gentle and cooperative",
    },
    "kankarej": {
        "type": "Indigenous Dual-purpose (Dairy & Draught)", "origin": "Gujarat, India (Kankrej territory)",
        "characteristics": "Active, strong draught animal, drought resistant, disease resistant",
        "milk_yield": "1500-2000 liters per lactation",
        "special_features": "Silver to gray to steel black color, lyre-shaped horns, large pendulous ears",
        "weight": "400-500 kg", "height": "125-135 cm", "temperament": "Active and energetic",
    },
    "sahiwal": {
        "type": "Indigenous Dairy Cow", "origin": "Punjab, Pakistan/India",
        "characteristics": "Heat resistant, tick resistant, high disease resistance, docile",
        "milk_yield": "2500-3200 liters per lactation",
        "special_features": "Brownish red to grayish red color, loose dewlap, compact build",
        "weight": "300-400 kg", "height": "115-125 cm", "temperament": "Docile and hardy",
    },
    "sahiwal cross": {
        "type": "Crossbred Dairy Cow", "origin": "Cross breeding programs (Sahiwal x exotic breeds)",
        "characteristics": "Hybrid vigor, improved milk yield, better adaptability than pure exotic",
        "milk_yield": "3000-4200 liters per lactation",
        "special_features": "Variable color depending on cross, moderate heat tolerance, enhanced productivity",
        "weight": "350-450 kg", "height": "120-130 cm", "temperament": "Balanced and adaptable",
    },
    "sibbi": {
        "type": "Indigenous Dual-purpose (Draught & Beef)", "origin": "Sibi, Baluchistan, Pakistan",
        "characteristics": "Largest Zebu breed, exceptional size, extremely hardy, massive build",
        "milk_yield": "1500-2200 liters per lactation",
        "special_features": "Pure white to grey with black neck, tallest cattle breed, exhibited at Sibi Mela",
        "weight": "500-800 kg", "height": "140-160 cm", "temperament": "Majestic and calm",
    },
}

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


# ---------------------------
# Helpers
# ---------------------------
def strip_module_prefix(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Return a copy of *state_dict* with a leading ``"module."`` removed from keys.

    That prefix typically appears when a model was saved while wrapped in
    ``nn.DataParallel`` / ``DistributedDataParallel``. Other keys are kept as-is.
    """
    clean = {}
    for k, v in state_dict.items():
        if k.startswith("module."):
            clean[k[len("module."):]] = v
        else:
            clean[k] = v
    return clean


def file_to_path(file_obj) -> str:
    """Resolve an uploaded-file value to a filesystem path.

    Accepts a path string or ``os.PathLike``, an object whose ``.name`` holds
    the path (e.g. a temp-file wrapper), or a dict carrying the path under
    ``"name"``, ``"path"`` or ``"file"``.

    Raises:
        ValueError: if no path can be extracted from *file_obj*.
    """
    if isinstance(file_obj, str):
        return file_obj
    if isinstance(file_obj, os.PathLike):
        # Must precede the ``.name`` check: ``pathlib.Path.name`` is only the
        # final path component and would silently drop the directory.
        return os.fspath(file_obj)
    if hasattr(file_obj, "name"):
        return file_obj.name
    if isinstance(file_obj, dict):
        path = file_obj.get("name") or file_obj.get("path") or file_obj.get("file")
        if not path:
            raise ValueError("File dict has no 'name', 'path' or 'file' entry")
        return path
    raise ValueError("Unsupported file input type")


def make_head(in_dim: int, num_classes: int) -> nn.Module:
    """Classification head: dropout (p=0.2) followed by a linear layer."""
    return nn.Sequential(nn.Dropout(0.2), nn.Linear(in_dim, num_classes))


# ---------------------------
# Classifier
# ---------------------------
class IndianBovineClassifier:
    """Breed classifier with pluggable backends (PyTorch, sklearn, or demo).

    On construction it tries a fixed list of model filenames (relative to the
    current working directory). If none loads, it stays in ``"demo"`` mode,
    where predictions are random. ``model_type`` selects the preprocessing and
    inference path: ``"demo"``, ``"sklearn"`` or a ``"pytorch:..."`` label.
    """

    # File extension -> loader kind used when ``kind="auto"``.
    _EXT_TO_KIND = {".pth": "pytorch", ".pkl": "pickle", ".joblib": "joblib"}

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.model_type = "demo"
        self.class_names = list(DEFAULT_CLASSES)
        self.num_classes = len(self.class_names)
        self.preprocess = self._default_preprocess()
        self._try_autoload()

    @staticmethod
    def _default_preprocess():
        """224x224 resize + tensor conversion + ImageNet mean/std normalisation."""
        return T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ])

    # -- state snapshot so a failed load cannot leave the object half-updated --
    def _snapshot_state(self):
        return (self.model, self.model_type, list(self.class_names),
                self.num_classes, self.preprocess)

    def _restore_state(self, snapshot) -> None:
        (self.model, self.model_type, self.class_names,
         self.num_classes, self.preprocess) = snapshot

    def _build_arch(self, arch: str, num_classes: int) -> nn.Module:
        """Build an untrained network for *arch* with *num_classes* outputs.

        Resolution order: timm (if installed and it knows *arch*), then the
        torchvision ``resnet18`` / ``efficientnet_b0`` builders, then the small
        fallback CNN. Also sets ``self.preprocess`` to match the chosen network.
        """
        a = (arch or "").strip()
        # Reset first so a previously loaded timm model's input size and
        # normalisation do not leak into a torchvision/fallback model.
        self.preprocess = self._default_preprocess()
        if a and HAS_TIMM:
            try:
                m = timm.create_model(a, pretrained=False, num_classes=num_classes)
                cfg = getattr(m, "default_cfg", None)
                if cfg:
                    size = cfg.get("input_size", (3, 224, 224))[-1]
                    mean = list(cfg.get("mean", IMAGENET_MEAN))
                    std = list(cfg.get("std", IMAGENET_STD))
                    self.preprocess = T.Compose([
                        T.Resize((size, size)),
                        T.ToTensor(),
                        T.Normalize(mean, std),
                    ])
                return m
            except Exception:
                # Unknown to timm (e.g. "tv_efficientnet_b0"): try torchvision.
                pass
        if a.lower() in {"resnet18", "tv_resnet18"}:
            m = tv_models.resnet18(weights=None)
            m.fc = nn.Linear(m.fc.in_features, num_classes)
            return m
        if a.lower() in {"efficientnet_b0", "tv_efficientnet_b0"}:
            m = tv_models.efficientnet_b0(weights=None)
            in_dim = m.classifier[1].in_features
            m.classifier = make_head(in_dim, num_classes)
            return m
        # fallback
        return self._simple_cnn(num_classes)

    def _simple_cnn(self, nc: int) -> nn.Module:
        """Small 3-conv CNN with an *nc*-way linear classifier (demo/fallback)."""
        class Simple(nn.Module):
            def __init__(self, out_dim):
                super().__init__()
                self.features = nn.Sequential(
                    nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(True), nn.MaxPool2d(2),
                    nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(True), nn.MaxPool2d(2),
                    nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(True),
                    nn.AdaptiveAvgPool2d((1, 1)),
                )
                self.classifier = nn.Sequential(nn.Dropout(0.5), nn.Linear(256, out_dim))

            def forward(self, x):
                x = self.features(x)
                x = torch.flatten(x, 1)
                return self.classifier(x)

        return Simple(nc)

    def _try_autoload(self):
        """Load the first candidate model file that loads cleanly, else enter demo mode."""
        # Attempt common filenames; quietly fall back to demo if none
        candidates = [
            ("indian_bovine_breeds.pth", "pytorch"),
            ("indian_bovine_model.pth", "pytorch"),
            ("model.pth", "pytorch"),
            ("indian_bovine_breeds.pkl", "pickle"),
            ("model.pkl", "pickle"),
            ("indian_bovine_breeds.joblib", "joblib"),
            ("model.joblib", "joblib"),
        ]
        for path, kind in candidates:
            if os.path.exists(path):
                try:
                    self._load_from_path(path, kind=kind)
                    print(f"Loaded model: {path}")
                    return
                except Exception as e:
                    print(f"Autoload failed for {path}: {e}")
        self.model = self._simple_cnn(self.num_classes).to(self.device).eval()
        self.model_type = "demo"

    def _maybe_set_classes_from_meta(self, meta: dict) -> bool:
        """Adopt class names from checkpoint metadata; return True if found.

        Looks for a list/tuple under ``classes``/``class_names``/``labels``/
        ``breeds`` (needs >1 entry), else a ``class_to_idx`` mapping.
        """
        keys = ["classes", "class_names", "labels", "breeds"]
        for k in keys:
            if k in meta and isinstance(meta[k], (list, tuple)) and len(meta[k]) > 1:
                self.class_names = list(meta[k])
                self.num_classes = len(self.class_names)
                return True
        c2i = meta.get("class_to_idx")
        if isinstance(c2i, dict) and len(c2i) > 1:
            # Order names by their index; unlike inverting and indexing
            # 0..n-1, this tolerates non-contiguous indices.
            self.class_names = [name for name, _ in sorted(c2i.items(), key=lambda kv: kv[1])]
            self.num_classes = len(self.class_names)
            return True
        return False

    @staticmethod
    def _default_arch_for(state: Dict[str, torch.Tensor]) -> str:
        """Pick an EfficientNet-B0 flavour for a checkpoint that lacks an "arch" key."""
        # torchvision's EfficientNet keeps its backbone under "features.*",
        # timm's does not; without timm both names resolve to torchvision.
        if not HAS_TIMM or any(k.startswith("features.") for k in state):
            return "tv_efficientnet_b0"
        return "efficientnet_b0"

    @staticmethod
    def _load_weights(model: nn.Module, state: Dict[str, torch.Tensor]) -> None:
        """Load *state* into *model*, allowing partial matches but not zero matches.

        Raises:
            ValueError: if no key in *state* matches a key of *model*, i.e.
                the checkpoint was made for a different architecture.
        """
        if not set(model.state_dict()).intersection(state):
            raise ValueError(
                "Checkpoint weights do not match the model architecture "
                "(no parameter names in common); add an 'arch' entry to the checkpoint."
            )
        result = model.load_state_dict(state, strict=False)
        if result.missing_keys or result.unexpected_keys:
            print(f"Partial weight load: {len(result.missing_keys)} missing, "
                  f"{len(result.unexpected_keys)} unexpected keys")

    def _load_pytorch_checkpoint(self, ckpt):
        """Install a model from a torch.load result (checkpoint dict or module)."""
        if isinstance(ckpt, dict):
            arch = ckpt.get("arch")
            self._maybe_set_classes_from_meta(ckpt)
            nc = ckpt.get("num_classes", self.num_classes)
            state = ckpt.get("model_state_dict", ckpt.get("state_dict"))
            if state is None and all(isinstance(k, str) for k in ckpt.keys()):
                state = ckpt  # raw state dict
            if state is None:
                raise ValueError("No state_dict in checkpoint.")
            state = strip_module_prefix(state)
            arch_name = arch or self._default_arch_for(state)
            model = self._build_arch(arch_name, nc)
            # ensure classifier head matches
            if hasattr(model, "classifier") and isinstance(model.classifier, nn.Sequential):
                last = model.classifier[-1]
                if isinstance(last, nn.Linear) and last.out_features != nc:
                    model.classifier[-1] = nn.Linear(last.in_features, nc)
            elif hasattr(model, "fc") and isinstance(model.fc, nn.Linear) and model.fc.out_features != nc:
                model.fc = nn.Linear(model.fc.in_features, nc)
            self._load_weights(model, state)
            if len(self.class_names) != nc:
                print(f"Warning: model has {nc} outputs but {len(self.class_names)} class names")
            self.num_classes = nc
            self.model = model.to(self.device).eval()
            self.model_type = f"pytorch:{arch_name}"
        else:
            # direct serialized torch.nn.Module
            self.model = ckpt.to(self.device).eval()
            self.model_type = "pytorch:serialized"

    def _load_generic_object(self, obj):
        """Install an unpickled torch module or an estimator exposing predict_proba."""
        if hasattr(obj, "eval") and hasattr(obj, "state_dict"):
            self.model = obj.to(self.device).eval()
            self.model_type = "pytorch:pickle"
        elif hasattr(obj, "predict_proba"):
            self.model = obj
            self.model_type = "sklearn"
        else:
            raise ValueError("Unsupported object in file (expect torch module/state_dict or sklearn estimator).")

    def _load_from_path(self, path: str, kind: str = "auto"):
        """Load a model file into this classifier.

        *kind* is ``"pytorch"``, ``"pickle"``, ``"joblib"`` or ``"auto"``
        (chosen from the file extension, defaulting to ``"pytorch"``). On any
        failure, the previous model, labels and preprocessing are restored
        before the exception propagates.

        SECURITY: torch.load / pickle / joblib can execute arbitrary code
        embedded in the file. Only load files from trusted sources.
        """
        if kind == "auto":
            kind = self._EXT_TO_KIND.get(os.path.splitext(path)[1].lower(), "pytorch")
        snapshot = self._snapshot_state()
        try:
            self._dispatch_load(path, kind, snapshot)
        except Exception:
            self._restore_state(snapshot)
            raise

    def _dispatch_load(self, path: str, kind: str, snapshot) -> None:
        """Run the loader for *kind*; ``_load_from_path`` handles rollback."""
        if kind in ("pytorch", "pickle"):
            # Prefer torch.load first for torch checkpoints, even if extension is .pkl
            # NOTE(review): on torch>=2.6 torch.load defaults to weights_only=True,
            # so fully serialized modules in .pth files will be rejected here.
            try:
                ckpt = torch.load(path, map_location=self.device)
                self._load_pytorch_checkpoint(ckpt)
                return
            except Exception as torch_err:
                if kind == "pytorch":
                    raise RuntimeError(f"PyTorch load failed: {torch_err}") from torch_err
                # Undo anything the failed torch attempt changed (e.g. class
                # names from metadata) before retrying as a plain pickle.
                self._restore_state(snapshot)
            # sklearn pickle fallback
            try:
                with open(path, "rb") as f:
                    obj = pickle.load(f)
            except pickle.UnpicklingError as pe:
                # Likely a torch checkpoint mislabeled as .pkl
                raise RuntimeError(
                    "This .pkl appears to be a PyTorch checkpoint; load via torch.load or rename to .pth."
                ) from pe
            self._load_generic_object(obj)
            return
        if kind == "joblib":
            obj = joblib.load(path)
            self._load_generic_object(obj)
            return
        raise ValueError(f"Unknown model kind: {kind}")

    # public API for UI
    def load_user_model(self, file_obj) -> str:
        """Load a user-supplied model file and return a status message."""
        path = file_to_path(file_obj)
        self._load_from_path(path, kind="auto")
        return f"✅ Loaded model: {os.path.basename(path)} | Type: {self.model_type} | Classes: {self.num_classes}"

    def load_classes_json(self, file_obj) -> str:
        """Replace class names from a JSON list file and return a status message.

        Raises:
            ValueError: if the JSON is not a list of at least two names.
        """
        path = file_to_path(file_obj)
        with open(path, "r", encoding="utf-8") as f:
            names = json.load(f)
        if not isinstance(names, list) or len(names) < 2:
            raise ValueError("classes.json must be a list with 2 or more class names.")
        # str() so non-string JSON entries still render as labels in the UI.
        self.class_names = [str(n) for n in names]
        self.num_classes = len(names)
        return f"✅ Loaded {len(names)} class names from {os.path.basename(path)}"

    # inference
    def preprocess_img(self, image: Image.Image):
        """Convert *image* to the input expected by the active backend.

        Torch/demo: normalised tensor with a batch dim, on ``self.device``.
        sklearn: a (1, 224*224*3) float array scaled to [0, 1].
        """
        if image.mode != "RGB":
            image = image.convert("RGB")
        if self.model_type.startswith("pytorch") or self.model_type == "demo":
            x = self.preprocess(image).unsqueeze(0).to(self.device)
            return x
        else:
            arr = np.array(image.resize((224, 224))).astype(np.float32) / 255.0
            return arr.flatten().reshape(1, -1)

    def _label(self, idx) -> str:
        """Class name for output index *idx*, or a generic label if out of range."""
        i = int(idx)
        if 0 <= i < len(self.class_names):
            return self.class_names[i]
        return f"Class {i}"

    def predict(self, image: Image.Image) -> Tuple[Dict[str, float], str]:
        """Return the top-3 ``{"Top k: name": prob}`` dict and the top class name."""
        if self.model is None:
            return {"Error": "Model not loaded"}, "Unknown"
        if self.model_type == "demo":
            # Deterministic per image size across runs (str hashes are
            # randomised per process), using a local RNG instead of
            # reseeding NumPy's global one.
            rng = np.random.default_rng(list(image.size))
            probs = rng.dirichlet(np.ones(self.num_classes) * 3.0)
        elif self.model_type.startswith("pytorch"):
            x = self.preprocess_img(image)
            with torch.no_grad():
                logits = self.model(x)
                probs = F.softmax(logits, dim=1).cpu().numpy()[0]
        elif self.model_type == "sklearn":
            x = self.preprocess_img(image)
            # NOTE(review): columns follow the estimator's ``classes_`` order;
            # assumes that matches ``self.class_names`` — confirm.
            probs = np.asarray(self.model.predict_proba(x)[0])
        else:
            rng = np.random.default_rng(42)
            probs = rng.dirichlet(np.ones(self.num_classes) * 2.0)
        top_idx = np.argsort(probs)[::-1][:3]
        results = {f"Top {rank}: {self._label(idx)}": float(probs[idx])
                   for rank, idx in enumerate(top_idx, start=1)}
        return results, self._label(top_idx[0])
--------------------------- # UI callbacks # --------------------------- classifier = IndianBovineClassifier() def classify_image(image: Image.Image): if image is None: return ( "Please upload an image of cattle or buffalo", "Upload an image to see detailed breed information", "| Attribute | Value |\n|-----------|-------|\n| Status | Awaiting image upload |", ) preds, top_breed = classifier.predict(image) if "Error" in preds: msg = preds["Error"] return ( f"❌ {msg}", "Error occurred during classification", f"| Attribute | Value |\n|-----------|-------|\n| Status | Error: {msg} |", ) indicator = "DEMO - " if classifier.model_type == "demo" else f"{classifier.model_type} - " md = f"{indicator}Classification Results:\n\n" for k, v in preds.items(): md += f"- {k}: {v:.2%}\n" if classifier.model_type == "demo": md += "\nDemo mode: Upload a .pth/.pkl/.joblib model for real predictions." if top_breed in BREED_INFO: info = BREED_INFO[top_breed] desc = f""" ## 🐄 {top_breed} Type: {info['type']} Origin: {info['origin']} Characteristics: {info['characteristics']} Milk Yield: {info['milk_yield']} Special Features: {info['special_features']} Weight: {info['weight']} Height: {info['height']} Temperament: {info['temperament']} """ table = f"""| Attribute | Value | |-----------|-------| | Type | {info['type']} | | Origin | {info['origin']} | | Weight | {info['weight']} | | Height | {info['height']} | | Milk Yield | {info['milk_yield']} | | Temperament | {info['temperament']} |""" else: desc = "Detailed information not available for this breed." table = "| Attribute | Value |\n|-----------|-------|\n| Status | Information not available |" return md, desc, table def upload_and_load_model(file_obj): if not file_obj: return "Please select a .pth, .pkl or .joblib file to load." try: return classifier.load_user_model(file_obj) except Exception as e: return f"❌ Failed to load model: {e}" def upload_classes(file_obj): if not file_obj: return "Please select a classes.json file." 
try: return classifier.load_classes_json(file_obj) except Exception as e: return f"❌ Failed to load classes.json: {e}" # --------------------------- # Minimal, responsive CSS # --------------------------- CUSTOM_CSS = """ .gradio-container { min-height: 100vh; } .header { text-align:center; padding: 1rem; } .header .title { font-size: 2em; font-weight: 700; } .footer { text-align:center; opacity:.75; padding:.75rem; } @media (max-width: 768px) { .title { font-size: 1.6em !important; } } """ # --------------------------- # Interface # --------------------------- def create_interface(): with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft(), fill_width=True, title="Indian Bovine Classifier") as app: gr.HTML(f"""