File size: 692 Bytes
44504f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def epoch(loader, model, device, criterion, opt=None):
    """Run one pass over ``loader``, training the model if ``opt`` is given.

    Labels from each batch are discarded. The inputs are flattened to rows of
    ``28 * 28`` features and moved to ``device``. The model's output is then
    scored against those same inputs (a reconstruction objective).

    Args:
        loader: Iterable yielding ``(inputs, labels)`` batches; labels are
            ignored.
        model: Model to run. It is put into ``train()`` mode when ``opt`` is
            given and ``eval()`` mode otherwise, and is left in that mode on
            return. It must expose a ``clamp()`` method, which is called
            after every optimizer step.
        device: Device the flattened inputs are moved to.
        criterion: Callable ``criterion(outputs, inputs)`` returning a scalar
            loss tensor.
        opt: Optimizer stepped once per batch. ``None`` means evaluate only:
            no gradients are computed and no parameters are updated.

    Returns:
        ``losses.avg`` from an ``AverageMeter``. Each batch's loss is added
        with weight ``inputs.size(0)``, so this is presumably the
        per-sample mean loss over the pass (confirm against
        ``AverageMeter``).
    """
    import torch  # function-scope: the file's import block is not visible here

    training = opt is not None
    losses = AverageMeter()

    if training:
        model.train()
    else:
        model.eval()

    # Record autograd graphs only when we will backpropagate. During
    # evaluation this avoids holding activations for a backward pass that
    # never happens.
    with torch.set_grad_enabled(training):
        for inputs, _ in tqdm(loader, leave=False):
            # reshape (rather than view) also accepts non-contiguous batches;
            # it returns a view whenever view() would have succeeded.
            inputs = inputs.reshape(-1, 28 * 28).to(device)
            outputs = model(inputs)
            loss = criterion(outputs, inputs)
            if training:
                opt.zero_grad(set_to_none=True)
                loss.backward()
                opt.step()
                # NOTE(review): presumably re-projects parameters onto a
                # constraint set after each update — confirm in the model.
                model.clamp()

            losses.update(loss.item(), inputs.size(0))

    return losses.avg