In [1]:
import torch
import numpy as np
from models_cifm.cifm import CIFM
import scanpy as sc

### 1. load model

In [2]:
def load_model(device='cpu'):
    """Load the pretrained CIFM model and prepare it for inference.

    Parameters
    ----------
    device : str or torch.device, default 'cpu'
        Device the model is moved to. Pass 'cuda' if you have a GPU.

    Returns
    -------
    CIFM
        The pretrained model in eval mode, with the source channel-to-Ensembl
        mapping attached as ``channel2ensembl_ids_source``.
    """
    # NOTE(review): torch.load unpickles Python objects; only load these .pt
    # files from a trusted checkout of the model repository.
    args_model = torch.load('./models_cifm/args.pt')
    model = CIFM.from_pretrained('ynyou/CIFM', args=args_model).to(device)
    # Source-side channel mapping, consumed later by model.channel_matching.
    model.channel2ensembl_ids_source = torch.load('./models_cifm/channel2ensembl.pt')
    # Switch to eval mode: this notebook only runs inference.
    model.eval()
    return model
model = load_model()

model.safetensors:   0%|          | 0.00/569M [00:00<?, ?B/s]

### 2. load and preprocess sample adata
- some requirements for adata:
- ```adata.X```: must contain the raw counts
- ```adata.obsm['spatial']```: the coordinates of the cells, in micrometers
- if the coordinates are in a different unit, the geometric graph may be malformed: the model uses a radius of 20 (micrometers) to construct the geometric graph, so a different unit can result in an overly sparse or overly dense graph

In [3]:
# Read the sample AnnData object; per the requirements listed above, X is
# expected to hold raw counts at this point.
adata = sc.read_h5ad('./adata.h5ad')
# Keep a copy of X under layers['counts'] before X is overwritten below.
adata.layers['counts'] = adata.X.copy()
# Scale every cell to a total of 1e4, then apply log1p; both calls update
# adata in place.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
# Display the AnnData summary as the cell output.
adata

AnnData object with n_obs × n_vars = 24844 × 18289
    obs: 'in_tissue'
    var: 'feature_types', 'genome', 'gene_names'
    uns: 'log1p'
    obsm: 'spatial'
    layers: 'counts'

### 3. match feature channels
- we need a list that maps each feature channel to its Ensembl IDs: ```channel2ensembl_ids_target```
- format: ```channel2ensembl_ids_target = [[ensemblid1_for_channel1, ensemblid2_for_channel1, ...], [ensemblid1_for_channel2, ensemblid2_for_channel2, ...], ...]```
- one channel can correspond to multiple Ensembl IDs, e.g., when the channels in your original data are annotated with gene names
- you can use BioMart to map each gene name to one or more Ensembl IDs

In [4]:
# Build the target mapping: one single-element list per channel, taken from
# the var index.
# NOTE(review): this assumes adata.var.index already holds Ensembl IDs; if it
# holds gene names, map them to Ensembl IDs first.
channel2ensembl_ids_target = [[i] for i in adata.var.index.tolist()]
# Match the data's channels against the model's source channels; the call
# prints how many channels matched and which ones did not.
model.channel_matching(channel2ensembl_ids_target, model.channel2ensembl_ids_source)

matching 18289 gene channels out of 18289 ; unmatched channels: []


### 4. embed the microenvironments centered at each cell

In [5]:
# Inference only: disable gradient tracking while computing the embeddings.
with torch.no_grad():
    embeddings = model.embed(adata)
# Show the tensor together with its shape; for the sample data this is one
# 1024-dimensional row per cell (24844 x 1024).
embeddings, embeddings.shape

(tensor([[-0.4326, -0.8625,  0.1121,  ...,  0.4980,  0.3855, -0.1965],
         [-0.6833, -0.9950,  0.1927,  ..., -0.2064,  0.6193,  0.0387],
         [-0.2099, -0.9877,  0.3462,  ...,  0.2102,  0.6807, -0.2155],
         ...,
         [-0.0187, -0.8444,  0.3058,  ...,  0.1030,  0.8362, -0.1859],
         [-0.5535, -0.8201,  0.7805,  ..., -0.1402,  0.5221, -0.3520],
         [-0.9339, -0.8467,  0.0600,  ...,  0.0406,  0.3608,  0.3418]]),
 torch.Size([24844, 1024]))

### 5. infer the potential gene expressions at certain locations

In [6]:
# For demonstration only: draw random target locations uniformly inside the
# bounding box of the observed cell coordinates.
N_TARGET_LOCS = 10
SEED = 0  # fixed seed so that Restart & Run All reproduces the same locations
rng = np.random.default_rng(SEED)

# Only the first two coordinate columns (x, y) are used.
spatial_xy = np.asarray(adata.obsm['spatial'])[:, :2]
xy_min = spatial_xy.min(axis=0)
xy_max = spatial_xy.max(axis=0)
# Scale unit-square samples to the [min, max] range of each axis.
target_locs = rng.random((N_TARGET_LOCS, 2)) * (xy_max - xy_min) + xy_min

# Inference only: no gradients are needed for the prediction.
with torch.no_grad():
    expressions = model.predict_cells_at_locations(adata, target_locs)
expressions, expressions.shape

(tensor([[0.0000, 0.0000, 2.8781,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 2.9699,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 3.2570,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]),
 torch.Size([10, 18289]))

In [7]:
# Convert the predictions back to normalized counts (each row sums to 1).
# Presumably the predictions are on the log1p scale used in preprocessing, so
# expm1 inverts them; torch.expm1 stays on the tensor (no NumPy round-trip)
# and is more accurate than exp(x) - 1 for values near zero.
counts_normalized = torch.expm1(expressions)
row_totals = counts_normalized.sum(dim=1, keepdim=True)
# Rows that sum to zero would produce NaN on division; divide them by 1
# instead so they stay all-zero.
safe_totals = torch.where(row_totals == 0, torch.ones_like(row_totals), row_totals)
counts_normalized = counts_normalized / safe_totals
counts_normalized, counts_normalized.shape

(tensor([[0.0000, 0.0000, 0.0002,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0002,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0003,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]),
 torch.Size([10, 18289]))