File size: 3,298 Bytes
51638da
1effaf5
 
 
 
 
 
419b35b
1d1ee87
1effaf5
 
 
 
419b35b
 
 
e80e29d
419b35b
1effaf5
 
 
 
 
 
 
51638da
1effaf5
 
 
 
1d1ee87
1effaf5
 
 
 
 
 
 
 
51638da
 
 
 
 
 
 
 
419b35b
1effaf5
51638da
1effaf5
51638da
 
a787930
51638da
1effaf5
 
419b35b
1effaf5
 
 
 
 
1d1ee87
1effaf5
419b35b
 
1effaf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d1ee87
 
 
1effaf5
 
e80e29d
 
1effaf5
e80e29d
 
 
 
1effaf5
 
 
 
e80e29d
 
 
1effaf5
 
7cc6c4a
1effaf5
 
7cc6c4a
 
1effaf5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
Define structure optimization tasks.
"""

from __future__ import annotations

from ase import Atoms
from ase.calculators.calculator import BaseCalculator
from ase.constraints import FixSymmetry
from ase.filters import *  # type: ignore
from ase.filters import Filter
from ase.optimize import *  # type: ignore
from ase.optimize.optimize import Optimizer
from prefect import task
from prefect.cache_policies import INPUTS, TASK_SOURCE
from prefect.runtime import task_run

from mlip_arena.tasks.utils import logger, pformat

# Map of user-facing filter names to ASE filter *classes* (not instances);
# ``run`` instantiates the selected class around the copied Atoms object.
# NOTE(review): newer ASE versions reportedly deprecate ExpCellFilter in
# favour of FrechetCellFilter — the entry is kept for backward compatibility.
_valid_filters: dict[str, type[Filter]] = {
    "Filter": Filter,
    "UnitCell": UnitCellFilter,
    "ExpCell": ExpCellFilter,
    "Strain": StrainFilter,
    "FrechetCell": FrechetCellFilter,
}  # type: ignore

# Map of user-facing optimizer names to ASE optimizer *classes* (not
# instances); ``run`` instantiates the selected class around the atoms/filter.
# The names come from the ``from ase.optimize import *`` star import above.
# NOTE(review): CellAwareBFGS is presumably only present in recent ASE
# releases; an older ASE would make this module fail at import — confirm the
# minimum supported ASE version.
_valid_optimizers: dict[str, type[Optimizer]] = {
    "MDMin": MDMin,
    "FIRE": FIRE,
    "FIRE2": FIRE2,
    "LBFGS": LBFGS,
    "LBFGSLineSearch": LBFGSLineSearch,
    "BFGS": BFGS,
    "BFGSLineSearch": BFGSLineSearch,
    "QuasiNewton": QuasiNewton,
    "GPMin": GPMin,
    "CellAwareBFGS": CellAwareBFGS,
    "ODE12r": ODE12r,
}  # type: ignore


def _generate_task_run_name():
    """Build a readable Prefect task-run name for an optimization run.

    Reads the ``atoms`` and ``calculator`` parameters of the current task run
    (via ``prefect.runtime.task_run``) and returns
    ``"<task name>: <chemical formula> - <calculator label>"``.
    """
    task_name = task_run.task_name
    parameters = task_run.parameters

    atoms = parameters["atoms"]
    calculator = parameters["calculator"]

    # ``calculator`` is normally a calculator *object* (see ``run``'s
    # signature). Interpolating it directly may fall back to the default
    # object repr, which embeds a memory address and makes run names
    # unreadable and non-deterministic. Prefer a plain string as-is, then a
    # non-empty string ``name`` attribute, then the class name.
    if isinstance(calculator, str):
        calculator_name = calculator
    else:
        name = getattr(calculator, "name", None)
        if isinstance(name, str) and name:
            calculator_name = name
        else:
            calculator_name = type(calculator).__name__

    return f"{task_name}: {atoms.get_chemical_formula()} - {calculator_name}"


@task(
    name="OPT", task_run_name=_generate_task_run_name, cache_policy=TASK_SOURCE + INPUTS
)
def run(
    atoms: Atoms,
    calculator: BaseCalculator,
    optimizer: Optimizer | str = BFGSLineSearch,
    optimizer_kwargs: dict | None = None,
    filter: Filter | str | None = None,
    filter_kwargs: dict | None = None,
    criterion: dict | None = None,
    symmetry: bool = False,
):
    """Relax a structure with an ASE optimizer, optionally wrapped in a filter.

    Args:
        atoms: Input structure. It is copied, so the caller's object is not
            modified.
        calculator: Calculator attached to the copied structure.
        optimizer: ASE optimizer class, or a key of ``_valid_optimizers``.
        optimizer_kwargs: Keyword arguments for the optimizer constructor.
        filter: ASE filter class, a key of ``_valid_filters``, or ``None`` to
            optimize the atoms directly.
        filter_kwargs: Keyword arguments for the filter constructor.
        criterion: Keyword arguments forwarded to ``optimizer.run(...)``
            (e.g. ``fmax``, ``steps``).
        symmetry: If True, constrain the copy with ``FixSymmetry``.

    Returns:
        dict with ``"atoms"`` (the optimized copy, calculator attached),
        ``"steps"`` (the optimizer's ``nsteps``) and ``"converged"``
        (result of the optimizer's ``converged()``).

    Raises:
        ValueError: If ``filter`` or ``optimizer`` is an unknown name, or if
            ``filter`` is neither ``None``, a known name, nor a ``Filter``
            subclass.
    """
    atoms = atoms.copy()
    atoms.calc = calculator

    if isinstance(filter, str):
        if filter not in _valid_filters:
            raise ValueError(f"Invalid filter: {filter}")
        filter = _valid_filters[filter]

    if isinstance(optimizer, str):
        if optimizer not in _valid_optimizers:
            raise ValueError(f"Invalid optimizer: {optimizer}")
        optimizer = _valid_optimizers[optimizer]

    # Reject unsupported filter values (e.g. an already-built Filter instance)
    # before doing any work. Previously such values skipped both branches and
    # crashed at the return statement with UnboundLocalError.
    if filter is not None and not (
        isinstance(filter, type) and issubclass(filter, Filter)
    ):
        raise ValueError(
            f"Invalid filter: {filter!r} "
            "(expected None, a filter name, or a Filter subclass)"
        )

    filter_kwargs = filter_kwargs or {}
    optimizer_kwargs = optimizer_kwargs or {}
    criterion = criterion or {}

    if symmetry:
        # NOTE(review): Atoms.set_constraint replaces (does not append to) any
        # constraints already present on the input — confirm that dropping
        # caller-supplied constraints here is intended.
        atoms.set_constraint(FixSymmetry(atoms))

    # The optimizer acts either on the atoms directly or on a filter that
    # wraps them (e.g. to also relax the unit cell).
    if filter is None:
        target = atoms
    else:
        target = filter(atoms, **filter_kwargs)
        logger.info(f"Using filter: {target}")
        logger.info(pformat(filter_kwargs))

    optimizer_instance = optimizer(target, **optimizer_kwargs)
    logger.info(f"Using optimizer: {optimizer_instance}")
    logger.info(pformat(optimizer_kwargs))
    logger.info(f"Criterion: {pformat(criterion)}")

    optimizer_instance.run(**criterion)

    return {
        "atoms": atoms,
        "steps": optimizer_instance.nsteps,
        "converged": optimizer_instance.converged(),
    }