import torchvision.transforms as T

class AugmentationPipeline:
    """
    Data augmentation and preprocessing transformations for CropGuard.
    """

    def __init__(self):
        # Mean and Std from ImageNet (can be adjusted later if needed)
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

        # Training transforms: stochastic geometric and photometric
        # augmentation, followed by tensor conversion and normalization.
        self.train_transforms = T.Compose([
            T.RandomHorizontalFlip(p=0.5),
            T.RandomVerticalFlip(p=0.5),
            T.RandomRotation(degrees=30),
            T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            T.ToTensor(),
            T.Normalize(mean=self.mean, std=self.std)
        ])

        # Validation and test phases share the same deterministic
        # preprocessing: no augmentation, only ToTensor and normalization.
        self.val_transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=self.mean, std=self.std)
        ])

        self.test_transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=self.mean, std=self.std)
        ])

    def get_transforms(self, phase="train"):
        """
        Returns the appropriate transformation based on phase.
        """
        if phase == "train":
            return self.train_transforms
        elif phase == "val":
            return self.val_transforms
        elif phase == "test":
            return self.test_transforms
        else:
            raise ValueError(f"Unknown phase: {phase}. Use 'train', 'val', or 'test'.")