Jechen00 commited on
Commit
2565d4d
·
verified ·
1 Parent(s): 4e7192d

Data transforms now use v2

Browse files
Files changed (1) hide show
  1. model_training/data_setup.py +15 -11
model_training/data_setup.py CHANGED
@@ -1,7 +1,9 @@
1
  #####################################
2
  # Packages & Dependencies
3
  #####################################
4
- from torchvision import transforms, datasets
 
 
5
  from torch.utils.data import DataLoader
6
 
7
  import utils
@@ -14,18 +16,20 @@ import numpy as np
14
  import matplotlib.pyplot as plt
15
 
16
  # Transformations applied to each image
17
- BASE_TRANSFORMS = transforms.Compose([
18
- transforms.ToTensor(), # Convert to tensor and rescale pixel values to within [0, 1]
19
- transforms.Normalize(mean = [0.1307], std = [0.3081]) # Normalize with MNIST stats
 
20
  ])
21
 
22
- TRAIN_TRANSFORMS = transforms.Compose([
23
- transforms.RandomAffine(degrees = 15, # Rotate up to -/+ 15 degrees
24
- scale = (0.8, 1.2), # Scale between 80 and 120 percent
25
- translate = (0.08, 0.08), # Translate up to -/+ 8 percent in both x and y
26
- shear = 10), # Shear up to -/+ 10 degrees
27
- transforms.ToTensor(), # Convert to tensor and rescale pixel values to within [0, 1]
28
- transforms.Normalize(mean = [0.1307], std = [0.3081]), # Normalize with MNIST stats
 
29
  ])
30
 
31
 
 
1
  #####################################
2
  # Packages & Dependencies
3
  #####################################
4
+ import torch
5
+ from torchvision import datasets
6
+ from torchvision.transforms import v2
7
  from torch.utils.data import DataLoader
8
 
9
  import utils
 
16
  import matplotlib.pyplot as plt
17
 
18
  # Transformations applied to each image
19
+ BASE_TRANSFORMS = v2.Compose([
20
+ v2.ToImage(), # Convert to tensor
21
+ v2.ToDtype(torch.float32, scale = True), # Rescale pixel values to within [0, 1]
22
+ v2.Normalize(mean = [0.1307], std = [0.3081]) # Normalize with MNIST stats
23
  ])
24
 
25
+ TRAIN_TRANSFORMS = v2.Compose([
26
+ v2.RandomAffine(degrees = 15, # Rotate up to -/+ 15 degrees
27
+ scale = (0.8, 1.2), # Scale between 80 and 120 percent
28
+ translate = (0.08, 0.08), # Translate up to -/+ 8 percent in both x and y
29
+ shear = 10), # Shear up to -/+ 10 degrees
30
+ v2.ToImage(), # Convert to tensor
31
+ v2.ToDtype(torch.float32, scale = True), # Rescale pixel values to within [0, 1]
32
+ v2.Normalize(mean = [0.1307], std = [0.3081]), # Normalize with MNIST stats
33
  ])
34
 
35