File size: 6,531 Bytes
5d2f37a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
import random
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
class CustomDataset(Dataset):
    """Dataset of red/green/blue/NIR band images paired with binary masks.

    Every file in ``red_dir`` defines one sample. Its companion files are
    found by replacing ``'red'`` in the file name with ``'green'``,
    ``'blue'``, ``'nir'`` and ``'gt'`` and looking in ``green_dir``,
    ``blue_dir``, ``nir_dir`` and ``mask_dir`` respectively.

    Args:
        red_dir, green_dir, blue_dir, nir_dir, mask_dir: directories holding
            the per-band images and the ground-truth masks (``str`` or
            ``Path``).
        pytorch: stored as ``self.pytorch``; not otherwise read in this class.

    Raises:
        FileNotFoundError: if any companion file of a red image is missing.
    """

    def __init__(self, red_dir, green_dir, blue_dir, nir_dir, mask_dir, pytorch=True):
        super().__init__()
        # Accept plain strings as well as Path objects.
        self.red_dir = Path(red_dir)
        self.green_dir = Path(green_dir)
        self.blue_dir = Path(blue_dir)
        self.nir_dir = Path(nir_dir)
        self.mask_dir = Path(mask_dir)
        # iterdir() yields entries in arbitrary, filesystem-dependent order;
        # sort so that sample indices are reproducible between runs/machines.
        red_files = sorted(f for f in self.red_dir.iterdir() if f.is_file())
        self.files = [self.combine_files(f) for f in red_files]
        self.pytorch = pytorch

    def combine_files(self, red_files: Path):
        """Return the paths of all bands and the mask for one red image.

        Args:
            red_files: path of a single red-band image (despite the plural).

        Returns:
            dict with keys ``'red'``, ``'green'``, ``'blue'``, ``'nir'`` and
            ``'mask'`` mapping to file paths.

        Raises:
            FileNotFoundError: if any of the derived paths does not exist.
        """
        base_name = red_files.name
        # NOTE(review): str.replace rewrites *every* 'red' in the name, so a
        # name containing 'red' elsewhere would map to a wrong companion file.
        files = {
            'red': red_files,
            'green': self.green_dir / base_name.replace('red', 'green'),
            'blue': self.blue_dir / base_name.replace('red', 'blue'),
            'nir': self.nir_dir / base_name.replace('red', 'nir'),
            'mask': self.mask_dir / base_name.replace('red', 'gt'),
        }
        for path in files.values():
            if not path.exists():
                raise FileNotFoundError(f'Missing file: {path} for {red_files}')
        return files

    def __len__(self):
        return len(self.files)

    @staticmethod
    def _read_image(path):
        """Load one image file into a NumPy array and close the file handle."""
        with Image.open(path) as img:
            return np.array(img)

    def open_as_array(self, idx, invert=False, nir_included=False):
        """Load sample ``idx`` as a float array scaled by its dtype's maximum.

        Args:
            idx: index into ``self.files``.
            invert: if True, move the channel axis first, i.e. transpose
                (H, W, C) -> (C, H, W).
            nir_included: if True, append the NIR band as a fourth channel.

        Returns:
            Array with the bands stacked on the last axis (or first, when
            ``invert``), each value divided by the maximum of the images'
            integer dtype. Assumes each band file is a single-channel image
            of the same size — TODO confirm for the dataset in use.

        Raises:
            ValueError: from ``np.iinfo`` if the pixel dtype is not an integer.
        """
        paths = self.files[idx]
        bands = [self._read_image(paths[key]) for key in ('red', 'green', 'blue')]
        if nir_included:
            bands.append(self._read_image(paths['nir']))
        rgb = np.stack(bands, axis=2)
        if invert:
            rgb = rgb.transpose((2, 0, 1))
        return rgb / np.iinfo(rgb.dtype).max

    def open_mask(self, idx, expand_dims=True):
        """Load the mask of sample ``idx`` as a binary integer array.

        Pixels equal to 255 (white, clouds) become 1; every other value
        becomes 0 (not clouds). With ``expand_dims`` a leading axis of size
        1 is added.
        """
        raw_mask = self._read_image(self.files[idx]['mask'])
        raw_mask = np.where(raw_mask == 255, 1, 0)
        return np.expand_dims(raw_mask, 0) if expand_dims else raw_mask

    def __getitem__(self, idx):
        """Return ``(X, y)`` float32 tensors: channel-first 4-band image and mask."""
        X = torch.tensor(self.open_as_array(idx, invert=True, nir_included=True), dtype=torch.float32)
        y = torch.tensor(self.open_mask(idx, expand_dims=True), dtype=torch.float32)
        return X, y
class doubleConv(nn.Module):
    """Two rounds of 3x3 Conv2d -> BatchNorm2d -> ReLU.

    The first convolution maps ``in_channels`` to ``out_channels``; the second
    keeps ``out_channels``. ``padding=1`` with a 3x3 kernel leaves height and
    width unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # Layers are created in the same order as a hand-written Sequential
        # (conv, bn, relu, conv, bn, relu), so indices and init are unchanged.
        for conv_in in (in_channels, out_channels):
            layers.extend([
                nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ])
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)
class downSample(nn.Module):
    """One encoder stage: ``doubleConv`` followed by 2x2 max-pooling.

    ``forward`` returns ``(skip, pooled)``: the feature map before pooling
    (used as a skip connection by the caller) and its pooled version with
    kernel 2 and stride 2.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = doubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        skip = self.conv(x)
        return skip, self.pool(skip)
class upSample(nn.Module):
    """One decoder stage of the U-Net.

    ``forward(x1, x2)`` upsamples ``x1`` by 2x with a transposed convolution
    (``in_channels`` -> ``out_channels``) and aligns its height/width with
    the skip tensor ``x2``. It then concatenates the two along dim 1 and
    applies ``doubleConv(out_channels * 2, out_channels)``. ``x2`` is
    expected to carry ``out_channels`` channels so that the concatenation
    has ``out_channels * 2`` channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = doubleConv(out_channels * 2, out_channels)

    @staticmethod
    def _match_size(x, ref):
        """Return ``x`` zero-padded (or cropped) so its last two dims equal ``ref``'s.

        ``nn.MaxPool2d`` floors odd sizes by default, so upsampling by 2 can
        come back one pixel smaller than the skip connection, and
        ``torch.cat`` would then fail. Returns ``x`` unchanged when the
        sizes already match.
        """
        diff_h = ref.size(-2) - x.size(-2)
        diff_w = ref.size(-1) - x.size(-1)
        if diff_h == 0 and diff_w == 0:
            return x
        # F.pad order for the last two dims is (left, right, top, bottom);
        # negative amounts crop instead of pad.
        return F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                         diff_h // 2, diff_h - diff_h // 2])

    def forward(self, x1, x2):
        x1 = self._match_size(self.up(x1), x2)
        return self.conv(torch.cat([x1, x2], 1))
class SpatialAttention(nn.Module):
    """Reweight every spatial position of ``x`` with a learned sigmoid gate.

    The gate is computed from two single-channel maps: the mean and the max
    over the channel dimension. They are concatenated, passed through a 3x3
    convolution and then a sigmoid. The output has the same shape as ``x``.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=3, padding=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max, _ = x.max(dim=1, keepdim=True)  # drop the argmax indices
        descriptor = torch.cat((channel_mean, channel_max), dim=1)
        gate = self.sigmoid(self.conv(descriptor))
        return x * gate
class UNet(nn.Module):
    """Three-level U-Net with spatial attention applied at the bottleneck.

    Encoder widths are 32 -> 64 -> 128, the bottleneck has 256 channels, and
    the decoder mirrors the encoder using the encoder's pre-pool features as
    skip connections. A final 1x1 convolution maps 32 channels to
    ``num_classes`` raw, un-activated outputs per pixel.

    Args:
        in_channels: number of input channels.
        num_classes: number of output channels.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        # Attribute names and construction order fix the state_dict keys and
        # the RNG sequence used for parameter initialisation.
        self.down_conv1 = downSample(in_channels, 32)
        self.down_conv2 = downSample(32, 64)
        self.down_conv3 = downSample(64, 128)
        self.bottleneck = doubleConv(128, 256)
        self.spatial_attention = SpatialAttention()
        self.up_conv1 = upSample(256, 128)
        self.up_conv2 = upSample(128, 64)
        self.up_conv3 = upSample(64, 32)
        self.out = nn.Conv2d(in_channels=32, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        skips = []
        for encoder in (self.down_conv1, self.down_conv2, self.down_conv3):
            skip, x = encoder(x)
            skips.append(skip)
        x = self.spatial_attention(self.bottleneck(x))
        decoders = (self.up_conv1, self.up_conv2, self.up_conv3)
        # Deepest decoder pairs with the deepest skip, hence reversed().
        for decoder, skip in zip(decoders, reversed(skips)):
            x = decoder(x, skip)
        return self.out(x)
def acc_fn(predb, yb):
    """Return the pixel accuracy of thresholded predictions.

    ``predb`` is treated as logits: an element is predicted positive when
    ``sigmoid(predb) > 0.5``. The result is a 0-dim float32 tensor holding
    the fraction of elements where the 0/1 prediction equals ``yb``.
    """
    predicted = torch.sigmoid(predb).gt(0.5).to(torch.float32)
    matches = predicted.eq(yb)
    return matches.float().mean()
def calculate_metrics(y_true, y_pred):
    """Compute binary segmentation metrics from hard 0/1 labels.

    Both tensors are compared against the literal values 1 (positive) and
    0 (negative). Elements with any other value fall into no
    confusion-matrix cell, so ``y_pred`` should already be binarised.

    Returns:
        dict of Python floats with keys "Jaccard index", "Precision",
        "Recall", "Specificity" and "Overall Accuracy". Each denominator
        gets 1e-10 added, so an empty denominator yields 0.0, not NaN.
    """
    eps = 1e-10
    actual_pos, actual_neg = y_true == 1, y_true == 0
    predicted_pos, predicted_neg = y_pred == 1, y_pred == 0
    tp = (actual_pos & predicted_pos).sum().float()
    tn = (actual_neg & predicted_neg).sum().float()
    fp = (actual_neg & predicted_pos).sum().float()
    fn = (actual_pos & predicted_neg).sum().float()
    scores = {
        "Jaccard index": tp / (tp + fn + fp + eps),
        "Precision": tp / (tp + fp + eps),
        "Recall": tp / (tp + fn + eps),
        "Specificity": tn / (tn + fp + eps),
        "Overall Accuracy": (tp + tn) / (tp + tn + fp + fn + eps),
    }
    return {name: value.item() for name, value in scores.items()}
|