SavlonBhai commited on
Commit
6b57a1b
·
verified ·
1 Parent(s): e5c357e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -66
app.py CHANGED
@@ -1,4 +1,11 @@
1
  # app.py
 
 
 
 
 
 
 
2
  import gradio as gr
3
  import torch
4
  import torch.nn as nn
@@ -7,15 +14,19 @@ import torchvision.models as tv_models
7
  import torchvision.transforms as T
8
  import numpy as np
9
  from PIL import Image
10
- import pickle, joblib, json, os, warnings
11
  warnings.filterwarnings("ignore")
12
 
 
13
  try:
14
  import timm
15
  HAS_TIMM = True
16
  except Exception:
17
  HAS_TIMM = False
18
 
 
 
 
19
  DEFAULT_CLASSES = [
20
  "Ayrshire cattle","Brown Swiss cattle","Holstein Friesian cattle",
21
  "Jaffrabadi","Jersey cattle","Murrah","Red Dane cattle",
@@ -83,20 +94,19 @@ BREED_INFO = {
83
  IMAGENET_MEAN = [0.485, 0.456, 0.406]
84
  IMAGENET_STD = [0.229, 0.224, 0.225]
85
 
86
- def strip_module_prefix(state_dict):
87
- new_sd = {}
88
- for k, v in state_dict.items():
89
- new_sd[k[7:]] = v if k.startswith("module.") else v
90
- if not k.startswith("module."):
91
- new_sd[k] = v
92
  clean = {}
93
- for k, v in new_sd.items():
94
  if k.startswith("module."):
95
- continue
96
- clean[k] = v
 
97
  return clean
98
 
99
- def file_to_path(file_obj):
100
  if isinstance(file_obj, str):
101
  return file_obj
102
  if hasattr(file_obj, "name"):
@@ -105,9 +115,12 @@ def file_to_path(file_obj):
105
  return file_obj.get("name") or file_obj.get("path") or file_obj.get("file")
106
  raise ValueError("Unsupported file input type")
107
 
108
- def make_head(in_dim, num_classes):
109
  return nn.Sequential(nn.Dropout(0.2), nn.Linear(in_dim, num_classes))
110
 
 
 
 
111
  class IndianBovineClassifier:
112
  def __init__(self):
113
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -116,50 +129,51 @@ class IndianBovineClassifier:
116
  self.class_names = list(DEFAULT_CLASSES)
117
  self.num_classes = len(self.class_names)
118
  self.preprocess = T.Compose([
119
- T.Resize((224,224)),
120
  T.ToTensor(),
121
- T.Normalize(IMAGENET_MEAN, IMAGENET_STD)
122
  ])
123
  self._try_autoload()
124
 
125
- def _build_arch(self, arch: str, num_classes: int):
126
  a = (arch or "").strip()
127
  if a and HAS_TIMM:
128
  try:
129
  m = timm.create_model(a, pretrained=False, num_classes=num_classes)
130
  cfg = getattr(m, "default_cfg", None)
131
  if cfg:
132
- size = cfg.get("input_size", (3,224,224))[-1]
133
  mean = list(cfg.get("mean", IMAGENET_MEAN))
134
- std = list(cfg.get("std", IMAGENET_STD))
135
  self.preprocess = T.Compose([
136
- T.Resize((size,size)),
137
  T.ToTensor(),
138
- T.Normalize(mean, std)
139
  ])
140
  return m
141
  except Exception:
142
  pass
143
- if a.lower() in {"resnet18","tv_resnet18"}:
144
  m = tv_models.resnet18(weights=None)
145
  m.fc = nn.Linear(m.fc.in_features, num_classes)
146
  return m
147
- if a.lower() in {"efficientnet_b0","tv_efficientnet_b0"}:
148
  m = tv_models.efficientnet_b0(weights=None)
149
  in_dim = m.classifier[1].in_features
150
  m.classifier = make_head(in_dim, num_classes)
151
  return m
 
152
  return self._simple_cnn(num_classes)
153
 
154
- def _simple_cnn(self, nc: int):
155
  class Simple(nn.Module):
156
  def __init__(self, out_dim):
157
  super().__init__()
158
  self.features = nn.Sequential(
159
- nn.Conv2d(3,64,3,padding=1), nn.ReLU(True), nn.MaxPool2d(2),
160
- nn.Conv2d(64,128,3,padding=1), nn.ReLU(True), nn.MaxPool2d(2),
161
- nn.Conv2d(128,256,3,padding=1), nn.ReLU(True),
162
- nn.AdaptiveAvgPool2d((1,1))
163
  )
164
  self.classifier = nn.Sequential(nn.Dropout(0.5), nn.Linear(256, out_dim))
165
  def forward(self, x):
@@ -169,6 +183,7 @@ class IndianBovineClassifier:
169
  return Simple(nc)
170
 
171
  def _try_autoload(self):
 
172
  candidates = [
173
  ("indian_bovine_breeds.pth","pytorch"),
174
  ("indian_bovine_model.pth","pytorch"),
@@ -176,12 +191,12 @@ class IndianBovineClassifier:
176
  ("indian_bovine_breeds.pkl","pickle"),
177
  ("model.pkl","pickle"),
178
  ("indian_bovine_breeds.joblib","joblib"),
179
- ("model.joblib","joblib")
180
  ]
181
  for path, kind in candidates:
182
  if os.path.exists(path):
183
  try:
184
- self._load_from_path(path, kind)
185
  print(f"Loaded model: {path}")
186
  return
187
  except Exception as e:
@@ -189,15 +204,15 @@ class IndianBovineClassifier:
189
  self.model = self._simple_cnn(self.num_classes).to(self.device).eval()
190
  self.model_type = "demo"
191
 
192
- def _maybe_set_classes_from_meta(self, meta: dict):
193
- keys = ["classes","class_names","labels","breeds"]
194
  for k in keys:
195
  if k in meta and isinstance(meta[k], (list, tuple)) and len(meta[k]) > 1:
196
  self.class_names = list(meta[k])
197
  self.num_classes = len(self.class_names)
198
  return True
199
  if "class_to_idx" in meta and isinstance(meta["class_to_idx"], dict):
200
- inv = {v:k for k,v in meta["class_to_idx"].items()}
201
  self.class_names = [inv[i] for i in range(len(inv))]
202
  self.num_classes = len(self.class_names)
203
  return True
@@ -210,11 +225,12 @@ class IndianBovineClassifier:
210
  nc = ckpt.get("num_classes", self.num_classes)
211
  state = ckpt.get("model_state_dict", ckpt.get("state_dict"))
212
  if state is None and all(isinstance(k, str) for k in ckpt.keys()):
213
- state = ckpt
214
  if state is None:
215
  raise ValueError("No state_dict in checkpoint.")
216
  state = strip_module_prefix(state)
217
  model = self._build_arch(arch or "efficientnet_b0", nc)
 
218
  if hasattr(model, "classifier") and isinstance(model.classifier, nn.Sequential):
219
  last = model.classifier[-1]
220
  if isinstance(last, nn.Linear) and last.out_features != nc:
@@ -226,6 +242,7 @@ class IndianBovineClassifier:
226
  self.model = model.to(self.device).eval()
227
  self.model_type = f"pytorch:{arch or 'tv_efficientnet_b0'}"
228
  else:
 
229
  self.model = ckpt.to(self.device).eval()
230
  self.model_type = "pytorch:serialized"
231
 
@@ -237,44 +254,67 @@ class IndianBovineClassifier:
237
  self.model = obj
238
  self.model_type = "sklearn"
239
  else:
240
- raise ValueError("Unsupported object in file.")
241
 
242
- def _load_from_path(self, path, kind="auto"):
243
  ext = os.path.splitext(path)[1].lower()
244
  if kind == "auto":
245
- kind = "pytorch" if ext == ".pth" else ("pickle" if ext == ".pkl" else ("joblib" if ext == ".joblib" else "pytorch"))
246
- if kind == "pytorch":
247
- ckpt = torch.load(path, map_location=self.device)
248
- self._load_pytorch_checkpoint(ckpt); return
249
- if kind == "pickle":
 
 
 
 
 
 
250
  try:
251
  ckpt = torch.load(path, map_location=self.device)
252
- self._load_pytorch_checkpoint(ckpt); return
253
- except Exception:
254
- pass
255
- with open(path, "rb") as f:
256
- obj = pickle.load(f)
257
- self._load_generic_object(obj); return
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  if kind == "joblib":
259
  obj = joblib.load(path)
260
- self._load_generic_object(obj); return
 
 
261
  raise ValueError(f"Unknown model kind: {kind}")
262
 
263
- def load_user_model(self, file_obj):
 
264
  path = file_to_path(file_obj)
265
  self._load_from_path(path, kind="auto")
266
  return f"✅ Loaded model: {os.path.basename(path)} | Type: {self.model_type} | Classes: {self.num_classes}"
267
 
268
- def load_classes_json(self, file_obj):
269
  path = file_to_path(file_obj)
270
  with open(path, "r", encoding="utf-8") as f:
271
  names = json.load(f)
272
  if not isinstance(names, list) or len(names) < 2:
273
- raise ValueError("classes.json must be a list of class names.")
274
  self.class_names = list(names)
275
  self.num_classes = len(names)
276
  return f"✅ Loaded {len(names)} class names from {os.path.basename(path)}"
277
 
 
278
  def preprocess_img(self, image: Image.Image):
279
  if image.mode != "RGB":
280
  image = image.convert("RGB")
@@ -282,18 +322,18 @@ class IndianBovineClassifier:
282
  x = self.preprocess(image).unsqueeze(0).to(self.device)
283
  return x
284
  else:
285
- arr = np.array(image.resize((224,224))).astype(np.float32)/255.0
286
- return arr.flatten().reshape(1,-1)
287
 
288
- def predict(self, image: Image.Image):
289
  if self.model is None:
290
- return {"Error":"Model not loaded"}, "Unknown"
291
  x = self.preprocess_img(image)
292
  if self.model_type.startswith("pytorch") or self.model_type == "demo":
293
  with torch.no_grad():
294
  if self.model_type == "demo":
295
  np.random.seed(hash(str(image.size)) % (2**32))
296
- probs = np.random.dirichlet(np.ones(self.num_classes)*3.0)
297
  else:
298
  logits = self.model(x)
299
  probs = F.softmax(logits, dim=1).cpu().numpy()[0]
@@ -301,11 +341,14 @@ class IndianBovineClassifier:
301
  probs = self.model.predict_proba(x)[0]
302
  else:
303
  np.random.seed(42)
304
- probs = np.random.dirichlet(np.ones(self.num_classes)*2.0)
305
  top_idx = np.argsort(probs)[::-1][:3]
306
  results = {f"Top {i+1}: {self.class_names[idx]}": float(probs[idx]) for i, idx in enumerate(top_idx)}
307
  return results, self.class_names[top_idx[0]]
308
 
 
 
 
309
  classifier = IndianBovineClassifier()
310
 
311
  def classify_image(image: Image.Image):
@@ -323,12 +366,13 @@ def classify_image(image: Image.Image):
323
  "Error occurred during classification",
324
  f"| Attribute | Value |\n|-----------|-------|\n| Status | Error: {msg} |",
325
  )
326
- indicator = "🎲 DEMO - " if classifier.model_type == "demo" else f"🔥 {classifier.model_type} - "
327
  md = f"{indicator}Classification Results:\n\n"
328
  for k, v in preds.items():
329
  md += f"- {k}: {v:.2%}\n"
330
  if classifier.model_type == "demo":
331
  md += "\nDemo mode: Upload a .pth/.pkl/.joblib model for real predictions."
 
332
  if top_breed in BREED_INFO:
333
  info = BREED_INFO[top_breed]
334
  desc = f"""
@@ -372,27 +416,32 @@ def upload_classes(file_obj):
372
  except Exception as e:
373
  return f"❌ Failed to load classes.json: {e}"
374
 
375
- # Minimal, responsive CSS (optional)
 
 
376
  CUSTOM_CSS = """
377
  .gradio-container { min-height: 100vh; }
378
  .header { text-align:center; padding: 1rem; }
379
  .header .title { font-size: 2em; font-weight: 700; }
380
  .footer { text-align:center; opacity:.75; padding:.75rem; }
381
  @media (max-width: 768px) {
382
- .title { font-size: 1.5em !important; }
383
  }
384
  """
385
 
 
 
 
386
  def create_interface():
387
- with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft(), fill_width=True, title="🐄 Indian Bovine Classifier") as app:
388
  gr.HTML(f"""
389
  <div class="header">
390
- <div class="title">🐄 Indian Bovine Breeds Classifier 🐃</div>
391
  <div>PyTorch runtime • {len(DEFAULT_CLASSES)} default classes • Device: {classifier.device}</div>
392
  </div>
393
  """)
394
 
395
- # Collapsible sidebar for loaders
396
  with gr.Sidebar():
397
  gr.Markdown("### Model loader")
398
  model_file = gr.File(label="Upload .pth / .pkl / .joblib", file_types=[".pth",".pkl",".joblib"], file_count="single")
@@ -406,22 +455,23 @@ def create_interface():
406
  classes_status = gr.Markdown("No external classes.json loaded.")
407
  load_classes_btn.click(upload_classes, inputs=[classes_file], outputs=[classes_status])
408
 
409
- # Main responsive canvas
410
  with gr.Row(equal_height=True):
411
  with gr.Column(scale=1, min_width=320, variant="panel"):
412
  gr.Markdown("### Upload image")
413
  image_input = gr.Image(type="pil", label="Cattle/Buffalo image")
414
- classify_btn = gr.Button("🔍 Classify", variant="secondary")
415
  with gr.Column(scale=2, min_width=360, variant="panel"):
416
  with gr.Tab("Results"):
417
- prediction_output = gr.Markdown(value="🔄 Upload an image to see classification.")
418
  with gr.Tab("Breed info"):
419
- breed_info_output = gr.Markdown(value="ℹ️ Breed info will appear here.")
420
  with gr.Tab("Stats"):
421
  breed_stats_table = gr.Markdown(value="| Attribute | Value |\n|-----------|-------|\n| Status | Awaiting classification... |")
422
 
423
  gr.Markdown(f"""<div class="footer">Model type: {classifier.model_type} • PyTorch {torch.__version__}</div>""")
424
 
 
425
  classify_btn.click(classify_image, inputs=[image_input], outputs=[prediction_output, breed_info_output, breed_stats_table])
426
  image_input.change(classify_image, inputs=[image_input], outputs=[prediction_output, breed_info_output, breed_stats_table])
427
 
@@ -429,4 +479,7 @@ def create_interface():
429
 
430
  if __name__ == "__main__":
431
  app = create_interface()
432
- app.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
1
  # app.py
2
+ import os
3
+ import json
4
+ import pickle
5
+ import joblib
6
+ import warnings
7
+ from typing import Tuple, Dict
8
+
9
  import gradio as gr
10
  import torch
11
  import torch.nn as nn
 
14
  import torchvision.transforms as T
15
  import numpy as np
16
  from PIL import Image
17
+
18
  warnings.filterwarnings("ignore")
19
 
20
+ # Optional timm
21
  try:
22
  import timm
23
  HAS_TIMM = True
24
  except Exception:
25
  HAS_TIMM = False
26
 
27
+ # ---------------------------
28
+ # Defaults & metadata
29
+ # ---------------------------
30
  DEFAULT_CLASSES = [
31
  "Ayrshire cattle","Brown Swiss cattle","Holstein Friesian cattle",
32
  "Jaffrabadi","Jersey cattle","Murrah","Red Dane cattle",
 
94
  IMAGENET_MEAN = [0.485, 0.456, 0.406]
95
  IMAGENET_STD = [0.229, 0.224, 0.225]
96
 
97
+ # ---------------------------
98
+ # Helpers
99
+ # ---------------------------
100
+ def strip_module_prefix(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
 
 
101
  clean = {}
102
+ for k, v in state_dict.items():
103
  if k.startswith("module."):
104
+ clean[k[7:]] = v
105
+ else:
106
+ clean[k] = v
107
  return clean
108
 
109
+ def file_to_path(file_obj) -> str:
110
  if isinstance(file_obj, str):
111
  return file_obj
112
  if hasattr(file_obj, "name"):
 
115
  return file_obj.get("name") or file_obj.get("path") or file_obj.get("file")
116
  raise ValueError("Unsupported file input type")
117
 
118
+ def make_head(in_dim: int, num_classes: int) -> nn.Module:
119
  return nn.Sequential(nn.Dropout(0.2), nn.Linear(in_dim, num_classes))
120
 
121
+ # ---------------------------
122
+ # Classifier
123
+ # ---------------------------
124
  class IndianBovineClassifier:
125
  def __init__(self):
126
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
129
  self.class_names = list(DEFAULT_CLASSES)
130
  self.num_classes = len(self.class_names)
131
  self.preprocess = T.Compose([
132
+ T.Resize((224, 224)),
133
  T.ToTensor(),
134
+ T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
135
  ])
136
  self._try_autoload()
137
 
138
+ def _build_arch(self, arch: str, num_classes: int) -> nn.Module:
139
  a = (arch or "").strip()
140
  if a and HAS_TIMM:
141
  try:
142
  m = timm.create_model(a, pretrained=False, num_classes=num_classes)
143
  cfg = getattr(m, "default_cfg", None)
144
  if cfg:
145
+ size = cfg.get("input_size", (3, 224, 224))[-1]
146
  mean = list(cfg.get("mean", IMAGENET_MEAN))
147
+ std = list(cfg.get("std", IMAGENET_STD))
148
  self.preprocess = T.Compose([
149
+ T.Resize((size, size)),
150
  T.ToTensor(),
151
+ T.Normalize(mean, std),
152
  ])
153
  return m
154
  except Exception:
155
  pass
156
+ if a.lower() in {"resnet18", "tv_resnet18"}:
157
  m = tv_models.resnet18(weights=None)
158
  m.fc = nn.Linear(m.fc.in_features, num_classes)
159
  return m
160
+ if a.lower() in {"efficientnet_b0", "tv_efficientnet_b0"}:
161
  m = tv_models.efficientnet_b0(weights=None)
162
  in_dim = m.classifier[1].in_features
163
  m.classifier = make_head(in_dim, num_classes)
164
  return m
165
+ # fallback
166
  return self._simple_cnn(num_classes)
167
 
168
+ def _simple_cnn(self, nc: int) -> nn.Module:
169
  class Simple(nn.Module):
170
  def __init__(self, out_dim):
171
  super().__init__()
172
  self.features = nn.Sequential(
173
+ nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(True), nn.MaxPool2d(2),
174
+ nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(True), nn.MaxPool2d(2),
175
+ nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(True),
176
+ nn.AdaptiveAvgPool2d((1, 1)),
177
  )
178
  self.classifier = nn.Sequential(nn.Dropout(0.5), nn.Linear(256, out_dim))
179
  def forward(self, x):
 
183
  return Simple(nc)
184
 
185
  def _try_autoload(self):
186
+ # Attempt common filenames; quietly fall back to demo if none
187
  candidates = [
188
  ("indian_bovine_breeds.pth","pytorch"),
189
  ("indian_bovine_model.pth","pytorch"),
 
191
  ("indian_bovine_breeds.pkl","pickle"),
192
  ("model.pkl","pickle"),
193
  ("indian_bovine_breeds.joblib","joblib"),
194
+ ("model.joblib","joblib"),
195
  ]
196
  for path, kind in candidates:
197
  if os.path.exists(path):
198
  try:
199
+ self._load_from_path(path, kind=kind)
200
  print(f"Loaded model: {path}")
201
  return
202
  except Exception as e:
 
204
  self.model = self._simple_cnn(self.num_classes).to(self.device).eval()
205
  self.model_type = "demo"
206
 
207
+ def _maybe_set_classes_from_meta(self, meta: dict) -> bool:
208
+ keys = ["classes", "class_names", "labels", "breeds"]
209
  for k in keys:
210
  if k in meta and isinstance(meta[k], (list, tuple)) and len(meta[k]) > 1:
211
  self.class_names = list(meta[k])
212
  self.num_classes = len(self.class_names)
213
  return True
214
  if "class_to_idx" in meta and isinstance(meta["class_to_idx"], dict):
215
+ inv = {v: k for k, v in meta["class_to_idx"].items()}
216
  self.class_names = [inv[i] for i in range(len(inv))]
217
  self.num_classes = len(self.class_names)
218
  return True
 
225
  nc = ckpt.get("num_classes", self.num_classes)
226
  state = ckpt.get("model_state_dict", ckpt.get("state_dict"))
227
  if state is None and all(isinstance(k, str) for k in ckpt.keys()):
228
+ state = ckpt # raw state dict
229
  if state is None:
230
  raise ValueError("No state_dict in checkpoint.")
231
  state = strip_module_prefix(state)
232
  model = self._build_arch(arch or "efficientnet_b0", nc)
233
+ # ensure classifier head matches
234
  if hasattr(model, "classifier") and isinstance(model.classifier, nn.Sequential):
235
  last = model.classifier[-1]
236
  if isinstance(last, nn.Linear) and last.out_features != nc:
 
242
  self.model = model.to(self.device).eval()
243
  self.model_type = f"pytorch:{arch or 'tv_efficientnet_b0'}"
244
  else:
245
+ # direct serialized torch.nn.Module
246
  self.model = ckpt.to(self.device).eval()
247
  self.model_type = "pytorch:serialized"
248
 
 
254
  self.model = obj
255
  self.model_type = "sklearn"
256
  else:
257
+ raise ValueError("Unsupported object in file (expect torch module/state_dict or sklearn estimator).")
258
 
259
+ def _load_from_path(self, path: str, kind: str = "auto"):
260
  ext = os.path.splitext(path)[1].lower()
261
  if kind == "auto":
262
+ if ext in {".pth"}:
263
+ kind = "pytorch"
264
+ elif ext in {".pkl"}:
265
+ kind = "pickle"
266
+ elif ext in {".joblib"}:
267
+ kind = "joblib"
268
+ else:
269
+ kind = "pytorch"
270
+
271
+ if kind in ("pytorch", "pickle"):
272
+ # Prefer torch.load first for torch checkpoints, even if extension is .pkl
273
  try:
274
  ckpt = torch.load(path, map_location=self.device)
275
+ self._load_pytorch_checkpoint(ckpt)
276
+ return
277
+ except Exception as torch_err:
278
+ if kind == "pytorch":
279
+ raise RuntimeError(f"PyTorch load failed: {torch_err}") from torch_err
280
+ # try sklearn-style pickle below
281
+
282
+ # sklearn pickle fallback
283
+ try:
284
+ with open(path, "rb") as f:
285
+ obj = pickle.load(f)
286
+ self._load_generic_object(obj)
287
+ return
288
+ except pickle.UnpicklingError as pe:
289
+ # Likely a torch checkpoint mislabeled as .pkl
290
+ raise RuntimeError(
291
+ "This .pkl appears to be a PyTorch checkpoint; load via torch.load or rename to .pth."
292
+ ) from pe
293
+
294
  if kind == "joblib":
295
  obj = joblib.load(path)
296
+ self._load_generic_object(obj)
297
+ return
298
+
299
  raise ValueError(f"Unknown model kind: {kind}")
300
 
301
+ # public API for UI
302
+ def load_user_model(self, file_obj) -> str:
303
  path = file_to_path(file_obj)
304
  self._load_from_path(path, kind="auto")
305
  return f"✅ Loaded model: {os.path.basename(path)} | Type: {self.model_type} | Classes: {self.num_classes}"
306
 
307
+ def load_classes_json(self, file_obj) -> str:
308
  path = file_to_path(file_obj)
309
  with open(path, "r", encoding="utf-8") as f:
310
  names = json.load(f)
311
  if not isinstance(names, list) or len(names) < 2:
312
+ raise ValueError("classes.json must be a list with 2 or more class names.")
313
  self.class_names = list(names)
314
  self.num_classes = len(names)
315
  return f"✅ Loaded {len(names)} class names from {os.path.basename(path)}"
316
 
317
+ # inference
318
  def preprocess_img(self, image: Image.Image):
319
  if image.mode != "RGB":
320
  image = image.convert("RGB")
 
322
  x = self.preprocess(image).unsqueeze(0).to(self.device)
323
  return x
324
  else:
325
+ arr = np.array(image.resize((224, 224))).astype(np.float32) / 255.0
326
+ return arr.flatten().reshape(1, -1)
327
 
328
+ def predict(self, image: Image.Image) -> Tuple[Dict[str, float], str]:
329
  if self.model is None:
330
+ return {"Error": "Model not loaded"}, "Unknown"
331
  x = self.preprocess_img(image)
332
  if self.model_type.startswith("pytorch") or self.model_type == "demo":
333
  with torch.no_grad():
334
  if self.model_type == "demo":
335
  np.random.seed(hash(str(image.size)) % (2**32))
336
+ probs = np.random.dirichlet(np.ones(self.num_classes) * 3.0)
337
  else:
338
  logits = self.model(x)
339
  probs = F.softmax(logits, dim=1).cpu().numpy()[0]
 
341
  probs = self.model.predict_proba(x)[0]
342
  else:
343
  np.random.seed(42)
344
+ probs = np.random.dirichlet(np.ones(self.num_classes) * 2.0)
345
  top_idx = np.argsort(probs)[::-1][:3]
346
  results = {f"Top {i+1}: {self.class_names[idx]}": float(probs[idx]) for i, idx in enumerate(top_idx)}
347
  return results, self.class_names[top_idx[0]]
348
 
349
+ # ---------------------------
350
+ # UI callbacks
351
+ # ---------------------------
352
  classifier = IndianBovineClassifier()
353
 
354
  def classify_image(image: Image.Image):
 
366
  "Error occurred during classification",
367
  f"| Attribute | Value |\n|-----------|-------|\n| Status | Error: {msg} |",
368
  )
369
+ indicator = "DEMO - " if classifier.model_type == "demo" else f"{classifier.model_type} - "
370
  md = f"{indicator}Classification Results:\n\n"
371
  for k, v in preds.items():
372
  md += f"- {k}: {v:.2%}\n"
373
  if classifier.model_type == "demo":
374
  md += "\nDemo mode: Upload a .pth/.pkl/.joblib model for real predictions."
375
+
376
  if top_breed in BREED_INFO:
377
  info = BREED_INFO[top_breed]
378
  desc = f"""
 
416
  except Exception as e:
417
  return f"❌ Failed to load classes.json: {e}"
418
 
419
+ # ---------------------------
420
+ # Minimal, responsive CSS
421
+ # ---------------------------
422
  CUSTOM_CSS = """
423
  .gradio-container { min-height: 100vh; }
424
  .header { text-align:center; padding: 1rem; }
425
  .header .title { font-size: 2em; font-weight: 700; }
426
  .footer { text-align:center; opacity:.75; padding:.75rem; }
427
  @media (max-width: 768px) {
428
+ .title { font-size: 1.6em !important; }
429
  }
430
  """
431
 
432
+ # ---------------------------
433
+ # Interface
434
+ # ---------------------------
435
  def create_interface():
436
+ with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft(), fill_width=True, title="Indian Bovine Classifier") as app:
437
  gr.HTML(f"""
438
  <div class="header">
439
+ <div class="title">Indian Bovine Breeds Classifier</div>
440
  <div>PyTorch runtime • {len(DEFAULT_CLASSES)} default classes • Device: {classifier.device}</div>
441
  </div>
442
  """)
443
 
444
+ # Collapsible sidebar
445
  with gr.Sidebar():
446
  gr.Markdown("### Model loader")
447
  model_file = gr.File(label="Upload .pth / .pkl / .joblib", file_types=[".pth",".pkl",".joblib"], file_count="single")
 
455
  classes_status = gr.Markdown("No external classes.json loaded.")
456
  load_classes_btn.click(upload_classes, inputs=[classes_file], outputs=[classes_status])
457
 
458
+ # Main canvas
459
  with gr.Row(equal_height=True):
460
  with gr.Column(scale=1, min_width=320, variant="panel"):
461
  gr.Markdown("### Upload image")
462
  image_input = gr.Image(type="pil", label="Cattle/Buffalo image")
463
+ classify_btn = gr.Button("Classify", variant="secondary")
464
  with gr.Column(scale=2, min_width=360, variant="panel"):
465
  with gr.Tab("Results"):
466
+ prediction_output = gr.Markdown(value="Upload an image to see classification.")
467
  with gr.Tab("Breed info"):
468
+ breed_info_output = gr.Markdown(value="Breed info will appear here.")
469
  with gr.Tab("Stats"):
470
  breed_stats_table = gr.Markdown(value="| Attribute | Value |\n|-----------|-------|\n| Status | Awaiting classification... |")
471
 
472
  gr.Markdown(f"""<div class="footer">Model type: {classifier.model_type} • PyTorch {torch.__version__}</div>""")
473
 
474
+ # Wiring
475
  classify_btn.click(classify_image, inputs=[image_input], outputs=[prediction_output, breed_info_output, breed_stats_table])
476
  image_input.change(classify_image, inputs=[image_input], outputs=[prediction_output, breed_info_output, breed_stats_table])
477
 
 
479
 
480
  if __name__ == "__main__":
481
  app = create_interface()
482
+ # Launch controls via env vars (optional)
483
+ share_flag = os.getenv("GRADIO_SHARE", "0").lower() in {"1", "true", "yes"}
484
+ ssr_flag = os.getenv("GRADIO_SSR_MODE", "true").lower() in {"1", "true", "yes"}
485
+ app.launch(server_name="0.0.0.0", server_port=7860, share=share_flag, ssr_mode=ssr_flag)