|
|
|
"""
|
|
Complete Standalone End-to-End Speech-to-Speaker Association Inference Script
|
|
Includes all necessary functions without external dependencies on custom files
|
|
"""
|
|
|
|
import json
|
|
import torch
|
|
import cv2
|
|
from torch_geometric.data import HeteroData, Batch
|
|
from typing import Dict, List, Any, Optional, Tuple, Union
|
|
import os
|
|
from ultralytics import YOLO
|
|
from utils.train_speaker import hungarian_matching,AssocGCN,infer_associations,DatasetLoader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DetectionPredictions:
    """Per-image store for the detections found in one unseen image.

    Each detection category is kept as a list of plain dict records that
    are filled through the ``add_*`` methods.
    """

    def __init__(self, image_path: str):
        self.image_path = image_path
        # One record list per detection category, all initially empty.
        self.panels: List[Dict] = []
        self.speech_bubbles: List[Dict] = []
        self.faces: List[Dict] = []
        self.bodies: List[Dict] = []
        # Placeholder until a caller stores the real image size here.
        self.image_size: Tuple[int, int] = (0, 0)

    def add_panel(self, bbox: List[float], confidence: float, panel_id: int):
        """Record a panel detection (a "frame" in YOLO terms)."""
        record = dict(bbox=bbox, confidence=confidence, id=panel_id)
        self.panels.append(record)

    def add_speech_bubble(self, bbox: List[float], confidence: float,
                          panel_id: int, bubble_id: int):
        """Record a speech bubble detection ("text" in YOLO terms)."""
        record = dict(bbox=bbox, confidence=confidence,
                      panel_id=panel_id, id=bubble_id)
        self.speech_bubbles.append(record)

    def add_face(self, bbox: List[float], confidence: float,
                 panel_id: int, face_id: int):
        """Record a face detection together with the panel it belongs to."""
        record = dict(bbox=bbox, confidence=confidence,
                      panel_id=panel_id, id=face_id)
        self.faces.append(record)

    def add_body(self, bbox: List[float], confidence: float,
                 panel_id: int, body_id: int):
        """Record a body detection together with the panel it belongs to."""
        record = dict(bbox=bbox, confidence=confidence,
                      panel_id=panel_id, id=body_id)
        self.bodies.append(record)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_panel_dict_from_predictions(predictions: DetectionPredictions,
                                       panel_id: int) -> Dict:
    """
    Build the panel dictionary for one panel of an unseen image.

    Only the speech bubbles and faces whose 'panel_id' equals ``panel_id``
    are included, each reduced to its id and bounding box. 'width' and
    'height' are the image size stored on ``predictions`` (not the panel's
    own box), and 'links' is always an empty list.

    NOTE(review): the layout is meant to match what
    DatasetLoader.create_hetero_data_from_panel consumes; that function is
    not visible here, so confirm against it.
    """
    bubble_entries = [
        {'bubble_id': detected['id'], 'bbox': detected['bbox']}
        for detected in predictions.speech_bubbles
        if detected['panel_id'] == panel_id
    ]
    face_entries = [
        {'face_id': detected['id'], 'bbox': detected['bbox']}
        for detected in predictions.faces
        if detected['panel_id'] == panel_id
    ]

    return {
        'panel_id': f"unseen_panel_{panel_id}",
        'width': predictions.image_size[0],
        'height': predictions.image_size[1],
        'bubbles': bubble_entries,
        'faces': face_entries,
        'links': [],
    }
|
|
|
|
|
|
def create_hetero_data_from_predictions(predictions: DetectionPredictions,
                                        panel_id: int) -> Optional[HeteroData]:
    """
    Turn the detections of one panel into a HeteroData graph.

    The panel dictionary built by create_panel_dict_from_predictions is
    handed to DatasetLoader.create_hetero_data_from_panel. If that call
    yields None, None is returned unchanged. Otherwise any 'edge_label'
    attribute on the ('bubble', 'to', 'face') edge store is removed before
    the graph is returned.
    """
    graph = DatasetLoader.create_hetero_data_from_panel(
        create_panel_dict_from_predictions(predictions, panel_id)
    )
    if graph is None:
        return None

    # Presumably the label is dropped because unseen images have no ground
    # truth links (the panel dict always carries an empty 'links' list).
    bubble_face_edges = graph['bubble', 'to', 'face']
    if hasattr(bubble_face_edges, 'edge_label'):
        delattr(bubble_face_edges, 'edge_label')

    return graph
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_predictions_from_yolo(img_path: str, yolo_model,
                              device: Optional[str] = None) -> DetectionPredictions:
    """
    Run YOLO on one image and sort the detections into our categories.

    YOLO class ids are mapped as {0: "body", 1: "face", 2: "frame",
    3: "text"}; "frame" detections become panels and "text" detections
    become speech bubbles. Detections with any other class id are skipped.

    Args:
        img_path: Path of the image to analyse.
        yolo_model: Object exposing ``predict(source=..., device=...,
            verbose=...)`` whose first result has a ``boxes`` iterable with
            ``cls``, ``xyxy`` and ``conf`` entries.
        device: Device string forwarded to ``yolo_model.predict``. When
            None (the default), 'cuda' is used if CUDA is available and
            'cpu' otherwise.

    Returns:
        A DetectionPredictions holding the image size as (width, height),
        the panels, and the speech bubbles, faces and bodies, each tagged
        with the id of the panel chosen by find_containing_panel. Every
        speech bubble additionally carries a 'seq' index within its panel.

    Raises:
        ValueError: If the image cannot be read by ``cv2.imread``.
    """
    CLASSES = {0: "body", 1: "face", 2: "frame", 3: "text"}

    predictions = DetectionPredictions(img_path)

    # The image is decoded here only to learn its dimensions; YOLO is
    # given the path and reads the file itself.
    img = cv2.imread(img_path)
    if img is None:
        raise ValueError(f"Could not load image: {img_path}")
    height, width = img.shape[:2]
    predictions.image_size = (width, height)

    # Fall back to the CPU instead of failing on hosts without CUDA.
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    results = yolo_model.predict(source=img_path, device=device, verbose=False)

    detections_by_type = {name: [] for name in CLASSES.values()}
    for box in results[0].boxes:
        detection_type = CLASSES.get(int(box.cls[0]))
        if detection_type is None:
            # The model reported a class this pipeline has no category for.
            continue
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        detections_by_type[detection_type].append({
            'bbox': [x1, y1, x2, y2],
            'confidence': float(box.conf[0]),
        })

    # Panels must be registered first: every other detection is assigned
    # to a panel by looking at the panel boxes.
    for i, frame in enumerate(detections_by_type["frame"]):
        predictions.add_panel(frame['bbox'], frame['confidence'], i)

    # With no detected frame, treat the whole image as a single panel.
    if not predictions.panels:
        predictions.add_panel([0, 0, width, height], 1.0, 0)

    for i, text in enumerate(detections_by_type["text"]):
        panel_id = find_containing_panel(text['bbox'], predictions.panels)
        predictions.add_speech_bubble(text['bbox'], text['confidence'], panel_id, i)

    # Number the bubbles of each panel by descending left edge, so the
    # right-most bubble of a panel gets seq 0. The dicts are the same
    # objects stored in predictions.speech_bubbles, so 'seq' lands there.
    bubbles_by_panel = {}
    for bubble in predictions.speech_bubbles:
        bubbles_by_panel.setdefault(bubble['panel_id'], []).append(bubble)
    for panel_bubbles in bubbles_by_panel.values():
        ordered = sorted(panel_bubbles, key=lambda b: b['bbox'][0], reverse=True)
        for seq, bubble in enumerate(ordered):
            bubble['seq'] = seq

    for i, face in enumerate(detections_by_type["face"]):
        panel_id = find_containing_panel(face['bbox'], predictions.panels)
        predictions.add_face(face['bbox'], face['confidence'], panel_id, i)

    for i, body in enumerate(detections_by_type["body"]):
        panel_id = find_containing_panel(body['bbox'], predictions.panels)
        predictions.add_body(body['bbox'], body['confidence'], panel_id, i)

    return predictions
|
|
|
|
|
|
def find_containing_panel(bbox: List[float], panels: List[Dict]) -> int:
    """
    Pick the panel a detection belongs to.

    The centre of ``bbox`` ([x1, y1, x2, y2]) is tested against each
    panel's 'bbox' in list order, edges inclusive, and the 'id' of the
    first panel containing it is returned. If no panel contains the
    centre, the first panel's id is returned; with no panels at all the
    result is 0.
    """
    center_x = (bbox[0] + bbox[2]) / 2
    center_y = (bbox[1] + bbox[3]) / 2

    def _covers_center(rect):
        return (rect[0] <= center_x <= rect[2] and
                rect[1] <= center_y <= rect[3])

    hits = (candidate for candidate in panels if _covers_center(candidate['bbox']))
    first_hit = next(hits, None)
    if first_hit is not None:
        return first_hit['id']

    # No panel contains the centre: default to the first panel, if any.
    if panels:
        return panels[0]['id']
    return 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _extract_state_dict(checkpoint):
    """Return the parameter mapping held by a loaded checkpoint object."""
    if isinstance(checkpoint, torch.nn.Module):
        # A whole pickled model: take its parameters rather than failing
        # later when the object is treated like a mapping.
        return checkpoint.state_dict()
    if not isinstance(checkpoint, dict):
        return checkpoint

    # Wrapper keys are checked in this fixed order of preference.
    for wrapper_key in ('model_state_dict', 'model_state', 'state_dict'):
        if wrapper_key in checkpoint:
            return checkpoint[wrapper_key]

    # No known wrapper key: treat the dict itself as the state dict,
    # minus the training bookkeeping entries.
    bookkeeping = {'epoch', 'loss', 'optimizer_state_dict'}
    return {k: v for k, v in checkpoint.items() if k not in bookkeeping}


def _strip_module_prefix(state_dict):
    """Remove a leading 'module.' from every key that starts with it."""
    prefix = 'module.'
    if not any(key.startswith(prefix) for key in state_dict.keys()):
        return state_dict
    # Slice off the prefix only; str.replace would also delete 'module.'
    # from the middle of a name such as 'module.submodule.weight'.
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_trained_speaker_model(model_path: str, device: str = "cuda") -> AssocGCN:
    """
    Load a trained AssocGCN model from a checkpoint file.

    The checkpoint may be a bare state dict, a dict wrapping it under
    'model_state_dict', 'model_state' or 'state_dict', or a pickled
    module. A leading 'module.' on parameter names is removed. Loading is
    first attempted with ``strict=True``; on RuntimeError it is retried
    with ``strict=False`` and the missing/unexpected keys are printed.

    Args:
        model_path: Path of the checkpoint file.
        device: Device the model is created on and the checkpoint is
            mapped to.

    Returns:
        The AssocGCN model in eval mode.

    Raises:
        Exception: Whatever the ``strict=False`` retry raises, re-raised
            after being printed.
    """
    model = AssocGCN().to(device)

    # NOTE(security): torch.load unpickles the file, so only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)

    state_dict = _strip_module_prefix(_extract_state_dict(checkpoint))

    try:
        model.load_state_dict(state_dict, strict=True)
        print("✅ Model loaded successfully with strict=True")
    except RuntimeError as e:
        print(f"⚠️ Warning: {str(e)}")
        print("Attempting to load with strict=False...")
        try:
            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
            if missing_keys:
                print(f"⚠️ Missing keys: {missing_keys}")
            if unexpected_keys:
                print(f"⚠️ Unexpected keys: {unexpected_keys}")
            print("✅ Model loaded successfully with strict=False")
        except Exception as e2:
            print(f"❌ Failed to load model: {str(e2)}")
            # Bare raise keeps the original traceback intact.
            raise

    model.eval()
    return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def identify_speaker(model: AssocGCN, image_path: str,
                     yolo_model, device: str = "cuda"
                     ) -> Tuple[Dict[str, Dict[int, int]], DetectionPredictions]:
    """
    Run the end-to-end inference pipeline on one unseen image.

    Detections come from get_predictions_from_yolo. For every panel a
    graph is built with create_hetero_data_from_predictions, moved to
    ``device`` and passed to infer_associations. Panels without a graph,
    or without at least one bubble node and one face node, are skipped.

    Args:
        model: Trained AssocGCN passed to infer_associations.
        image_path: Path of the image to analyse.
        yolo_model: Detector handed to get_predictions_from_yolo.
        device: Device the panel graphs are moved to before inference.
            NOTE(review): presumably ``model`` must live on the same
            device; confirm with the caller.

    Returns:
        A ``(results, predictions)`` tuple. ``results`` maps each
        processed panel's ``panel_id`` to the value infer_associations
        returned for it (presumably bubble id -> face id; confirm against
        utils.train_speaker). ``results`` is empty when the image has no
        speech bubbles or no faces. The tuple is returned in those cases
        too, so callers can always unpack two values.

    Raises:
        Exception: Any error raised during detection or inference is
            printed and re-raised unchanged.
    """
    try:
        print(f"🔍 Running YOLO detection on {os.path.basename(image_path)}...")
        predictions = get_predictions_from_yolo(image_path, yolo_model)

        if not predictions.speech_bubbles:
            print("⚠️ No speech bubbles/text detected. Cannot perform association.")
            return {}, predictions

        if not predictions.faces:
            print("⚠️ No faces detected. Cannot perform association.")
            return {}, predictions

        # Build one graph per distinct panel id, in ascending id order so
        # the result order is deterministic.
        panels_data = []
        for panel_id in sorted({p['id'] for p in predictions.panels}):
            panel_data = create_hetero_data_from_predictions(predictions, panel_id)
            if panel_data is not None:
                panels_data.append(panel_data.to(device))

        results = {}
        for panel_data in panels_data:
            has_both_node_types = ('bubble' in panel_data.node_types and
                                   'face' in panel_data.node_types)
            if (has_both_node_types and
                    panel_data['bubble'].x.size(0) > 0 and
                    panel_data['face'].x.size(0) > 0):
                results[panel_data.panel_id] = infer_associations(model, panel_data)
            else:
                print(f"⚠️ {panel_data.panel_id}: No valid bubbles or faces, skipping...")

        if not results:
            print("No associations found in the image.")
        total_associations = sum(
            len(mapping) for mapping in results.values() if mapping
        )

        print(f"\n✅ Total associations found: {total_associations}")

        return results, predictions

    except Exception as e:
        print(f"❌ Error during inference: {str(e)}")
        # Bare raise keeps the original traceback intact.
        raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|