File size: 378 Bytes
9e426da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch

@torch.no_grad()
def no_grad(net):
    """Freeze every parameter of *net* and switch it to evaluation mode.

    Each tensor yielded by ``net.parameters()`` has its ``requires_grad``
    flag cleared, then ``net.eval()`` is called. The change is made in
    place on the object passed in.

    Args:
        net: Object exposing ``parameters()`` and ``eval()``
            (presumably a ``torch.nn.Module`` -- confirm against callers).

    Returns:
        The same ``net`` object, so the call can be used inline.
    """
    for weight in net.parameters():
        # In-place flag flip; equivalent to assigning requires_grad = False.
        weight.requires_grad_(False)

    net.eval()
    return net

@torch.no_grad()
def filter_nograd_tensors(params_list):
    """Return only the entries of *params_list* that require gradients.

    Args:
        params_list: Iterable of objects exposing a ``requires_grad``
            attribute (e.g. tensors or parameters).

    Returns:
        A new list holding, in their original order, the items whose
        ``requires_grad`` is truthy. The input is not modified.
    """
    return [tensor for tensor in params_list if tensor.requires_grad]