Data transforms now use v2
Browse files- model_training/data_setup.py +15 -11
model_training/data_setup.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
#####################################
|
2 |
# Packages & Dependencies
|
3 |
#####################################
|
4 |
-
|
|
|
|
|
5 |
from torch.utils.data import DataLoader
|
6 |
|
7 |
import utils
|
@@ -14,18 +16,20 @@ import numpy as np
|
|
14 |
import matplotlib.pyplot as plt
|
15 |
|
16 |
# Transformations applied to each image
|
17 |
-
BASE_TRANSFORMS =
|
18 |
-
|
19 |
-
|
|
|
20 |
])
|
21 |
|
22 |
-
TRAIN_TRANSFORMS =
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
29 |
])
|
30 |
|
31 |
|
|
|
1 |
#####################################
|
2 |
# Packages & Dependencies
|
3 |
#####################################
|
4 |
+
import torch
|
5 |
+
from torchvision import datasets
|
6 |
+
from torchvision.transforms import v2
|
7 |
from torch.utils.data import DataLoader
|
8 |
|
9 |
import utils
|
|
|
16 |
import matplotlib.pyplot as plt
|
17 |
|
18 |
# Transformations applied to each image
|
19 |
+
BASE_TRANSFORMS = v2.Compose([
|
20 |
+
v2.ToImage(), # Convert to tensor
|
21 |
+
v2.ToDtype(torch.float32, scale = True), # Rescale pixel values to within [0, 1]
|
22 |
+
v2.Normalize(mean = [0.1307], std = [0.3081]) # Normalize with MNIST stats
|
23 |
])
|
24 |
|
25 |
+
TRAIN_TRANSFORMS = v2.Compose([
|
26 |
+
v2.RandomAffine(degrees = 15, # Rotate up to -/+ 15 degrees
|
27 |
+
scale = (0.8, 1.2), # Scale between 80 and 120 percent
|
28 |
+
translate = (0.08, 0.08), # Translate up to -/+ 8 percent in both x and y
|
29 |
+
shear = 10), # Shear up to -/+ 10 degrees
|
30 |
+
v2.ToImage(), # Convert to tensor
|
31 |
+
v2.ToDtype(torch.float32, scale = True), # Rescale pixel values to within [0, 1]
|
32 |
+
v2.Normalize(mean = [0.1307], std = [0.3081]), # Normalize with MNIST stats
|
33 |
])
|
34 |
|
35 |
|