Spaces:
Runtime error
Runtime error
def epoch(loader, model, device, criterion, opt=None):
    """Run one pass over ``loader`` and return the running average loss.

    Each batch is reshaped to ``(-1, 28 * 28)``, moved to ``device`` and fed
    through ``model``; the loss is ``criterion(outputs, inputs)``, i.e. the
    model output is scored against the input itself. The second element of
    every batch yielded by ``loader`` is ignored.

    Args:
        loader: Iterable yielding ``(inputs, _)`` pairs.
        model: Object exposing ``train()``, ``eval()``, ``__call__`` and
            ``clamp()``.
        device: Target passed to ``inputs.to(...)``.
        criterion: Callable invoked as ``criterion(outputs, inputs)``.
        opt: Optimizer. When ``None`` the pass is evaluation-only: the model
            is put in eval mode, gradients are disabled and no update step
            is taken.

    Returns:
        ``losses.avg`` from the ``AverageMeter`` that is updated with each
        batch loss and batch size (presumably a size-weighted mean; confirm
        against ``AverageMeter``).
    """
    # Local import keeps this block self-contained; the inputs are already
    # torch tensors, so the package is necessarily available.
    import torch

    # One flag drives mode selection, gradient tracking and the update step,
    # so the three can never disagree.
    training = opt is not None

    losses = AverageMeter()
    if training:
        model.train()
    else:
        model.eval()

    # Skip autograd bookkeeping during evaluation: nothing calls backward()
    # there, so building the graph only costs memory and time.
    with torch.set_grad_enabled(training):
        for inputs, _ in tqdm(loader, leave=False):
            inputs = inputs.view(-1, 28 * 28).to(device)
            outputs = model(inputs)
            loss = criterion(outputs, inputs)
            if training:
                opt.zero_grad(set_to_none=True)
                loss.backward()
                opt.step()
                # NOTE(review): the original indentation was lost; clamp() is
                # assumed to belong to the training branch (applied right
                # after each optimizer step). Confirm against the model.
                model.clamp()
            losses.update(loss.item(), inputs.size(0))
    return losses.avg