Spaces:
Running
Running
gracefully bail out collate
Browse files- README.md +5 -0
- mlip_arena/models/__init__.py +20 -11
README.md
CHANGED
@@ -3,7 +3,12 @@ title: MLIP Arena
|
|
3 |
emoji: ⚛
|
4 |
sdk: streamlit
|
5 |
sdk_version: 1.43.2 # The latest supported version
|
|
|
6 |
app_file: serve/app.py
|
|
|
|
|
|
|
|
|
7 |
---
|
8 |
|
9 |
|
|
|
3 |
emoji: ⚛
|
4 |
sdk: streamlit
|
5 |
sdk_version: 1.43.2 # The latest supported version
|
6 |
+
python_version: 3.11
|
7 |
app_file: serve/app.py
|
8 |
+
colorFrom: indigo
|
9 |
+
colorTo: yellow
|
10 |
+
pinned: true
|
11 |
+
short_description: Benchmark machine learning interatomic potential at scale
|
12 |
---
|
13 |
|
14 |
|
mlip_arena/models/__init__.py
CHANGED
@@ -9,12 +9,20 @@ T = TypeVar("T", bound="MLIP")
|
|
9 |
|
10 |
import torch
|
11 |
import yaml
|
|
|
|
|
12 |
from huggingface_hub import PyTorchModelHubMixin
|
13 |
from torch import nn
|
|
|
14 |
|
15 |
-
|
16 |
-
from
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
try:
|
20 |
from prefect.logging import get_run_logger
|
@@ -58,18 +66,18 @@ class MLIP(
|
|
58 |
|
59 |
@classmethod
|
60 |
def from_pretrained(
|
61 |
-
cls
|
62 |
-
pretrained_model_name_or_path:
|
63 |
*,
|
64 |
force_download: bool = False,
|
65 |
-
resume_download:
|
66 |
-
proxies:
|
67 |
-
token:
|
68 |
-
cache_dir:
|
69 |
local_files_only: bool = False,
|
70 |
-
revision:
|
71 |
**model_kwargs,
|
72 |
-
) ->
|
73 |
return super().from_pretrained(
|
74 |
pretrained_model_name_or_path,
|
75 |
force_download=force_download,
|
@@ -108,6 +116,7 @@ class MLIPCalculator(MLIP, Calculator):
|
|
108 |
# Additional initialization if needed
|
109 |
# self.name: str = self.__class__.__name__
|
110 |
from mlip_arena.models.utils import get_freer_device
|
|
|
111 |
self.device = device or get_freer_device()
|
112 |
self.cutoff = cutoff
|
113 |
self.model.to(self.device)
|
|
|
9 |
|
10 |
import torch
|
11 |
import yaml
|
12 |
+
from ase import Atoms
|
13 |
+
from ase.calculators.calculator import Calculator, all_changes
|
14 |
from huggingface_hub import PyTorchModelHubMixin
|
15 |
from torch import nn
|
16 |
+
from typing_extensions import Self
|
17 |
|
18 |
+
try:
|
19 |
+
from mlip_arena.data.collate import collate_fn
|
20 |
+
except ImportError:
|
21 |
+
# Fallback to a dummy function if the import fails
|
22 |
+
def collate_fn(batch: list[Atoms], cutoff: float) -> None:
|
23 |
+
raise ImportError(
|
24 |
+
"collate_fn import failed. Please install the required dependencies."
|
25 |
+
)
|
26 |
|
27 |
try:
|
28 |
from prefect.logging import get_run_logger
|
|
|
66 |
|
67 |
@classmethod
|
68 |
def from_pretrained(
|
69 |
+
cls,
|
70 |
+
pretrained_model_name_or_path: str | Path,
|
71 |
*,
|
72 |
force_download: bool = False,
|
73 |
+
resume_download: bool | None = None,
|
74 |
+
proxies: dict | None = None,
|
75 |
+
token: str | bool | None = None,
|
76 |
+
cache_dir: str | Path | None = None,
|
77 |
local_files_only: bool = False,
|
78 |
+
revision: str | None = None,
|
79 |
**model_kwargs,
|
80 |
+
) -> Self:
|
81 |
return super().from_pretrained(
|
82 |
pretrained_model_name_or_path,
|
83 |
force_download=force_download,
|
|
|
116 |
# Additional initialization if needed
|
117 |
# self.name: str = self.__class__.__name__
|
118 |
from mlip_arena.models.utils import get_freer_device
|
119 |
+
|
120 |
self.device = device or get_freer_device()
|
121 |
self.cutoff = cutoff
|
122 |
self.model.to(self.device)
|