diff --git a/.github/README.md b/.github/README.md
index eb971408393c9ab89fc13e7a5f66abbedaf93087..89da83e2f4ed3cb86b96d353f82cb41dff86e928 100644
--- a/.github/README.md
+++ b/.github/README.md
@@ -1,22 +1,25 @@
-> [!CAUTION]
-> MLIP Arena is currently in pre-alpha. The results are not stable. Please intepret them with care.
+MLIP Arena is a unified platform for evaluating foundation machine learning interatomic potentials (MLIPs) beyond conventional error metrics. It focuses on revealing the physical soundness learned by MLIPs and assessing their utilitarian performance agnostic to the underlying model architecture. The platform's benchmarks are specifically designed to evaluate the readiness and reliability of open-source, open-weight models in accurately reproducing both qualitative and quantitative behaviors of atomic systems.
+
+MLIP Arena leverages the modern pythonic workflow orchestrator [Prefect](https://www.prefect.io/) to enable advanced task/flow chaining and caching.
 
 > [!NOTE]
 > Contributions of new tasks are very welcome! If you're interested in joining the effort, please reach out to Yuan at [cyrusyc@berkeley.edu](mailto:cyrusyc@berkeley.edu). See [project page](https://github.com/orgs/atomind-ai/projects/1) for some outstanding tasks, or propose new one in [Discussion](https://github.com/atomind-ai/mlip-arena/discussions/new?category=ideas).
 
-MLIP Arena is a unified platform for evaluating foundation machine learning interatomic potentials (MLIPs) beyond conventional error metrics. It focuses on revealing the physics and chemistry learned by these models and assessing their utilitarian performance agnostic to underlying model architecture. The platform's benchmarks are specifically designed to evaluate the readiness and reliability of open-source, open-weight models in accurately reproducing both qualitative and quantitative behaviors of atomic systems.
+## Announcement
+
+- **[April 8, 2025]** [🎉 MLIP Arena accepted as an ICLR AI4Mat Spotlight! 🎉](https://openreview.net/forum?id=ysKfIavYQE#discussion) Huge thanks to all co-authors for their contributions!
 
-MLIP Arena leverages modern pythonic workflow orchestrator [Prefect](https://www.prefect.io/) to enable advanced task/flow chaining and caching.
 
 ## Installation
@@ -28,6 +31,8 @@ pip install mlip-arena
 
 ### From source
 
+> [!CAUTION]
+> We recommend starting from a clean virtual environment due to compatibility issues between multiple popular MLIPs. We provide a one-script installation using uv for minimal package conflicts and fast installation!
+
 **Linux**
 
 ```bash
@@ -63,7 +68,7 @@ bash scripts/install-macosx.sh
 
 ## Quickstart
 
-### Molecular dynamics (MD)
+### First example: Molecular dynamics
 
 Arena provides a unified interface to run all the compiled MLIPs.
 This can be achieved simply by looping through `MLIPEnum`:
@@ -109,6 +114,29 @@ The implemented tasks are available under `mlip_arena.tasks.
15:46:06.786 | INFO | Flow run 'amigurumi-beagle' - Beginning flow run 'amigurumi-beagle' for flow 'benchmark-one'\n",
+ "\n"
+ ],
+ "text/plain": [
+ "15:46:06.786 | \u001b[36mINFO\u001b[0m | Flow run\u001b[35m 'amigurumi-beagle'\u001b[0m - Beginning flow run\u001b[35m 'amigurumi-beagle'\u001b[0m for flow\u001b[1;35m 'benchmark-one'\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "2a1bc33089b44a308a44cd14979533f7",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "0it [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "6c5c9eda59644f528f76c5b2b18b272d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "mof/mofs.db: 0%| | 0.00/168k [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
15:46:07.619 | INFO | Task run 'Widom Insertion: C28H16O10V2 + CO2 - MACE-MP(M)' - Optimizing structure\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "15:46:07.619 | \u001b[36mINFO\u001b[0m | Task run 'Widom Insertion: C28H16O10V2 + CO2 - MACE-MP(M)' - Optimizing structure\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Selected GPU cuda:0 with 40339.31 MB free memory from 1 GPUs\n",
+ "Using device: cuda:0\n",
+ "Selected GPU cuda:0 with 40339.31 MB free memory from 1 GPUs\n",
+ "Default dtype float32 does not match model dtype float64, converting models to float32.\n",
+ "Using calculator: 15:46:14.836 | INFO | Task run 'OPT: C28H16O10V2 - MACE-MP(M)' - Finished in state Completed()\n", + "\n" + ], + "text/plain": [ + "15:46:14.836 | \u001b[36mINFO\u001b[0m | Task run 'OPT: C28H16O10V2 - MACE-MP(M)' - Finished in state \u001b[32mCompleted\u001b[0m()\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
15:46:14.840 | INFO | Task run 'Widom Insertion: C28H16O10V2 + CO2 - MACE-MP(M)' - Optimizing gas molecule\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "15:46:14.840 | \u001b[36mINFO\u001b[0m | Task run 'Widom Insertion: C28H16O10V2 + CO2 - MACE-MP(M)' - Optimizing gas molecule\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Selected GPU cuda:0 with 40301.98 MB free memory from 1 GPUs\n",
+ "Using device: cuda:0\n",
+ "Selected GPU cuda:0 with 40301.98 MB free memory from 1 GPUs\n",
+ "Default dtype float32 does not match model dtype float64, converting models to float32.\n",
+ "Using calculator: 15:46:17.127 | INFO | Task run 'OPT: CO2 - MACE-MP(M)' - Finished in state Completed()\n", + "\n" + ], + "text/plain": [ + "15:46:17.127 | \u001b[36mINFO\u001b[0m | Task run 'OPT: CO2 - MACE-MP(M)' - Finished in state \u001b[32mCompleted\u001b[0m()\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of accessible positions: 357364 out of total 498623\n", + "Selected GPU cuda:0 with 40280.80 MB free memory from 1 GPUs\n", + "Using device: cuda:0\n", + "Selected GPU cuda:0 with 40280.80 MB free memory from 1 GPUs\n", + "Default dtype float32 does not match model dtype float64, converting models to float32.\n", + "Using calculator:
15:52:13.884 | INFO | Task run 'Widom Insertion: C28H16O10V2 + CO2 - MACE-MP(M)' - Finished in state Completed()\n",
+ "\n"
+ ],
+ "text/plain": [
+ "15:52:13.884 | \u001b[36mINFO\u001b[0m | Task run 'Widom Insertion: C28H16O10V2 + CO2 - MACE-MP(M)' - Finished in state \u001b[32mCompleted\u001b[0m()\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
15:52:14.235 | INFO | Flow run 'amigurumi-beagle' - Finished in state Completed()\n",
+ "\n"
+ ],
+ "text/plain": [
+ "15:52:14.235 | \u001b[36mINFO\u001b[0m | Flow run\u001b[35m 'amigurumi-beagle'\u001b[0m - Finished in state \u001b[32mCompleted\u001b[0m()\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[2.3816888372250245e-06, 2.5323794093995965e-06, inf]"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "result = benchmark_one()\n",
+ "result[0]['henry_coefficient']"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Run workflow"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
11:40:36.644 | WARNING | MDAnalysis.coordinates.AMBER - netCDF4 is not available. Writing AMBER ncdf files will be slow.\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "11:40:36.644 | \u001b[38;5;184mWARNING\u001b[0m | MDAnalysis.coordinates.AMBER - netCDF4 is not available. Writing AMBER ncdf files will be slow.\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "11:40:44.431 | INFO | distributed.http.proxy - To route to workers diagnostics web server please install jupyter-server-proxy: python -m pip install jupyter-server-proxy\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "11:40:44.431 | \u001b[36mINFO\u001b[0m | distributed.http.proxy - To route to workers diagnostics web server please install jupyter-server-proxy: python -m pip install jupyter-server-proxy\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "11:40:44.445 | INFO | distributed.scheduler - State start\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "11:40:44.445 | \u001b[36mINFO\u001b[0m | distributed.scheduler - State start\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "11:40:44.503 | INFO | distributed.scheduler - Scheduler at: tcp://128.55.64.42:36351\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "11:40:44.503 | \u001b[36mINFO\u001b[0m | distributed.scheduler - Scheduler at: tcp://128.55.64.42:36351\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "11:40:44.505 | INFO | distributed.scheduler - dashboard at: http://128.55.64.42:8787/status\n", + "\n" + ], + "text/plain": [ + "11:40:44.505 | \u001b[36mINFO\u001b[0m | distributed.scheduler - dashboard at: \u001b[94mhttp://128.55.64.42:8787/status\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
11:40:44.506 | INFO | distributed.scheduler - Registering Worker plugin shuffle\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "11:40:44.506 | \u001b[36mINFO\u001b[0m | distributed.scheduler - Registering Worker plugin shuffle\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "#!/bin/bash\n",
+ "\n",
+ "#SBATCH -A matgen\n",
+ "#SBATCH --mem=0\n",
+ "#SBATCH -t 00:30:00\n",
+ "#SBATCH -J mof\n",
+ "#SBATCH -q regular\n",
+ "#SBATCH -N 1\n",
+ "#SBATCH -C gpu\n",
+ "#SBATCH -G 4\n",
+ "source ~/.bashrc\n",
+ "module load python\n",
+ "source activate /pscratch/sd/c/cyrusyc/.conda/mlip-arena\n",
+ "/pscratch/sd/c/cyrusyc/.conda/mlip-arena/bin/python -m distributed.cli.dask_worker tcp://128.55.64.42:36351 --name dummy-name --nthreads 1 --memory-limit 59.60GiB --nanny --death-timeout 60\n",
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "11:40:44.514 | INFO | distributed.deploy.adaptive - Adaptive scaling started: minimum=10 maximum=20\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "11:40:44.514 | \u001b[36mINFO\u001b[0m | distributed.deploy.adaptive - Adaptive scaling started: minimum=10 maximum=20\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "11:40:44.522 | INFO | distributed.scheduler - Receive client connection: Client-a27a9a6e-c09c-11ef-8318-c77ccf4f19b4\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "11:40:44.522 | \u001b[36mINFO\u001b[0m | distributed.scheduler - Receive client connection: Client-a27a9a6e-c09c-11ef-8318-c77ccf4f19b4\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "11:40:44.523 | INFO | distributed.core - Starting established connection to tcp://128.55.64.42:48148\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "11:40:44.523 | \u001b[36mINFO\u001b[0m | distributed.core - Starting established connection to tcp://128.55.64.42:48148\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "11:40:45.046 | INFO | prefect.engine - Created flow run 'enormous-hog' for flow 'run'\n", + "\n" + ], + "text/plain": [ + "11:40:45.046 | \u001b[36mINFO\u001b[0m | prefect.engine - Created flow run\u001b[35m 'enormous-hog'\u001b[0m for flow\u001b[1;35m 'run'\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
11:40:45.048 | INFO | prefect.engine - View at https://app.prefect.cloud/account/f7d40474-9362-4bfa-8950-ee6a43ec00f3/workspace/d4bb0913-5f5e-49f7-bfc5-06509088baeb/runs/flow-run/c0c7a3f2-d8d0-4f17-9789-4e070f17bf3b\n",
+ "\n"
+ ],
+ "text/plain": [
+ "11:40:45.048 | \u001b[36mINFO\u001b[0m | prefect.engine - View at \u001b[94mhttps://app.prefect.cloud/account/f7d40474-9362-4bfa-8950-ee6a43ec00f3/workspace/d4bb0913-5f5e-49f7-bfc5-06509088baeb/runs/flow-run/c0c7a3f2-d8d0-4f17-9789-4e070f17bf3b\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
11:40:45.366 | INFO | prefect.task_runner.dask - Connecting to existing Dask cluster SLURMCluster(00ac1d39, 'tcp://128.55.64.42:36351', workers=0, threads=0, memory=0 B)\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "11:40:45.366 | \u001b[36mINFO\u001b[0m | prefect.task_runner.dask - Connecting to existing Dask cluster SLURMCluster(00ac1d39, 'tcp://128.55.64.42:36351', workers=0, threads=0, memory=0 B)\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "11:40:45.395 | INFO | distributed.scheduler - Receive client connection: PrefectDaskClient-a2fe06b3-c09c-11ef-8318-c77ccf4f19b4\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "11:40:45.395 | \u001b[36mINFO\u001b[0m | distributed.scheduler - Receive client connection: PrefectDaskClient-a2fe06b3-c09c-11ef-8318-c77ccf4f19b4\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "11:40:45.401 | INFO | distributed.core - Starting established connection to tcp://128.55.64.42:48168\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "11:40:45.401 | \u001b[36mINFO\u001b[0m | distributed.core - Starting established connection to tcp://128.55.64.42:48168\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c88bb75bf9b84285bfd6d524e7d73650",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "0it [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "11:40:45.501 | INFO | Task run 'get_atoms_from_db-6be' - Created task run 'get_atoms_from_db-6be' for task 'get_atoms_from_db'\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "11:40:45.501 | \u001b[36mINFO\u001b[0m | Task run 'get_atoms_from_db-6be' - Created task run 'get_atoms_from_db-6be' for task 'get_atoms_from_db'\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from dask.distributed import Client\n",
+ "from dask_jobqueue import SLURMCluster\n",
+ "from prefect_dask import DaskTaskRunner\n",
+ "from mlip_arena.tasks.mof.flow import run as MOF\n",
+ "\n",
+ "# Orchestrate your awesome dask workflow runner\n",
+ "\n",
+ "nodes_per_alloc = 1\n",
+ "gpus_per_alloc = 4\n",
+ "ntasks = 1\n",
+ "\n",
+ "cluster_kwargs = dict(\n",
+ " cores=1,\n",
+ " memory=\"64 GB\",\n",
+ " shebang=\"#!/bin/bash\",\n",
+ " account=\"matgen\",\n",
+ " walltime=\"00:30:00\",\n",
+ " job_mem=\"0\",\n",
+ " job_script_prologue=[\n",
+ " \"source ~/.bashrc\",\n",
+ " \"module load python\",\n",
+ " \"source activate /pscratch/sd/c/cyrusyc/.conda/mlip-arena\",\n",
+ " ],\n",
+ " job_directives_skip=[\"-n\", \"--cpus-per-task\", \"-J\"],\n",
+ " job_extra_directives=[\n",
+ " \"-J mof\",\n",
+ " \"-q regular\",\n",
+ " f\"-N {nodes_per_alloc}\",\n",
+ " \"-C gpu\",\n",
+ " f\"-G {gpus_per_alloc}\",\n",
+ " ],\n",
+ ")\n",
+ "\n",
+ "cluster = SLURMCluster(**cluster_kwargs)\n",
+ "print(cluster.job_script())\n",
+ "cluster.adapt(minimum_jobs=10, maximum_jobs=20)\n",
+ "client = Client(cluster)\n",
+ "\n",
+ "# Run the workflow on HPC cluster in parallel\n",
+ "\n",
+ "results = MOF.with_options(\n",
+ " task_runner=DaskTaskRunner(address=client.scheduler.address),\n",
+ " # log_prints=True,\n",
+ ")()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "mlip-arena",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.8"
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "state": {},
+ "version_major": 2,
+ "version_minor": 0
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/mlip_arena/models/__init__.py b/mlip_arena/models/__init__.py
index ec149a30292a24f1e4a1aa689bdc28f8afe1aaf6..833ffa00c5ff25755e50af03c1f70237625a9d29 100644
--- a/mlip_arena/models/__init__.py
+++ b/mlip_arena/models/__init__.py
@@ -3,6 +3,9 @@ from __future__ import annotations
import importlib
from enum import Enum
from pathlib import Path
+from typing import Dict, Optional, Type, TypeVar, Union
+
+T = TypeVar("T", bound="MLIP")
import torch
import yaml
@@ -11,6 +14,7 @@ from torch import nn
from ase import Atoms
from ase.calculators.calculator import Calculator, all_changes
+from mlip_arena.data.collate import collate_fn
try:
from prefect.logging import get_run_logger
@@ -49,6 +53,35 @@ class MLIP(
# self.model = torch.compile(model)
self.model = model
+ def _save_pretrained(self, save_directory: Path) -> None:
+ return super()._save_pretrained(save_directory)
+
+ @classmethod
+ def from_pretrained(
+ cls: Type[T],
+ pretrained_model_name_or_path: Union[str, Path],
+ *,
+ force_download: bool = False,
+ resume_download: Optional[bool] = None,
+ proxies: Optional[Dict] = None,
+ token: Optional[Union[str, bool]] = None,
+ cache_dir: Optional[Union[str, Path]] = None,
+ local_files_only: bool = False,
+ revision: Optional[str] = None,
+ **model_kwargs,
+ ) -> T:
+ return super().from_pretrained(
+ pretrained_model_name_or_path,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ token=token,
+ cache_dir=cache_dir,
+ local_files_only=local_files_only,
+ revision=revision,
+ **model_kwargs,
+ )
+
def forward(self, x):
return self.model(x)
@@ -101,7 +134,6 @@ class MLIPCalculator(MLIP, Calculator):
):
"""Calculate energies and forces for the given Atoms object"""
super().calculate(atoms, properties, system_changes)
- from mlip_arena.data.collate import collate_fn
# TODO: move collate_fn to here in MLIPCalculator
data = collate_fn([atoms], cutoff=self.cutoff).to(self.device)
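For context, the `from_pretrained` override added above forwards directly to the parent hub-mixin implementation, so loading a checkpoint follows the usual mixin call. The repository id below is a hypothetical placeholder, not a real checkpoint:

```python
from mlip_arena.models import MLIP

# "some-org/some-mlip-checkpoint" is purely illustrative; substitute a real
# Hugging Face Hub repository id or a local directory containing the weights.
model = MLIP.from_pretrained("some-org/some-mlip-checkpoint", revision=None)
```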
diff --git a/mlip_arena/models/registry.yaml b/mlip_arena/models/registry.yaml
index 814a7a204da51a4bdd4ef09f32f3ae1a6533da50..b8c6e5def10b18ad00e5040015027c9730e70cf3 100644
--- a/mlip_arena/models/registry.yaml
+++ b/mlip_arena/models/registry.yaml
@@ -363,42 +363,4 @@ ORB:
prediction: EFS
nvt: true
npt: true
- license: Apache-2.0
-
-# ORBv2(MPTrj):
-# module: externals
-# class: ORBv2
-# family: orb
-# package:
-# checkpoint: orb-mptraj-only-v2-20241014.ckpt
-# username:
-# last-update: 2024-10-29T00:00:00
-# datetime: 2024-10-29T00:00:00 # TODO: Fake datetime
-# datasets:
-# - MPTrj
-# github: https://github.com/orbital-materials/orb-models
-# doi:
-# date: 2024-10-15
-# prediction: EFS
-# nvt: true
-# npt: true
-# license: Apache-2.0
-
-# eqV2(MPTrj-S):
-# module: externals
-# class: eqV2
-# family: fairchem
-# package: fairchem-core==1.2.0
-# checkpoint: eqV2_31M_mp.pt
-# username: fairchem # HF handle
-# last-update: 2024-10-29T00:00:00
-# datetime: 2024-10-29T00:00:00
-# datasets:
-# - MPTrj
-# prediction: EFS
-# nvt: true
-# npt: true
-# date: 2024-10-18
-# github: https://github.com/FAIR-Chem/fairchem
-# doi: https://arxiv.org/abs/2410.12771
-# license: Modified Apache-2.0 (Meta)
\ No newline at end of file
+ license: Apache-2.0
\ No newline at end of file
diff --git a/mlip_arena/models/utils.py b/mlip_arena/models/utils.py
index 33acd9c2778c0f86b29ceb551152f8fbe4e2d333..40d7aa9085e3d90b698ceb78801e9dea9c83a384 100644
--- a/mlip_arena/models/utils.py
+++ b/mlip_arena/models/utils.py
@@ -12,12 +12,12 @@ except (ImportError, RuntimeError):
def get_freer_device() -> torch.device:
"""Get the GPU with the most free memory, or use MPS if available.
- s
- Returns:
- torch.device: The selected GPU device or MPS.
- Raises:
- ValueError: If no GPU or MPS is available.
+ Returns:
+ torch.device: The selected GPU device or MPS.
+
+ Raises:
+ ValueError: If no GPU or MPS is available.
"""
device_count = torch.cuda.device_count()
if device_count > 0:
diff --git a/mlip_arena/tasks/README.md b/mlip_arena/tasks/README.md
index 3f922b07d9630a83f5862a25b808983a50126797..b2be171a46e3596709631a323bef18f53c61069d 100644
--- a/mlip_arena/tasks/README.md
+++ b/mlip_arena/tasks/README.md
@@ -1,3 +1,21 @@
+
+
+## Task
+
+In the language of the Prefect workflow manager, we define a task as *one operation on one input structure* that generates the result for **one sample**. For example, [Structure optimization (OPT)](optimize.py) runs one structure optimization on one input structure and returns the relaxed structure.
+
+It is possible to chain multiple subtasks into a single, more complex task. For example, [Equation of states (EOS)](eos.py) first performs one full [OPT](optimize.py) relaxation, then parallelizes/serializes multiple constrained [OPT](optimize.py) tasks in one call, and returns the equation of state and bulk modulus of the structure.
+
+Some general tasks can be reused across benchmarks (a minimal usage sketch follows the list):
+- [Structure optimization (OPT)](optimize.py)
+- [Molecular dynamics (MD)](md.py)
+- [Equation of states (EOS)](eos.py)
+
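+For illustration, here is a minimal sketch of calling one reusable task per model. The import style mirrors the MOF flow example in the example notebook; the keyword arguments are illustrative, so check [optimize.py](optimize.py) for the actual signature.
+
+```python
+from ase.build import bulk
+
+from mlip_arena.models import MLIPEnum
+from mlip_arena.tasks.optimize import run as OPT
+
+atoms = bulk("Cu", "fcc", a=3.6)
+
+for model in MLIPEnum:
+    # One task call = one operation on one input structure -> one relaxed sample.
+    # Keyword names here are hypothetical; see optimize.py for the real interface.
+    result = OPT(atoms=atoms.copy(), calculator_name=model.name, fmax=0.05)
+```
+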
+## Flow
+
+A flow is meant to parallelize multiple tasks and to be orchestrated for production runs at scale on high-throughput clusters, as sketched below.
+
+
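+As a sketch of how a flow fans tasks out over models (following the Prefect/Dask pattern shown in the example notebook; the task import and keyword names are again illustrative):
+
+```python
+from ase.build import bulk
+from prefect import flow
+
+from mlip_arena.models import MLIPEnum
+from mlip_arena.tasks.optimize import run as OPT
+
+
+@flow
+def relax_with_all_models(atoms):
+    # Submit one OPT task per registered MLIP; pairing the flow with a
+    # DaskTaskRunner (as in the notebook example) runs them in parallel.
+    futures = [
+        OPT.submit(atoms=atoms.copy(), calculator_name=model.name)
+        for model in MLIPEnum
+    ]
+    return [future.result() for future in futures]
+
+
+results = relax_with_all_models(bulk("Cu", "fcc", a=3.6))
+```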