#!/usr/bin/env python3
"""
Complete Standalone End-to-End Speech-to-Speaker Association Inference Script
Includes all necessary functions without external dependencies on custom files
"""
import json
import os
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import cv2
import torch
from torch_geometric.data import HeteroData
from ultralytics import YOLO

from utils.train_speaker import AssocGCN, DatasetLoader, hungarian_matching, infer_associations
# ============================================================================
# CLASSES FOR UNSEEN IMAGE PROCESSING
# ============================================================================
class DetectionPredictions:
"""Container for object detection predictions from an unseen image"""
def __init__(self, image_path: str):
self.image_path = image_path
self.panels: List[Dict] = []
self.speech_bubbles: List[Dict] = []
self.faces: List[Dict] = []
self.bodies: List[Dict] = []
self.image_size: Tuple[int, int] = (0, 0) # (width, height)
def add_panel(self, bbox: List[float], confidence: float, panel_id: int):
"""Add panel detection (frame in YOLO terms)"""
self.panels.append({
'bbox': bbox, 'confidence': confidence, 'id': panel_id
})
def add_speech_bubble(self, bbox: List[float], confidence: float,
panel_id: int, bubble_id: int):
"""Add speech bubble detection (text in YOLO terms)"""
self.speech_bubbles.append({
'bbox': bbox, 'confidence': confidence,
'panel_id': panel_id, 'id': bubble_id
})
def add_face(self, bbox: List[float], confidence: float,
panel_id: int, face_id: int):
"""Add face detection"""
self.faces.append({
'bbox': bbox, 'confidence': confidence,
'panel_id': panel_id, 'id': face_id
})
def add_body(self, bbox: List[float], confidence: float,
panel_id: int, body_id: int):
"""Add body detection"""
self.bodies.append({
'bbox': bbox, 'confidence': confidence,
'panel_id': panel_id, 'id': body_id
})
# ============================================================================
# DATA PROCESSING FUNCTIONS
# ============================================================================
def create_panel_dict_from_predictions(predictions: DetectionPredictions,
panel_id: int) -> Dict:
"""
ADAPTED: Creates a panel dictionary from detection predictions
Uses the same format as original create_hetero_data_from_panel expects
"""
# Filter predictions for this panel
panel_bubbles = [b for b in predictions.speech_bubbles if b['panel_id'] == panel_id]
panel_faces = [f for f in predictions.faces if f['panel_id'] == panel_id]
    # Convert detections into the bubble/face list layout the loader expects
    bubbles = [{'bubble_id': b['id'], 'bbox': b['bbox']} for b in panel_bubbles]
    faces = [{'face_id': f['id'], 'bbox': f['bbox']} for f in panel_faces]
# Create panel dict in expected format
panel_dict = {
'panel_id': f"unseen_panel_{panel_id}",
'width': predictions.image_size[0],
'height': predictions.image_size[1],
'bubbles': bubbles,
'faces': faces,
'links': [] # No ground truth links for unseen images
}
return panel_dict
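# For reference, the returned dict has the shape the loader expects
# (values below are illustrative):
#
#     {'panel_id': 'unseen_panel_0', 'width': 1024, 'height': 1536,
#      'bubbles': [{'bubble_id': 0, 'bbox': [120, 80, 300, 160]}],
#      'faces': [{'face_id': 0, 'bbox': [400, 500, 520, 640]}],
#      'links': []}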
def create_hetero_data_from_predictions(predictions: DetectionPredictions,
panel_id: int) -> Optional[HeteroData]:
"""
ENHANCED: Creates HeteroData from detection predictions
Leverages existing create_hetero_data_from_panel function
"""
# Convert predictions to panel dict format
panel_dict = create_panel_dict_from_predictions(predictions, panel_id)
# Use existing function to create HeteroData
hetero_data = DatasetLoader.create_hetero_data_from_panel(panel_dict)
if hetero_data is not None:
# Remove ground truth labels since we don't have them for unseen images
if hasattr(hetero_data['bubble', 'to', 'face'], 'edge_label'):
delattr(hetero_data['bubble', 'to', 'face'], 'edge_label')
return hetero_data
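# Note: the graph returned by the loader carries 'bubble' and 'face' node sets
# plus ('bubble', 'to', 'face') candidate edges; only the ground-truth
# edge_label attribute is stripped above, since unseen images have no labels.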
# ============================================================================
# YOLO DETECTION INTEGRATION
# ============================================================================
def get_predictions_from_yolo(img_path: str, yolo_model,
                              device: str = "cuda") -> DetectionPredictions:
    """
    Runs the YOLO detector and maps its classes onto our detection categories.
    YOLO classes: {0: "body", 1: "face", 2: "frame", 3: "text"}
    """
    CLASSES = {0: "body", 1: "face", 2: "frame", 3: "text"}
# Create predictions container
predictions = DetectionPredictions(img_path)
# Load image to get dimensions
img = cv2.imread(img_path)
if img is None:
raise ValueError(f"Could not load image: {img_path}")
height, width = img.shape[:2]
predictions.image_size = (width, height)
# Get YOLO predictions
    results = yolo_model.predict(source=img_path, device=device, verbose=False)
# Process detections and organize by type
detections_by_type = {"body": [], "face": [], "frame": [], "text": []}
for box in results[0].boxes:
c = int(box.cls[0])
x1, y1, x2, y2 = map(int, box.xyxy[0])
confidence = float(box.conf[0])
detection_type = CLASSES[c]
detections_by_type[detection_type].append({
'bbox': [x1, y1, x2, y2],
'confidence': confidence
})
# Process frame detections as panels
for i, frame in enumerate(detections_by_type["frame"]):
predictions.add_panel(frame['bbox'], frame['confidence'], i)
# If no frames detected, create a default full-image panel
if len(predictions.panels) == 0:
predictions.add_panel([0, 0, width, height], 1.0, 0)
# Process text detections as speech bubbles
for i, text in enumerate(detections_by_type["text"]):
panel_id = find_containing_panel(text['bbox'], predictions.panels)
predictions.add_speech_bubble(text['bbox'], text['confidence'], panel_id, i)
    # Group bubbles by panel_id and assign a right-to-left reading order
    grouped_texts = defaultdict(list)
    for bubble in predictions.speech_bubbles:
        grouped_texts[bubble['panel_id']].append(bubble)
    # Sort by bbox[0] descending so the rightmost bubble gets seq 0
    for bubbles_in_panel in grouped_texts.values():
        sorted_bubbles = sorted(bubbles_in_panel, key=lambda b: b['bbox'][0], reverse=True)
        for seq, bubble in enumerate(sorted_bubbles):
            bubble['seq'] = seq
# Process face detections
for i, face in enumerate(detections_by_type["face"]):
panel_id = find_containing_panel(face['bbox'], predictions.panels)
predictions.add_face(face['bbox'], face['confidence'], panel_id, i)
# Process body detections (optional - can be used for additional context)
for i, body in enumerate(detections_by_type["body"]):
panel_id = find_containing_panel(body['bbox'], predictions.panels)
predictions.add_body(body['bbox'], body['confidence'], panel_id, i)
return predictions
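# Illustrative usage (the weights path is a placeholder, not shipped here):
#
#     preds = get_predictions_from_yolo("page_001.png", YOLO("detector.pt"))
#     print(len(preds.panels), len(preds.speech_bubbles), len(preds.faces))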
def find_containing_panel(bbox: List[float], panels: List[Dict]) -> int:
"""
Helper function to determine which panel contains a detection
"""
bbox_center_x = (bbox[0] + bbox[2]) / 2
bbox_center_y = (bbox[1] + bbox[3]) / 2
for panel in panels:
p_bbox = panel['bbox']
if (p_bbox[0] <= bbox_center_x <= p_bbox[2] and
p_bbox[1] <= bbox_center_y <= p_bbox[3]):
return panel['id']
# Return first panel if not contained in any
return panels[0]['id'] if panels else 0
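# e.g. find_containing_panel([10, 10, 50, 50], [{'bbox': [0, 0, 100, 100], 'id': 3}])
# returns 3, since the detection's center (30, 30) lies inside the panel bbox.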
# ============================================================================
# MODEL LOADING WITH ERROR HANDLING
# ============================================================================
def load_trained_speaker_model(model_path: str, device: str = "cuda") -> AssocGCN:
"""
FIXED: Load trained AssocGCN model with proper error handling
Handles different checkpoint formats including 'model_state' key
"""
# Create model instance
model = AssocGCN().to(device)
# Load the checkpoint
checkpoint = torch.load(model_path, map_location=device)
# Handle different checkpoint formats
if isinstance(checkpoint, dict):
# Check for different possible keys where model state is stored
if 'model_state_dict' in checkpoint:
state_dict = checkpoint['model_state_dict']
        elif 'model_state' in checkpoint:
state_dict = checkpoint['model_state']
elif 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
# If no specific key, assume the entire checkpoint is the state dict
# but filter out non-model keys
state_dict = {k: v for k, v in checkpoint.items()
if k not in {'epoch', 'loss', 'optimizer_state_dict'}}
else:
# Assume checkpoint is directly the state dict
state_dict = checkpoint
    # Strip a DataParallel 'module.' prefix if present (slicing rather than
    # str.replace avoids touching any interior 'module.' substrings)
    if any(key.startswith('module.') for key in state_dict.keys()):
        state_dict = {
            (key[len('module.'):] if key.startswith('module.') else key): value
            for key, value in state_dict.items()
        }
try:
# Try to load with strict=True first
model.load_state_dict(state_dict, strict=True)
print("✅ Model loaded successfully with strict=True")
except RuntimeError as e:
print(f"⚠️ Warning: {str(e)}")
print("Attempting to load with strict=False...")
try:
# Try with strict=False to ignore missing/unexpected keys
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
if missing_keys:
print(f"⚠️ Missing keys: {missing_keys}")
if unexpected_keys:
print(f"⚠️ Unexpected keys: {unexpected_keys}")
print("✅ Model loaded successfully with strict=False")
except Exception as e2:
print(f"❌ Failed to load model: {str(e2)}")
raise e2
model.eval()
return model
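# Illustrative usage (checkpoint path is a placeholder):
#
#     gcn = load_trained_speaker_model("checkpoints/assoc_gcn.pt", device="cuda")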
# ============================================================================
# MAIN INFERENCE PIPELINE
# ============================================================================
@torch.no_grad()
def identify_speaker(model: AssocGCN, image_path: str, yolo_model,
                     device: str = "cuda") -> Tuple[Dict[str, Dict[int, int]], DetectionPredictions]:
    """
    End-to-end inference pipeline for unseen images. Returns the per-panel
    bubble-to-face mapping together with the raw DetectionPredictions.
    """
    try:
        # Step 1: Run YOLO detection pipeline
        print(f"🔍 Running YOLO detection on {os.path.basename(image_path)}...")
        predictions = get_predictions_from_yolo(image_path, yolo_model, device=device)
        # Check that we have the minimum required detections
        if len(predictions.speech_bubbles) == 0:
            print("⚠️ No speech bubbles/text detected. Cannot perform association.")
            return {}, predictions
        if len(predictions.faces) == 0:
            print("⚠️ No faces detected. Cannot perform association.")
            return {}, predictions
        # Step 2: Create HeteroData for each panel
        panels_data = []
        panel_ids = list({p['id'] for p in predictions.panels})
        for panel_id in panel_ids:
            panel_data = create_hetero_data_from_predictions(predictions, panel_id)
            if panel_data is not None:
                panels_data.append(panel_data.to(device))
        # Step 3: Run model inference per panel
        results = {}
        for panel_data in panels_data:
            if ('bubble' in panel_data.node_types and 'face' in panel_data.node_types
                    and panel_data['bubble'].x.size(0) > 0 and panel_data['face'].x.size(0) > 0):
                # Run inference using the existing association routine
                mapping = infer_associations(model, panel_data)
                results[panel_data.panel_id] = mapping
            else:
                print(f"⚠️ {panel_data.panel_id}: No valid bubbles or faces, skipping...")
        # Summarize: count bubble-to-face links across all panels
        total_associations = sum(len(mapping) for mapping in results.values())
        if not results:
            print("No associations found in the image.")
        print(f"\n✅ Total associations found: {total_associations}")
        return results, predictions
    except Exception as e:
        print(f"❌ Error during inference: {str(e)}")
        raise
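# Minimal end-to-end sketch when run as a script. The weight and image paths
# below are placeholders; substitute the checkpoints your application uses.
if __name__ == "__main__":
    yolo = YOLO("detector.pt")                                       # hypothetical detector weights
    gcn = load_trained_speaker_model("assoc_gcn.pt", device="cuda")  # hypothetical checkpoint
    mappings, detections = identify_speaker(gcn, "page_001.png", yolo)
    # Dump the per-panel bubble-to-face mapping as JSON (keys coerced to str)
    print(json.dumps({str(pid): {str(b): int(f) for b, f in m.items()}
                      for pid, m in mappings.items()}, indent=2))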