# Instructions

In this tutorial, we will perform multi-label classification using an ECG-FM model finetuned on the [MIMIC-IV-ECG v1.0 dataset](https://physionet.org/content/mimic-iv-ecg/1.0/). It outlines the data and model loading, as well as inference, same-sample prediction aggregation, and visualizations for embeddings and saliency maps.

ECG-FM was developed in collaboration with the [fairseq_signals](https://github.com/Jwoo5/fairseq-signals) framework, which implements a collection of deep learning methods for ECG analysis.

Each ECG is segmented into 5 s inputs, and the predictions for the segments of the same sample are then combined using a label-specific aggregation.

This document serves largely as a quickstart introduction. Much of this functionality is also available via the [fairseq-signals scripts](https://github.com/bowang-lab/ECG-FM/blob/main/notebooks/infer_cli.ipynb), as well as the [ECG-FM scripts](https://github.com/bowang-lab/ECG-FM/tree/main/scripts).

## Installation

Begin by cloning [fairseq_signals](https://github.com/Jwoo5/fairseq-signals) and refer to the installation section in the top-level README. For example, the following commands are sufficient at the present moment:
```
# Creating `fairseq` environment:
conda create --name fairseq python=3.10.6
source activate fairseq
git clone https://github.com/Jwoo5/fairseq-signals
cd fairseq-signals
python3 -m pip install pip==24.0
python3 -m pip install -e .
```

In [None]:
# Third-party packages used by later cells; which ones you need depends on
# the functionality you run.
!pip install huggingface-hub
!pip install pandas
!pip install ecg-transform==0.1.3
!pip install umap-learn
!pip install plotly

In [102]:
import os

# Parent directory of the current working directory; the data and checkpoint
# paths used in later cells are built relative to it.
root = os.path.dirname(os.getcwd())

## Download checkpoints

Checkpoints are available on [HuggingFace](https://huggingface.co/wanglab/ecg-fm). The finetuned model can be downloaded using the following command:

In [None]:
import os
from huggingface_hub import hf_hub_download

# Fetch both the model weights (.pt) and the accompanying config (.yaml).
# The inference section requires `ckpts/mimic_iv_ecg_finetuned.pt` to exist,
# so downloading only the .yaml is not enough.
for ckpt_file in ('mimic_iv_ecg_finetuned.pt', 'mimic_iv_ecg_finetuned.yaml'):
    _ = hf_hub_download(
        repo_id='wanglab/ecg-fm',
        filename=ckpt_file,
        local_dir=os.path.join(root, 'ckpts'),
    )

# Inference

In [None]:
# Imported here because this cell runs before the notebook's main
# `import torch` cell.
import torch

# Finetuned checkpoint downloaded in the previous section
ckpt_path: str = os.path.join(root, 'ckpts/mimic_iv_ecg_finetuned.pt')
if not os.path.isfile(ckpt_path):
    raise FileNotFoundError(
        f"Checkpoint not found at {ckpt_path!r}; run the download cell first."
    )

# Use the GPU when one is visible, otherwise fall back to the CPU
device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size: int = 16
num_workers: int = 0

# Whether attention weights should be collected for the saliency maps
extract_saliency: bool = True

In [None]:
from typing import Any, List

def to_list(obj: Any) -> List[Any]:
    """
    Coerce ``obj`` to a list.

    Parameters
    ----------
    obj : Any
        Object to convert.

    Returns
    -------
    List[Any]
        ``obj`` itself when it is already a list; ``list(obj)`` when it is a
        NumPy array, set, or dict (a dict yields its keys); otherwise a
        single-element list wrapping ``obj`` (this includes strings and tuples).
    """
    # Imported locally: this helper is defined before the notebook's
    # module-level `import numpy as np`, so relying on the global name would
    # raise NameError for any non-list input.
    import numpy as np

    if isinstance(obj, list):
        return obj

    if isinstance(obj, (np.ndarray, set, dict)):
        return list(obj)

    return [obj]

# Directory holding the preprocessed sample ECG files
data_dir = os.path.join(root, 'data/code_15/org')

# `os.listdir` returns entries in arbitrary order; sorting keeps the file
# order (and therefore the segment indices used later) reproducible.
file_paths = [
    os.path.join(data_dir, file) for file in sorted(os.listdir(data_dir))
]
file_paths = to_list(file_paths)
file_paths

## Prepare data

To simplify this tutorial, we have processed a sample of 10 ECGs (14 5s segments) from the [CODE-15% v1.0.0 dataset](https://zenodo.org/records/4916206/) using our [end-to-end data preprocessing pipeline](https://github.com/Jwoo5/fairseq-signals/tree/master/scripts/preprocess/ecg). Its README is also helpful if looking to perform inference using your own dataset, where there are already preprocessing scripts implemented for several public datasets.

In [None]:
from typing import List
from itertools import chain

from scipy.io import loadmat

import numpy as np

import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader

from ecg_transform.inp import ECGInput, ECGInputSchema
from ecg_transform.sample import ECGMetadata, ECGSample
from ecg_transform.t.base import ECGTransform
from ecg_transform.t.common import (
 HandleConstantLeads,
 LinearResample,
 ReorderLeads,
)
from ecg_transform.t.scale import Standardize
from ecg_transform.t.cut import SegmentNonoverlapping

class ECGFMDataset(Dataset):
    """
    Dataset yielding transformed ECG inputs loaded from ``.mat`` files.

    Each item is read with ``scipy.io.loadmat``, wrapped in an ``ECGInput``
    with its metadata, and passed through ``ECGSample`` together with the
    schema and transforms given at construction.
    """

    def __init__(self, schema, transforms, file_paths):
        # Stored as-is; nothing is loaded until an item is requested
        self.file_paths = file_paths
        self.transforms = transforms
        self.schema = schema

    def __len__(self):
        """Return the number of files in the dataset."""
        n_files = len(self.file_paths)
        return n_files

    def __getitem__(self, idx):
        """
        Load and transform the file at position ``idx``.

        Returns
        -------
        tuple
            ``(source, inp)`` where ``source`` is ``sample.out`` as a float
            tensor and ``inp`` is the untransformed ``ECGInput`` (its metadata
            carries the originating file path in ``file``).
        """
        path = self.file_paths[idx]
        record = loadmat(path)

        sample_rate = int(record['org_sample_rate'][0, 0])
        feats = record['feats']
        # NOTE(review): axis 1 is used as the sample count, so `feats` is
        # presumably (leads, samples) -- confirm against the preprocessing.
        n_samples = feats.shape[1]

        meta = ECGMetadata(
            sample_rate=sample_rate,
            num_samples=n_samples,
            lead_names=['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6'],
            unit=None,
            input_start=0,
            input_end=n_samples,
        )
        # Remember the source file so predictions can be grouped per file later
        meta.file = path

        inp = ECGInput(feats, meta)
        sample = ECGSample(inp, self.schema, self.transforms)

        return torch.from_numpy(sample.out).float(), inp

def collate_fn(inps):
    """
    Merge dataset items into one batch of segments.

    Parameters
    ----------
    inps : sequence
        Items as produced by the dataset: ``item[0]`` is a tensor whose first
        dimension indexes the segments of one ECG, ``item[1]`` is the object
        identifying that ECG.

    Returns
    -------
    tuple
        The segment tensors concatenated along the first dimension, and a list
        holding, for every segment, the identifying object of its ECG.
    """
    segment_tensors = []
    sample_ids = []
    for item in inps:
        source, origin = item[0], item[1]
        segment_tensors.append(source)
        # One entry per segment so rows can be traced back to their ECG
        sample_ids.extend([origin] * source.shape[0])

    return torch.cat(segment_tensors), sample_ids

def file_paths_to_loader(
    file_paths: List[str],
    schema: ECGInputSchema,
    transforms: List[ECGTransform],
    batch_size=64,
    num_workers=7,
):
    """
    Build a ``DataLoader`` over the given ECG files.

    Parameters
    ----------
    file_paths : List[str]
        Paths of the files to load.
    schema : ECGInputSchema
        Schema handed to the dataset.
    transforms : List[ECGTransform]
        Transforms handed to the dataset.
    batch_size : int
        Number of files per batch (each file may contribute several segments).
    num_workers : int
        Number of loader worker processes.

    Returns
    -------
    DataLoader
        Non-shuffling loader over an ``ECGFMDataset``, batching with
        ``collate_fn`` and keeping the final partial batch.
    """
    dataset = ECGFMDataset(schema, transforms, file_paths)

    # Fixed order and no dropped items, so outputs line up with `file_paths`
    loader_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        sampler=None,
        shuffle=False,
        collate_fn=collate_fn,
        drop_last=False,
    )
    return DataLoader(dataset, **loader_options)

In [None]:
# pandas is otherwise first imported in the later "Load model" cell; import it
# here so this cell works when the notebook is run top to bottom.
import pandas as pd

# Lead order used for the model inputs
ECG_FM_LEAD_ORDER = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
SAMPLE_RATE = 500
# Number of samples in one 5 s input at SAMPLE_RATE
N_SAMPLES = SAMPLE_RATE*5

# Label definitions; the index (`name` column) provides the label names used
# as prediction column headers.
label_def = pd.read_csv(
    os.path.join(root, 'data/mimic_iv_ecg/labels/label_def.csv'),
    index_col='name',
)
label_names = label_def.index.to_list()
label_names

In [None]:
# Per-label aggregation ('max' or 'mean') used to combine the segment-level
# predictions that belong to the same file; passed to `DataFrame.agg` in the
# "Prediction aggregation" section.
AGG_METHODS = {
 'Poor data quality': 'max',
 'Sinus rhythm': 'mean',
 'Premature ventricular contraction': 'max',
 'Tachycardia': 'mean',
 'Ventricular tachycardia': 'max',
 'Supraventricular tachycardia with aberrancy': 'max',
 'Bradycardia': 'mean',
 'Infarction': 'mean',
 'Atrioventricular block': 'mean',
 'Right bundle branch block': 'mean',
 'Left bundle branch block': 'mean',
 'Electronic pacemaker': 'max',
 'Atrial fibrillation': 'mean',
 'Atrial flutter': 'mean',
 'Accessory pathway conduction': 'mean',
 '1st degree atrioventricular block': 'mean',
 'Bifascicular block': 'mean',
}

# Input schema: sample rate, lead order and number of samples per input
ECG_FM_SCHEMA = ECGInputSchema(
 sample_rate=SAMPLE_RATE,
 expected_lead_order=ECG_FM_LEAD_ORDER,
 required_num_samples=N_SAMPLES,
)

# Preprocessing steps handed to the dataset: lead reordering, resampling to
# SAMPLE_RATE, constant-lead handling, standardization, and cutting into
# non-overlapping segments of N_SAMPLES.
# NOTE(review): presumably applied in list order -- confirm in ecg_transform.
ECG_FM_TRANSFORMS = [
 ReorderLeads(
 expected_order=ECG_FM_LEAD_ORDER,
 missing_lead_strategy='raise',
 ),
 LinearResample(desired_sample_rate=SAMPLE_RATE),
 HandleConstantLeads(strategy='zero'),
 Standardize(),
 SegmentNonoverlapping(segment_length=N_SAMPLES),
]

# Loader over `file_paths`, using the batch size and worker count chosen in
# the inference settings cell
loader = file_paths_to_loader(
 file_paths,
 ECG_FM_SCHEMA,
 ECG_FM_TRANSFORMS,
 batch_size=batch_size,
 num_workers=num_workers,
)

## Load model

In [None]:
from typing import Dict, List, Optional, Tuple, Type, Union
from collections import OrderedDict

import numpy as np
import pandas as pd

import torch

from fairseq_signals.models import build_model_from_checkpoint
from fairseq_signals.models.classification.ecg_transformer_classifier import (
 ECGTransformerClassificationModel
)

In [None]:
model: ECGTransformerClassificationModel = build_model_from_checkpoint(
    checkpoint_path=ckpt_path,
)

# Saliency maps are derived from attention weights, so force the transformer
# encoder and every one of its layers to return them.
if extract_saliency:
    transformer = model.encoder.encoder
    for module in [transformer, *transformer.layers]:
        module.need_weights = extract_saliency

# Switch to evaluation mode and move the model to the target device
model.eval()
model.to(device)

## Infer

In [None]:
def encoder_out_to_emb(x, device='cpu'):
    """
    Pool encoder outputs over dimension 1 into one embedding per input.

    The sum along dimension 1 is divided element-wise by the number of
    non-zero entries along that dimension, following
    fairseq_signals/models/classification/ecg_transformer_classifier.py.
    ``device`` is accepted but not used.
    """
    summed = x.sum(dim=1)
    nonzero_count = (x != 0).sum(dim=1)
    return summed / nonzero_count

def infer(
    model,
    loader,
    device,
    extract_saliency: bool = True,
):
    """
    Run the model over every batch of ``loader`` and collect the outputs.

    Parameters
    ----------
    model
        Model called as ``model(source=...)`` and returning a dict with the
        keys ``'out'`` and ``'encoder_out'`` (and ``'saliency'`` when
        ``extract_saliency`` is true).
    loader
        Iterable yielding ``(source, inp)`` batches, where ``inp`` holds one
        input object per segment, each exposing ``meta.file``.
    device
        Device the inputs are moved to before the forward pass.
    extract_saliency : bool
        Whether to collect ``out['saliency']`` and derive ``'attn_max'``.

    Returns
    -------
    dict
        ``'inps'``: the input objects, one per segment;
        ``'sources'``: model inputs as a NumPy array;
        ``'embs'``: pooled encoder embeddings as a NumPy array;
        ``'pred'``: DataFrame of sigmoid probabilities, columns ``label_names``
        and index the originating file of each segment;
        ``'attn_max'`` (only with ``extract_saliency``): maximum over axis 2 of
        the last attention layer, squeezed, as a NumPy array.
    """
    inps = []
    sources = []
    probs = []
    embs = []
    saliency = []
    file_names = []

    # Pure inference: disabling autograd avoids retaining activations for a
    # backward pass that never happens.
    with torch.no_grad():
        for source, inp in loader:
            source = source.to(device)
            out = model(source=source)

            inps.extend(inp)
            file_names.extend(i.meta.file for i in inp)

            # Move per-batch results off the device immediately so memory use
            # does not grow with the number of batches. Sigmoid is
            # element-wise, so applying it per batch equals applying it to
            # the concatenated logits.
            sources.append(source.cpu())
            probs.append(torch.sigmoid(out['out']).cpu())
            embs.append(encoder_out_to_emb(out['encoder_out']).cpu())

            # Only touch the saliency output when it was requested; the model
            # is only configured to return attention weights in that case.
            if extract_saliency:
                saliency.append(out['saliency'].cpu())

    # Handle predictions
    pred = pd.DataFrame(
        torch.cat(probs).numpy(),
        columns=label_names,
        index=file_names,
    )

    results = {
        'inps': inps,
        'sources': torch.cat(sources).numpy(),
        'embs': torch.cat(embs).numpy(),
        'pred': pred,
    }

    # Handle saliency
    if extract_saliency:
        attn = torch.cat(saliency)[:, -1]  # Consider only the last attention layer
        results['attn_max'] = attn.max(dim=2).values.squeeze().numpy()

    return results

In [None]:
# Forward the notebook-level setting so that disabling saliency extraction in
# the inference settings is honoured here as well.
results = infer(model, loader, device, extract_saliency=extract_saliency)

In [None]:
# Segment-level predictions: one row per 5 s segment, indexed by the file the
# segment came from (segments of the same file share an index value)
pred = results['pred']
print(f"Number of 5 s segment predictions: {len(pred)}.")
pred

# Result handling

## Prediction aggregation

In [None]:
# Collapse the segment-level rows that share a file index into a single row
# per file, reducing each label with its method from AGG_METHODS
segments_by_file = pred.groupby(pred.index)
pred_agg = segments_by_file.agg(AGG_METHODS)
pred_agg = pred_agg.astype(float)
print(f"Number of sample-aggregated predictions: {len(pred_agg)}.")
pred_agg

## Visualizing embeddings

In [None]:
import matplotlib.pyplot as plt
import umap

# Project the segment embeddings to 2D; a small neighbourhood suits the small
# number of embeddings in this tutorial
reducer = umap.UMAP(n_neighbors=3, min_dist=0.1, n_components=2, random_state=42)
embs_2d = reducer.fit_transform(results['embs'])

# Generate a color map: one color per source file
sample_identifier = pred.index.to_series()
unique_values = sample_identifier.unique()
colors = plt.colormaps.get_cmap('tab20') # Use a colormap with enough distinct colors
# Wrap the index so that files beyond the colormap's last entry cycle through
# the palette instead of all receiving the same out-of-range color
color_map = {val: colors(i % colors.N) for i, val in enumerate(unique_values)}
colored_items = sample_identifier.map(color_map)

# Plot the 2D UMAP visualization
plt.scatter(
    embs_2d[:, 0],
    embs_2d[:, 1],
    s=30,
    alpha=0.9,
    color=colored_items.values,
    rasterized=True,
)

# Remove axis labels and grid
plt.xticks([])
plt.yticks([])
plt.grid(False)

More fitting when visualizing many embeddings:
```
import matplotlib.pyplot as plt
import umap
reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42) # Better when there are more embeddings
# Fit the new reducer; otherwise the plot would reuse a previous projection
embs_2d = reducer.fit_transform(results['embs'])

# Plot the 2D UMAP visualization
plt.scatter(
    embs_2d[:, 0],
    embs_2d[:, 1],
    s=1,
    alpha=0.9,
    rasterized=True,
)

# Remove axis labels and grid
plt.xticks([])
plt.yticks([])
plt.grid(False)
```

## Saliency maps

In [None]:
from typing import Tuple

import numpy as np

from scipy.ndimage import map_coordinates

import matplotlib.pyplot as plt
import plotly.graph_objects as go

In [None]:
# Index of the segment (row of `results['sources']`) to visualize
sample_idx = 0

# Lead whose trace is plotted, located by name in the model's lead order
saliency_lead = 'II'
lead_ind = ECG_FM_LEAD_ORDER.index(saliency_lead)

In [None]:
# Model input for the chosen segment and lead, and the attention-derived
# saliency values stored for the same segment
signal = results['sources'][sample_idx, lead_ind]
attn_max = results['attn_max'][sample_idx]

In [None]:
def blend_colors_hex(start_color: str, end_color: str, activations: np.ndarray) -> np.ndarray:
    """
    Blends between two hex colors based on an array of blend factors.

    Note that, despite the function name, the result is NOT hexadecimal
    strings: it is an array of RGB triplets scaled to the range [0, 1].

    Parameters
    ----------
    start_color : str
        Hexadecimal color code of the form ``'#RRGGBB'`` for the start color.
    end_color : str
        Hexadecimal color code of the form ``'#RRGGBB'`` for the end color.
    activations : np.ndarray
        Blend factors where 0 corresponds to the start color and 1 to the end
        color. Any array-like is accepted; it is flattened by the blend.

    Returns
    -------
    np.ndarray
        Array of shape ``(len(activations), 3)`` holding the blended RGB
        values as floats in [0, 1].

    Raises
    ------
    ValueError
        If any of the input blend factors are not within the range [0, 1].
    """
    # Accept lists/tuples as well as arrays
    activations = np.asarray(activations, dtype=float)

    if np.any((activations < 0) | (activations > 1)):
        raise ValueError("All blend factors must be between 0 and 1.")

    # Convert hexadecimal to RGB (skips the leading '#')
    def hex_to_rgb(hex_color: str) -> Tuple[int, ...]:
        return tuple(int(hex_color[i: i+2], 16) for i in (1, 3, 5))

    # Get RGB tuples (0-255 per channel)
    start_rgb = np.array(hex_to_rgb(start_color))
    end_rgb = np.array(hex_to_rgb(end_color))

    # Linear interpolation between the two colors, one row per blend factor
    blended_rgb = np.outer(1 - activations, start_rgb) + np.outer(activations, end_rgb)

    # Scale the 0-255 channel values to [0, 1]
    return blended_rgb / 255

def colored_line_segments(data: np.ndarray, colors: np.ndarray, ax=None, **kwargs):
    """
    Draw ``data`` as consecutive line segments, each with its own color.

    Segment ``i`` connects the points ``(i, data[i])`` and
    ``(i + 1, data[i + 1])`` and is drawn with ``colors[i]``.

    Parameters
    ----------
    data : np.ndarray
        Array of y-values; x-values are the integer positions.
    colors : np.ndarray
        One color per segment, i.e. one fewer entry than ``data``.
    ax : optional
        Object whose ``plot`` method is used for drawing. When ``None``,
        ``plt.plot`` is used instead.
    **kwargs
        Forwarded unchanged to every ``plot`` call.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``colors`` does not have exactly one fewer element than ``data``.
    """
    n_segments = len(data) - 1
    if len(colors) != n_segments:
        raise ValueError("Colors array must have one fewer elements than data array.")

    # Draw on the given axes when provided, otherwise on the current figure
    target = plt if ax is None else ax

    for left in range(n_segments):
        right = left + 1
        target.plot(
            [left, right],
            [data[left], data[right]],
            color=colors[left],
            **kwargs,
        )

def prep_saliency_values(attn_max, target_sample_length):
    """
    Resample 1-D saliency values and min-max normalize them to [0, 1].

    Parameters
    ----------
    attn_max : np.ndarray
        1-D array of saliency values.
    target_sample_length : int
        Number of signal samples to be colored. The result has
        ``target_sample_length - 1`` entries, i.e. one per line segment
        between consecutive samples.

    Returns
    -------
    np.ndarray
        The resampled values, shifted and scaled so the minimum is 0 and the
        maximum is 1. When the resampled values are (numerically) constant,
        an all-zero array is returned instead of dividing by zero.
    """
    attn_max = np.asarray(attn_max)

    # Resample to the target size: evenly spaced positions spanning the
    # original index range, evaluated by spline interpolation
    positions = np.linspace(0, attn_max.shape[0] - 1, target_sample_length - 1)
    resampled = map_coordinates(attn_max, positions[np.newaxis, :])

    # Min-max normalization
    shifted = resampled - resampled.min()
    value_range = shifted.max()

    # A flat map has no range to normalize by; 0/0 would otherwise yield NaNs
    # that propagate into the colors
    if value_range <= 1e-12 * np.abs(resampled).max():
        return np.zeros_like(shifted)

    return shifted / value_range

# Samples of the plotted trace and their positions on the x-axis
y_values = signal.ravel()
time = np.arange(len(y_values))

# Resample the saliency to the number of plotted samples so that exactly one
# color is produced per line segment between consecutive samples. Scaling the
# target by `attn_max.shape[0]` would over-stretch the map of a single segment
# and leave only its first fraction visible.
saliency_prepped = prep_saliency_values(
    attn_max.ravel(),
    len(y_values),
)
saliency_colors = blend_colors_hex('#0047AB', '#DC143C', saliency_prepped)
# Convert the [0, 1] RGB floats to 0-255 integers for plotly's 'rgb(r,g,b)'
saliency_colors = (saliency_colors*255).astype(int)

# Define a custom colorscale from blue to red
# NOTE(review): not referenced by the traces below
colorscale = [[0, 'blue'], [1, 'red']] # Simple gradient from blue to red

# Create the figure: one trace per segment, each with its own color
fig = go.Figure()
for i, rgb in enumerate(saliency_colors):
    fig.add_trace(
        go.Scatter(
            x=[time[i], time[i + 1]],
            y=[y_values[i], y_values[i + 1]],
            mode='lines',
            line=dict(color='rgb({},{},{})'.format(*rgb), width=2),
            showlegend=False # Avoid cluttering the legend
        )
    )
fig['layout']['yaxis'].update(autorange = True)
fig['layout']['xaxis'].update(autorange = True)

fig.show()