Yuning You
commited on
Commit
·
be4e6f5
1
Parent(s):
6788772
update
Browse files- adata.h5ad +2 -2
- models/__init__.py +0 -8
- models/__pycache__/__init__.cpython-311.pyc +0 -0
- models/__pycache__/cifm.cpython-311.pyc +0 -0
- models/__pycache__/egnn_void_invariant.cpython-311.pyc +0 -0
- models/__pycache__/mlp_and_gnn.cpython-311.pyc +0 -0
- models/cifm.py +7 -6
- models/layers/__pycache__/__init__.cpython-311.pyc +0 -0
- models/layers/__pycache__/egnn_layer_void_invariant.cpython-311.pyc +0 -0
- test.ipynb +234 -0
adata.h5ad
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f4f8b3caccbb84f31fa795ad012d22c28068d2fc8a8c1a28d7b034483a168e08
|
3 |
+
size 90959812
|
models/__init__.py
DELETED
@@ -1,8 +0,0 @@
|
|
1 |
-
# from models.schnet import SchNetModel
|
2 |
-
# from models.dimenet import DimeNetPPModel
|
3 |
-
# from models.spherenet import SphereNetModel
|
4 |
-
# from models.egnn import EGNNModel
|
5 |
-
# from models.gvpgnn import GVPGNNModel
|
6 |
-
# from models.tfn import TFNModel
|
7 |
-
# from models.mace import MACEModel
|
8 |
-
from models.egnn_void_invariant import VIEGNNModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/__pycache__/__init__.cpython-311.pyc
DELETED
Binary file (302 Bytes)
|
|
models/__pycache__/cifm.cpython-311.pyc
ADDED
Binary file (10.1 kB). View file
|
|
models/__pycache__/egnn_void_invariant.cpython-311.pyc
CHANGED
Binary files a/models/__pycache__/egnn_void_invariant.cpython-311.pyc and b/models/__pycache__/egnn_void_invariant.cpython-311.pyc differ
|
|
models/__pycache__/mlp_and_gnn.cpython-311.pyc
CHANGED
Binary files a/models/__pycache__/mlp_and_gnn.cpython-311.pyc and b/models/__pycache__/mlp_and_gnn.cpython-311.pyc differ
|
|
models/cifm.py
CHANGED
@@ -2,10 +2,9 @@ import torch
|
|
2 |
import torch.nn as nn
|
3 |
from torch_geometric.nn import radius_graph
|
4 |
import scanpy as sc
|
5 |
-
from main import Model
|
6 |
from huggingface_hub import PyTorchModelHubMixin
|
7 |
from models.mlp_and_gnn import MLPBiasFree
|
8 |
-
from models import VIEGNNModel
|
9 |
|
10 |
|
11 |
class CIFM(
|
@@ -59,10 +58,9 @@ class CIFM(
|
|
59 |
embs_in.append(self.gene_encoder.layers[0].weight.data[:, idx_source])
|
60 |
embs_out1.append(self.mask_cell_expression.layers[-1].weight.data[idx_source])
|
61 |
embs_out2.append(self.mask_cell_dropout.layers[-1].weight.data[idx_source])
|
62 |
-
else:
|
63 |
-
unmatched_channels.append(ensembl)
|
64 |
|
65 |
if len(embs_in) == 0:
|
|
|
66 |
continue
|
67 |
|
68 |
embs_in = torch.stack(embs_in).mean(dim=0)
|
@@ -98,8 +96,11 @@ class CIFM(
|
|
98 |
|
99 |
expressions_dec = self.relu(self.mask_cell_expression(embeddings_dec))
|
100 |
dropouts_dec = self.sigmoid(self.mask_cell_dropout(embeddings_dec))
|
|
|
|
|
|
|
101 |
expressions_dec[dropouts_dec<=0.5] = 0
|
102 |
-
return expressions_dec
|
103 |
|
104 |
def embed(self, adata):
|
105 |
device = next(self.parameters()).device
|
@@ -119,7 +120,7 @@ class CIFM(
|
|
119 |
expressions = torch.tensor(adata.X.toarray(), dtype=torch.float32).to(device)
|
120 |
expressions = torch.cat([expressions, torch.zeros(locations.shape[0], expressions.shape[1])], dim=0)
|
121 |
coordinates = torch.tensor(adata.obsm['spatial'], dtype=torch.float32)
|
122 |
-
coordinates = torch.cat([coordinates,
|
123 |
coordinates = torch.cat([coordinates, torch.zeros(coordinates.shape[0], 1)], dim=1).to(device)
|
124 |
edge_index = radius_graph(coordinates, r=self.radius_spatial_graph, max_num_neighbors=10000, loop=True)
|
125 |
idx_cells_to_predict = torch.arange(expressions.shape[0]-locations.shape[0], expressions.shape[0]).to(device)
|
|
|
2 |
import torch.nn as nn
|
3 |
from torch_geometric.nn import radius_graph
|
4 |
import scanpy as sc
|
|
|
5 |
from huggingface_hub import PyTorchModelHubMixin
|
6 |
from models.mlp_and_gnn import MLPBiasFree
|
7 |
+
from models.egnn_void_invariant import VIEGNNModel
|
8 |
|
9 |
|
10 |
class CIFM(
|
|
|
58 |
embs_in.append(self.gene_encoder.layers[0].weight.data[:, idx_source])
|
59 |
embs_out1.append(self.mask_cell_expression.layers[-1].weight.data[idx_source])
|
60 |
embs_out2.append(self.mask_cell_dropout.layers[-1].weight.data[idx_source])
|
|
|
|
|
61 |
|
62 |
if len(embs_in) == 0:
|
63 |
+
unmatched_channels += ensembls
|
64 |
continue
|
65 |
|
66 |
embs_in = torch.stack(embs_in).mean(dim=0)
|
|
|
96 |
|
97 |
expressions_dec = self.relu(self.mask_cell_expression(embeddings_dec))
|
98 |
dropouts_dec = self.sigmoid(self.mask_cell_dropout(embeddings_dec))
|
99 |
+
|
100 |
+
# import pdb ; pdb.set_trace()
|
101 |
+
|
102 |
expressions_dec[dropouts_dec<=0.5] = 0
|
103 |
+
return expressions_dec
|
104 |
|
105 |
def embed(self, adata):
|
106 |
device = next(self.parameters()).device
|
|
|
120 |
expressions = torch.tensor(adata.X.toarray(), dtype=torch.float32).to(device)
|
121 |
expressions = torch.cat([expressions, torch.zeros(locations.shape[0], expressions.shape[1])], dim=0)
|
122 |
coordinates = torch.tensor(adata.obsm['spatial'], dtype=torch.float32)
|
123 |
+
coordinates = torch.cat([coordinates, locations], dim=0)
|
124 |
coordinates = torch.cat([coordinates, torch.zeros(coordinates.shape[0], 1)], dim=1).to(device)
|
125 |
edge_index = radius_graph(coordinates, r=self.radius_spatial_graph, max_num_neighbors=10000, loop=True)
|
126 |
idx_cells_to_predict = torch.arange(expressions.shape[0]-locations.shape[0], expressions.shape[0]).to(device)
|
models/layers/__pycache__/__init__.cpython-311.pyc
CHANGED
Binary files a/models/layers/__pycache__/__init__.cpython-311.pyc and b/models/layers/__pycache__/__init__.cpython-311.pyc differ
|
|
models/layers/__pycache__/egnn_layer_void_invariant.cpython-311.pyc
CHANGED
Binary files a/models/layers/__pycache__/egnn_layer_void_invariant.cpython-311.pyc and b/models/layers/__pycache__/egnn_layer_void_invariant.cpython-311.pyc differ
|
|
test.ipynb
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import torch\n",
|
10 |
+
"import numpy as np\n",
|
11 |
+
"from models.cifm import CIFM\n",
|
12 |
+
"import scanpy as sc"
|
13 |
+
]
|
14 |
+
},
|
15 |
+
{
|
16 |
+
"cell_type": "code",
|
17 |
+
"execution_count": 2,
|
18 |
+
"metadata": {},
|
19 |
+
"outputs": [
|
20 |
+
{
|
21 |
+
"data": {
|
22 |
+
"text/plain": [
|
23 |
+
"CIFM(\n",
|
24 |
+
" (gene_encoder): MLPBiasFree(\n",
|
25 |
+
" (layers): ModuleList(\n",
|
26 |
+
" (0): Linear(in_features=18289, out_features=1024, bias=False)\n",
|
27 |
+
" (1-3): 3 x Linear(in_features=1024, out_features=1024, bias=False)\n",
|
28 |
+
" )\n",
|
29 |
+
" (layernorms): ModuleList(\n",
|
30 |
+
" (0-2): 3 x LayerNorm((1024,), eps=1e-05, elementwise_affine=False)\n",
|
31 |
+
" )\n",
|
32 |
+
" (activation): ReLU()\n",
|
33 |
+
" )\n",
|
34 |
+
" (model): VIEGNNModel(\n",
|
35 |
+
" (emb_in): Linear(in_features=1024, out_features=1024, bias=False)\n",
|
36 |
+
" (convs): ModuleList(\n",
|
37 |
+
" (0-1): 2 x EGNNLayer(emb_dim=1024, aggr=sum)\n",
|
38 |
+
" )\n",
|
39 |
+
" (pred): MLPBiasFree(\n",
|
40 |
+
" (layers): ModuleList(\n",
|
41 |
+
" (0-3): 4 x Linear(in_features=1024, out_features=1024, bias=False)\n",
|
42 |
+
" )\n",
|
43 |
+
" (layernorms): ModuleList(\n",
|
44 |
+
" (0-2): 3 x LayerNorm((1024,), eps=1e-05, elementwise_affine=False)\n",
|
45 |
+
" )\n",
|
46 |
+
" (activation): ReLU()\n",
|
47 |
+
" )\n",
|
48 |
+
" )\n",
|
49 |
+
" (mask_cell_decoder): VIEGNNModel(\n",
|
50 |
+
" (emb_in): Linear(in_features=1024, out_features=1024, bias=False)\n",
|
51 |
+
" (convs): ModuleList(\n",
|
52 |
+
" (0-1): 2 x EGNNLayer(emb_dim=1024, aggr=sum)\n",
|
53 |
+
" )\n",
|
54 |
+
" (pred): MLPBiasFree(\n",
|
55 |
+
" (layers): ModuleList(\n",
|
56 |
+
" (0-3): 4 x Linear(in_features=1024, out_features=1024, bias=False)\n",
|
57 |
+
" )\n",
|
58 |
+
" (layernorms): ModuleList(\n",
|
59 |
+
" (0-2): 3 x LayerNorm((1024,), eps=1e-05, elementwise_affine=False)\n",
|
60 |
+
" )\n",
|
61 |
+
" (activation): ReLU()\n",
|
62 |
+
" )\n",
|
63 |
+
" )\n",
|
64 |
+
" (mask_cell_expression): MLPBiasFree(\n",
|
65 |
+
" (layers): ModuleList(\n",
|
66 |
+
" (0-2): 3 x Linear(in_features=1024, out_features=1024, bias=False)\n",
|
67 |
+
" (3): Linear(in_features=1024, out_features=18289, bias=False)\n",
|
68 |
+
" )\n",
|
69 |
+
" (layernorms): ModuleList(\n",
|
70 |
+
" (0-2): 3 x LayerNorm((1024,), eps=1e-05, elementwise_affine=False)\n",
|
71 |
+
" )\n",
|
72 |
+
" (activation): ReLU()\n",
|
73 |
+
" )\n",
|
74 |
+
" (mask_cell_dropout): MLPBiasFree(\n",
|
75 |
+
" (layers): ModuleList(\n",
|
76 |
+
" (0-2): 3 x Linear(in_features=1024, out_features=1024, bias=False)\n",
|
77 |
+
" (3): Linear(in_features=1024, out_features=18289, bias=False)\n",
|
78 |
+
" )\n",
|
79 |
+
" (layernorms): ModuleList(\n",
|
80 |
+
" (0-2): 3 x LayerNorm((1024,), eps=1e-05, elementwise_affine=False)\n",
|
81 |
+
" )\n",
|
82 |
+
" (activation): ReLU()\n",
|
83 |
+
" )\n",
|
84 |
+
" (mask_embedding): Embedding(1, 1024)\n",
|
85 |
+
" (relu): ReLU()\n",
|
86 |
+
" (sigmoid): Sigmoid()\n",
|
87 |
+
")"
|
88 |
+
]
|
89 |
+
},
|
90 |
+
"execution_count": 2,
|
91 |
+
"metadata": {},
|
92 |
+
"output_type": "execute_result"
|
93 |
+
}
|
94 |
+
],
|
95 |
+
"source": [
|
96 |
+
"args_model = torch.load('./model_files/args.pt')\n",
|
97 |
+
"model = CIFM.from_pretrained('ynyou/CIFM', args=args_model)\n",
|
98 |
+
"model.eval()"
|
99 |
+
]
|
100 |
+
},
|
101 |
+
{
|
102 |
+
"cell_type": "code",
|
103 |
+
"execution_count": 3,
|
104 |
+
"metadata": {},
|
105 |
+
"outputs": [
|
106 |
+
{
|
107 |
+
"data": {
|
108 |
+
"text/plain": [
|
109 |
+
"AnnData object with n_obs × n_vars = 24844 × 18289\n",
|
110 |
+
" obs: 'in_tissue'\n",
|
111 |
+
" var: 'feature_types', 'genome', 'gene_names'\n",
|
112 |
+
" uns: 'log1p'\n",
|
113 |
+
" obsm: 'spatial'\n",
|
114 |
+
" layers: 'counts'"
|
115 |
+
]
|
116 |
+
},
|
117 |
+
"execution_count": 3,
|
118 |
+
"metadata": {},
|
119 |
+
"output_type": "execute_result"
|
120 |
+
}
|
121 |
+
],
|
122 |
+
"source": [
|
123 |
+
"channel2ensembl = torch.load('./model_files/channel2ensembl.pt')\n",
|
124 |
+
"adata = sc.read_h5ad('./adata.h5ad')\n",
|
125 |
+
"adata.layers['counts'] = adata.X.copy()\n",
|
126 |
+
"sc.pp.normalize_total(adata)\n",
|
127 |
+
"sc.pp.log1p(adata)\n",
|
128 |
+
"adata"
|
129 |
+
]
|
130 |
+
},
|
131 |
+
{
|
132 |
+
"cell_type": "code",
|
133 |
+
"execution_count": 4,
|
134 |
+
"metadata": {},
|
135 |
+
"outputs": [
|
136 |
+
{
|
137 |
+
"name": "stdout",
|
138 |
+
"output_type": "stream",
|
139 |
+
"text": [
|
140 |
+
"matching 18289 gene channels out of 18289 unmatched channels: []\n"
|
141 |
+
]
|
142 |
+
}
|
143 |
+
],
|
144 |
+
"source": [
|
145 |
+
"model.channel_matching(adata, channel2ensembl)"
|
146 |
+
]
|
147 |
+
},
|
148 |
+
{
|
149 |
+
"cell_type": "code",
|
150 |
+
"execution_count": 5,
|
151 |
+
"metadata": {},
|
152 |
+
"outputs": [
|
153 |
+
{
|
154 |
+
"data": {
|
155 |
+
"text/plain": [
|
156 |
+
"(tensor([[-0.4132, -0.9847, 0.1647, ..., -0.8351, -0.8177, -1.3235],\n",
|
157 |
+
" [ 0.8701, 0.0967, -0.3676, ..., 0.2687, -1.4821, 0.1605],\n",
|
158 |
+
" [-0.5178, -0.4442, -0.0862, ..., -0.7446, -0.5761, -0.5571],\n",
|
159 |
+
" ...,\n",
|
160 |
+
" [ 1.2264, 1.2326, 0.2791, ..., 0.8018, -1.4069, 1.4567],\n",
|
161 |
+
" [ 0.6699, -0.6107, 0.2450, ..., -0.1975, -0.6034, -0.6608],\n",
|
162 |
+
" [-1.9240, -1.8125, -0.0766, ..., -0.2799, -0.0217, -2.2051]]),\n",
|
163 |
+
" torch.Size([13898, 1024]))"
|
164 |
+
]
|
165 |
+
},
|
166 |
+
"execution_count": 5,
|
167 |
+
"metadata": {},
|
168 |
+
"output_type": "execute_result"
|
169 |
+
}
|
170 |
+
],
|
171 |
+
"source": [
|
172 |
+
"with torch.no_grad():\n",
|
173 |
+
" embeddings = model.embed(adata)\n",
|
174 |
+
"embeddings, embeddings.shape"
|
175 |
+
]
|
176 |
+
},
|
177 |
+
{
|
178 |
+
"cell_type": "code",
|
179 |
+
"execution_count": 5,
|
180 |
+
"metadata": {},
|
181 |
+
"outputs": [
|
182 |
+
{
|
183 |
+
"data": {
|
184 |
+
"text/plain": [
|
185 |
+
"(tensor([[0.0000, 0.0000, 0.8603, ..., 0.0000, 0.0000, 0.0000],\n",
|
186 |
+
" [0.0000, 0.0000, 0.6644, ..., 0.0000, 0.0000, 0.0000],\n",
|
187 |
+
" [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
|
188 |
+
" ...,\n",
|
189 |
+
" [0.0000, 0.0000, 0.9809, ..., 0.0000, 0.0000, 0.0000],\n",
|
190 |
+
" [0.6641, 0.0000, 0.6858, ..., 0.0000, 0.0000, 0.0000],\n",
|
191 |
+
" [0.4999, 0.0000, 0.5311, ..., 0.0000, 0.0000, 0.0000]]),\n",
|
192 |
+
" torch.Size([10, 18289]))"
|
193 |
+
]
|
194 |
+
},
|
195 |
+
"execution_count": 5,
|
196 |
+
"metadata": {},
|
197 |
+
"output_type": "execute_result"
|
198 |
+
}
|
199 |
+
],
|
200 |
+
"source": [
|
201 |
+
"rand_loc = np.random.rand(10, 2)\n",
|
202 |
+
"x_min, x_max = adata.obsm['spatial'][:, 0].min(), adata.obsm['spatial'][:, 0].max()\n",
|
203 |
+
"y_min, y_max = adata.obsm['spatial'][:, 1].min(), adata.obsm['spatial'][:, 1].max()\n",
|
204 |
+
"rand_loc[:, 0] = rand_loc[:, 0] * (x_max - x_min) + x_min\n",
|
205 |
+
"rand_loc[:, 1] = rand_loc[:, 1] * (y_max - y_min) + y_min\n",
|
206 |
+
"\n",
|
207 |
+
"with torch.no_grad():\n",
|
208 |
+
" expressions = model.predict_cells_at_locations(adata, rand_loc)\n",
|
209 |
+
"expressions, expressions.shape"
|
210 |
+
]
|
211 |
+
}
|
212 |
+
],
|
213 |
+
"metadata": {
|
214 |
+
"kernelspec": {
|
215 |
+
"display_name": "Python 3 (ipykernel)",
|
216 |
+
"language": "python",
|
217 |
+
"name": "python3"
|
218 |
+
},
|
219 |
+
"language_info": {
|
220 |
+
"codemirror_mode": {
|
221 |
+
"name": "ipython",
|
222 |
+
"version": 3
|
223 |
+
},
|
224 |
+
"file_extension": ".py",
|
225 |
+
"mimetype": "text/x-python",
|
226 |
+
"name": "python",
|
227 |
+
"nbconvert_exporter": "python",
|
228 |
+
"pygments_lexer": "ipython3",
|
229 |
+
"version": "3.11.10"
|
230 |
+
}
|
231 |
+
},
|
232 |
+
"nbformat": 4,
|
233 |
+
"nbformat_minor": 2
|
234 |
+
}
|