import os
import torch
from torch.utils.data import Dataset
from torchvision.transforms import functional as F
from PIL import Image
import xml.etree.ElementTree as ET
import yaml

with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)

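# The class below reads config["model"]["image_sample_size"]; a minimal
# config.yaml satisfying that might look like this (keys match the code,
# the value mirrors the original 10,000-image cap):
#
#   model:
#     image_sample_size: 10000
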
class LegoDataset(Dataset):
    """Detection dataset pairing JPEG images with VOC-style XML box annotations."""

    def __init__(self, image_dir, annotation_dir, transform=None):
        self.image_dir = image_dir
        self.annotation_dir = annotation_dir
        self.transform = transform
        # Cap the dataset at config["model"]["image_sample_size"] images
        # (10,000 in the original setup).
        self.image_files = sorted(
            f for f in os.listdir(image_dir) if f.endswith(".jpg")
        )[:config["model"]["image_sample_size"]]
    
    def __len__(self):
        return len(self.image_files)
    
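    # Expected annotation layout (VOC-style XML; element names match the
    # parser below, the coordinate values are illustrative):
    #
    #   <annotation>
    #     <object>
    #       <name>lego</name>
    #       <bndbox>
    #         <xmin>120</xmin> <ymin>80</ymin>
    #         <xmax>260</xmax> <ymax>210</ymax>
    #       </bndbox>
    #     </object>
    #   </annotation>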
    def parse_annotation(self, xml_file):
        tree = ET.parse(xml_file)
        root = tree.getroot()
        boxes = []
        labels = []
        
        for obj in root.findall("object"):
            bbox = obj.find("bndbox")
            xmin = int(bbox.find("xmin").text)
            ymin = int(bbox.find("ymin").text)
            xmax = int(bbox.find("xmax").text)
            ymax = int(bbox.find("ymax").text)

            boxes.append([xmin, ymin, xmax, ymax])
            # Single-class dataset: every object is assumed to be a 'lego'
            # brick, mapped to class 1 (class 0 is background in torchvision
            # detection models).
            labels.append(1)
        
        # reshape(-1, 4) keeps the box tensor 2-D even when an image has no
        # objects, matching the (N, 4) shape detection models expect.
        return (
            torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            torch.tensor(labels, dtype=torch.int64),
        )
    
    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_files[idx])
        annotation_path = os.path.join(self.annotation_dir, self.image_files[idx].replace(".jpg", ".xml"))
        
        image = Image.open(image_path).convert("RGB")
        boxes, labels = self.parse_annotation(annotation_path)
        
        target = {"boxes": boxes, "labels": labels}
        
        if self.transform:
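            # Note: geometric transforms would also need to remap `boxes`;
            # pure tensor/color transforms (e.g. F.to_tensor) are safe as-is.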
            image = self.transform(image)
        
        return image, target
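

# Usage sketch (paths and batch size are illustrative assumptions, not part
# of the original project). Detection targets are dicts of variable-size
# tensors, so PyTorch's default collate would fail; zipping the batch into a
# tuple of images and a tuple of targets keeps each sample intact.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    def collate_fn(batch):
        return tuple(zip(*batch))

    dataset = LegoDataset(
        image_dir="data/images",            # assumed directory layout
        annotation_dir="data/annotations",  # assumed directory layout
        transform=F.to_tensor,              # PIL image -> float tensor in [0, 1]
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

    images, targets = next(iter(loader))
    print(len(images), targets[0]["boxes"].shape)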