Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Train ``model`` for one pass over ``dataloader`` and return its metrics.

    For every ``(images, labels)`` batch the tensors are moved to ``device``,
    gradients are reset, the forward/backward pass is run and the optimizer
    takes one step.

    Args:
        model: Module to train; it is switched to training mode.
        dataloader: Iterable yielding ``(images, labels)`` batches. It does
            not need to support ``len()``; batches are counted while
            iterating, so generators and loaders over iterable-style
            datasets work too.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called once per
            batch.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        device: Target passed to ``Tensor.to`` for both images and labels.

    Returns:
        ``(epoch_loss, epoch_acc)``: the unweighted mean of the per-batch
        loss values, and the fraction of samples whose highest-scoring index
        along dim 1 equals the label.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Predicted class = index of the largest score along dim 1.
        preds = outputs.argmax(dim=1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)
        num_batches += 1

    if total == 0:
        # Same exception type the bare division would raise, but with a
        # message that points at the actual cause.
        raise ZeroDivisionError(
            "dataloader yielded no samples; cannot compute epoch loss/accuracy"
        )

    # Divide by the number of batches actually seen rather than
    # len(dataloader), which is unavailable for unsized iterables.
    epoch_loss = running_loss / num_batches
    epoch_acc = correct / total
    return epoch_loss, epoch_acc
def validate_one_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` over ``dataloader`` without updating its weights.

    The model is switched to evaluation mode and every batch is processed
    under ``torch.no_grad()``, so no gradients are recorded.

    Args:
        model: Module to evaluate; it is switched to evaluation mode.
        dataloader: Iterable yielding ``(images, labels)`` batches. It does
            not need to support ``len()``; batches are counted while
            iterating, so generators and loaders over iterable-style
            datasets work too.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        device: Target passed to ``Tensor.to`` for both images and labels.

    Returns:
        ``(epoch_loss, epoch_acc)``: the unweighted mean of the per-batch
        loss values, and the fraction of samples whose highest-scoring index
        along dim 1 equals the label.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            # Predicted class = index of the largest score along dim 1.
            preds = outputs.argmax(dim=1)
            correct += preds.eq(labels).sum().item()
            total += labels.size(0)
            num_batches += 1

    if total == 0:
        # Same exception type the bare division would raise, but with a
        # message that points at the actual cause.
        raise ZeroDivisionError(
            "dataloader yielded no samples; cannot compute epoch loss/accuracy"
        )

    # Divide by the number of batches actually seen rather than
    # len(dataloader), which is unavailable for unsized iterables.
    epoch_loss = running_loss / num_batches
    epoch_acc = correct / total
    return epoch_loss, epoch_acc