Spaces:
Running
Running
File size: 3,955 Bytes
a367f8b 54047c6 50064b5 54047c6 a367f8b 56200e6 a367f8b 50064b5 54047c6 50064b5 a367f8b 54047c6 50064b5 a367f8b 54047c6 50064b5 a367f8b a2da3dd a367f8b a2da3dd a367f8b 56200e6 a367f8b 072f65e 54047c6 df7cf57 56200e6 54047c6 531f267 56200e6 a367f8b df7cf57 50064b5 a367f8b 50064b5 a367f8b 50064b5 a367f8b 072f65e a367f8b df7cf57 50064b5 a367f8b 56200e6 a367f8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
"""Utility functions for MLIP models."""
from __future__ import annotations
from pprint import pformat
import torch
from ase import units
from ase.calculators.calculator import BaseCalculator
from ase.calculators.mixing import SumCalculator
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
from mlip_arena.models import MLIPEnum
# Pick the module logger once, at import time. Use Prefect's run logger when
# it can be obtained; otherwise fall back to loguru.
# NOTE(review): because this runs only at import, a module first imported
# outside a Prefect run keeps the loguru logger even inside later flow runs.
try:
    from prefect.logging import get_run_logger

    logger = get_run_logger()
except (RuntimeError, ImportError):
    # ImportError: Prefect is not installed.
    # RuntimeError: get_run_logger() failed, presumably because there is no
    # active flow/task run context at import time.
    from loguru import logger
def get_freer_device() -> torch.device:
    """Get the CUDA GPU with the most free memory, else MPS, else CPU.

    Free memory is queried from the CUDA driver via ``torch.cuda.mem_get_info``,
    so memory held by *other* processes sharing a GPU is taken into account.
    (Subtracting this process's ``torch.cuda.memory_allocated`` from the total
    memory ignores every other process.)

    Returns:
        torch.device: ``cuda:<i>`` for the visible GPU with the most free
        memory if any CUDA device is visible; otherwise ``mps`` if available;
        otherwise ``cpu``. No exception is raised when no accelerator exists.
    """
    device_count = torch.cuda.device_count()
    if device_count > 0:
        # mem_get_info(i) returns (free_bytes, total_bytes) for device i.
        # NOTE(review): this may initialise a CUDA context on every visible
        # device; confirm that cost is acceptable for the calling workflow.
        mem_free = [torch.cuda.mem_get_info(i)[0] for i in range(device_count)]
        # Ties resolve to the lowest device index.
        free_gpu_index = mem_free.index(max(mem_free))
        device = torch.device(f"cuda:{free_gpu_index}")
        logger.info(
            f"Selected GPU {device} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs"
        )
    elif torch.backends.mps.is_available():
        # No CUDA GPUs, but Apple's Metal backend is usable.
        logger.info("No GPU available. Using MPS.")
        device = torch.device("mps")
    else:
        # Neither CUDA nor MPS: fall back to CPU.
        logger.info("No GPU or MPS available. Using CPU.")
        device = torch.device("cpu")
    return device
def get_calculator(
    calculator_name: str | MLIPEnum | type[BaseCalculator] | BaseCalculator,
    calculator_kwargs: dict | None = None,
    dispersion: bool = False,
    dispersion_kwargs: dict | None = None,
    device: str | None = None,
) -> BaseCalculator:
    """Get a calculator with optional DFT-D3 dispersion correction.

    Args:
        calculator_name: One of
            - an ``MLIPEnum`` member, whose value is called with
              ``calculator_kwargs``;
            - the name of an ``MLIPEnum`` member (looked up in
              ``MLIPEnum.__members__``);
            - a ``BaseCalculator`` subclass, instantiated with
              ``calculator_kwargs``;
            - a ``BaseCalculator`` instance, used as-is (kwargs are ignored).
        calculator_kwargs: Keyword arguments for the calculator constructor.
            ``"device"`` is always set to the resolved device, overriding any
            value given here. The caller's dict is not modified.
        dispersion: If True, combine the calculator with a
            ``TorchDFTD3Calculator`` in an ASE ``SumCalculator``.
        dispersion_kwargs: Keyword arguments for ``TorchDFTD3Calculator``.
            When None or empty, ``damping="bj", xc="pbe",
            cutoff=40.0 * units.Bohr`` is used. ``"device"`` is always set;
            the caller's dict is not modified.
        device: Torch device string. When falsy, ``get_freer_device()`` picks it.

    Returns:
        The calculator, or a ``SumCalculator`` of it plus the D3 calculator.

    Raises:
        ValueError: If ``calculator_name`` is none of the accepted forms.
    """
    device = device or str(get_freer_device())
    # Build a new dict so the caller's kwargs are never mutated.
    calculator_kwargs = {**(calculator_kwargs or {}), "device": device}
    logger.info(f"Using device: {device}")

    if isinstance(calculator_name, MLIPEnum):
        calc = calculator_name.value(**calculator_kwargs)
        calc_label = calculator_name.name
    elif (
        isinstance(calculator_name, str)
        and calculator_name in MLIPEnum.__members__
    ):
        # __members__ (not hasattr) so attribute names such as "mro" fall
        # through to the ValueError below instead of a KeyError here.
        calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
        calc_label = calculator_name
    elif isinstance(calculator_name, type) and issubclass(
        calculator_name, BaseCalculator
    ):
        logger.warning(f"Using custom calculator class: {calculator_name}")
        calc = calculator_name(**calculator_kwargs)
        calc_label = type(calc).__name__
    elif isinstance(calculator_name, BaseCalculator):
        logger.warning(
            f"Using custom calculator object (kwargs are ignored): {calculator_name}"
        )
        calc = calculator_name
        calc_label = type(calc).__name__
    else:
        raise ValueError(f"Invalid calculator: {calculator_name}")

    # Kept for callers that invoke ``calc.__str__()`` explicitly. An instance
    # attribute does not change ``str(calc)`` or f-string output (special
    # methods are looked up on the type), so the log below uses the label
    # directly. The label is bound eagerly: a closure over ``calc`` would
    # report "SumCalculator" once ``calc`` is rebound below.
    calc.__str__ = lambda label=calc_label: label

    logger.info(f"Using calculator: {calc_label}")
    logger.info(pformat(calculator_kwargs))

    if dispersion:
        # Empty/None -> default D3 settings; build a new dict so the caller's
        # dispersion_kwargs is never mutated.
        dispersion_kwargs = {
            **(
                dispersion_kwargs
                or dict(damping="bj", xc="pbe", cutoff=40.0 * units.Bohr)
            ),
            "device": device,
        }
        disp_calc = TorchDFTD3Calculator(**dispersion_kwargs)
        calc = SumCalculator([calc, disp_calc])
        # TODO: rename the SumCalculator
        logger.info(f"Using dispersion: {disp_calc}")
        logger.info(pformat(dispersion_kwargs))

    assert isinstance(calc, BaseCalculator)
    return calc
|