#!/usr/bin/env python3 """ Complete Standalone End-to-End Speech-to-Speaker Association Inference Script Includes all necessary functions without external dependencies on custom files """ import json import torch import cv2 from torch_geometric.data import HeteroData, Batch from typing import Dict, List, Any, Optional, Tuple, Union import os from ultralytics import YOLO from utils.train_speaker import hungarian_matching,AssocGCN,infer_associations,DatasetLoader # ============================================================================ # NEW CLASSES FOR UNSEEN IMAGE PROCESSING # ============================================================================ class DetectionPredictions: """Container for object detection predictions from an unseen image""" def __init__(self, image_path: str): self.image_path = image_path self.panels: List[Dict] = [] self.speech_bubbles: List[Dict] = [] self.faces: List[Dict] = [] self.bodies: List[Dict] = [] self.image_size: Tuple[int, int] = (0, 0) # (width, height) def add_panel(self, bbox: List[float], confidence: float, panel_id: int): """Add panel detection (frame in YOLO terms)""" self.panels.append({ 'bbox': bbox, 'confidence': confidence, 'id': panel_id }) def add_speech_bubble(self, bbox: List[float], confidence: float, panel_id: int, bubble_id: int): """Add speech bubble detection (text in YOLO terms)""" self.speech_bubbles.append({ 'bbox': bbox, 'confidence': confidence, 'panel_id': panel_id, 'id': bubble_id }) def add_face(self, bbox: List[float], confidence: float, panel_id: int, face_id: int): """Add face detection""" self.faces.append({ 'bbox': bbox, 'confidence': confidence, 'panel_id': panel_id, 'id': face_id }) def add_body(self, bbox: List[float], confidence: float, panel_id: int, body_id: int): """Add body detection""" self.bodies.append({ 'bbox': bbox, 'confidence': confidence, 'panel_id': panel_id, 'id': body_id }) # ============================================================================ # ENHANCED DATA PROCESSING FUNCTIONS # ============================================================================ def create_panel_dict_from_predictions(predictions: DetectionPredictions, panel_id: int) -> Dict: """ ADAPTED: Creates a panel dictionary from detection predictions Uses the same format as original create_hetero_data_from_panel expects """ # Filter predictions for this panel panel_bubbles = [b for b in predictions.speech_bubbles if b['panel_id'] == panel_id] panel_faces = [f for f in predictions.faces if f['panel_id'] == panel_id] # Create bubbles list in expected format bubbles = [] for i, bubble in enumerate(panel_bubbles): bubbles.append({ 'bubble_id': bubble['id'], 'bbox': bubble['bbox'] }) # Create faces list in expected format faces = [] for i, face in enumerate(panel_faces): faces.append({ 'face_id': face['id'], 'bbox': face['bbox'] }) # Create panel dict in expected format panel_dict = { 'panel_id': f"unseen_panel_{panel_id}", 'width': predictions.image_size[0], 'height': predictions.image_size[1], 'bubbles': bubbles, 'faces': faces, 'links': [] # No ground truth links for unseen images } return panel_dict def create_hetero_data_from_predictions(predictions: DetectionPredictions, panel_id: int) -> Optional[HeteroData]: """ ENHANCED: Creates HeteroData from detection predictions Leverages existing create_hetero_data_from_panel function """ # Convert predictions to panel dict format panel_dict = create_panel_dict_from_predictions(predictions, panel_id) # Use existing function to create HeteroData hetero_data = DatasetLoader.create_hetero_data_from_panel(panel_dict) if hetero_data is not None: # Remove ground truth labels since we don't have them for unseen images if hasattr(hetero_data['bubble', 'to', 'face'], 'edge_label'): delattr(hetero_data['bubble', 'to', 'face'], 'edge_label') return hetero_data # ============================================================================ # YOLO DETECTION INTEGRATION # ============================================================================ def get_predictions_from_yolo(img_path: str, yolo_model) -> DetectionPredictions: """ INTEGRATED: YOLO detection pipeline that maps to our detection categories Maps YOLO classes: {0:"body", 1:"face", 2:"frame", 3:"text"} """ CLASSES = {0: "body", 1: "face", 2: "frame", 3: "text"} # Create predictions container predictions = DetectionPredictions(img_path) # Load image to get dimensions img = cv2.imread(img_path) if img is None: raise ValueError(f"Could not load image: {img_path}") height, width = img.shape[:2] predictions.image_size = (width, height) # Get YOLO predictions results = yolo_model.predict(source=img_path, device='cuda', verbose=False) # Process detections and organize by type detections_by_type = {"body": [], "face": [], "frame": [], "text": []} for box in results[0].boxes: c = int(box.cls[0]) x1, y1, x2, y2 = map(int, box.xyxy[0]) confidence = float(box.conf[0]) detection_type = CLASSES[c] detections_by_type[detection_type].append({ 'bbox': [x1, y1, x2, y2], 'confidence': confidence }) # Process frame detections as panels for i, frame in enumerate(detections_by_type["frame"]): predictions.add_panel(frame['bbox'], frame['confidence'], i) # If no frames detected, create a default full-image panel if len(predictions.panels) == 0: predictions.add_panel([0, 0, width, height], 1.0, 0) # Process text detections as speech bubbles for i, text in enumerate(detections_by_type["text"]): panel_id = find_containing_panel(text['bbox'], predictions.panels) predictions.add_speech_bubble(text['bbox'], text['confidence'], panel_id, i) from collections import defaultdict # Group bubbles by panel_id grouped_texts = defaultdict(list) for bubble in predictions.speech_bubbles: grouped_texts[bubble['panel_id']].append(bubble) # Assign seq based on right-to-left order (sort by bbox[0] descending) for panel_id, bubbles in grouped_texts.items(): sorted_bubbles = sorted(bubbles, key=lambda b: b['bbox'][0], reverse=True) for seq, bubble in enumerate(sorted_bubbles): bubble['seq'] = seq # # Optional: print results # for bubble in speech_bubbles: # print(bubble) # Process face detections for i, face in enumerate(detections_by_type["face"]): panel_id = find_containing_panel(face['bbox'], predictions.panels) predictions.add_face(face['bbox'], face['confidence'], panel_id, i) # Process body detections (optional - can be used for additional context) for i, body in enumerate(detections_by_type["body"]): panel_id = find_containing_panel(body['bbox'], predictions.panels) predictions.add_body(body['bbox'], body['confidence'], panel_id, i) return predictions def find_containing_panel(bbox: List[float], panels: List[Dict]) -> int: """ Helper function to determine which panel contains a detection """ bbox_center_x = (bbox[0] + bbox[2]) / 2 bbox_center_y = (bbox[1] + bbox[3]) / 2 for panel in panels: p_bbox = panel['bbox'] if (p_bbox[0] <= bbox_center_x <= p_bbox[2] and p_bbox[1] <= bbox_center_y <= p_bbox[3]): return panel['id'] # Return first panel if not contained in any return panels[0]['id'] if panels else 0 # ============================================================================ # MODEL LOADING WITH ERROR HANDLING # ============================================================================ def load_trained_speaker_model(model_path: str, device: str = "cuda") -> AssocGCN: """ FIXED: Load trained AssocGCN model with proper error handling Handles different checkpoint formats including 'model_state' key """ # Create model instance model = AssocGCN().to(device) # Load the checkpoint checkpoint = torch.load(model_path, map_location=device) # Handle different checkpoint formats if isinstance(checkpoint, dict): # Check for different possible keys where model state is stored if 'model_state_dict' in checkpoint: state_dict = checkpoint['model_state_dict'] elif 'model_state' in checkpoint: # This handles your specific case state_dict = checkpoint['model_state'] elif 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] else: # If no specific key, assume the entire checkpoint is the state dict # but filter out non-model keys state_dict = {k: v for k, v in checkpoint.items() if k not in {'epoch', 'loss', 'optimizer_state_dict'}} else: # Assume checkpoint is directly the state dict state_dict = checkpoint # Handle DataParallel module prefix if present if any(key.startswith('module.') for key in state_dict.keys()): new_state_dict = {} for key, value in state_dict.items(): new_key = key.replace('module.', '') if key.startswith('module.') else key new_state_dict[new_key] = value state_dict = new_state_dict try: # Try to load with strict=True first model.load_state_dict(state_dict, strict=True) print("✅ Model loaded successfully with strict=True") except RuntimeError as e: print(f"⚠️ Warning: {str(e)}") print("Attempting to load with strict=False...") try: # Try with strict=False to ignore missing/unexpected keys missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) if missing_keys: print(f"⚠️ Missing keys: {missing_keys}") if unexpected_keys: print(f"⚠️ Unexpected keys: {unexpected_keys}") print("✅ Model loaded successfully with strict=False") except Exception as e2: print(f"❌ Failed to load model: {str(e2)}") raise e2 model.eval() return model # ============================================================================ # MAIN INFERENCE PIPELINE # ============================================================================ @torch.no_grad() def identify_speaker(model: AssocGCN, image_path: str, yolo_model, device: str = "cuda") -> Dict[str, Dict[int, int]]: """ COMPLETE: End-to-end inference pipeline for unseen images """ try: # Step 1: Run YOLO detection pipeline print(f"🔍 Running YOLO detection on {os.path.basename(image_path)}...") predictions = get_predictions_from_yolo(image_path, yolo_model) # print(f"📊 Detection Results:") # print(f" - Panels/Frames: {len(predictions.panels)}") # print(f" - Speech Bubbles/Text: {len(predictions.speech_bubbles)}") # print(f" - Faces: {len(predictions.faces)}") # print(f" - Bodies: {len(predictions.bodies)}") # Check if we have the minimum required detections if len(predictions.speech_bubbles) == 0: print("⚠️ No speech bubbles/text detected. Cannot perform association.") return {} if len(predictions.faces) == 0: print("⚠️ No faces detected. Cannot perform association.") return {} # Step 2: Create HeteroData for each panel panels_data = [] panel_ids = list(set([p['id'] for p in predictions.panels])) for panel_id in panel_ids: panel_data = create_hetero_data_from_predictions(predictions, panel_id) # print("panel data", panel_data , dir(panel_data)) # print(( 'bubble' in panel_data.node_types , 'face' in panel_data.node_types, # panel_data['bubble'].x.size(0) > 0 , panel_data['face'].x.size(0) > 0)) if panel_data is not None: panel_data = panel_data.to(device) panels_data.append(panel_data) # Step 3: Run model inference results = {} for panel_data in panels_data: # print("ppppp",panel_data == None) # if (hasattr(panel_data, 'bubble') and hasattr(panel_data, 'face') and # panel_data['bubble'].x.size(0) > 0 and panel_data['face'].x.size(0) > 0): if ( panel_data != None and 'bubble' in panel_data.node_types and 'face' in panel_data.node_types and panel_data['bubble'].x.size(0) > 0 and panel_data['face'].x.size(0) > 0): # Run inference using existing function mapping = infer_associations(model, panel_data) # print(f"\\n🖼️ {panel_data.panel_id}:") # if mapping: # for bubble_idx, face_idx in mapping.items(): # print(f" Text/Bubble {bubble_idx} → Face {face_idx}") # else: # print(" No associations found") results[panel_data.panel_id] = mapping else: print(f"⚠️ {panel_data.panel_id}: No valid bubbles or faces, skipping...") total_associations = 0 if results: for panel_id, mapping in results.items(): if mapping: # print(f"\\n🖼️ {panel_id}:") for bubble_id, face_id in mapping.items(): # print(f" Text/Bubble {bubble_id} ← → Face {face_id}") total_associations += 1 # else: # print(f"\\n🖼️ {panel_id}: No associations found") else: print("No associations found in the image.") print(f"\\n✅ Total associations found: {total_associations}") return results,predictions except Exception as e: print(f"❌ Error during inference: {str(e)}") raise e # #!/usr/bin/env python3 # """ # inference.py ── Run saved AssocGCN on new images / panels. # Generate per-panel mappings and (optionally) aggregate metrics. # """ # import argparse # import torch # from pathlib import Path # from torch_geometric.data import Batch # from typing import Dict, List # from utils import load_model # from train_speaker import DatasetLoader, infer_associations # your original file # @torch.no_grad() # def evaluate_panels(model, # panels: List["HeteroData"], # compute_metrics: bool = True) -> None: # device = next(model.parameters()).device # tp = fp = fn = 0 # for idx, data in enumerate(panels): # mapping = infer_associations(model, data) # print(f"\n🖼️ Panel {data.panel_id}:") # for bub, face in mapping.items(): # print(f" Bubble {bub} → Face {face}") # if compute_metrics and "edge_label" in data["bubble", "to", "face"]: # gt = {(i.item(), j.item()) # for i, j, lbl in zip(*data["bubble", "to", "face"].edge_index, # data["bubble", "to", "face"].edge_label) # if lbl == 1} # pred = {(b, f) for b, f in mapping.items()} # tp += len(gt & pred) # fp += len(pred - gt) # fn += len(gt - pred) # if compute_metrics: # prec = tp / (tp + fp) if (tp + fp) else 0 # rec = tp / (tp + fn) if (tp + fn) else 0 # f1 = 2*prec*rec/(prec+rec) if (prec+rec) else 0 # print("\n📊 Aggregated metrics") # print(f" Precision: {prec:.3f}") # print(f" Recall : {rec:.3f}") # print(f" F1 Score : {f1:.3f}") # def identify_speaker(config): # pass