import torch
import numpy as np
import subprocess

def get_device(force_cpu=False):
    """Pick the best available torch device: CUDA, then MPS, then CPU."""
    if force_cpu:
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        # Clear any cached MPS memory before handing the device back
        torch.mps.empty_cache()
        return "mps"
    else:
        return "cpu"

def get_torch_and_np_dtypes(device, use_bfloat16=False):
    """Return matching (torch, numpy) dtypes for the given device.

    NumPy has no bfloat16, so the numpy side stays float16 on GPU/MPS
    regardless of `use_bfloat16`.
    """
    if device in ("cuda", "mps"):
        # Half precision on GPU / Apple silicon; bfloat16 only if requested
        torch_dtype = torch.bfloat16 if use_bfloat16 else torch.float16
        np_dtype = np.float16
    else:
        # CPU stays in full precision
        torch_dtype = torch.float32
        np_dtype = np.float32
    return torch_dtype, np_dtype

def cuda_version_check():
    """Return (cuda_version, device_name) on CUDA systems, else (None, None)."""
    if torch.cuda.is_available():
        try:
            # Parse the "release X.Y" field from `nvcc --version`, e.g.
            # "Cuda compilation tools, release 12.4, V12.4.131"
            cuda_runtime = subprocess.check_output(["nvcc", "--version"]).decode()
            cuda_version = cuda_runtime.split("release ")[-1].split(",")[0]
        except Exception:
            # Fall back to the CUDA version PyTorch was built with if nvcc isn't available
            cuda_version = torch.version.cuda

        device_name = torch.cuda.get_device_name(0)
        return cuda_version, device_name
    else:
        return None, None
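
# --- Illustrative usage sketch (not part of the original module) ---
# Shows how the helpers above might be wired together: pick a device,
# derive matching dtypes, and report CUDA details when available.
# The `__main__` guard and print formatting are assumptions added here
# purely for demonstration.
if __name__ == "__main__":
    device = get_device()
    torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
    print(f"device={device}, torch dtype={torch_dtype}, numpy dtype={np_dtype}")

    cuda_version, device_name = cuda_version_check()
    if cuda_version is not None:
        print(f"CUDA {cuda_version} on {device_name}")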