Fix weight loading
Browse files
app.py
CHANGED
|
@@ -54,7 +54,7 @@ class Ensemble(nn.Module):
|
|
| 54 |
|
| 55 |
|
| 56 |
checkpoints = ["fold0.ckpt", "fold1.ckpt", "fold2.ckpt"]
|
| 57 |
-
weights = [torch.load(ckpt)["state_dict"] for ckpt in checkpoints]
|
| 58 |
weights = [{k.replace("model.", "") : v for k, v in wt.items()} for wt in weights]
|
| 59 |
models = [Net2D(wt) for wt in weights]
|
| 60 |
ensemble = Ensemble(models).eval()
|
|
|
|
| 54 |
|
| 55 |
|
| 56 |
checkpoints = ["fold0.ckpt", "fold1.ckpt", "fold2.ckpt"]
|
| 57 |
+
weights = [torch.load(ckpt, map_location=torch.device("cpu"))["state_dict"] for ckpt in checkpoints]
|
| 58 |
weights = [{k.replace("model.", "") : v for k, v in wt.items()} for wt in weights]
|
| 59 |
models = [Net2D(wt) for wt in weights]
|
| 60 |
ensemble = Ensemble(models).eval()
|