#!/usr/bin/env python3
"""
Complete GCN Training Pipeline for Speech Bubble to Speaker Association

Fixed version that handles the dataset format correctly and resolves
training issues.
"""

import json
import os
import random
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from scipy.optimize import linear_sum_assignment
from torch_geometric.data import Batch, HeteroData

# from utils.utilities import save_checkpoint

CHECKPOINT_DIR = Path("checkpoints")
CHECKPOINT_DIR.mkdir(exist_ok=True)


def save_checkpoint(model: torch.nn.Module, epoch: int, loss: float,
                    path: Path = CHECKPOINT_DIR / "assoc_gcn.pt") -> None:
    """Persist full training state so you can resume fine-tuning later."""
    path = Path(path)
    torch.save({
        "epoch": epoch,
        "loss": loss,
        "model_state": model.state_dict()
    }, path)
    print(f"✅ Model checkpoint saved to {path.resolve()}")
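

# Companion loader sketch (not part of the original pipeline): it assumes only
# the dict layout written by save_checkpoint() above, and maps tensors to CPU
# so a GPU-trained checkpoint can be restored on a CPU-only host.
def load_checkpoint(model: torch.nn.Module,
                    path: Path = CHECKPOINT_DIR / "assoc_gcn.pt") -> Tuple[int, float]:
    """Restore model weights in-place and return the saved (epoch, loss)."""
    state = torch.load(Path(path), map_location="cpu")
    model.load_state_dict(state["model_state"])
    return state["epoch"], state["loss"]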


class DatasetLoader:
    """Handles loading and preprocessing of the converted GCN dataset."""

    @staticmethod
    def load_converted_dataset(json_path: str) -> List[HeteroData]:
        """Load the converted GCN dataset and create PyTorch Geometric HeteroData objects."""
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        dataset = []
        for panel in data.get('panels', []):
            het_data = DatasetLoader.create_hetero_data_from_panel(panel)
            if het_data is not None:
                dataset.append(het_data)

        print(f"Loaded {len(dataset)} panels from {json_path}")
        return dataset

    @staticmethod
    def create_hetero_data_from_panel(panel: Dict) -> Optional[HeteroData]:
        """Convert a single panel from the converted dataset into HeteroData format."""
        bubbles = panel.get('bubbles', [])
        faces = panel.get('faces', [])
        links = panel.get('links', [])

        if len(bubbles) == 0 or len(faces) == 0:
            return None

        W, H = panel['width'], panel['height']

        # Node features: normalized center, size, area, and aspect ratio per box
        def box_features(bbox):
            x1, y1, x2, y2 = bbox
            cx, cy = (x1 + x2) / (2 * W), (y1 + y2) / (2 * H)
            w, h = (x2 - x1) / W, (y2 - y1) / H
            area = w * h
            aspect = w / h if h > 0 else 1.0
            return [cx, cy, w, h, area, aspect]

        bubble_features = [box_features(b['bbox']) for b in bubbles]
        face_features = [box_features(f['bbox']) for f in faces]

        # Mappings from dataset ids to node indices
        bubble_id_to_idx = {b['bubble_id']: i for i, b in enumerate(bubbles)}
        face_id_to_idx = {f['face_id']: i for i, f in enumerate(faces)}

        # Ground-truth links as (bubble_idx, face_idx) pairs
        gt_links = set()
        for link in links:
            if link['bubble_id'] in bubble_id_to_idx and link['face_id'] in face_id_to_idx:
                gt_links.add((bubble_id_to_idx[link['bubble_id']],
                              face_id_to_idx[link['face_id']]))

        # Fully connected bipartite candidate edges with geometric features
        edge_indices, edge_features, edge_labels = [], [], []
        for i, bubble in enumerate(bubbles):
            for j, face in enumerate(faces):
                b_x1, b_y1, b_x2, b_y2 = bubble['bbox']
                f_x1, f_y1, f_x2, f_y2 = face['bbox']

                # Normalized center offset and distance
                b_cx, b_cy = (b_x1 + b_x2) / (2 * W), (b_y1 + b_y2) / (2 * H)
                f_cx, f_cy = (f_x1 + f_x2) / (2 * W), (f_y1 + f_y2) / (2 * H)
                dx, dy = b_cx - f_cx, b_cy - f_cy
                dist = (dx ** 2 + dy ** 2) ** 0.5

                # IoU between the two boxes
                xx1, yy1 = max(b_x1, f_x1), max(b_y1, f_y1)
                xx2, yy2 = min(b_x2, f_x2), min(b_y2, f_y2)
                inter = max(0, xx2 - xx1) * max(0, yy2 - yy1)
                union = ((b_x2 - b_x1) * (b_y2 - b_y1)
                         + (f_x2 - f_x1) * (f_y2 - f_y1) - inter)
                iou = inter / union if union > 0 else 0

                edge_indices.append([i, j])
                edge_features.append([dx, dy, dist, iou])
                edge_labels.append(1.0 if (i, j) in gt_links else 0.0)

        if len(edge_indices) == 0:
            return None

        # Assemble the heterogeneous graph
        data = HeteroData()
        data['bubble'].x = torch.tensor(bubble_features, dtype=torch.float)
        data['face'].x = torch.tensor(face_features, dtype=torch.float)
        data['bubble', 'to', 'face'].edge_index = torch.tensor(
            edge_indices, dtype=torch.long).t().contiguous()
        data['bubble', 'to', 'face'].edge_attr = torch.tensor(edge_features, dtype=torch.float)
        data['bubble', 'to', 'face'].edge_label = torch.tensor(edge_labels, dtype=torch.float)

        # Metadata
        data.panel_id = panel['panel_id']
        data.width = W
        data.height = H

        return data
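

# Illustrative example of the JSON layout that load_converted_dataset() expects,
# inferred from the keys read above; all values are made up.
_EXAMPLE_PANEL_JSON = {
    "panels": [{
        "panel_id": "p_000",                  # hypothetical id
        "width": 800, "height": 600,          # panel size in pixels
        "bubbles": [{"bubble_id": "b0", "bbox": [10, 20, 110, 80]}],
        "faces": [{"face_id": "f0", "bbox": [300, 200, 360, 270]}],
        "links": [{"bubble_id": "b0", "face_id": "f0"}],
    }]
}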


class AssocGCN(nn.Module):
    """Graph Convolutional Network for Speech Bubble to Speaker Association."""

    def __init__(self, in_feats: int = 6, hid: int = 128):
        super().__init__()
        self.node_encoder = nn.Sequential(
            nn.Linear(in_feats, hid),
            nn.ReLU(),
            nn.Linear(hid, hid)
        )
        # Message passing layers: each consumes [src node, dst node, edge attr]
        self.conv1 = nn.Sequential(
            nn.Linear(hid * 2 + 4, hid),  # node features + edge features
            nn.ReLU(),
            nn.Linear(hid, hid)
        )
        self.conv2 = nn.Sequential(
            nn.Linear(hid * 2 + 4, hid),
            nn.ReLU(),
            nn.Linear(hid, hid)
        )
        self.conv3 = nn.Sequential(
            nn.Linear(hid * 2 + 4, hid),
            nn.ReLU(),
            nn.Linear(hid, hid)
        )
        # Edge classifier
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * hid + 4, hid),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hid, 1)
        )

    def forward(self, data):
        bubble_x = self.node_encoder(data['bubble'].x)
        face_x = self.node_encoder(data['face'].x)
        edge_index = data['bubble', 'to', 'face'].edge_index
        edge_attr = data['bubble', 'to', 'face'].edge_attr
        src_idx, dst_idx = edge_index[0], edge_index[1]

        # Message passing: mean-aggregate edge messages into both endpoints
        for conv in [self.conv1, self.conv2, self.conv3]:
            edge_input = torch.cat([bubble_x[src_idx], face_x[dst_idx], edge_attr], dim=1)
            edge_updates = conv(edge_input)

            # Scatter the per-edge messages back to their endpoint nodes
            # (vectorized equivalent of summing edge_updates per node)
            bubble_updates = torch.zeros_like(bubble_x).index_add_(0, src_idx, edge_updates)
            face_updates = torch.zeros_like(face_x).index_add_(0, dst_idx, edge_updates)

            # Normalize by node degree
            bubble_degrees = torch.bincount(src_idx, minlength=bubble_x.size(0)).float().clamp(min=1)
            face_degrees = torch.bincount(dst_idx, minlength=face_x.size(0)).float().clamp(min=1)
            bubble_updates = bubble_updates / bubble_degrees.unsqueeze(1)
            face_updates = face_updates / face_degrees.unsqueeze(1)

            # Residual connection
            bubble_x = bubble_x + bubble_updates
            face_x = face_x + face_updates

        # Final edge prediction
        edge_input = torch.cat([bubble_x[src_idx], face_x[dst_idx], edge_attr], dim=1)
        logits = self.edge_mlp(edge_input).squeeze(-1)
        return logits


def hungarian_matching(scores: torch.Tensor, src_indices: torch.Tensor,
                       dst_indices: torch.Tensor) -> Dict[int, int]:
    """Apply the Hungarian algorithm for optimal bipartite matching."""
    if len(scores) == 0:
        return {}

    num_bubbles = src_indices.max().item() + 1 if len(src_indices) > 0 else 0
    num_faces = dst_indices.max().item() + 1 if len(dst_indices) > 0 else 0

    # Unscored pairs get a large positive cost; scored pairs get -p, so that
    # minimizing total cost maximizes total match probability
    cost_matrix = np.full((num_bubbles, num_faces), 1e6, dtype=np.float32)
    scores_np = scores.detach().cpu().sigmoid().numpy()
    for s, d, score in zip(src_indices.cpu(), dst_indices.cpu(), scores_np):
        cost_matrix[s, d] = -score  # negative for minimization

    row_indices, col_indices = linear_sum_assignment(cost_matrix)
    mapping = {}
    for r, c in zip(row_indices, col_indices):
        if cost_matrix[r, c] < 0:  # valid assignment (a scored edge, not filler)
            mapping[int(r)] = int(c)
    return mapping
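

# Tiny hand-made sanity check for hungarian_matching() (illustrative only, not
# part of the pipeline): two bubbles and two faces, fully connected, with
# logits favoring the b0->f0 / b1->f1 pairing.
def _demo_hungarian_matching() -> None:
    scores = torch.tensor([4.0, -2.0, -3.0, 5.0])  # one logit per candidate edge
    src = torch.tensor([0, 0, 1, 1])               # bubble index of each edge
    dst = torch.tensor([0, 1, 0, 1])               # face index of each edge
    print(hungarian_matching(scores, src, dst))    # -> {0: 0, 1: 1}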


def train_gcn(dataset: List[HeteroData], epochs: int = 200,
              batch_size: int = 16, lr: float = 1e-4):
    """Train the GCN model on the dataset."""
    if len(dataset) == 0:
        raise ValueError("Dataset is empty!")

    print(f"Training on {len(dataset)} panels...")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = AssocGCN().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # Class weighting: positive edges are rare, so up-weight them in the loss
    total_positive = sum(data['bubble', 'to', 'face'].edge_label.sum().item() for data in dataset)
    total_edges = sum(len(data['bubble', 'to', 'face'].edge_label) for data in dataset)
    pos_weight = (total_edges - total_positive) / total_positive if total_positive > 0 else 9.0
    print(f"Positive edges: {int(total_positive)}/{total_edges} ({100 * total_positive / total_edges:.1f}%)")
    print(f"Using pos_weight: {pos_weight:.2f}")
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device))

    model.train()
    best_loss = float("inf")
    for epoch in range(epochs):
        total_correct = 0
        total_samples = 0
        total_tp = 0
        total_fp = 0
        total_fn = 0
        total_loss = 0.0

        # Shuffle a copy of the dataset each epoch (it's a list, not a dict)
        shuffled_dataset = dataset.copy()
        random.shuffle(shuffled_dataset)

        num_batches = (len(shuffled_dataset) + batch_size - 1) // batch_size
        for batch_idx in range(num_batches):
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, len(shuffled_dataset))
            batch_data = shuffled_dataset[start_idx:end_idx]
            batch = Batch.from_data_list(batch_data).to(device)

            # Forward pass
            logits = model(batch)
            labels = batch['bubble', 'to', 'face'].edge_label
            loss = loss_fn(logits, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * len(batch_data)

            # Accuracy and confusion counts at a 0.5 threshold
            preds = (torch.sigmoid(logits) > 0.5).float()
            total_correct += (preds == labels).sum().item()
            total_samples += labels.numel()
            total_tp += ((preds == 1) & (labels == 1)).sum().item()
            total_fp += ((preds == 1) & (labels == 0)).sum().item()
            total_fn += ((preds == 0) & (labels == 1)).sum().item()

        avg_loss = total_loss / len(shuffled_dataset)
        accuracy = total_correct / total_samples

        if avg_loss < best_loss:
            best_loss = avg_loss
            save_checkpoint(model, epoch + 1, best_loss)  # epoch is 0-indexed

        # Precision, recall, and F1 over all edges seen this epoch
        recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0.0
        precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

        print(f"Epoch {epoch + 1:02d}/{epochs}: Loss = {avg_loss:.4f}, "
              f"Accuracy = {accuracy:.4f}, Recall = {recall:.4f}, F1 = {f1:.4f}")

    print("Training completed!")
    return model


def infer_associations(model, data):
    """Infer speech bubble to speaker associations."""
    device = next(model.parameters()).device
    data = data.to(device)
    model.eval()
    with torch.no_grad():
        logits = model(data)
    src, dst = data['bubble', 'to', 'face'].edge_index
    mapping = hungarian_matching(logits, src, dst)
    return mapping


# Example usage and testing
def train_speaker(config):
    # Gather every converted panel file under <root>/panel_data/
    dataset = []
    panel_data_dir = os.path.join(config["root"], "panel_data")
    for panel_data_file in os.listdir(panel_data_dir):
        try:
            print(panel_data_file)
            dataset += DatasetLoader.load_converted_dataset(
                os.path.join(panel_data_dir, panel_data_file))
        except FileNotFoundError:
            print(f"Error: {panel_data_file} not found!")
            print("Please ensure your converted dataset file exists.")
        except Exception as e:
            print(f"Error: {e}")
            print("Please check your dataset format and file paths.")

    if len(dataset) == 0:
        print("No valid panels found in dataset!")
        return None

    model = train_gcn(dataset, epochs=30, batch_size=16)  # reduced epochs for testing

    # Quick inference demo on the first panel
    test_data = dataset[0]
    mapping = infer_associations(model, test_data)
    print("\nInference Results:")
    for bubble_id, face_id in mapping.items():
        print(f"Bubble {bubble_id} → Face {face_id}")

    return model
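

# Minimal entry-point sketch: the config layout is assumed from how
# train_speaker() reads it above -- a single "root" key pointing at a directory
# that contains a panel_data/ subfolder of converted JSON files. The path below
# is hypothetical; point it at your own dataset root.
if __name__ == "__main__":
    train_speaker({"root": "./dataset"})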