Spaces:
Sleeping
Sleeping
File size: 1,330 Bytes
09823ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import torch
import torch.nn as nn
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Train ``model`` for one pass over ``dataloader``.

    For every ``(images, labels)`` batch the gradients are cleared, the
    forward pass and loss are computed, the loss is back-propagated and the
    optimizer takes one step.

    Args:
        model: Module to train; it is switched to training mode first.
        dataloader: Iterable yielding ``(images, labels)`` batches. It does
            not need to support ``len()``; batches are counted while
            iterating.
        optimizer: Optimizer whose ``zero_grad()`` and ``step()`` are called
            once per batch.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        device: Target passed to ``Tensor.to`` for both images and labels.

    Returns:
        ``(epoch_loss, epoch_acc)``: the mean of the per-batch loss values
        and the fraction of samples whose predicted class equals the label.
        Both are ``0.0`` when ``dataloader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Predicted class = index of the largest value along dim 1
        # (assumes outputs are laid out as (batch, num_classes) -- TODO confirm).
        _, preds = outputs.max(1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)
        num_batches += 1

    # Count batches ourselves instead of calling len(dataloader): iterables
    # without __len__ would otherwise fail only after the whole pass has run,
    # and an empty loader would divide by zero.
    epoch_loss = running_loss / num_batches if num_batches else 0.0
    epoch_acc = correct / total if total else 0.0
    return epoch_loss, epoch_acc
def validate_one_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` over one pass of ``dataloader`` without updating it.

    The model is switched to evaluation mode and all batches are processed
    under ``torch.no_grad()``, so no gradients are recorded.

    Args:
        model: Module to evaluate; it is switched to evaluation mode first.
        dataloader: Iterable yielding ``(images, labels)`` batches. It does
            not need to support ``len()``; batches are counted while
            iterating.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        device: Target passed to ``Tensor.to`` for both images and labels.

    Returns:
        ``(epoch_loss, epoch_acc)``: the mean of the per-batch loss values
        and the fraction of samples whose predicted class equals the label.
        Both are ``0.0`` when ``dataloader`` yields no batches.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            # Predicted class = index of the largest value along dim 1
            # (assumes outputs are laid out as (batch, num_classes) -- TODO confirm).
            _, preds = outputs.max(1)
            correct += preds.eq(labels).sum().item()
            total += labels.size(0)
            num_batches += 1

    # Count batches ourselves instead of calling len(dataloader): iterables
    # without __len__ would otherwise raise, and an empty loader would divide
    # by zero.
    epoch_loss = running_loss / num_batches if num_batches else 0.0
    epoch_acc = correct / total if total else 0.0
    return epoch_loss, epoch_acc