File size: 3,253 Bytes
1e50f35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""
Define equation of state flows.

https://github.com/materialsvirtuallab/matcalc/blob/main/matcalc/eos.py
"""

from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING

import numpy as np
from ase import Atoms
from ase.filters import *  # type: ignore
from ase.optimize import *  # type: ignore
from ase.optimize.optimize import Optimizer
from prefect import flow
from pymatgen.analysis.eos import BirchMurnaghan

from prefect.futures import wait

from mlip_arena.models.utils import MLIPEnum
from mlip_arena.tasks.optimize import run as OPT

if TYPE_CHECKING:
    from ase.filters import Filter


@flow
def fit(
    atoms: Atoms,
    calculator_name: str | MLIPEnum,
    calculator_kwargs: dict | None,
    device: str | None = None,
    optimizer: Optimizer | str = BFGSLineSearch,  # type: ignore
    optimizer_kwargs: dict | None = None,
    filter: Filter | str | None = None,
    filter_kwargs: dict | None = None,
    criterion: dict | None = None,
    max_abs_strain: float = 0.1,
    npoints: int = 11,
):
    """
    Compute the equation of state (EOS) for the given atoms and calculator.

    The input structure is first relaxed with ``OPT`` using the given optimizer
    and filter. The relaxed cell is then scaled isotropically so that the cell
    volume takes ``npoints`` evenly spaced values between
    ``(1 - max_abs_strain)`` and ``(1 + max_abs_strain)`` times the relaxed
    volume. Each strained structure is relaxed again via ``OPT.submit`` with
    ``filter=None`` (presumably positions-only, i.e. fixed cell — confirm against
    ``OPT``). A Birch-Murnaghan EOS is then fitted to the resulting
    energy/volume points.

    Args:
        atoms: The input atoms.
        calculator_name: The name of the calculator to use.
        calculator_kwargs: Additional kwargs to pass to the calculator.
        device: The device to use.
        optimizer: The optimizer to use.
        optimizer_kwargs: Additional kwargs to pass to the optimizer.
        filter: The filter used for the initial relaxation only.
        filter_kwargs: Additional kwargs to pass to the filter (initial
            relaxation only).
        criterion: The convergence criterion passed to every relaxation.
        max_abs_strain: The maximum absolute volumetric strain; must satisfy
            ``0 < max_abs_strain < 1``.
        npoints: The number of volumes to sample; must be at least 4, the
            number of Birch-Murnaghan fit parameters.

    Returns:
        A dictionary with
        ``"eos"``: ``{"volumes": [...], "energies": [...]}`` sorted by
        ascending volume, and
        ``"K"``: the fitted bulk modulus in GPa (``BirchMurnaghan.b0_GPa``).

    Raises:
        ValueError: If ``max_abs_strain`` or ``npoints`` is out of range.
            This is checked before any relaxation is run.
    """
    # Validate up front so bad arguments fail before the expensive relaxations.
    if not 0 < max_abs_strain < 1:
        raise ValueError(
            f"max_abs_strain must be in the open interval (0, 1), got {max_abs_strain}"
        )
    if npoints < 4:
        raise ValueError(
            f"npoints must be at least 4 to fit a Birch-Murnaghan EOS, got {npoints}"
        )

    result = OPT(
        atoms=atoms,
        calculator_name=calculator_name,
        calculator_kwargs=calculator_kwargs,
        device=device,
        optimizer=optimizer,
        optimizer_kwargs=optimizer_kwargs,
        filter=filter,
        filter_kwargs=filter_kwargs,
        criterion=criterion,
    )

    relaxed = result["atoms"]
    cell0 = relaxed.get_cell()

    # Volume scales with the cube of a linear cell scaling, so take the cube
    # root to make the sampled *volumes* (not lattice lengths) evenly spaced.
    factors = np.linspace(1 - max_abs_strain, 1 + max_abs_strain, npoints) ** (1 / 3)

    futures = []
    for factor in factors:
        # Use a fresh name so the `atoms` parameter is not shadowed.
        strained = relaxed.copy()
        strained.set_cell(cell0 * factor, scale_atoms=True)

        futures.append(
            OPT.submit(
                atoms=strained,
                calculator_name=calculator_name,
                calculator_kwargs=calculator_kwargs,
                device=device,
                optimizer=optimizer,
                optimizer_kwargs=optimizer_kwargs,
                filter=None,
                filter_kwargs=None,
                criterion=criterion,
            )
        )

    wait(futures)

    # Fetch each future's result once, then sort the (volume, energy) points by
    # volume. The sort is stable, so equal volumes keep submission order.
    strained_atoms = [future.result()["atoms"] for future in futures]
    points = sorted(
        ((a.get_volume(), a.get_potential_energy()) for a in strained_atoms),
        key=lambda point: point[0],
    )
    volumes = [volume for volume, _ in points]
    energies = [energy for _, energy in points]

    bm = BirchMurnaghan(volumes=volumes, energies=energies)
    bm.fit()

    return {
        "eos": {"volumes": volumes, "energies": energies},
        "K": bm.b0_GPa,
    }