#!/usr/bin/env python3
"""
Clinical Analysis Module for ECG-FM
Handles real clinical predictions from finetuned model
"""

import json
from typing import Any, Dict, List, Optional

import numpy as np
import torch

# Number of labels the finetuned ECG-FM classifier head is expected to emit;
# used to recognise logits inside tuple outputs and to validate label_def.csv.
EXPECTED_NUM_LABELS = 17

# Threshold used for any label that has no configured clinical threshold.
DEFAULT_THRESHOLD = 0.7

# Ordered (label, rhythm description) pairs: the first label present in the
# detected abnormalities determines the reported rhythm.
# NOTE(review): several entries (bundle branch blocks, infarction, pacemaker,
# data quality) are not rhythms in the strict sense — confirm this is the
# intended reporting for the "rhythm" field.
_RHYTHM_PRIORITY = (
    ("Atrial fibrillation", "Atrial Fibrillation"),
    ("Atrial flutter", "Atrial Flutter"),
    ("Ventricular tachycardia", "Ventricular Tachycardia"),
    ("Supraventricular tachycardia with aberrancy",
     "Supraventricular Tachycardia with Aberrancy"),
    ("Bradycardia", "Bradycardia"),
    ("Tachycardia", "Tachycardia"),
    ("Premature ventricular contraction", "Premature Ventricular Contractions"),
    ("1st degree atrioventricular block", "1st Degree AV Block"),
    ("Atrioventricular block", "AV Block"),
    ("Right bundle branch block", "Right Bundle Branch Block"),
    ("Left bundle branch block", "Left Bundle Branch Block"),
    ("Bifascicular block", "Bifascicular Block"),
    ("Accessory pathway conduction", "Accessory Pathway Conduction"),
    ("Infarction", "Myocardial Infarction"),
    ("Electronic pacemaker", "Electronic Pacemaker"),
    ("Poor data quality", "Poor Data Quality - Rhythm Unclear"),
)


def _debug_log_model_output(model_output: Any) -> None:
    """Print the type of a model output and, for dicts, a summary of each entry."""
    print(f"🔍 DEBUG: analyze_ecg_features received: {type(model_output)}")
    if isinstance(model_output, dict):
        print(f"🔍 DEBUG: Keys: {list(model_output.keys())}")
        for key, value in model_output.items():
            if isinstance(value, torch.Tensor):
                print(f"🔍 DEBUG: {key} shape: {value.shape}, dtype: {value.dtype}")
            else:
                print(f"🔍 DEBUG: {key}: {type(value)} - {value}")


def _logits_to_probs(logits: Any) -> np.ndarray:
    """Apply an element-wise sigmoid to logits and return a flat NumPy array.

    Accepts a torch tensor or anything ``np.asarray`` understands. The result
    is flattened with ``ravel``, so a leading batch dimension of size 1 is
    dropped implicitly. With a larger batch the samples are concatenated, and
    extract_clinical_from_probabilities keeps only the first ``len(labels)``
    values.
    """
    if isinstance(logits, torch.Tensor):
        logits = logits.detach()
        # Compute half/bfloat16/integer inputs in float32: Tensor.numpy()
        # cannot export bfloat16.
        if logits.dtype not in (torch.float32, torch.float64):
            logits = logits.float()
        return torch.sigmoid(logits).cpu().numpy().ravel()
    values = np.asarray(logits, dtype=np.float64).ravel()
    # exp(-x) overflows to inf for very negative logits; 1 / (1 + inf) == 0.0
    # is the correct limit, so the overflow warning is suppressed.
    with np.errstate(over="ignore"):
        return 1.0 / (1.0 + np.exp(-values))


def analyze_ecg_features(model_output: Any) -> Dict[str, Any]:
    """Extract clinical predictions from finetuned ECG-FM model output.

    Supported shapes of ``model_output``, checked in this order:
      * a logits ``torch.Tensor`` (classifier output);
      * a tuple containing a 2-D tensor whose last dimension is
        EXPECTED_NUM_LABELS;
      * a dict-like object with ``'label_logits'``, ``'out'`` or
        ``'features'``.

    Returns:
        The result dict of extract_clinical_from_probabilities, or a
        create_fallback_response dict when no usable logits are found or any
        error occurs (errors are printed, never raised).
    """
    try:
        _debug_log_model_output(model_output)

        # Tensor and tuple inputs must be handled before the key lookups:
        # `'label_logits' in tensor` raises in torch (Tensor.__contains__
        # rejects str), so checking keys first made the tensor path
        # unreachable.
        if isinstance(model_output, torch.Tensor):
            print("✅ Found direct logits tensor - using classifier model output")
            return extract_clinical_from_probabilities(_logits_to_probs(model_output))

        if isinstance(model_output, tuple):
            print("✅ Found tuple output - checking for logits")
            for item in model_output:
                if (isinstance(item, torch.Tensor) and item.dim() == 2
                        and item.shape[-1] == EXPECTED_NUM_LABELS):
                    print("✅ Found logits in tuple - using classifier model output")
                    return extract_clinical_from_probabilities(_logits_to_probs(item))
            print("❌ No suitable logits found in tuple")
            return create_fallback_response("Tuple output but no logits found")

        if 'label_logits' in model_output:
            print("✅ Found label_logits - using finetuned model output")
            return extract_clinical_from_probabilities(
                _logits_to_probs(model_output['label_logits']))

        if 'out' in model_output:
            print("✅ Found 'out' key - using finetuned model output")
            return extract_clinical_from_probabilities(
                _logits_to_probs(model_output['out']))

        if 'features' in model_output:
            # PRETRAINED MODEL - fallback to feature analysis
            features = model_output.get('features', [])
            if isinstance(features, torch.Tensor):
                features = features.detach().cpu().numpy()
            if len(features) > 0:
                return estimate_clinical_from_features(features)
            return create_fallback_response("Insufficient features")

        return create_fallback_response("No clinical data available")

    except Exception as e:
        # Deliberate best-effort boundary: callers always receive a result dict.
        print(f"❌ Error in clinical analysis: {e}")
        return create_fallback_response("Analysis error")


def extract_clinical_from_probabilities(
    probs: np.ndarray,
    labels: Optional[List[str]] = None,
    thresholds: Optional[Dict[str, float]] = None,
) -> Dict[str, Any]:
    """Turn per-label probabilities into a clinical findings dict.

    Args:
        probs: Per-label probabilities, aligned with ``labels``. Flattened,
            then truncated or zero-padded to the label count.
        labels: Label names; loaded from label_def.csv when omitted.
        thresholds: Per-label decision thresholds; loaded from
            thresholds.json when omitted. Labels missing from the mapping use
            DEFAULT_THRESHOLD.

    Returns:
        Findings dict (rhythm, abnormalities, confidence metrics,
        probabilities, ...), or a fallback dict if anything fails.
    """
    try:
        if labels is None:
            labels = load_label_definitions()
        if thresholds is None:
            # Pass the labels through so label_def.csv is not read twice.
            thresholds = load_clinical_thresholds(labels)

        probs = np.asarray(probs, dtype=np.float64).ravel()
        n_labels = len(labels)
        if len(probs) != n_labels:
            print(f"⚠️ Warning: Probability array length ({len(probs)}) doesn't match label count ({n_labels})")
            if len(probs) > n_labels:
                probs = probs[:n_labels]
            else:
                probs = np.pad(probs, (0, n_labels - len(probs)), 'constant', constant_values=0.0)

        # A label is reported when its probability reaches its threshold.
        abnormalities = [
            label for label, prob in zip(labels, probs)
            if prob >= thresholds.get(label, DEFAULT_THRESHOLD)
        ]

        rhythm = determine_rhythm_from_abnormalities(abnormalities)
        confidence_metrics = calculate_confidence_metrics(probs, thresholds)

        return {
            "rhythm": rhythm,
            "heart_rate": None,  # Will be calculated from features if available
            "qrs_duration": None,  # Will be calculated from features if available
            "qt_interval": None,  # Will be calculated from features if available
            "pr_interval": None,  # Will be calculated from features if available
            "axis_deviation": "Normal",  # Will be calculated from features if available
            "abnormalities": abnormalities,
            "confidence": confidence_metrics["overall_confidence"],
            "confidence_level": confidence_metrics["confidence_level"],
            "review_required": confidence_metrics["review_required"],
            "probabilities": probs.tolist(),
            "label_probabilities": dict(zip(labels, probs.tolist())),
            "method": "clinical_predictions",
            "warning": None,
            "labels_used": labels,
            "thresholds_used": thresholds,
        }
    except Exception as e:
        print(f"❌ Error in clinical probability extraction: {e}")
        return create_fallback_response(f"Clinical analysis failed: {str(e)}")


def estimate_clinical_from_features(features: np.ndarray) -> Dict[str, Any]:
    """Fallback for pretrained (non-finetuned) outputs.

    Always returns a fallback response: no validated algorithm maps raw
    ECG-FM features to clinical parameters yet, so none is attempted.
    """
    try:
        if len(features) == 0:
            return create_fallback_response("No features available for estimation")

        # ECG-FM features require proper validation and analysis.
        # We cannot provide reliable clinical estimates without validated algorithms.
        print("⚠️ Clinical estimation from features requires validated ECG-FM algorithms")
        print("   Returning fallback response to prevent incorrect clinical information")
        return create_fallback_response("Clinical estimation from features not yet validated")

    except Exception as e:
        print(f"❌ Error in clinical feature estimation: {e}")
        return create_fallback_response(f"Feature estimation error: {str(e)}")


def create_fallback_response(reason: str) -> Dict[str, Any]:
    """Build the result dict used when clinical analysis cannot be performed.

    Args:
        reason: Human-readable explanation, returned in the ``warning`` field.
    """
    return {
        "rhythm": "Analysis Failed",
        "heart_rate": None,
        "qrs_duration": None,
        "qt_interval": None,
        "pr_interval": None,
        "axis_deviation": "Unknown",
        "abnormalities": [],
        "confidence": 0.0,
        "confidence_level": "None",
        "review_required": True,
        "probabilities": [],
        "label_probabilities": {},
        "method": "fallback",
        "warning": reason,
        "labels_used": [],
        "thresholds_used": {},
    }


# Helper functions for enhanced clinical analysis

def load_label_definitions(path: str = 'label_def.csv') -> List[str]:
    """Load ECG-FM label names from a header-less CSV file.

    The label name is taken from the second column (index 1) of every row.
    NOTE(review): ``header=None`` means a header row, if the file has one,
    would be read as a label — confirm the file format.

    Args:
        path: CSV location; a relative path resolves against the current
            working directory.

    Returns:
        Label names in file order (empty if the file has fewer than two columns).

    Raises:
        RuntimeError: If pandas is unavailable or the file cannot be read/parsed.
    """
    try:
        # Imported lazily so this module can be imported without pandas.
        import pandas as pd
        df = pd.read_csv(path, header=None)
    except Exception as e:
        print(f"❌ CRITICAL ERROR: Could not load {path}: {e}")
        print("   ECG-FM clinical analysis cannot proceed without proper labels")
        raise RuntimeError(f"Failed to load ECG-FM label definitions: {e}") from e

    # With header=None the columns are labelled 0..n-1; column 1 holds names.
    label_names = df[1].tolist() if df.shape[1] >= 2 else []

    if len(label_names) != EXPECTED_NUM_LABELS:
        print(f"⚠️ Warning: Expected {EXPECTED_NUM_LABELS} labels, got {len(label_names)}")
        print(f"   Labels: {label_names}")

    print(f"✅ Loaded {len(label_names)} official ECG-FM labels")
    return label_names


def load_clinical_thresholds(
    labels: Optional[List[str]] = None,
    path: str = 'thresholds.json',
) -> Dict[str, float]:
    """Load per-label decision thresholds from the ``clinical_thresholds`` key of a JSON file.

    Every label without a configured threshold is filled with
    DEFAULT_THRESHOLD. If the JSON file is missing or malformed, every label
    gets DEFAULT_THRESHOLD.

    Args:
        labels: Label names to validate against; loaded via
            load_label_definitions() when omitted.
        path: JSON location; a relative path resolves against the current
            working directory.

    Raises:
        RuntimeError: If the label definitions have to be loaded and cannot be.
    """
    if labels is None:
        labels = load_label_definitions()  # raises RuntimeError on failure

    try:
        with open(path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        thresholds = {
            name: float(value)
            for name, value in config.get('clinical_thresholds', {}).items()
        }
    except (OSError, ValueError, TypeError, AttributeError) as e:
        # OSError: missing/unreadable file; ValueError: bad JSON or a value
        # that is not a number; TypeError/AttributeError: unexpected structure.
        print(f"❌ CRITICAL ERROR: Could not load {path}: {e}")
        print(f"   Using default threshold of {DEFAULT_THRESHOLD} for all labels")
        default_thresholds = {label: DEFAULT_THRESHOLD for label in labels}
        print(f"✅ Created default thresholds for {len(default_thresholds)} labels")
        return default_thresholds

    missing_labels = [label for label in labels if label not in thresholds]
    if missing_labels:
        print(f"⚠️ Warning: Missing thresholds for labels: {missing_labels}")
        for label in missing_labels:
            thresholds[label] = DEFAULT_THRESHOLD

    print(f"✅ Loaded thresholds for {len(thresholds)} clinical labels")
    return thresholds


def determine_rhythm_from_abnormalities(abnormalities: List[str]) -> str:
    """Map detected abnormality labels to a single rhythm description.

    Returns "Normal Sinus Rhythm" when nothing was detected. Otherwise returns
    the description of the first entry in _RHYTHM_PRIORITY that is present,
    or "Abnormal Rhythm" if none of those labels were detected.
    """
    if not abnormalities:
        return "Normal Sinus Rhythm"

    detected = set(abnormalities)
    for label, rhythm in _RHYTHM_PRIORITY:
        if label in detected:
            return rhythm
    return "Abnormal Rhythm"


def calculate_confidence_metrics(probs: np.ndarray, thresholds: Dict[str, float]) -> Dict[str, Any]:
    """Summarise a probability vector into confidence metrics and a review flag.

    Args:
        probs: Per-label probabilities.
        thresholds: Currently unused; kept for interface compatibility.

    Returns:
        Dict of plain Python floats/bools/strings (JSON-serializable).

    Raises:
        ValueError: If ``probs`` is empty.
    """
    values = np.asarray(probs, dtype=np.float64)
    if values.size == 0:
        raise ValueError("Cannot compute confidence metrics for an empty probability array")

    max_prob = float(np.max(values))
    mean_prob = float(np.mean(values))

    if max_prob >= 0.8:
        confidence_level = "High"
    elif max_prob >= 0.6:
        confidence_level = "Medium"
    else:
        confidence_level = "Low"

    # NOTE(review): in a multi-label setting where most labels are negative
    # the mean is typically low, so the `mean_prob < 0.4` clause may flag
    # most records for review — confirm this is intended.
    review_required = bool(max_prob < 0.6 or mean_prob < 0.4)

    return {
        "overall_confidence": max_prob,
        "confidence_level": confidence_level,
        "review_required": review_required,
        "mean_probability": mean_prob,
        "max_probability": max_prob,
    }