import gradio as gr
import numpy as np
import wfdb
import tensorflow as tf
from scipy import signal
import os
import subprocess
import shutil
import requests
import zipfile

# Disable GPU usage to avoid CUDA warnings (force TensorFlow onto the CPU)
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Get HF_TOKEN from the environment. It is not used directly below, but
# reading it here makes the app fail fast if the Space secret is missing.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN not found. Please set it in the Space's environment variables.")

# Repository and dataset details
REPO_URL = "https://github.com/AutoECG/Automated-ECG-Interpretation.git"
REPO_DIR = "Automated-ECG-Interpretation"
DATASET_URL = "https://physionet.org/static/published-projects/ptb-xl/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3.zip"
DATASET_DIR = "ptb-xl"
PERSISTENT_DIR = "/data"


def ensure_persistent_dir():
    """Ensure the persistent /data directory exists."""
    os.makedirs(PERSISTENT_DIR, exist_ok=True)


def clone_repository():
    """Clone the model repository if it is not already present."""
    if not os.path.exists(REPO_DIR):
        print("Cloning repository...")
        try:
            subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
            print("Repository cloned successfully.")
        except subprocess.CalledProcessError as e:
            print(f"Error cloning repository: {e}")
    else:
        print("Repository already cloned.")


def download_and_extract_dataset():
    """Download and extract the PTB-XL dataset into the persistent directory."""
    ensure_persistent_dir()
    zip_path = os.path.join(PERSISTENT_DIR, "ptb-xl.zip")
    extract_path = os.path.join(PERSISTENT_DIR, DATASET_DIR)
    if not os.path.exists(extract_path):
        print("Downloading PTB-XL dataset...")
        response = requests.get(DATASET_URL, stream=True)
        response.raise_for_status()  # Fail early on a bad download
        with open(zip_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print("Extracting dataset...")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(PERSISTENT_DIR)
        os.remove(zip_path)  # Clean up the zip file
        # The PhysioNet archive unpacks into a long, versioned folder name,
        # not DATASET_DIR, so rename it or later path lookups will fail. This
        # assumes the archive's top-level folder keeps PhysioNet's
        # "ptb-xl-..." naming.
        for name in os.listdir(PERSISTENT_DIR):
            candidate = os.path.join(PERSISTENT_DIR, name)
            if name.startswith("ptb-xl-") and os.path.isdir(candidate):
                os.rename(candidate, extract_path)
                break
        print("Dataset extracted successfully.")
    else:
        print("Dataset already extracted.")


# Clone the repo and download the dataset on startup
clone_repository()
download_and_extract_dataset()

# Load the pre-trained model, copying it into persistent storage first
MODEL_FILENAME = "model.h5"
MODEL_PATH = os.path.join(REPO_DIR, MODEL_FILENAME)
PERSISTENT_MODEL_PATH = os.path.join(PERSISTENT_DIR, MODEL_FILENAME)
if not os.path.exists(PERSISTENT_MODEL_PATH):
    if os.path.exists(MODEL_PATH):
        shutil.copy(MODEL_PATH, PERSISTENT_MODEL_PATH)
    else:
        raise FileNotFoundError(
            f"Model file not found at {MODEL_PATH}. "
            "Please ensure it's in the repository or upload it manually."
        )
model = tf.keras.models.load_model(PERSISTENT_MODEL_PATH)


def preprocess_ecg(file_path):
    """Read a WFDB record and shape it into the model's expected input."""
    # wfdb expects the record name without the .dat extension; note that it
    # also needs the matching .hea header file next to the .dat file.
    record = wfdb.rdrecord(file_path.removesuffix(".dat"))
    ecg_signal = record.p_signal[:, 0]  # First lead only
    # Resample to 360 Hz, then z-score normalize (the epsilon guards against a flat signal)
    target_fs = 360
    num_samples = int(len(ecg_signal) * target_fs / record.fs)
    ecg_resampled = signal.resample(ecg_signal, num_samples)
    ecg_normalized = (ecg_resampled - np.mean(ecg_resampled)) / (np.std(ecg_resampled) + 1e-8)
    # Pad or truncate to a fixed 3600-sample window (10 s at 360 Hz)
    if len(ecg_normalized) < 3600:
        ecg_normalized = np.pad(ecg_normalized, (0, 3600 - len(ecg_normalized)), "constant")
    else:
        ecg_normalized = ecg_normalized[:3600]
    return ecg_normalized.reshape(1, 3600, 1)


def predict_ecg(file=None, dataset_file=None):
    """Run the model on an uploaded file or a selected PTB-XL sample."""
    if file:
        file_path = file.name
    elif dataset_file:
        file_path = os.path.join(PERSISTENT_DIR, DATASET_DIR, "records500", dataset_file)
    else:
        return "Please upload a file or select a dataset sample."
    ecg_data = preprocess_ecg(file_path)
    prediction = model.predict(ecg_data)
    # Single sigmoid output: > 0.5 means abnormal; report confidence for the chosen class
    label = "Abnormal" if prediction[0][0] > 0.5 else "Normal"
    confidence = float(prediction[0][0]) if label == "Abnormal" else float(1 - prediction[0][0])
    return f"Prediction: {label}\nConfidence: {confidence:.2%}"


# Build the dropdown list of PTB-XL records (paths relative to records500/)
dataset_files = []
records_dir = os.path.join(PERSISTENT_DIR, DATASET_DIR, "records500")
if os.path.exists(records_dir):
    for root, _, files in os.walk(records_dir):
        for file in files:
            if file.endswith(".dat"):
                dataset_files.append(os.path.relpath(os.path.join(root, file), records_dir))

# Gradio interface
interface = gr.Interface(
    fn=predict_ecg,
    inputs=[
        gr.File(label="Upload ECG File (.dat format)"),
        gr.Dropdown(choices=dataset_files, label="Or Select a PTB-XL Sample"),
    ],
    outputs=gr.Textbox(label="ECG Interpretation"),
    title="Automated ECG Interpretation",
    description=(
        "Upload an ECG file (.dat) or select a sample from the PTB-XL dataset "
        "for automated interpretation."
    ),
)

# Launch the app
interface.launch()
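
# ---------------------------------------------------------------------------
# Usage sketch (assumes the script is saved as app.py, the Hugging Face
# Spaces convention; the filename is not confirmed by the source). With the
# dependencies installed, the app can be run directly:
#
#   pip install gradio numpy wfdb tensorflow scipy requests
#   python app.py
#
# On first start it clones the model repository, downloads and extracts the
# PTB-XL archive (a multi-gigabyte download) into /data, and then serves the
# Gradio UI on Gradio's default port, 7860.
# ---------------------------------------------------------------------------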