Spaces:
Runtime error
Runtime error
| import os | |
| import cv2 | |
| import numpy as np | |
| import pandas as pd | |
| import tensorflow as tf | |
| from tensorflow.keras.models import load_model | |
| import gradio as gr | |
| import biosppy.signals.ecg as ecg | |
| from PIL import Image | |
| import traceback | |
# Directory that receives uploaded files and the CSVs produced from images.
UPLOAD_FOLDER = "/tmp/uploads"
# exist_ok=True replaces the racy "check, then create" pattern: creating an
# existing directory is not an error.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Load the pre-trained Keras model once at import time. The relative path is
# resolved against the current working directory, so ecgScratchEpoch2.hdf5 is
# expected to sit in the directory the app is started from.
try:
    model = load_model("ecgScratchEpoch2.hdf5")
except Exception as e:
    # Chain the original error so its traceback is not lost behind the wrapper.
    raise Exception(f"Failed to load model: {str(e)}") from e
def _to_grayscale(image):
    """Return *image* as a writable 2-D uint8 array.

    Accepts a PIL image or an array of shape (H, W), (H, W, 1), (H, W, 3) or
    (H, W, 4); 3-channel data is treated as RGB and 4-channel data as RGBA,
    with any transparency composited onto a white background.
    """
    # np.array() on palette / bilevel / other PIL modes would return palette
    # indices or booleans rather than intensities, so normalise them first.
    if hasattr(image, "mode") and image.mode not in ("L", "RGB", "RGBA"):
        image = image.convert("RGBA")
    img = np.array(image)
    if img.ndim == 3 and img.shape[2] == 1:
        img = img[:, :, 0]
    elif img.ndim == 3 and img.shape[2] == 4:
        # Transparent pixels may carry dark RGB values, which the dark-trace
        # threshold below would mistake for waveform; flatten onto white.
        alpha = img[:, :, 3:4].astype(np.float32) / 255.0
        rgb = img[:, :, :3].astype(np.float32)
        img = (rgb * alpha + 255.0 * (1.0 - alpha)).astype(np.uint8)
    if img.ndim == 3:
        img = cv2.cvtColor(np.ascontiguousarray(img), cv2.COLOR_RGB2GRAY)
    if img.ndim != 2:
        raise ValueError(f"Unsupported image shape {img.shape}")
    return img.astype(np.uint8, copy=False)


def _column_trace(contour, width):
    """Return, for each of *width* columns, the mean row of the contour points.

    Columns without any contour point are filled by linear interpolation
    between their nearest covered neighbours; columns before the first / after
    the last covered column repeat that edge value.
    """
    # OpenCV contours are (N, 1, 2) arrays of (x, y) points.
    points = contour.reshape(-1, 2)
    xs = points[:, 0]
    ys = points[:, 1].astype(float)
    counts = np.bincount(xs, minlength=width)
    sums = np.bincount(xs, weights=ys, minlength=width)
    covered = np.flatnonzero(counts)
    means = sums[covered] / counts[covered]
    return np.interp(np.arange(width), covered, means)


def image_to_signal(image):
    """Convert an ECG image to a 1-D signal and save it as a CSV file.

    The image is resized to 1000x500, dark pixels (< 200) are taken as the
    trace, and the largest external contour is reduced to one value per
    column. Row coordinates are flipped so that upward deflections in the
    picture become larger values, and the result is scaled to [0, 1000].

    Args:
        image: A PIL image or numpy array (grayscale, RGB or RGBA).

    Returns:
        Path of the written CSV: a single " Sample Value" column with 1000
        samples (all zeros if the trace is perfectly flat).

    Raises:
        Exception: "Image processing error: ..." wrapping any failure.
    """
    try:
        gray = _to_grayscale(image)
        # Fixed canvas (width x height) so every image yields 1000 samples.
        gray = cv2.resize(gray, (1000, 500))
        # Pixels darker than 200 become foreground (255); assumes a dark trace
        # on a light background.
        _, binary = cv2.threshold(gray, 200, 255, cv2.THRESH_BINARY_INV)
        # [-2] picks the contour list under both OpenCV 3 (3 return values) and
        # OpenCV 4 (2 return values). CHAIN_APPROX_NONE keeps every boundary
        # pixel so each column's mean sees both edges of the drawn line.
        contours = cv2.findContours(
            binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
        )[-2]
        if not contours:
            raise ValueError("No waveform detected in the image")
        # Assumes the waveform is the dark blob enclosing the largest area.
        contour = max(contours, key=cv2.contourArea)
        rows = _column_trace(contour, binary.shape[1])
        # Image rows grow downward: subtract from the maximum to flip the trace
        # upright, then scale to [0, 1000]. Guard the flat-trace case, which
        # would otherwise divide by zero and yield NaNs.
        span = rows.max() - rows.min()
        if span > 0:
            signal = (rows.max() - rows) / span * 1000
        else:
            signal = np.zeros_like(rows)
        csv_path = os.path.join(UPLOAD_FOLDER, "converted_signal.csv")
        pd.DataFrame(signal, columns=[" Sample Value"]).to_csv(csv_path, index=False)
        return csv_path
    except Exception as e:
        raise Exception(f"Image processing error: {str(e)}") from e
# Maps the model's argmax index to a class label (same order as the original code).
_ECG_CLASSES = ["Normal", "APC", "LBB", "PAB", "PVC", "RBB", "VEB"]
# Side length of the square beat image fed to the model.
_BEAT_IMAGE_SIZE = 128
# Structuring element used to thicken the rasterised beat.
_DILATE_KERNEL = np.ones((4, 4), np.uint8)


def _read_samples(csv_path):
    """Return the " Sample Value" column of *csv_path* as a float array.

    A column whose name matches after stripping whitespace (e.g. "Sample Value")
    is accepted as well. Non-numeric cells are dropped, so reported sample
    indices refer to positions in the cleaned series.
    """
    frame = pd.read_csv(csv_path)
    column = " Sample Value"
    if column not in frame.columns:
        matches = [name for name in frame.columns if str(name).strip() == column.strip()]
        if not matches:
            raise KeyError(
                f"CSV must contain a '{column.strip()}' column; found {list(frame.columns)}"
            )
        column = matches[0]
    return pd.to_numeric(frame[column], errors="coerce").dropna().to_numpy(dtype=float)


def _beat_windows(peaks):
    """Yield (start, end) plain-int indices for every peak with a neighbour on each side.

    A window runs from the midpoint of the previous peak interval to the
    midpoint of the next one.
    """
    for prev, peak, nxt in zip(peaks[:-2], peaks[1:-1], peaks[2:]):
        yield int(prev + abs(peak - prev) // 2), int(nxt - abs(nxt - peak) // 2)


def _beat_image(segment):
    """Rasterise one beat into a (128, 128, 1) image.

    Row = sample index, column = truncated amplitude / 10, as in the original
    code. Only the first 128 samples are drawn and columns are clipped to
    [0, 127]; previously longer windows raised IndexError and negative
    amplitudes silently wrapped to the right edge.
    """
    size = _BEAT_IMAGE_SIZE
    img = np.zeros((size, size))
    samples = np.asarray(segment[:size], dtype=float)
    cols = np.clip((samples / 10).astype(int), 0, size - 1)
    img[np.arange(len(samples)), cols] = 255
    img = cv2.dilate(img, _DILATE_KERNEL, iterations=1)
    return img.reshape(size, size, 1)


def model_predict(csv_path):
    """Predict ECG arrhythmia classes for each beat in a CSV file.

    Peaks are located with biosppy's Christov segmenter (sampling rate fixed
    at 200 Hz here). Each interior beat window longer than 10 samples is
    rasterised and classified by the global Keras ``model``.

    Args:
        csv_path: Path to a CSV containing a " Sample Value" column.

    Returns:
        ``[{"file": csv_path, "results": {label: [(start, end), ...]}}]`` with
        keys APC, Normal, LBB, PAB, PVC, RBB, VEB.

    Raises:
        Exception: "Prediction error: ..." wrapping any failure.
    """
    try:
        result = {label: [] for label in ("APC", "Normal", "LBB", "PAB", "PVC", "RBB", "VEB")}
        data = _read_samples(csv_path)
        peaks = ecg.christov_segmenter(signal=data, sampling_rate=200)[0]
        windows = [(s, e) for s, e in _beat_windows(peaks) if len(data[s:e]) > 10]
        if windows:
            # One batched predict call instead of one call per beat.
            images = np.stack([_beat_image(data[s:e]) for s, e in windows])
            predictions = model.predict(images, verbose=0).argmax(axis=1)
            for window, prediction in zip(windows, predictions):
                result[_ECG_CLASSES[int(prediction)]].append(window)
        return [{"file": csv_path, "results": result}]
    except Exception as e:
        raise Exception(f"Prediction error: {str(e)}") from e
def _flatten_to_rgb(image):
    """Return PIL *image* as an RGB image, compositing any transparency onto white."""
    rgba = image.convert("RGBA")
    flattened = Image.new("RGB", rgba.size, (255, 255, 255))
    flattened.paste(rgba, mask=rgba.getchannel("A"))
    return flattened


def classify_ecg(file):
    """Gradio callback: classify arrhythmias in an uploaded ECG CSV or image.

    Args:
        file: What ``gr.File`` delivers, i.e. a path string (Gradio 4+) or a
            tempfile-like object with a ``.name`` path (Gradio 3). A PIL image
            is also accepted for direct calls. ``None`` means nothing was
            uploaded.

    Returns:
        A text report with a "File: ..." line per result, then one
        "Label: [(start, end), ...]" line per class with at least one beat (or
        "No beats could be classified."). On failure, the error message plus
        traceback.
    """
    unsupported = "Unsupported file type. Use CSV, PNG, or JPG."
    try:
        if file is None:
            return "No file uploaded."
        # Resolve the upload to a source path; the extension of the ORIGINAL
        # file decides how it is processed. Previously every path string was
        # treated as CSV, so uploaded images were parsed as text.
        if isinstance(file, str):
            src_path = file
        elif isinstance(getattr(file, "name", None), str):
            src_path = file.name
        elif isinstance(file, Image.Image):
            src_path = None
        else:
            return unsupported
        ext = "png" if src_path is None else os.path.splitext(src_path)[1].lstrip(".").lower()

        if ext in ("png", "jpg", "jpeg"):
            if src_path is None:
                image = _flatten_to_rgb(file)
            else:
                with Image.open(src_path) as opened:
                    image = _flatten_to_rgb(opened)
            image.save(os.path.join(UPLOAD_FOLDER, "uploaded_file.png"))
            csv_path = image_to_signal(image)
        elif ext == "csv":
            csv_path = os.path.join(UPLOAD_FOLDER, "uploaded_file.csv")
            with open(src_path, "rb") as src, open(csv_path, "wb") as dst:
                dst.write(src.read())
        else:
            return unsupported

        results = model_predict(csv_path)

        lines = []
        for result in results:
            lines.append(f"File: {result['file']}")
            # Format spans explicitly so numpy scalar reprs never leak into the text.
            found = [
                f"{label}: [" + ", ".join(f"({int(a)}, {int(b)})" for a, b in spans) + "]"
                for label, spans in result["results"].items()
                if spans
            ]
            lines.extend(found or ["No beats could be classified."])
        return "\n".join(lines) + "\n"
    except Exception as e:
        # NOTE(review): this shows the full traceback to end users, which helps
        # debugging but exposes server internals; consider logging it instead.
        return f"Error: {str(e)}\n{traceback.format_exc()}"
# Gradio front end: a single file upload goes in, a plain-text report comes out.
_ecg_upload = gr.File(label="Upload ECG Image (PNG/JPG) or CSV")
_ecg_report = gr.Textbox(label="Classification Results")

iface = gr.Interface(
    fn=classify_ecg,
    inputs=_ecg_upload,
    outputs=_ecg_report,
    title="ECG Arrhythmia Classification",
    description=(
        "Upload an ECG image (PNG/JPG) or CSV file to classify arrhythmias. "
        "Images will be converted to CSV before processing."
    ),
)

if __name__ == "__main__":
    # Listen on all network interfaces, port 7860.
    _launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    iface.launch(**_launch_options)