File size: 346 Bytes
5ac1897
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
from typing import List, Dict, Tuple


def recursive_detach(x):
    """Return ``x`` with every tensor detached from the autograd graph.

    Walks nested dicts, lists and tuples, replacing each ``torch.Tensor`` with
    ``tensor.detach()``. Container structure is preserved:

    * dicts (and dict subclasses) come back as plain ``dict``;
    * lists (and list subclasses) come back as plain ``list``;
    * named tuples keep their own class, other tuples come back as ``tuple``.

    Any other value is returned as-is (not copied).

    Args:
        x: A tensor, a (possibly nested) dict/list/tuple, or any other object.

    Returns:
        An object with the same nested structure as ``x`` in which every
        tensor has been replaced by its detached counterpart.
    """
    if isinstance(x, torch.Tensor):
        return x.detach()
    if isinstance(x, dict):
        return {k: recursive_detach(v) for k, v in x.items()}
    if isinstance(x, tuple):
        items = [recursive_detach(v) for v in x]
        # Named tuples take their fields positionally, so they cannot be
        # rebuilt from a single iterable; rebuild them with their own class
        # to keep field access working for callers.
        if hasattr(x, "_fields"):
            return type(x)(*items)
        return tuple(items)
    if isinstance(x, list):
        return [recursive_detach(v) for v in x]
    return x