Yuning You committed on
Commit
be4e6f5
·
1 Parent(s): 6788772
adata.h5ad CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1126447b46abf9c31e77a009423473988738f0295c23a702b0416dd7f56e208d
3
- size 32068372
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4f8b3caccbb84f31fa795ad012d22c28068d2fc8a8c1a28d7b034483a168e08
3
+ size 90959812
models/__init__.py DELETED
@@ -1,8 +0,0 @@
1
- # from models.schnet import SchNetModel
2
- # from models.dimenet import DimeNetPPModel
3
- # from models.spherenet import SphereNetModel
4
- # from models.egnn import EGNNModel
5
- # from models.gvpgnn import GVPGNNModel
6
- # from models.tfn import TFNModel
7
- # from models.mace import MACEModel
8
- from models.egnn_void_invariant import VIEGNNModel
 
 
 
 
 
 
 
 
 
models/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (302 Bytes)
 
models/__pycache__/cifm.cpython-311.pyc ADDED
Binary file (10.1 kB). View file
 
models/__pycache__/egnn_void_invariant.cpython-311.pyc CHANGED
Binary files a/models/__pycache__/egnn_void_invariant.cpython-311.pyc and b/models/__pycache__/egnn_void_invariant.cpython-311.pyc differ
 
models/__pycache__/mlp_and_gnn.cpython-311.pyc CHANGED
Binary files a/models/__pycache__/mlp_and_gnn.cpython-311.pyc and b/models/__pycache__/mlp_and_gnn.cpython-311.pyc differ
 
models/cifm.py CHANGED
@@ -2,10 +2,9 @@ import torch
2
  import torch.nn as nn
3
  from torch_geometric.nn import radius_graph
4
  import scanpy as sc
5
- from main import Model
6
  from huggingface_hub import PyTorchModelHubMixin
7
  from models.mlp_and_gnn import MLPBiasFree
8
- from models import VIEGNNModel
9
 
10
 
11
  class CIFM(
@@ -59,10 +58,9 @@ class CIFM(
59
  embs_in.append(self.gene_encoder.layers[0].weight.data[:, idx_source])
60
  embs_out1.append(self.mask_cell_expression.layers[-1].weight.data[idx_source])
61
  embs_out2.append(self.mask_cell_dropout.layers[-1].weight.data[idx_source])
62
- else:
63
- unmatched_channels.append(ensembl)
64
 
65
  if len(embs_in) == 0:
 
66
  continue
67
 
68
  embs_in = torch.stack(embs_in).mean(dim=0)
@@ -98,8 +96,11 @@ class CIFM(
98
 
99
  expressions_dec = self.relu(self.mask_cell_expression(embeddings_dec))
100
  dropouts_dec = self.sigmoid(self.mask_cell_dropout(embeddings_dec))
 
 
 
101
  expressions_dec[dropouts_dec<=0.5] = 0
102
- return expressions_dec[mapping]
103
 
104
  def embed(self, adata):
105
  device = next(self.parameters()).device
@@ -119,7 +120,7 @@ class CIFM(
119
  expressions = torch.tensor(adata.X.toarray(), dtype=torch.float32).to(device)
120
  expressions = torch.cat([expressions, torch.zeros(locations.shape[0], expressions.shape[1])], dim=0)
121
  coordinates = torch.tensor(adata.obsm['spatial'], dtype=torch.float32)
122
- coordinates = torch.cat([coordinates, torch.zeros(locations.shape[0], 2)], dim=1)
123
  coordinates = torch.cat([coordinates, torch.zeros(coordinates.shape[0], 1)], dim=1).to(device)
124
  edge_index = radius_graph(coordinates, r=self.radius_spatial_graph, max_num_neighbors=10000, loop=True)
125
  idx_cells_to_predict = torch.arange(expressions.shape[0]-locations.shape[0], expressions.shape[0]).to(device)
 
2
  import torch.nn as nn
3
  from torch_geometric.nn import radius_graph
4
  import scanpy as sc
 
5
  from huggingface_hub import PyTorchModelHubMixin
6
  from models.mlp_and_gnn import MLPBiasFree
7
+ from models.egnn_void_invariant import VIEGNNModel
8
 
9
 
10
  class CIFM(
 
58
  embs_in.append(self.gene_encoder.layers[0].weight.data[:, idx_source])
59
  embs_out1.append(self.mask_cell_expression.layers[-1].weight.data[idx_source])
60
  embs_out2.append(self.mask_cell_dropout.layers[-1].weight.data[idx_source])
 
 
61
 
62
  if len(embs_in) == 0:
63
+ unmatched_channels += ensembls
64
  continue
65
 
66
  embs_in = torch.stack(embs_in).mean(dim=0)
 
96
 
97
  expressions_dec = self.relu(self.mask_cell_expression(embeddings_dec))
98
  dropouts_dec = self.sigmoid(self.mask_cell_dropout(embeddings_dec))
99
+
100
+ # import pdb ; pdb.set_trace()
101
+
102
  expressions_dec[dropouts_dec<=0.5] = 0
103
+ return expressions_dec
104
 
105
  def embed(self, adata):
106
  device = next(self.parameters()).device
 
120
  expressions = torch.tensor(adata.X.toarray(), dtype=torch.float32).to(device)
121
  expressions = torch.cat([expressions, torch.zeros(locations.shape[0], expressions.shape[1])], dim=0)
122
  coordinates = torch.tensor(adata.obsm['spatial'], dtype=torch.float32)
123
+ coordinates = torch.cat([coordinates, locations], dim=0)
124
  coordinates = torch.cat([coordinates, torch.zeros(coordinates.shape[0], 1)], dim=1).to(device)
125
  edge_index = radius_graph(coordinates, r=self.radius_spatial_graph, max_num_neighbors=10000, loop=True)
126
  idx_cells_to_predict = torch.arange(expressions.shape[0]-locations.shape[0], expressions.shape[0]).to(device)
models/layers/__pycache__/__init__.cpython-311.pyc CHANGED
Binary files a/models/layers/__pycache__/__init__.cpython-311.pyc and b/models/layers/__pycache__/__init__.cpython-311.pyc differ
 
models/layers/__pycache__/egnn_layer_void_invariant.cpython-311.pyc CHANGED
Binary files a/models/layers/__pycache__/egnn_layer_void_invariant.cpython-311.pyc and b/models/layers/__pycache__/egnn_layer_void_invariant.cpython-311.pyc differ
 
test.ipynb ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import torch\n",
10
+ "import numpy as np\n",
11
+ "from models.cifm import CIFM\n",
12
+ "import scanpy as sc"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": 2,
18
+ "metadata": {},
19
+ "outputs": [
20
+ {
21
+ "data": {
22
+ "text/plain": [
23
+ "CIFM(\n",
24
+ " (gene_encoder): MLPBiasFree(\n",
25
+ " (layers): ModuleList(\n",
26
+ " (0): Linear(in_features=18289, out_features=1024, bias=False)\n",
27
+ " (1-3): 3 x Linear(in_features=1024, out_features=1024, bias=False)\n",
28
+ " )\n",
29
+ " (layernorms): ModuleList(\n",
30
+ " (0-2): 3 x LayerNorm((1024,), eps=1e-05, elementwise_affine=False)\n",
31
+ " )\n",
32
+ " (activation): ReLU()\n",
33
+ " )\n",
34
+ " (model): VIEGNNModel(\n",
35
+ " (emb_in): Linear(in_features=1024, out_features=1024, bias=False)\n",
36
+ " (convs): ModuleList(\n",
37
+ " (0-1): 2 x EGNNLayer(emb_dim=1024, aggr=sum)\n",
38
+ " )\n",
39
+ " (pred): MLPBiasFree(\n",
40
+ " (layers): ModuleList(\n",
41
+ " (0-3): 4 x Linear(in_features=1024, out_features=1024, bias=False)\n",
42
+ " )\n",
43
+ " (layernorms): ModuleList(\n",
44
+ " (0-2): 3 x LayerNorm((1024,), eps=1e-05, elementwise_affine=False)\n",
45
+ " )\n",
46
+ " (activation): ReLU()\n",
47
+ " )\n",
48
+ " )\n",
49
+ " (mask_cell_decoder): VIEGNNModel(\n",
50
+ " (emb_in): Linear(in_features=1024, out_features=1024, bias=False)\n",
51
+ " (convs): ModuleList(\n",
52
+ " (0-1): 2 x EGNNLayer(emb_dim=1024, aggr=sum)\n",
53
+ " )\n",
54
+ " (pred): MLPBiasFree(\n",
55
+ " (layers): ModuleList(\n",
56
+ " (0-3): 4 x Linear(in_features=1024, out_features=1024, bias=False)\n",
57
+ " )\n",
58
+ " (layernorms): ModuleList(\n",
59
+ " (0-2): 3 x LayerNorm((1024,), eps=1e-05, elementwise_affine=False)\n",
60
+ " )\n",
61
+ " (activation): ReLU()\n",
62
+ " )\n",
63
+ " )\n",
64
+ " (mask_cell_expression): MLPBiasFree(\n",
65
+ " (layers): ModuleList(\n",
66
+ " (0-2): 3 x Linear(in_features=1024, out_features=1024, bias=False)\n",
67
+ " (3): Linear(in_features=1024, out_features=18289, bias=False)\n",
68
+ " )\n",
69
+ " (layernorms): ModuleList(\n",
70
+ " (0-2): 3 x LayerNorm((1024,), eps=1e-05, elementwise_affine=False)\n",
71
+ " )\n",
72
+ " (activation): ReLU()\n",
73
+ " )\n",
74
+ " (mask_cell_dropout): MLPBiasFree(\n",
75
+ " (layers): ModuleList(\n",
76
+ " (0-2): 3 x Linear(in_features=1024, out_features=1024, bias=False)\n",
77
+ " (3): Linear(in_features=1024, out_features=18289, bias=False)\n",
78
+ " )\n",
79
+ " (layernorms): ModuleList(\n",
80
+ " (0-2): 3 x LayerNorm((1024,), eps=1e-05, elementwise_affine=False)\n",
81
+ " )\n",
82
+ " (activation): ReLU()\n",
83
+ " )\n",
84
+ " (mask_embedding): Embedding(1, 1024)\n",
85
+ " (relu): ReLU()\n",
86
+ " (sigmoid): Sigmoid()\n",
87
+ ")"
88
+ ]
89
+ },
90
+ "execution_count": 2,
91
+ "metadata": {},
92
+ "output_type": "execute_result"
93
+ }
94
+ ],
95
+ "source": [
96
+ "args_model = torch.load('./model_files/args.pt')\n",
97
+ "model = CIFM.from_pretrained('ynyou/CIFM', args=args_model)\n",
98
+ "model.eval()"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "execution_count": 3,
104
+ "metadata": {},
105
+ "outputs": [
106
+ {
107
+ "data": {
108
+ "text/plain": [
109
+ "AnnData object with n_obs × n_vars = 24844 × 18289\n",
110
+ " obs: 'in_tissue'\n",
111
+ " var: 'feature_types', 'genome', 'gene_names'\n",
112
+ " uns: 'log1p'\n",
113
+ " obsm: 'spatial'\n",
114
+ " layers: 'counts'"
115
+ ]
116
+ },
117
+ "execution_count": 3,
118
+ "metadata": {},
119
+ "output_type": "execute_result"
120
+ }
121
+ ],
122
+ "source": [
123
+ "channel2ensembl = torch.load('./model_files/channel2ensembl.pt')\n",
124
+ "adata = sc.read_h5ad('./adata.h5ad')\n",
125
+ "adata.layers['counts'] = adata.X.copy()\n",
126
+ "sc.pp.normalize_total(adata)\n",
127
+ "sc.pp.log1p(adata)\n",
128
+ "adata"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": 4,
134
+ "metadata": {},
135
+ "outputs": [
136
+ {
137
+ "name": "stdout",
138
+ "output_type": "stream",
139
+ "text": [
140
+ "matching 18289 gene channels out of 18289 unmatched channels: []\n"
141
+ ]
142
+ }
143
+ ],
144
+ "source": [
145
+ "model.channel_matching(adata, channel2ensembl)"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": 5,
151
+ "metadata": {},
152
+ "outputs": [
153
+ {
154
+ "data": {
155
+ "text/plain": [
156
+ "(tensor([[-0.4132, -0.9847, 0.1647, ..., -0.8351, -0.8177, -1.3235],\n",
157
+ " [ 0.8701, 0.0967, -0.3676, ..., 0.2687, -1.4821, 0.1605],\n",
158
+ " [-0.5178, -0.4442, -0.0862, ..., -0.7446, -0.5761, -0.5571],\n",
159
+ " ...,\n",
160
+ " [ 1.2264, 1.2326, 0.2791, ..., 0.8018, -1.4069, 1.4567],\n",
161
+ " [ 0.6699, -0.6107, 0.2450, ..., -0.1975, -0.6034, -0.6608],\n",
162
+ " [-1.9240, -1.8125, -0.0766, ..., -0.2799, -0.0217, -2.2051]]),\n",
163
+ " torch.Size([13898, 1024]))"
164
+ ]
165
+ },
166
+ "execution_count": 5,
167
+ "metadata": {},
168
+ "output_type": "execute_result"
169
+ }
170
+ ],
171
+ "source": [
172
+ "with torch.no_grad():\n",
173
+ " embeddings = model.embed(adata)\n",
174
+ "embeddings, embeddings.shape"
175
+ ]
176
+ },
177
+ {
178
+ "cell_type": "code",
179
+ "execution_count": 5,
180
+ "metadata": {},
181
+ "outputs": [
182
+ {
183
+ "data": {
184
+ "text/plain": [
185
+ "(tensor([[0.0000, 0.0000, 0.8603, ..., 0.0000, 0.0000, 0.0000],\n",
186
+ " [0.0000, 0.0000, 0.6644, ..., 0.0000, 0.0000, 0.0000],\n",
187
+ " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
188
+ " ...,\n",
189
+ " [0.0000, 0.0000, 0.9809, ..., 0.0000, 0.0000, 0.0000],\n",
190
+ " [0.6641, 0.0000, 0.6858, ..., 0.0000, 0.0000, 0.0000],\n",
191
+ " [0.4999, 0.0000, 0.5311, ..., 0.0000, 0.0000, 0.0000]]),\n",
192
+ " torch.Size([10, 18289]))"
193
+ ]
194
+ },
195
+ "execution_count": 5,
196
+ "metadata": {},
197
+ "output_type": "execute_result"
198
+ }
199
+ ],
200
+ "source": [
201
+ "rand_loc = np.random.rand(10, 2)\n",
202
+ "x_min, x_max = adata.obsm['spatial'][:, 0].min(), adata.obsm['spatial'][:, 0].max()\n",
203
+ "y_min, y_max = adata.obsm['spatial'][:, 1].min(), adata.obsm['spatial'][:, 1].max()\n",
204
+ "rand_loc[:, 0] = rand_loc[:, 0] * (x_max - x_min) + x_min\n",
205
+ "rand_loc[:, 1] = rand_loc[:, 1] * (y_max - y_min) + y_min\n",
206
+ "\n",
207
+ "with torch.no_grad():\n",
208
+ " expressions = model.predict_cells_at_locations(adata, rand_loc)\n",
209
+ "expressions, expressions.shape"
210
+ ]
211
+ }
212
+ ],
213
+ "metadata": {
214
+ "kernelspec": {
215
+ "display_name": "Python 3 (ipykernel)",
216
+ "language": "python",
217
+ "name": "python3"
218
+ },
219
+ "language_info": {
220
+ "codemirror_mode": {
221
+ "name": "ipython",
222
+ "version": 3
223
+ },
224
+ "file_extension": ".py",
225
+ "mimetype": "text/x-python",
226
+ "name": "python",
227
+ "nbconvert_exporter": "python",
228
+ "pygments_lexer": "ipython3",
229
+ "version": "3.11.10"
230
+ }
231
+ },
232
+ "nbformat": 4,
233
+ "nbformat_minor": 2
234
+ }