Yuning You commited on
Commit
14bd8e0
·
1 Parent(s): 6bd1fb8
Files changed (1) hide show
  1. test.ipynb +9 -6
test.ipynb CHANGED
@@ -100,10 +100,14 @@
100
  }
101
  ],
102
  "source": [
103
- "args_model = torch.load('./models_cifm/args.pt')\n",
104
- "device = 'cpu' # or 'cuda' if you have a GPU\n",
105
- "model = CIFM.from_pretrained('ynyou/CIFM', args=args_model).to(device)\n",
106
- "model.eval()"
 
 
 
 
107
  ]
108
  },
109
  {
@@ -172,8 +176,7 @@
172
  ],
173
  "source": [
174
  "channel2ensembl_ids_target = [[i] for i in adata.var.index.tolist()]\n",
175
- "channel2ensembl_ids_source = torch.load('./models_cifm/channel2ensembl.pt')\n",
176
- "model.channel_matching(channel2ensembl_ids_target, channel2ensembl_ids_source)"
177
  ]
178
  },
179
  {
 
100
  }
101
  ],
102
  "source": [
103
+ "def load_model():\n",
104
+ " args_model = torch.load('./models_cifm/args.pt')\n",
105
+ " device = 'cpu' # or 'cuda' if you have a GPU\n",
106
+ " model = CIFM.from_pretrained('ynyou/CIFM', args=args_model).to(device)\n",
107
+ " model.channel2ensembl_ids_source = torch.load('./models_cifm/channel2ensembl.pt')\n",
108
+ " model.eval()\n",
109
+ " return model\n",
110
+ "model = load_model()"
111
  ]
112
  },
113
  {
 
176
  ],
177
  "source": [
178
  "channel2ensembl_ids_target = [[i] for i in adata.var.index.tolist()]\n",
179
+ "model.channel_matching(channel2ensembl_ids_target, model.channel2ensembl_ids_source)"
 
180
  ]
181
  },
182
  {