#!/usr/bin/env python3
"""
Discover ECG-FM Model Labels
Inspect the actual labels that the finetuned model outputs
"""

import csv
import io
import json
import time
from typing import Any, Dict, List

import numpy as np
import requests
import torch

# Base URL of the deployed API that is probed when no other URL is given.
DEFAULT_API_URL = "https://mystic-cbk-ecg-fm-api.hf.space"

# Lead names sent with the sample signal, one per row of the signal matrix.
LEAD_NAMES = [
    "I", "II", "III", "aVR", "aVL", "aVF",
    "V1", "V2", "V3", "V4", "V5", "V6",
]


def _build_sample_payload() -> Dict[str, Any]:
    """Build the request payload: a minimal 12-lead signal of random noise.

    The signal is drawn from an unseeded normal distribution (mean 0,
    standard deviation 0.1), so it differs on every run. It only exercises
    the API; it is not a realistic ECG.
    """
    n_samples = 500
    fs = 500
    sample_ecg = np.random.normal(0, 0.1, (len(LEAD_NAMES), n_samples)).tolist()
    return {
        "signal": sample_ecg,
        "fs": fs,
        "lead_names": list(LEAD_NAMES),
        "recording_duration": n_samples / fs,  # 1.0 second
    }


def _report_analysis_result(result: Dict[str, Any]) -> None:
    """Print the structure of an ``/analyze`` response and save any labels.

    If the response carries ``clinical_analysis.label_probabilities``, every
    label is printed with its probability and the label names are written to
    disk via ``save_discovered_labels``. A top-level ``probabilities`` array
    is summarised as well.
    """
    print("\n📋 Response Structure Analysis:")
    print(f" Keys: {list(result.keys())}")

    if 'clinical_analysis' in result:
        clinical = result['clinical_analysis']
        print(f"\n🏥 Clinical Analysis Keys: {list(clinical.keys())}")

        if 'label_probabilities' in clinical:
            label_probs = clinical['label_probabilities']
            print(f"\n🏷️ Label Probabilities Found: {len(label_probs)} labels")
            print(" Labels and probabilities:")
            for label, prob in label_probs.items():
                print(f" {label}: {prob:.3f}")

            # Save discovered labels
            save_discovered_labels(list(label_probs))
        else:
            print("❌ No label_probabilities found in response")
            print(" This suggests the model might not be outputting clinical labels yet")

    if 'probabilities' in result:
        probs = result['probabilities']
        print(f"\n📊 Raw Probabilities Array: {len(probs)} values")
        print(f" First 10 values: {probs[:10]}")

        # If we have probabilities but no labels, we need to discover the
        # label mapping. `or {}` also covers a null clinical_analysis.
        clinical = result.get('clinical_analysis') or {}
        if len(probs) > 0 and 'label_probabilities' not in clinical:
            print("\n⚠️ Model outputs probabilities but no label names")
            print(" This suggests we need to find the label definitions from the model")


def test_model_with_sample_ecg(api_url: str = DEFAULT_API_URL) -> None:
    """Test the deployed model to see what labels it actually outputs.

    Sends a health check to ``{api_url}/health`` and, if it returns 200,
    posts a random sample signal to ``{api_url}/analyze`` and reports the
    response. All progress and failures are printed; nothing is returned
    and no exception is propagated.

    Args:
        api_url: Base URL of the API to probe, without a trailing slash.
    """
    print("🔍 Discovering ECG-FM Model Labels")
    print("=" * 50)

    # Test with a simple ECG signal
    payload = _build_sample_payload()
    sample_ecg = payload["signal"]

    print("📊 Testing with sample ECG signal...")
    print(f" Signal shape: {len(sample_ecg)} leads x {len(sample_ecg[0])} samples")

    try:
        print(f"\n🌐 Testing deployed API: {api_url}")

        # Test health first
        health_response = requests.get(f"{api_url}/health", timeout=30)
        if health_response.status_code == 200:
            print("✅ API is healthy")
        else:
            print(f"❌ API health check failed: {health_response.status_code}")
            return

        # Test full analysis
        print("\n🔬 Testing full ECG analysis...")
        analysis_response = requests.post(
            f"{api_url}/analyze",
            json=payload,
            timeout=180,
        )

        if analysis_response.status_code != 200:
            print(f"❌ Analysis failed: {analysis_response.status_code}")
            print(f" Response: {analysis_response.text}")
            return

        result = analysis_response.json()
        print("✅ Analysis successful!")
        _report_analysis_result(result)

    except Exception as e:
        # Deliberately broad: this is the reporting boundary of a diagnostic
        # script, so network errors, invalid JSON and unexpected response
        # shapes are all printed instead of aborting the run.
        print(f"❌ Error testing API: {e}")


def save_discovered_labels(labels: List[str]) -> None:
    """Save discovered labels to ``discovered_labels.csv`` and ``model_labels.txt``.

    The CSV holds one ``index,label`` row per label; the text file holds one
    label per line. Both are written as UTF-8 to the current working
    directory, overwriting existing files. Failures are printed, not raised.

    Args:
        labels: Label names in model output order.
    """
    try:
        # Use the csv module so labels containing commas, quotes or newlines
        # are quoted instead of corrupting the row structure.
        buffer = io.StringIO()
        csv.writer(buffer, lineterminator="\n").writerows(enumerate(labels))
        csv_text = buffer.getvalue()
        # csv.writer terminates every row; drop the last terminator so the
        # file is byte-identical to a plain '\n'.join for simple labels.
        if csv_text.endswith("\n"):
            csv_text = csv_text[:-1]

        with open('discovered_labels.csv', 'w', encoding='utf-8') as f:
            f.write(csv_text)

        print("\n💾 Discovered labels saved to: discovered_labels.csv")
        print(f" Total labels: {len(labels)}")

        # Also create a simple list file
        with open('model_labels.txt', 'w', encoding='utf-8') as f:
            f.write('\n'.join(labels))

        print(" Labels list saved to: model_labels.txt")

    except Exception as e:
        # Best effort: a failed save must not abort the discovery run.
        print(f"❌ Error saving discovered labels: {e}")


def inspect_model_checkpoint() -> None:
    """Print guidance on inspecting the model checkpoint for label definitions.

    This function only prints advice; it does not load or read a checkpoint.
    """
    print("\n🔍 Model Checkpoint Inspection")
    print("=" * 40)

    print("💡 To properly discover model labels, you should:")
    print("1. Load the model checkpoint locally")
    print("2. Inspect the model's classification head")
    print("3. Check for label mapping in the checkpoint")
    print("4. Or test with known ECG data to see output patterns")

    print("\n📚 Alternative approaches:")
    print("1. Check ECG-FM paper/repository for label definitions")
    print("2. Contact the model authors for label mapping")
    print("3. Use a small labeled dataset to map outputs to known conditions")


def main() -> None:
    """Main function to discover model labels."""
    print("🧪 ECG-FM Model Label Discovery")
    print("=" * 50)

    print("🎯 Goal: Discover the actual labels that the finetuned model outputs")
    print(" This will help us create the correct label_def.csv")

    # Test with deployed API
    test_model_with_sample_ecg()

    # Provide guidance for further investigation
    inspect_model_checkpoint()

    print("\n💡 Next Steps:")
    print("1. Run this script to test the deployed API")
    print("2. Check if label_probabilities are returned")
    print("3. If yes, use those labels; if no, investigate further")
    print("4. Update label_def.csv with the correct labels")


if __name__ == "__main__":
    main()