import torch
import torch.nn as nn


def _count_correct(outputs, labels):
    """Return how many rows of ``outputs`` have their max-scoring index equal to ``labels``.

    The prediction for each row is the index of the maximum along dim 1,
    exactly as ``outputs.max(1)`` reports it.
    """
    # assumes outputs is (batch, num_classes) scores and labels holds integer
    # class indices of shape (batch,) — TODO confirm against model/dataset.
    _, preds = outputs.max(1)
    return preds.eq(labels).sum().item()


def _epoch_metrics(running_loss, num_batches, correct, total):
    """Convert accumulated counters into ``(mean batch loss, accuracy)``.

    Raises:
        ValueError: if no batches or no samples were seen. Without this check
            the caller would hit an opaque ZeroDivisionError.
    """
    if num_batches == 0 or total == 0:
        raise ValueError("dataloader yielded no samples; cannot compute epoch metrics")
    return running_loss / num_batches, correct / total


def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one training pass over ``dataloader``, updating ``model`` in place.

    Args:
        model: module to train; it is switched to training mode first.
        dataloader: any iterable of ``(images, labels)`` batches. A ``__len__``
            is not required.
        optimizer: optimizer stepped once per batch.
        criterion: loss function called as ``criterion(outputs, labels)``.
        device: target passed to ``Tensor.to`` for each batch.

    Returns:
        ``(epoch_loss, epoch_acc)``.
        - ``epoch_loss`` is the unweighted mean of the per-batch
          ``criterion`` values. NOTE: with a smaller final batch this is not
          the same as the per-sample mean.
        - ``epoch_acc`` is correct predictions divided by total samples.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Count batches while iterating instead of calling len(dataloader).
        # len() is unavailable for generators and many iterable-style loaders.
        num_batches += 1
        correct += _count_correct(outputs, labels)
        total += labels.size(0)
    return _epoch_metrics(running_loss, num_batches, correct, total)


def validate_one_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` over ``dataloader`` without computing gradients.

    Args:
        model: module to evaluate; it is switched to eval mode first.
        dataloader: any iterable of ``(images, labels)`` batches. A ``__len__``
            is not required.
        criterion: loss function called as ``criterion(outputs, labels)``.
        device: target passed to ``Tensor.to`` for each batch.

    Returns:
        ``(epoch_loss, epoch_acc)``, defined the same way as in
        ``train_one_epoch``: the mean of per-batch losses, and correct
        predictions divided by total samples.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            num_batches += 1
            correct += _count_correct(outputs, labels)
            total += labels.size(0)
    return _epoch_metrics(running_loss, num_batches, correct, total)