import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
import torchvision.models as models
from torch.utils.data import DataLoader
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
from torchvision.ops import box_iou

from dataset import PlaneDataset, transform
# Load dataset
dataset = PlaneDataset("Images", "Annotations", transform=transform)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))
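# Note: the lambda collate_fn simply zips the batch into (images, targets) tuples; detection
# targets contain a variable number of boxes per image, so the default collate cannot stack
# them into a single tensor.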
# Load pre-trained Faster R-CNN model
model = models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
# Replace the classifier for detecting planes
num_classes = 2 # 1 for plane + 1 for background
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
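# FastRCNNPredictor is the box head: it maps ROI features to per-class scores and box
# regression offsets, here sized for the two classes defined above instead of the
# original 91 COCO categories.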
# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# Track statistics
train_losses = []
mAPs = []
# Approximate mAP proxy: per image, the fraction of predicted boxes whose best IoU with a
# ground-truth box exceeds 0.5. This is a precision estimate at a single IoU threshold
# rather than a full COCO-style mean Average Precision, but it is cheap to compute per epoch.
def compute_mAP(model, dataloader, device):
    model.eval()
    iou_threshold = 0.5
    all_precisions = []
    with torch.no_grad():
        for images, targets in dataloader:
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            preds = model(images)
            for pred, target in zip(preds, targets):
                pred_boxes = pred["boxes"]
                gt_boxes = target["boxes"]
                if len(pred_boxes) == 0 or len(gt_boxes) == 0:
                    continue
                # IoU matrix between every predicted box and every ground-truth box
                ious = box_iou(pred_boxes, gt_boxes)
                # A prediction counts as correct if its best-matching ground truth clears the threshold
                correct = (ious.max(dim=1).values > iou_threshold).float()
                precision = correct.sum() / max(len(pred_boxes), 1)
                all_precisions.append(precision.item())
    return np.mean(all_precisions) if all_precisions else 0.0
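# Optional alternative (a sketch, not used by this script): if the torchmetrics package is
# installed, it provides a COCO-style mean Average Precision that consumes the same
# prediction/target dictionaries produced in the loop above.
#
#   from torchmetrics.detection.mean_ap import MeanAveragePrecision
#
#   metric = MeanAveragePrecision()
#   metric.update(preds, targets)   # preds: dicts with "boxes", "scores", "labels";
#                                   # targets: dicts with "boxes", "labels"
#   print(metric.compute()["map"])  # overall mAP averaged over IoU thresholds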
# Training loop with statistics logging
num_epochs = 5
plt.ion()                          # Interactive mode so the plots update live during training
plt.figure(figsize=(10, 5))        # Create the figure once; it is cleared and redrawn each epoch

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for images, targets in dataloader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()
        # In training mode the model returns a dict of losses (RPN and ROI-head
        # classification/regression terms); sum them for a single backward pass.
        loss_dict = model(images, targets)
        loss = sum(loss_dict.values())
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Compute and log statistics
    avg_loss = total_loss / len(dataloader)
    train_losses.append(avg_loss)
    mAP = compute_mAP(model, dataloader, device)
    mAPs.append(mAP)
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {avg_loss:.4f} | mAP: {mAP:.4f}")

    # Live plot of training progress
    plt.clf()
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label="Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("Training Loss")

    plt.subplot(1, 2, 2)
    plt.plot(mAPs, label="mAP")
    plt.xlabel("Epoch")
    plt.ylabel("mAP")
    plt.legend()
    plt.title("Mean Average Precision")
    plt.pause(0.1)
# Save the trained weights and the final training curves
os.makedirs("models", exist_ok=True)
os.makedirs("plots", exist_ok=True)
torch.save(model.state_dict(), "models/plane_detector.pth")

plt.ioff()  # Turn off interactive mode
plt.savefig("plots/training_progress.png")  # Save the plots before the blocking show() call
plt.show()  # Show final plots
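# --- Inference sketch (illustrative only; the checkpoint path, sample index, and score
# threshold below are assumptions). Shows how the saved weights might be reloaded and
# applied to a single image tensor.
#
#   detector = models.detection.fasterrcnn_resnet50_fpn(weights=None, num_classes=num_classes)
#   detector.load_state_dict(torch.load("models/plane_detector.pth", map_location=device))
#   detector.to(device).eval()
#
#   with torch.no_grad():
#       image, _ = dataset[0]                          # any 3xHxW float tensor works here
#       output = detector([image.to(device)])[0]
#       keep = output["scores"] > 0.5                  # assumed confidence threshold
#       print("Detected plane boxes:", output["boxes"][keep].cpu().numpy())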