import torchvision.transforms as T

# Per-channel RGB mean/std of ImageNet: the normalization expected by
# torchvision's ImageNet-pretrained models. Stored as tuples so the shared
# module-level values cannot be mutated; each Normalize gets its own list copy.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def _normalize():
    """Return a fresh ``Normalize`` transform using the ImageNet statistics."""
    return T.Normalize(mean=list(IMAGENET_MEAN), std=list(IMAGENET_STD))


def get_transforms(phase="train", *, image_size=224):
    """Build the image preprocessing pipeline for a training or evaluation phase.

    Args:
        phase: ``"train"`` selects the augmenting training pipeline. Any other
            value (e.g. ``"val"`` or ``"test"``) selects the deterministic
            evaluation pipeline; unknown values are not rejected.
        image_size: Side length, in pixels, of the square output image.
            Defaults to 224.

    Returns:
        A ``torchvision.transforms.Compose`` that produces an
        ``image_size x image_size`` image, converts it with ``ToTensor`` and
        normalizes it with the ImageNet mean/std. Inputs are presumably PIL
        images (what ``ToTensor`` and the geometric transforms accept) --
        confirm against the dataset that uses this.
    """
    if phase == "train":
        return T.Compose([
            # Augmentation: random scale/aspect crop resized to image_size,
            # then random flip, rotation (up to +/-20 degrees) and color jitter.
            T.RandomResizedCrop(image_size),
            T.RandomHorizontalFlip(),
            T.RandomRotation(20),
            T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            T.ToTensor(),
            _normalize(),
        ])

    # Deterministic pipeline: resize straight to a square. Aspect ratio is not
    # preserved and there is no center crop.
    # NOTE(review): the common ImageNet eval recipe is Resize(256) +
    # CenterCrop(224); kept as-is to avoid changing existing behavior.
    return T.Compose([
        T.Resize((image_size, image_size)),
        T.ToTensor(),
        _normalize(),
    ])