File size: 1,109 Bytes
cbe2793
 
d8fce5c
fc31c67
 
d8fce5c
cbe2793
 
 
 
 
 
 
 
 
 
fc31c67
cbe2793
fc31c67
 
 
 
cbe2793
fc31c67
cbe2793
fc31c67
cbe2793
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28

import torch
import logging

def select_device(min_memory=2048):
    """Return the most suitable ``torch.device`` for running computations.

    Selection order:

    1. If CUDA is available, the GPU with the most free memory is chosen,
       provided it has at least ``min_memory`` MiB free. If even the best
       GPU is below that threshold, a warning is logged and the CPU is used.
    2. Otherwise, if the MPS backend (Apple GPUs) is available, it is used.
    3. Otherwise a warning is logged and the CPU is used.

    Args:
        min_memory: Minimum amount of free GPU memory, in MiB, required to
            select a CUDA device.

    Returns:
        A ``torch.device`` for ``cuda:<index>``, ``mps`` or ``cpu``.
    """
    logger = logging.getLogger(__name__)

    if torch.cuda.is_available():
        free_by_gpu = []
        for index in range(torch.cuda.device_count()):
            try:
                # Ask the driver for the real free memory. This also accounts
                # for memory held by *other* processes, which the caching
                # allocator statistics of this process cannot see.
                free_bytes, _total_bytes = torch.cuda.mem_get_info(index)
            except (AttributeError, RuntimeError):
                # Older torch without mem_get_info, or the device could not
                # be queried: fall back to a per-process estimate.
                props = torch.cuda.get_device_properties(index)
                free_bytes = props.total_memory - torch.cuda.memory_reserved(index)
            free_by_gpu.append((index, free_bytes))

        # Guard against CUDA reporting "available" with no enumerable device;
        # max() on an empty list would raise ValueError.
        if free_by_gpu:
            selected_gpu, max_free_memory = max(free_by_gpu, key=lambda pair: pair[1])
            free_memory_mb = max_free_memory / (1024 * 1024)
            if free_memory_mb >= min_memory:
                return torch.device(f'cuda:{selected_gpu}')
            logger.warning(
                'GPU %s has %s MB memory left. Switching to CPU.',
                selected_gpu,
                round(free_memory_mb, 2),
            )
            return torch.device('cpu')

    # torch.backends.mps does not exist on older torch releases.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        # For Apple M1/M2 chips with Metal Performance Shaders
        logger.info('Apple GPU found, using MPS.')
        return torch.device('mps')

    logger.warning('No GPU found, use CPU instead')
    return torch.device('cpu')