CropGuard / src /model /train.py
mitraarka27's picture
🚀 Initial full clean push to Hugging Face
09823ea
import torch
import torch.nn as nn
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one training pass over ``dataloader`` and return (loss, accuracy).

    Puts ``model`` in training mode. For every ``(images, labels)`` batch it
    moves the batch to ``device``, runs a forward pass, back-propagates
    ``criterion(outputs, labels)`` and takes one ``optimizer`` step.

    Args:
        model: The module being trained.
        dataloader: Any iterable of ``(images, labels)`` tensor pairs. It does
            not need to define ``__len__``; batches are counted while iterating.
        optimizer: Optimizer over ``model``'s parameters; ``zero_grad()`` and
            ``step()`` are called once per batch.
        criterion: Loss function whose result supports ``.backward()`` and
            ``.item()`` (i.e. a scalar tensor).
        device: Target device each batch is moved to (e.g. ``"cpu"``).

    Returns:
        ``(epoch_loss, epoch_acc)``. ``epoch_loss`` is the mean of the
        per-batch loss values. ``epoch_acc`` is the fraction of samples whose
        arg-max prediction equals the label. Each is ``float("nan")`` when
        there was nothing to average (no batches / no samples).
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # NOTE: unweighted mean over batches -- a short final batch counts as
        # much as a full one (assumes criterion averages within a batch;
        # confirm if exact per-sample loss matters).
        running_loss += loss.item()
        num_batches += 1

        # Assumes outputs are (batch, num_classes) scores -- TODO confirm.
        preds = outputs.argmax(dim=1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)

    # Guard the divisions so an empty loader yields NaN instead of crashing.
    epoch_loss = running_loss / num_batches if num_batches else float("nan")
    epoch_acc = correct / total if total else float("nan")
    return epoch_loss, epoch_acc
def validate_one_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` over ``dataloader`` and return (loss, accuracy).

    Puts ``model`` in eval mode and runs every ``(images, labels)`` batch
    under ``torch.no_grad()``, so no gradients are recorded and no parameters
    are updated. The model is left in eval mode on return.

    Args:
        model: The module being evaluated.
        dataloader: Any iterable of ``(images, labels)`` tensor pairs. It does
            not need to define ``__len__``; batches are counted while iterating.
        criterion: Loss function whose result supports ``.item()``.
        device: Target device each batch is moved to (e.g. ``"cpu"``).

    Returns:
        ``(epoch_loss, epoch_acc)``. ``epoch_loss`` is the mean of the
        per-batch loss values. ``epoch_acc`` is the fraction of samples whose
        arg-max prediction equals the label. Each is ``float("nan")`` when
        there was nothing to average (no batches / no samples).
    """
    model.eval()
    running_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            # NOTE: unweighted mean over batches, matching the training loop
            # (assumes criterion averages within a batch -- confirm).
            running_loss += loss.item()
            num_batches += 1

            # Assumes outputs are (batch, num_classes) scores -- TODO confirm.
            preds = outputs.argmax(dim=1)
            correct += preds.eq(labels).sum().item()
            total += labels.size(0)

    # Guard the divisions so an empty loader yields NaN instead of crashing.
    epoch_loss = running_loss / num_batches if num_batches else float("nan")
    epoch_acc = correct / total if total else float("nan")
    return epoch_loss, epoch_acc