File size: 3,955 Bytes
a367f8b
 
 
 
54047c6
 
 
50064b5
54047c6
a367f8b
56200e6
 
a367f8b
 
50064b5
 
 
 
 
 
 
54047c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50064b5
a367f8b
 
54047c6
 
50064b5
a367f8b
 
54047c6
50064b5
a367f8b
a2da3dd
a367f8b
 
a2da3dd
 
 
a367f8b
 
 
56200e6
a367f8b
 
072f65e
54047c6
 
 
df7cf57
 
56200e6
54047c6
 
 
 
531f267
56200e6
a367f8b
 
 
df7cf57
50064b5
 
a367f8b
50064b5
 
 
a367f8b
 
 
50064b5
a367f8b
 
 
 
072f65e
a367f8b
df7cf57
50064b5
 
a367f8b
56200e6
a367f8b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""Utility functions for MLIP models."""

from __future__ import annotations

from pprint import pformat

import torch
from ase import units
from ase.calculators.calculator import BaseCalculator
from ase.calculators.mixing import SumCalculator
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator

from mlip_arena.models import MLIPEnum

# Prefer Prefect's run logger when this module is imported inside a Prefect
# flow/task context; outside such a context get_run_logger() raises a
# RuntimeError (and prefect itself may not be installed, hence ImportError),
# in which case we fall back to loguru's module-level logger.
try:
    from prefect.logging import get_run_logger

    logger = get_run_logger()
except (ImportError, RuntimeError):
    from loguru import logger


def get_freer_device() -> torch.device:
    """Select the torch device with the most free memory.

    Prefers the CUDA GPU with the largest amount of unallocated memory;
    falls back to MPS (Apple Silicon), then to CPU. Never raises — a CPU
    device is always returned as a last resort.

    Returns:
        torch.device: The selected device (``cuda:<i>``, ``mps``, or ``cpu``).
    """
    device_count = torch.cuda.device_count()
    if device_count > 0:
        # If CUDA GPUs are available, select the one with the most free memory
        # NOTE(review): memory_allocated() only counts allocations made by this
        # process, so memory held by other processes is invisible here and the
        # "freest" GPU may be busier than it appears — consider
        # torch.cuda.mem_get_info for a device-wide view. TODO confirm intent.
        mem_free = [
            torch.cuda.get_device_properties(i).total_memory
            - torch.cuda.memory_allocated(i)
            for i in range(device_count)
        ]
        free_gpu_index = mem_free.index(max(mem_free))
        device = torch.device(f"cuda:{free_gpu_index}")
        logger.info(
            f"Selected GPU {device} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs"
        )
    elif torch.backends.mps.is_available():
        # If no CUDA GPUs are available but MPS is, use MPS
        logger.info("No GPU available. Using MPS.")
        device = torch.device("mps")
    else:
        # Fallback to CPU if neither CUDA GPUs nor MPS are available
        logger.info("No GPU or MPS available. Using CPU.")
        device = torch.device("cpu")

    return device


def get_calculator(
    calculator_name: str | MLIPEnum | BaseCalculator | type[BaseCalculator],
    calculator_kwargs: dict | None = None,
    dispersion: bool = False,
    dispersion_kwargs: dict | None = None,
    device: str | None = None,
) -> BaseCalculator:
    """Construct an ASE calculator, optionally with a D3 dispersion correction.

    Args:
        calculator_name: An ``MLIPEnum`` member, the name of one as a string,
            a ``BaseCalculator`` subclass, or an already-constructed
            ``BaseCalculator`` instance (in which case kwargs are ignored).
        calculator_kwargs: Extra keyword arguments for the calculator
            constructor. The caller's dict is never mutated; a copy is made
            before ``device`` is injected.
        dispersion: If True, wrap the calculator together with a
            ``TorchDFTD3Calculator`` in a ``SumCalculator``.
        dispersion_kwargs: Keyword arguments for the dispersion calculator.
            Defaults to BJ-damped PBE D3 with a 40 Bohr cutoff. Never mutated.
        device: Torch device string; defaults to the freest available device.

    Returns:
        BaseCalculator: The constructed (possibly summed) calculator.

    Raises:
        ValueError: If ``calculator_name`` cannot be resolved to a calculator.
    """

    device = device or str(get_freer_device())

    # Copy before injecting the device so the caller's dict is not mutated.
    calculator_kwargs = dict(calculator_kwargs or {})
    calculator_kwargs.update({"device": device})

    logger.info(f"Using device: {device}")

    # NOTE(review): assigning __str__ on an *instance* does not affect
    # str(calc) or f-strings — Python looks dunders up on the type. These
    # assignments are kept for backward compatibility with any code that
    # reads calc.__str__ directly, but the log lines below show the default
    # repr — TODO confirm whether a per-class override was intended.
    if isinstance(calculator_name, MLIPEnum):
        calc = calculator_name.value(**calculator_kwargs)
        calc.__str__ = lambda: calculator_name.name
    elif isinstance(calculator_name, str) and hasattr(MLIPEnum, calculator_name):
        calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
        calc.__str__ = lambda: calculator_name
    elif isinstance(calculator_name, type) and issubclass(
        calculator_name, BaseCalculator
    ):
        logger.warning(f"Using custom calculator class: {calculator_name}")
        calc = calculator_name(**calculator_kwargs)
        calc.__str__ = lambda: f"{calc.__class__.__name__}"
    elif isinstance(calculator_name, BaseCalculator):
        logger.warning(
            f"Using custom calculator object (kwargs are ignored): {calculator_name}"
        )
        calc = calculator_name
        calc.__str__ = lambda: f"{calc.__class__.__name__}"
    else:
        raise ValueError(f"Invalid calculator: {calculator_name}")

    logger.info(f"Using calculator: {calc}")
    if calculator_kwargs:
        logger.info(pformat(calculator_kwargs))

    # Copy (or build the default) before injecting the device so the
    # caller's dict is not mutated.
    dispersion_kwargs = dict(
        dispersion_kwargs
        or dict(damping="bj", xc="pbe", cutoff=40.0 * units.Bohr)
    )
    dispersion_kwargs.update({"device": device})

    if dispersion:
        disp_calc = TorchDFTD3Calculator(
            **dispersion_kwargs,
        )
        calc = SumCalculator([calc, disp_calc])
        # TODO: rename the SumCalculator

        logger.info(f"Using dispersion: {disp_calc}")
        if dispersion_kwargs:
            logger.info(pformat(dispersion_kwargs))

    assert isinstance(calc, BaseCalculator)
    return calc