Yuning You
commited on
Commit
·
14bd8e0
1
Parent(s):
6bd1fb8
update
Browse files- test.ipynb +9 -6
test.ipynb
CHANGED
@@ -100,10 +100,14 @@
|
|
100 |
}
|
101 |
],
|
102 |
"source": [
|
103 |
-
"
|
104 |
-
"
|
105 |
-
"
|
106 |
-
"model.
|
|
|
|
|
|
|
|
|
107 |
]
|
108 |
},
|
109 |
{
|
@@ -172,8 +176,7 @@
|
|
172 |
],
|
173 |
"source": [
|
174 |
"channel2ensembl_ids_target = [[i] for i in adata.var.index.tolist()]\n",
|
175 |
-
"
|
176 |
-
"model.channel_matching(channel2ensembl_ids_target, channel2ensembl_ids_source)"
|
177 |
]
|
178 |
},
|
179 |
{
|
|
|
100 |
}
|
101 |
],
|
102 |
"source": [
|
103 |
+
"def load_model():\n",
|
104 |
+
" args_model = torch.load('./models_cifm/args.pt')\n",
|
105 |
+
" device = 'cpu' # or 'cuda' if you have a GPU\n",
|
106 |
+
" model = CIFM.from_pretrained('ynyou/CIFM', args=args_model).to(device)\n",
|
107 |
+
" model.channel2ensembl_ids_source = torch.load('./models_cifm/channel2ensembl.pt')\n",
|
108 |
+
" model.eval()\n",
|
109 |
+
" return model\n",
|
110 |
+
"model = load_model()"
|
111 |
]
|
112 |
},
|
113 |
{
|
|
|
176 |
],
|
177 |
"source": [
|
178 |
"channel2ensembl_ids_target = [[i] for i in adata.var.index.tolist()]\n",
|
179 |
+
"model.channel_matching(channel2ensembl_ids_target, model.channel2ensembl_ids_source)"
|
|
|
180 |
]
|
181 |
},
|
182 |
{
|