import torch
import torchvision
import torchvision.transforms as T
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
import os
import json
from PIL import Image
from tqdm import tqdm  # Progress bar for the training loop
# Paths (Modify These)
DATASET_DIR = "datasets/images" # Folder containing images
ANNOTATIONS_FILE = "datasets/annotations.json" # Path to COCO JSON
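# Expected annotation structure (COCO-style; a minimal sketch showing only the
# fields the dataset class below actually reads):
#
#     {
#       "images": [{"id": 0, "file_name": "img0.jpg"}, ...],
#       "annotations": [{"image_id": 0, "bbox": [xmin, ymin, xmax, ymax], "category_id": 1}, ...]
#     }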
# Define Custom COCO Dataset Class (Without pycocotools)
class CocoDataset(Dataset):
    def __init__(self, root, annotation_file, transforms=None):
        self.root = root
        with open(annotation_file, "r") as f:
            self.coco_data = json.load(f)
        self.image_data = {img["id"]: img for img in self.coco_data["images"]}
        # Keep an ordered list of image ids so integer indices map to images
        # even when ids are not contiguous or zero-based.
        self.image_ids = list(self.image_data.keys())
        self.annotations = self.coco_data["annotations"]
        self.transforms = transforms

    def __len__(self):
        return len(self.image_ids)
    def __getitem__(self, idx):
        # Look up the image record first so the except block below can name it
        image_info = self.image_data[self.image_ids[idx]]
        try:
            image_path = os.path.join(self.root, image_info["file_name"])
            image = Image.open(image_path).convert("RGB")
            img_width, img_height = image.size  # Used to clamp boxes to the image
            # Gather all annotations belonging to this image
            annotations = [ann for ann in self.annotations if ann["image_id"] == image_info["id"]]
            boxes = []
            labels = []
            for ann in annotations:
                # Boxes are assumed to be stored in corner format
                # [xmin, ymin, xmax, ymax], not COCO's default [x, y, width, height]
                xmin, ymin, xmax, ymax = ann["bbox"]
                # Clamp coordinates to the image bounds
                xmin = max(0, xmin)
                ymin = max(0, ymin)
                xmax = min(img_width, xmax)
                ymax = min(img_height, ymax)
                if xmax > xmin and ymax > ymin:
                    boxes.append([xmin, ymin, xmax, ymax])
                    labels.append(ann["category_id"])
                else:
                    print(f"⚠️ Skipping invalid bbox {ann['bbox']} in image {image_info['file_name']} (image_id: {image_info['id']})")
            if len(boxes) == 0:
                print(f"⚠️ Skipping entire image {image_info['file_name']} because no valid bounding boxes remain.")
                return None, None
            # Convert to tensors
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            labels = torch.as_tensor(labels, dtype=torch.int64)
            target = {"boxes": boxes, "labels": labels}
            if self.transforms:
                image = self.transforms(image)
            return image, target
        except Exception as e:
            print(f"⚠️ Skipping image {image_info['file_name']} due to error: {e}")
            return None, None
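# Note: a raw COCO export stores "bbox" as [x, y, width, height]. This script
# assumes the annotation file was already converted to corner format. If yours
# still uses the raw format, a one-time conversion along these lines would be
# needed first (a sketch, assuming the annotation layout shown above):
#
#     with open(ANNOTATIONS_FILE) as f:
#         coco = json.load(f)
#     for ann in coco["annotations"]:
#         x, y, w, h = ann["bbox"]
#         ann["bbox"] = [x, y, x + w, y + h]
#     with open(ANNOTATIONS_FILE, "w") as f:
#         json.dump(coco, f)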
# Define Image Transformations (PIL -> float tensor in [0, 1])
transform = T.Compose([T.ToTensor()])
# Load Dataset, capped at 10,000 samples
full_dataset = CocoDataset(root=DATASET_DIR, annotation_file=ANNOTATIONS_FILE, transforms=transform)
subset_size = min(10000, len(full_dataset))
dataset = Subset(full_dataset, list(range(subset_size)))

def collate_fn(batch):
    # Drop samples that failed to load (__getitem__ returns (None, None) for
    # those), then repack the rest as (images, targets) tuples
    batch = [item for item in batch if item[0] is not None]
    return tuple(zip(*batch)) if batch else ((), ())

data_loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
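# Optional sanity check (a sketch): pull one batch and inspect shapes before
# committing to a full training run.
#
#     imgs, tgts = next(iter(data_loader))
#     print(len(imgs), imgs[0].shape, tgts[0]["boxes"].shape)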
# Load Faster R-CNN Model with pretrained COCO weights
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
# Freeze Backbone Layers so only the detection heads are trained
for param in model.backbone.parameters():
    param.requires_grad = False
# Modify Classifier Head for Custom Classes
num_classes = 2  # One object class + one for background
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
device = torch.device("cpu")
# Check for MPS Availability (uncomment to train on Apple Silicon):
# if torch.backends.mps.is_available():
#     print("✅ Using MPS (Apple Metal GPU)")
#     device = torch.device("mps")
# else:
#     print("⚠️ MPS not available, using CPU")
#     device = torch.device("cpu")
model.to(device)
# Training Setup: only optimize parameters that still require gradients
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=0.0001)
num_epochs = 5
# Training Loop
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    print(f"Epoch {epoch+1}/{num_epochs}...")
    for images, targets in tqdm(data_loader, desc=f"Training Epoch {epoch+1}"):
        if len(images) == 0:
            # Every sample in this batch failed to load
            print("⚠️ Skipping batch with no valid samples")
            continue
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        if any(len(t["boxes"]) == 0 for t in targets):
            print("⚠️ Skipping batch with no valid bounding boxes")
            continue
        optimizer.zero_grad()
        # In training mode the model returns a dict of losses
        loss_dict = model(images, targets)
        loss = sum(loss for loss in loss_dict.values())
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
# Save Trained Model
torch.save(model.state_dict(), "faster_rcnn_custom.pth")
print("Training Complete! Model saved as 'faster_rcnn_custom.pth'")