ianpan commited on
Commit
3dfe59a
·
1 Parent(s): a1306dd

Fix weight loading

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -54,7 +54,7 @@ class Ensemble(nn.Module):
54
 
55
 
56
  checkpoints = ["fold0.ckpt", "fold1.ckpt", "fold2.ckpt"]
57
- weights = [torch.load(ckpt)["state_dict"] for ckpt in checkpoints]
58
  weights = [{k.replace("model.", "") : v for k, v in wt.items()} for wt in weights]
59
  models = [Net2D(wt) for wt in weights]
60
  ensemble = Ensemble(models).eval()
 
54
 
55
 
56
  checkpoints = ["fold0.ckpt", "fold1.ckpt", "fold2.ckpt"]
57
+ weights = [torch.load(ckpt, map_location=torch.device("cpu"))["state_dict"] for ckpt in checkpoints]
58
  weights = [{k.replace("model.", "") : v for k, v in wt.items()} for wt in weights]
59
  models = [Net2D(wt) for wt in weights]
60
  ensemble = Ensemble(models).eval()