# Spaces:
# Sleeping
# Sleeping
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
import torchvision.models as models
from torch.utils.data import DataLoader
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
from torchvision.ops import box_iou

from dataset import PlaneDataset, transform
# ---------------------------------------------------------------------------
# Data: detection samples vary in size, so batches are kept as tuples
# instead of being stacked into a single tensor.
# ---------------------------------------------------------------------------
def _collate_detection_batch(samples):
    """Transpose a list of samples into one tuple per sample field."""
    return tuple(zip(*samples))


dataset = PlaneDataset("Images", "Annotations", transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=_collate_detection_batch,
)

# ---------------------------------------------------------------------------
# Model: start from the pre-trained Faster R-CNN and swap in a fresh box
# predictor sized for our classes (background + plane).
# ---------------------------------------------------------------------------
model = models.detection.fasterrcnn_resnet50_fpn(
    weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT,
)
num_classes = 2  # 1 for plane + 1 for background
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = models.detection.faster_rcnn.FastRCNNPredictor(
    in_features,
    num_classes,
)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Per-epoch history, appended to by the training loop.
train_losses, mAPs = [], []
def compute_mAP(model, dataloader, device):
    """Return the mean per-image precision of ``model`` at IoU 0.5.

    Despite the name, this is not the COCO/VOC mean Average Precision: for
    each image it computes the fraction of predicted boxes whose best IoU
    with any ground-truth box exceeds 0.5, then averages those fractions
    over the scored images. Prediction scores are not used.

    Args:
        model: Detection model. It is called as ``model(images)`` in eval
            mode and must return one dict with a ``"boxes"`` entry per image.
        dataloader: Iterable yielding ``(images, targets)`` batches, where
            each target is a dict with a ``"boxes"`` entry.
        device: Device the images and ground-truth boxes are moved to.

    Returns:
        The mean precision as a float, or 0.0 if no image could be scored.
    """
    iou_threshold = 0.5
    all_precisions = []

    # Remember the caller's mode so evaluation does not leave the model
    # switched to eval mode behind the caller's back.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for images, targets in dataloader:
                images = [img.to(device) for img in images]
                preds = model(images)
                for pred, target in zip(preds, targets):
                    pred_boxes = pred["boxes"]
                    if len(pred_boxes) == 0:
                        # No detections: precision is 0/0, so the image is
                        # left out of the average.
                        continue

                    # Only the boxes are read, so other target entries are
                    # not moved to the device.
                    gt_boxes = target["boxes"].to(device)
                    if len(gt_boxes) == 0:
                        # Nothing to match: every detection is a false
                        # positive, so the image scores 0 rather than being
                        # skipped (skipping would inflate the metric).
                        all_precisions.append(0.0)
                        continue

                    ious = box_iou(pred_boxes, gt_boxes)
                    # A prediction counts as correct when its best-matching
                    # ground-truth box overlaps it by more than the threshold.
                    correct = (ious.max(dim=1).values > iou_threshold).float()
                    precision = correct.sum() / len(pred_boxes)
                    all_precisions.append(precision.item())
    finally:
        model.train(was_training)

    return float(np.mean(all_precisions)) if all_precisions else 0.0
# Training loop with statistics logging
num_epochs = 5

# Create the output directories up front so the final save cannot fail on a
# missing directory after the whole training run has completed.
os.makedirs("models", exist_ok=True)
os.makedirs("plots", exist_ok=True)

plt.ion()  # Turn on interactive mode for live plotting
# One figure, created once and redrawn each epoch. Creating a figure inside
# the loop would open a new figure every epoch and never close the old ones.
fig = plt.figure(figsize=(10, 5))

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for images, targets in dataloader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()
        # In training mode the model returns a dict of loss terms; the
        # optimized loss is their sum.
        loss_dict = model(images, targets)
        loss = sum(loss_dict.values())
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Compute and log statistics
    avg_loss = total_loss / len(dataloader)
    train_losses.append(avg_loss)
    # NOTE: this is evaluated on the training dataloader, not held-out data.
    mAP = compute_mAP(model, dataloader, device)
    mAPs.append(mAP)
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {avg_loss:.4f} | mAP: {mAP:.4f}")

    # Live plot of training progress: redraw both panels on the same figure.
    fig.clf()
    loss_ax = fig.add_subplot(1, 2, 1)
    loss_ax.plot(train_losses, label="Loss")
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Loss")
    loss_ax.legend()
    loss_ax.set_title("Training Loss")

    map_ax = fig.add_subplot(1, 2, 2)
    map_ax.plot(mAPs, label="mAP")
    map_ax.set_xlabel("Epoch")
    map_ax.set_ylabel("mAP")
    map_ax.legend()
    map_ax.set_title("Mean Average Precision")
    plt.pause(0.1)

# Save model
torch.save(model.state_dict(), "models/plane_detector.pth")

plt.ioff()  # Turn off interactive mode
# Save before showing: with interactive mode off, show() blocks until the
# window is closed, and saving afterwards can write an empty image.
fig.savefig("plots/training_progress.png")
plt.show()  # Show final plots