Spaces:
Running
Running
File size: 4,241 Bytes
a133fcb b7d94da a133fcb 17c5f78 a133fcb 17c5f78 a133fcb 17c5f78 a133fcb 17c5f78 a133fcb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
from __future__ import annotations
from pathlib import Path
import yaml
from ase import Atoms
from fairchem.core import OCPCalculator
from huggingface_hub import hf_hub_download
# Model registry (YAML) stored in the parent of this module's directory.
# Most wrappers below read their default checkpoint name from this mapping.
# yaml.safe_load is used so the file cannot instantiate arbitrary objects.
_REGISTRY_FILE = Path(__file__).parents[1] / "registry.yaml"
with _REGISTRY_FILE.open(encoding="utf-8") as f:
    REGISTRY = yaml.safe_load(f)
class eSEN(OCPCalculator):
    """fairchem ``OCPCalculator`` loading an eSEN checkpoint from Hugging Face."""

    def __init__(
        self,
        checkpoint=REGISTRY["eSEN"]["checkpoint"],
        cache_dir=None,
        cpu=False,  # TODO: cannot assign device
        seed=0,
        **kwargs,
    ) -> None:
        """
        Download the eSEN checkpoint and initialise the fairchem calculator.

        Parameters
        ----------
        checkpoint : str, default=REGISTRY["eSEN"]["checkpoint"]
            Filename of the checkpoint inside the ``fairchem/OMAT24`` Hugging
            Face repository.
        cache_dir : str or None, default=None
            Cache directory forwarded to ``hf_hub_download``.
        cpu : bool, default=False
            Forwarded to ``OCPCalculator`` as ``cpu``.
        seed : int, default=0
            Forwarded to ``OCPCalculator`` as ``seed``.

        Other Parameters
        ----------------
        **kwargs
            Forwarded to ``OCPCalculator``. A ``device`` keyword is accepted
            but discarded (see the TODO on ``cpu``).
        """
        # https://huggingface.co/facebook/OMAT24/resolve/main/esen_30m_oam.pt
        # A fixed repository revision is pinned so the same weights are fetched.
        repo_id = "fairchem/OMAT24"
        pinned_revision = "13ab5b8d71af67bd1c83fbbf53250c82cd87f506"
        local_checkpoint = hf_hub_download(
            repo_id,
            filename=checkpoint,
            revision=pinned_revision,
            cache_dir=cache_dir,
        )
        # ``device`` is not forwarded to the superclass.
        forwarded = {key: value for key, value in kwargs.items() if key != "device"}
        super().__init__(
            checkpoint_path=local_checkpoint,
            cpu=cpu,
            seed=seed,
            **forwarded,
        )
class eqV2(OCPCalculator):
    """fairchem ``OCPCalculator`` loading an eqV2 (OMat) checkpoint from Hugging Face."""

    def __init__(
        self,
        checkpoint=REGISTRY["eqV2(OMat)"]["checkpoint"],
        cache_dir=None,
        cpu=False,  # TODO: cannot assign device
        seed=0,
        **kwargs,
    ) -> None:
        """
        Initialize an eqV2 calculator.

        Parameters
        ----------
        checkpoint : str, default=REGISTRY["eqV2(OMat)"]["checkpoint"]
            Filename of the eqV2 checkpoint inside the ``fairchem/OMAT24``
            Hugging Face repository (e.g. ``"eqV2_86M_omat_mp_salex.pt"``).
        cache_dir : str or None, default=None
            Directory in which ``hf_hub_download`` stores the downloaded
            checkpoint; ``None`` uses the Hugging Face default cache.
        cpu : bool, default=False
            Whether to run the model on CPU or GPU (forwarded to
            ``OCPCalculator``).
        seed : int, default=0
            The random seed for the model.

        Other Parameters
        ----------------
        **kwargs
            Any additional keyword arguments are passed to the superclass.
            A ``device`` keyword is accepted but discarded (see the TODO on
            ``cpu``).
        """
        # https://huggingface.co/fairchem/OMAT24/resolve/main/eqV2_86M_omat_mp_salex.pt
        # A fixed repository revision is pinned so the same weights are fetched.
        checkpoint_path = hf_hub_download(
            "fairchem/OMAT24",
            filename=checkpoint,
            revision="bf92f9671cb9d5b5c77ecb4aa8b317ff10b882ce",
            cache_dir=cache_dir,
        )
        # ``device`` is not forwarded to the superclass.
        kwargs.pop("device", None)
        super().__init__(
            checkpoint_path=checkpoint_path,
            cpu=cpu,
            seed=seed,
            **kwargs,
        )
class EquiformerV2(OCPCalculator):
    """fairchem ``OCPCalculator`` for EquiformerV2 (OC22), loaded by model name."""

    def __init__(
        self,
        checkpoint=REGISTRY["EquiformerV2(OC22)"]["checkpoint"],
        # TODO: cannot assign device
        local_cache="/tmp/ocp/",
        cpu=False,
        seed=0,
        **kwargs,
    ) -> None:
        """
        Initialize the calculator from a fairchem model name.

        Parameters
        ----------
        checkpoint : str, default=REGISTRY["EquiformerV2(OC22)"]["checkpoint"]
            Model name forwarded to ``OCPCalculator`` as ``model_name``.
        local_cache : str, default="/tmp/ocp/"
            Forwarded to ``OCPCalculator`` as ``local_cache``.
        cpu : bool, default=False
            Forwarded to ``OCPCalculator`` as ``cpu``.
        seed : int, default=0
            Forwarded to ``OCPCalculator`` as ``seed``.

        Other Parameters
        ----------------
        **kwargs
            Forwarded to ``OCPCalculator``; a ``device`` keyword is discarded.
        """
        kwargs.pop("device", None)
        super().__init__(
            model_name=checkpoint,
            local_cache=local_cache,
            cpu=cpu,
            seed=seed,
            **kwargs,
        )

    def calculate(self, atoms: Atoms, properties, system_changes) -> None:
        """
        Run the model, then also store the forces under the ``"force"`` key.

        The ``"force"`` entry is a copy of ``self.results["forces"]`` with the
        constraints of ``atoms`` applied, matching what ``atoms.get_forces()``
        returns when this calculator is attached to ``atoms``.
        """
        super().calculate(atoms, properties, system_changes)
        # Read the forces from this calculator's own results instead of calling
        # atoms.get_forces(): that goes through atoms.calc, which fails when no
        # calculator is attached and uses a different calculator if another one
        # is attached.
        force = self.results["forces"].copy()
        # atoms.get_forces() applies constraints by default; keep that behavior.
        for constraint in atoms.constraints:
            constraint.adjust_forces(atoms, force)
        self.results.update(
            force=force,
        )
class EquiformerV2OC20(OCPCalculator):
    """
    fairchem ``OCPCalculator`` for EquiformerV2 (OC20), loaded by model name.

    NOTE(review): unlike ``EquiformerV2`` and ``eSCN`` in this module, this
    class does not override ``calculate`` to add a ``"force"`` results key —
    confirm whether that omission is intended.
    """

    def __init__(
        self,
        # NOTE(review): the default reads the "EquiformerV2(OC22)" registry
        # entry although this class is named OC20 — confirm this is intended.
        checkpoint=REGISTRY["EquiformerV2(OC22)"]["checkpoint"],
        # TODO: cannot assign device
        local_cache="/tmp/ocp/",
        cpu=False,
        seed=0,
        **kwargs,
    ) -> None:
        """
        Initialize the calculator from a fairchem model name.

        Parameters
        ----------
        checkpoint : str, default=REGISTRY["EquiformerV2(OC22)"]["checkpoint"]
            Model name forwarded to ``OCPCalculator`` as ``model_name``.
        local_cache : str, default="/tmp/ocp/"
            Forwarded to ``OCPCalculator`` as ``local_cache``.
        cpu : bool, default=False
            Forwarded to ``OCPCalculator`` as ``cpu``.
        seed : int, default=0
            Forwarded to ``OCPCalculator`` as ``seed``.

        Other Parameters
        ----------------
        **kwargs
            Forwarded to ``OCPCalculator``; a ``device`` keyword is discarded.
        """
        # ``device`` is not forwarded to the superclass.
        forwarded = {name: val for name, val in kwargs.items() if name != "device"}
        super().__init__(
            model_name=checkpoint,
            local_cache=local_cache,
            cpu=cpu,
            seed=seed,
            **forwarded,
        )
class eSCN(OCPCalculator):
    """fairchem ``OCPCalculator`` for eSCN (OC20), loaded by model name."""

    def __init__(
        self,
        checkpoint="eSCN-L6-M3-Lay20-S2EF-OC20-All+MD",  # TODO: import from registry
        # TODO: cannot assign device
        local_cache="/tmp/ocp/",
        cpu=False,
        seed=0,
        **kwargs,
    ) -> None:
        """
        Initialize the calculator from a fairchem model name.

        Parameters
        ----------
        checkpoint : str, default="eSCN-L6-M3-Lay20-S2EF-OC20-All+MD"
            Model name forwarded to ``OCPCalculator`` as ``model_name``.
        local_cache : str, default="/tmp/ocp/"
            Forwarded to ``OCPCalculator`` as ``local_cache``.
        cpu : bool, default=False
            Forwarded to ``OCPCalculator`` as ``cpu``.
        seed : int, default=0
            Forwarded to ``OCPCalculator`` as ``seed``.

        Other Parameters
        ----------------
        **kwargs
            Forwarded to ``OCPCalculator``; a ``device`` keyword is discarded.
        """
        kwargs.pop("device", None)
        super().__init__(
            model_name=checkpoint,
            local_cache=local_cache,
            cpu=cpu,
            seed=seed,
            **kwargs,
        )

    def calculate(self, atoms: Atoms, properties, system_changes) -> None:
        """
        Run the model, then also store the forces under the ``"force"`` key.

        The ``"force"`` entry is a copy of ``self.results["forces"]`` with the
        constraints of ``atoms`` applied, matching what ``atoms.get_forces()``
        returns when this calculator is attached to ``atoms``.
        """
        super().calculate(atoms, properties, system_changes)
        # Read the forces from this calculator's own results instead of calling
        # atoms.get_forces(): that goes through atoms.calc, which fails when no
        # calculator is attached and uses a different calculator if another one
        # is attached.
        force = self.results["forces"].copy()
        # atoms.get_forces() applies constraints by default; keep that behavior.
        for constraint in atoms.constraints:
            constraint.adjust_forces(atoms, force)
        self.results.update(
            force=force,
        )
|