cyrusyc commited on
Commit
d97b3b5
·
1 Parent(s): 258a9f9

gracefully bail out collate

Browse files
Files changed (2) hide show
  1. README.md +5 -0
  2. mlip_arena/models/__init__.py +20 -11
README.md CHANGED
@@ -3,7 +3,12 @@ title: MLIP Arena
3
  emoji: ⚛
4
  sdk: streamlit
5
  sdk_version: 1.43.2 # The latest supported version
 
6
  app_file: serve/app.py
 
 
 
 
7
  ---
8
 
9
 
 
3
  emoji: ⚛
4
  sdk: streamlit
5
  sdk_version: 1.43.2 # The latest supported version
6
+ python_version: 3.11
7
  app_file: serve/app.py
8
+ colorFrom: indigo
9
+ colorTo: yellow
10
+ pinned: true
11
+ short_description: Benchmark machine learning interatomic potential at scale
12
  ---
13
 
14
 
mlip_arena/models/__init__.py CHANGED
@@ -9,12 +9,20 @@ T = TypeVar("T", bound="MLIP")
9
 
10
  import torch
11
  import yaml
 
 
12
  from huggingface_hub import PyTorchModelHubMixin
13
  from torch import nn
 
14
 
15
- from ase import Atoms
16
- from ase.calculators.calculator import Calculator, all_changes
17
- from mlip_arena.data.collate import collate_fn
 
 
 
 
 
18
 
19
  try:
20
  from prefect.logging import get_run_logger
@@ -58,18 +66,18 @@ class MLIP(
58
 
59
  @classmethod
60
  def from_pretrained(
61
- cls: Type[T],
62
- pretrained_model_name_or_path: Union[str, Path],
63
  *,
64
  force_download: bool = False,
65
- resume_download: Optional[bool] = None,
66
- proxies: Optional[Dict] = None,
67
- token: Optional[Union[str, bool]] = None,
68
- cache_dir: Optional[Union[str, Path]] = None,
69
  local_files_only: bool = False,
70
- revision: Optional[str] = None,
71
  **model_kwargs,
72
- ) -> T:
73
  return super().from_pretrained(
74
  pretrained_model_name_or_path,
75
  force_download=force_download,
@@ -108,6 +116,7 @@ class MLIPCalculator(MLIP, Calculator):
108
  # Additional initialization if needed
109
  # self.name: str = self.__class__.__name__
110
  from mlip_arena.models.utils import get_freer_device
 
111
  self.device = device or get_freer_device()
112
  self.cutoff = cutoff
113
  self.model.to(self.device)
 
9
 
10
  import torch
11
  import yaml
12
+ from ase import Atoms
13
+ from ase.calculators.calculator import Calculator, all_changes
14
  from huggingface_hub import PyTorchModelHubMixin
15
  from torch import nn
16
+ from typing_extensions import Self
17
 
18
+ try:
19
+ from mlip_arena.data.collate import collate_fn
20
+ except ImportError:
21
+ # Fallback to a dummy function if the import fails
22
+ def collate_fn(batch: list[Atoms], cutoff: float) -> None:
23
+ raise ImportError(
24
+ "collate_fn import failed. Please install the required dependencies."
25
+ )
26
 
27
  try:
28
  from prefect.logging import get_run_logger
 
66
 
67
  @classmethod
68
  def from_pretrained(
69
+ cls,
70
+ pretrained_model_name_or_path: str | Path,
71
  *,
72
  force_download: bool = False,
73
+ resume_download: bool | None = None,
74
+ proxies: dict | None = None,
75
+ token: str | bool | None = None,
76
+ cache_dir: str | Path | None = None,
77
  local_files_only: bool = False,
78
+ revision: str | None = None,
79
  **model_kwargs,
80
+ ) -> Self:
81
  return super().from_pretrained(
82
  pretrained_model_name_or_path,
83
  force_download=force_download,
 
116
  # Additional initialization if needed
117
  # self.name: str = self.__class__.__name__
118
  from mlip_arena.models.utils import get_freer_device
119
+
120
  self.device = device or get_freer_device()
121
  self.cutoff = cutoff
122
  self.model.to(self.device)