#####################################
# Packages
#####################################
from typing import Tuple, Dict, List

import torch

import utils
#####################################
# Functions
#####################################
def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               device: torch.device) -> Tuple[float, float]:
    '''
    Performs a training step for a PyTorch model.

    Args:
        model (torch.nn.Module): PyTorch model that will be trained.
        dataloader (torch.utils.data.DataLoader): DataLoader containing the data to train on.
        loss_fn (torch.nn.Module): Loss function used as the error metric.
        optimizer (torch.optim.Optimizer): Optimization method used to update model parameters per batch.
        device (torch.device): Device to train on.

    Returns:
        train_loss (float): The average loss calculated over the training set.
        train_acc (float): The accuracy calculated over the training set.
    '''
    model.train()
    train_loss = torch.tensor(0.0, device=device)
    train_acc = torch.tensor(0.0, device=device)
    num_samps = len(dataloader.dataset)

    # Loop through all batches in the dataloader
    for X, y in dataloader:
        optimizer.zero_grad() # Clear old accumulated gradients
        X, y = X.to(device), y.to(device)

        y_logits = model(X) # Get logits
        loss = loss_fn(y_logits, y)
        train_loss += loss.detach() * X.shape[0] # Accumulate summed loss (assumes loss_fn averages over the batch)

        loss.backward() # Perform backpropagation
        optimizer.step() # Update parameters

        y_pred = y_logits.argmax(dim=1) # No softmax needed before argmax (softmax preserves order)
        train_acc += (y_pred == y).sum() # Count correct predictions in the batch

    # Get average loss and accuracy per sample
    train_loss = train_loss.item() / num_samps
    train_acc = train_acc.item() / num_samps

    return train_loss, train_acc
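
# A minimal sketch of calling train_step on its own for one epoch. The names
# `model` and `train_dl` and the loss/optimizer choices below are illustrative
# assumptions, not requirements of this module:
#
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     loss_fn = torch.nn.CrossEntropyLoss()
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
#     epoch_loss, epoch_acc = train_step(model, train_dl, loss_fn, optimizer, device)
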
def test_step(model: torch.nn.Module,
              dataloader: torch.utils.data.DataLoader,
              loss_fn: torch.nn.Module,
              device: torch.device) -> Tuple[float, float]:
    '''
    Performs a testing step for a PyTorch model.

    Args:
        model (torch.nn.Module): PyTorch model that will be tested.
        dataloader (torch.utils.data.DataLoader): DataLoader containing the data to test on.
        loss_fn (torch.nn.Module): Loss function used as the error metric.
        device (torch.device): Device to compute on.

    Returns:
        test_loss (float): The average loss calculated over the test set.
        test_acc (float): The accuracy calculated over the test set.
    '''
    model.eval()
    test_loss = torch.tensor(0.0, device=device)
    test_acc = torch.tensor(0.0, device=device)
    num_samps = len(dataloader.dataset)

    with torch.inference_mode():
        # Loop through all batches in the dataloader
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            y_logits = model(X) # Get logits
            test_loss += loss_fn(y_logits, y) * X.shape[0] # Accumulate summed loss (assumes loss_fn averages over the batch)

            y_pred = y_logits.argmax(dim=1) # No softmax needed before argmax (softmax preserves order)
            test_acc += (y_pred == y).sum() # Count correct predictions in the batch

    # Get average loss and accuracy per sample
    test_loss = test_loss.item() / num_samps
    test_acc = test_acc.item() / num_samps

    return test_loss, test_acc
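
# A quick sketch of running test_step by itself, e.g. to evaluate a saved
# checkpoint. `model`, `test_dl`, `loss_fn`, and the checkpoint path are
# hypothetical placeholders:
#
#     model.load_state_dict(torch.load('checkpoints/model.pth', map_location=device))
#     test_loss, test_acc = test_step(model, test_dl, loss_fn, device)
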
def train(model: torch.nn.Module,
          train_dl: torch.utils.data.DataLoader,
          test_dl: torch.utils.data.DataLoader,
          loss_fn: torch.nn.Module,
          optimizer: torch.optim.Optimizer,
          num_epochs: int,
          patience: int,
          min_delta: float,
          device: torch.device,
          save_mod: bool = True,
          save_dir: str = '',
          mod_name: str = '') -> Dict[str, List[float]]:
    '''
    Performs the training and testing steps for a PyTorch model,
    with early stopping applied based on test loss.

    Args:
        model (torch.nn.Module): PyTorch model to train.
        train_dl (torch.utils.data.DataLoader): DataLoader for training.
        test_dl (torch.utils.data.DataLoader): DataLoader for testing.
        loss_fn (torch.nn.Module): Loss function used as the error metric.
        optimizer (torch.optim.Optimizer): Optimizer used to update model parameters per batch.
        num_epochs (int): Maximum number of epochs to train.
        patience (int): Number of consecutive epochs without sufficient improvement in test loss to tolerate before early stopping.
        min_delta (float): Minimum decrease in test loss that counts as an improvement and resets the patience counter.
        device (torch.device): Device to train on.
        save_mod (bool, optional): If True, saves the model whenever test loss improves. Default is True.
        save_dir (str, optional): Directory to save the model to. Must be nonempty if save_mod is True.
        mod_name (str, optional): Filename for the saved model. Must be nonempty if save_mod is True.

    Returns:
        res (dict): A results dictionary containing lists of train and test metrics for each epoch.
    '''
    bold_start, bold_end = '\033[1m', '\033[0m'

    if save_mod:
        assert save_dir, 'save_dir cannot be None or empty.'
        assert mod_name, 'mod_name cannot be None or empty.'

    # Initialize results dictionary
    res = {'train_loss': [],
           'train_acc': [],
           'test_loss': [],
           'test_acc': []}

    # Initialize best_loss and counter for early stopping
    best_loss, counter = None, 0

    for epoch in range(num_epochs):
        # Perform training and testing steps
        train_loss, train_acc = train_step(model, train_dl, loss_fn, optimizer, device)
        test_loss, test_acc = test_step(model, test_dl, loss_fn, device)

        # Store loss and accuracy values
        res['train_loss'].append(train_loss)
        res['train_acc'].append(train_acc)
        res['test_loss'].append(test_loss)
        res['test_acc'].append(test_acc)

        print(f'Epoch: {epoch + 1} | '
              f'train_loss = {train_loss:.4f} | train_acc = {train_acc:.4f} | '
              f'test_loss = {test_loss:.4f} | test_acc = {test_acc:.4f}')

        # Check for improvement in test loss
        if best_loss is None:
            best_loss = test_loss
            if save_mod:
                utils.save_model(model, save_dir, mod_name)
        elif test_loss < best_loss - min_delta:
            best_loss = test_loss
            counter = 0
            if save_mod:
                utils.save_model(model, save_dir, mod_name)
                print(f'{bold_start}[SAVED]{bold_end} Adequate improvement in test loss; model saved.')
        else:
            counter += 1
            if counter > patience:
                print(f'{bold_start}[ALERT]{bold_end} No improvement in test loss after {counter} epochs; early stopping triggered.')
                break

    return res
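
#####################################
# Example Usage
#####################################
# A minimal, self-contained sketch of how these functions fit together. The toy
# model, synthetic data, and hyperparameters are illustrative assumptions, not
# part of the original pipeline; save_mod=False avoids the utils.save_model
# dependency.
if __name__ == '__main__':
    from torch.utils.data import DataLoader, TensorDataset

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Synthetic 3-class classification data (hypothetical)
    X, y = torch.randn(512, 20), torch.randint(0, 3, (512,))
    train_dl = DataLoader(TensorDataset(X[:384], y[:384]), batch_size=32, shuffle=True)
    test_dl = DataLoader(TensorDataset(X[384:], y[384:]), batch_size=32)

    # Small toy classifier
    model = torch.nn.Sequential(torch.nn.Linear(20, 32),
                                torch.nn.ReLU(),
                                torch.nn.Linear(32, 3)).to(device)
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    res = train(model, train_dl, test_dl, loss_fn, optimizer,
                num_epochs=5, patience=2, min_delta=0.0,
                device=device, save_mod=False)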