File size: 3,584 Bytes
86e22bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
import torchvision.models as models
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
from torch.utils.data import DataLoader
import torch.optim as optim
import matplotlib.pyplot as plt
from dataset import PlaneDataset, transform
from torchvision.ops import box_iou
import numpy as np

# Load dataset
dataset = PlaneDataset("Images", "Annotations", transform=transform)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))

# Load pre-trained Faster R-CNN model
model = models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)

# Replace the classifier for detecting planes
num_classes = 2  # 1 for plane + 1 for background
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Track statistics
train_losses = []
mAPs = []

# Function to compute mAP (mean Average Precision)
def compute_mAP(model, dataloader, device):
    model.eval()
    iou_threshold = 0.5
    all_precisions = []

    with torch.no_grad():
        for images, targets in dataloader:
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            preds = model(images)
            
            for pred, target in zip(preds, targets):
                pred_boxes = pred["boxes"]
                pred_scores = pred["scores"]
                gt_boxes = target["boxes"]

                if len(pred_boxes) == 0 or len(gt_boxes) == 0:
                    continue

                ious = box_iou(pred_boxes, gt_boxes)
                correct = (ious.max(dim=1).values > iou_threshold).float()
                precision = correct.sum() / max(len(pred_boxes), 1)
                all_precisions.append(precision.item())

    return np.mean(all_precisions) if all_precisions else 0.0

# Training loop with statistics logging
num_epochs = 5
plt.ion()  # Turn on interactive mode for live plotting

for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for images, targets in dataloader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()
        loss_dict = model(images, targets)
        loss = sum(loss for loss in loss_dict.values())
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Compute and log statistics
    avg_loss = total_loss / len(dataloader)
    train_losses.append(avg_loss)
    mAP = compute_mAP(model, dataloader, device)
    mAPs.append(mAP)

    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {avg_loss:.4f} | mAP: {mAP:.4f}")

    # Live Plot Training Progress
    plt.figure(figsize=(10, 5))
    plt.clf()
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label="Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("Training Loss")

    plt.subplot(1, 2, 2)
    plt.plot(mAPs, label="mAP")
    plt.xlabel("Epoch")
    plt.ylabel("mAP")
    plt.legend()
    plt.title("Mean Average Precision")

    plt.pause(0.1)

# Save model
torch.save(model.state_dict(), "models/plane_detector.pth")
plt.ioff()  # Turn off interactive mode
plt.show() 
plt.savefig("plots/training_progress.png") # Show final plots