# radio-tiramisu/data/rg_masks.py
import json
import math
import random
import warnings
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from astropy.io import fits
from astropy.io.fits.verify import VerifyWarning
from einops import rearrange
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid, save_image
warnings.simplefilter('ignore', category=VerifyWarning)
import warnings
import numpy as np
import torch
from astropy.stats import sigma_clip
from astropy.visualization import ZScaleInterval
from torch.utils.data import DataLoader
warnings.simplefilter('ignore', category=VerifyWarning)
# Segmentation class names; the index in this list is the integer label
# stored in the masks (0 = background).
CLASSES = ['background', 'spurious', 'compact', 'extended']
# Binary RGB color used to visualize each class, same order as CLASSES.
COLORS = [[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]
def get_transforms(img_size):
    """Build the FITS-image preprocessing pipeline.

    Cleans NaNs, applies z-scale stretching and sigma clipping, squashes
    values with tanh, min-max normalizes to [0, 1], resizes to
    (img_size, img_size) and repeats the single channel three times.
    """
    steps = [
        RemoveNaNs(),
        ZScale(),
        SigmaClip(),
        ToTensor(),
        torch.nn.Tanh(),
        MinMaxNormalize(),
        Unsqueeze(),
        T.Resize((img_size, img_size)),
        RepeatChannels(3),
    ]
    return T.Compose(steps)
class RemoveNaNs(object):
    """Zero out NaN pixels of a numpy image, in place."""

    def __call__(self, img):
        # In-place masked write, like the original boolean-index assignment.
        np.putmask(img, np.isnan(img), 0)
        return img
class ZScale(object):
    """Linearly rescale an image using astropy's z-scale interval.

    The computed (lo, hi) limits are mapped onto [0, 1]; pixels outside
    the interval land outside that range (no clipping here).
    """

    def __init__(self, contrast=0.15):
        self.contrast = contrast

    def __call__(self, img):
        # Renamed locals: the original shadowed the builtins min/max.
        lo, hi = ZScaleInterval(contrast=self.contrast).get_limits(img)
        return (img - lo) / (hi - lo)
class SigmaClip(object):
    """Suppress extreme pixel values with astropy sigma clipping."""

    def __init__(self, sigma=3, masked=True):
        self.sigma = sigma
        # masked=True yields a numpy MaskedArray rather than NaN-filled data.
        self.masked = masked

    def __call__(self, img):
        return sigma_clip(img, sigma=self.sigma, masked=self.masked)
class MinMaxNormalize(object):
    """Rescale an image linearly so its values span [0, 1].

    Fix: a constant input (max == min) previously divided by zero and
    produced NaNs; such images are now mapped to all zeros.
    """

    def __call__(self, img):
        lo, hi = img.min(), img.max()
        span = hi - lo
        if span == 0:
            # Degenerate (flat) image: normalization is undefined; return zeros.
            return img - lo
        return (img - lo) / span
class ToTensor(object):
    """Convert an array-like image into a freshly-allocated float32 tensor."""

    def __call__(self, img):
        tensor_img = torch.tensor(img, dtype=torch.float32)
        return tensor_img
class RepeatChannels(object):
    """Tile the channel axis of an (N, C, H, W) tensor *ch* times."""

    def __init__(self, ch):
        self.ch = ch

    def __call__(self, img):
        # Tuple form of Tensor.repeat — same as repeat(1, ch, 1, 1).
        return img.repeat((1, self.ch, 1, 1))
class FromNumpy(object):
    """Convert a numpy array to a float32 torch tensor.

    The astype(np.float32) copy already fixes the dtype (and guarantees a
    writable, contiguous buffer for from_numpy); the old extra
    .type(torch.float32) round-trip was redundant and has been removed.
    """

    def __call__(self, img):
        return torch.from_numpy(img.astype(np.float32))
class Unsqueeze(object):
    """Prepend a singleton dimension (e.g. HxW -> 1xHxW)."""

    def __call__(self, img):
        # Indexing with None is equivalent to unsqueeze(0).
        return img[None]
def mask_to_rgb(mask):
    """Paint an integer class mask into an RGB image using COLORS.

    Each class index i is rendered with COLORS[i]; background (class 0)
    stays black. The channel axis of *mask* is expanded to 3.
    """
    rgb = torch.zeros_like(mask, device=mask.device).repeat(1, 3, 1, 1)
    for cls_idx, triplet in enumerate(COLORS):
        color = torch.tensor(triplet, device=mask.device).reshape(-1, 1, 1)
        rgb = rgb + color * (mask == cls_idx)
    return rgb
def get_data_loader(dataset, batch_size, split="train"):
    """Build a DataLoader for *dataset*.

    Training splits are shuffled and drop the last partial batch; other
    splits keep sample order and every sample. (The pointless
    `batch_size = batch_size` self-assignment was removed.)
    """
    is_train = split == "train"
    # Cap workers so tiny batch sizes don't spawn idle processes.
    workers = min(8, batch_size)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=workers,
        persistent_workers=True,
        drop_last=is_train,
    )
def rgb_to_tensor(mask):
    """Collapse a 3-channel one-hot RGB mask (3, H, W) into class labels.

    R, G, B channels encode classes 1, 2 and 3; the per-pixel maximum of
    the weighted channels gives the label map, returned as (1, H, W).

    Fix: the original scaled the unpacked channel views in place
    (`g *= 2`, `b *= 3`), silently mutating the caller's tensor; this
    version leaves the input untouched.
    """
    r, g, b = mask
    weighted = torch.stack([r * 1, g * 2, b * 3])
    labels, _ = torch.max(weighted, dim=0, keepdim=True)
    return labels
def rand_horizontal_flip(img, mask):
    """Flip *img* and *mask* together horizontally with probability 0.5.

    Applying the same flip to both keeps the pair spatially aligned.
    """
    flip = random.random() < 0.5
    if flip:
        img, mask = TF.hflip(img), TF.hflip(mask)
    return img, mask
class RGDataset(Dataset):
    """Radio-galaxy segmentation dataset backed by FITS files.

    Items are (image, mask) pairs. Images go through NaN removal, z-scale
    stretching, sigma clipping, tanh squashing, min-max normalization,
    resizing to (img_size, img_size) and channel repetition to 3. Masks are
    resized with nearest-neighbor interpolation and returned as long class
    indices. (Large blocks of commented-out dead code were removed.)
    """

    def __init__(self, data_dir, img_paths, img_size=128):
        """data_dir: dataset root directory; img_paths: text file listing
        one FITS path per line, relative to data_dir; img_size: output
        spatial resolution."""
        super().__init__()
        data_dir = Path(data_dir)
        with open(img_paths) as f:
            self.img_paths = f.read().splitlines()
        self.img_paths = [data_dir / p for p in self.img_paths]
        self.transforms = T.Compose([
            RemoveNaNs(),
            ZScale(),
            SigmaClip(),
            ToTensor(),
            torch.nn.Tanh(),
            MinMaxNormalize(),
            Unsqueeze(),
            T.Resize((img_size, img_size)),
            RepeatChannels(3),
        ])
        self.img_size = img_size
        # Nearest-neighbor keeps mask labels discrete after resizing.
        self.mask_transforms = T.Compose([
            FromNumpy(),
            Unsqueeze(),
            T.Resize((img_size, img_size),
                     interpolation=T.InterpolationMode.NEAREST),
        ])

    def get_mask(self, img_path, type):
        """Load the segmentation mask paired with *img_path*.

        type="real": per-object FITS masks listed in a JSON annotation are
        loaded and merged with an element-wise max. type="synthetic": a
        single conditioning-mask FITS file is loaded; if it is RGB-encoded
        (first dim == 3) it is collapsed to integer class labels.
        """
        assert type in ["real", "synthetic"], f"Type {type} not supported"
        if type == "real":
            ann_path = str(img_path).replace(
                'imgs', 'masks').replace('.fits', '.json')
            ann_dir = Path(ann_path).parent
            ann_path = ann_dir / f'mask_{ann_path.split("/")[-1]}'
            with open(ann_path) as j:
                mask_info = json.load(j)
            masks = []
            for obj in mask_info['objs']:
                seg_path = ann_dir / obj['mask']
                mask = fits.getdata(seg_path)
                mask = self.mask_transforms(mask.astype(np.float32))
                masks.append(mask)
            # Merge the per-object masks into a single label map.
            mask, _ = torch.max(torch.stack(masks), dim=0)
        elif type == "synthetic":
            # The conditioning mask lives next to the generated image.
            mask_path = str(img_path).replace("gen_fits", "cond_fits")
            mask = fits.getdata(mask_path)
            mask = self.mask_transforms(mask)
            mask = mask.squeeze()
            if mask.shape[0] == 3:
                # RGB-encoded mask: collapse channels to class indices.
                mask = rgb_to_tensor(mask)
        return mask

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        image_path = self.img_paths[idx]
        img = fits.getdata(image_path)
        img = self.transforms(img)
        # Synthetic samples are identified by their path.
        mask_type = 'synthetic' if "synthetic" in str(image_path) else 'real'
        mask = self.get_mask(image_path, type=mask_type)
        mask = mask.long()
        return img.squeeze(), mask.squeeze()
class SyntheticRGDataset(Dataset):
    """Dataset of generated (synthetic) radio-galaxy FITS images paired
    with the conditioning masks they were generated from.

    Items are (image, mask) pairs with joint random horizontal flipping as
    augmentation. (Commented-out dead code was removed.)
    """

    def __init__(self, data_dir, img_paths, img_size=128):
        """data_dir: dataset root directory; img_paths: text file listing
        one generated-FITS path per line, relative to data_dir; img_size:
        output spatial resolution."""
        super().__init__()
        data_dir = Path(data_dir)
        with open(img_paths) as f:
            self.img_paths = f.read().splitlines()
        self.img_paths = [data_dir / p for p in self.img_paths]
        self.transforms = T.Compose([
            RemoveNaNs(),
            ZScale(),
            SigmaClip(),
            ToTensor(),
            torch.nn.Tanh(),
            MinMaxNormalize(),
            Unsqueeze(),
            T.Resize((img_size, img_size)),
            RepeatChannels(3),
        ])
        self.img_size = img_size
        # Nearest-neighbor keeps mask labels discrete after resizing.
        self.mask_transforms = T.Compose([
            FromNumpy(),
            Unsqueeze(),
            T.Resize((img_size, img_size),
                     interpolation=T.InterpolationMode.NEAREST),
        ])

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        image_path = self.img_paths[idx]
        img = fits.getdata(image_path)
        img = self.transforms(img).squeeze()
        # The conditioning mask lives next to the generated image.
        mask_path = str(image_path).replace("gen_fits", "cond_fits")
        mask = fits.getdata(mask_path)
        mask = self.mask_transforms(mask)
        # Flip image and mask together so they stay spatially aligned.
        img, mask = rand_horizontal_flip(img, mask)
        mask = mask.squeeze().long()
        return img, mask
if __name__ == '__main__':
    # Smoke test: save one sample image/mask and a grid of a batch of masks.
    rgtrain = SyntheticRGDataset('data/rg-dataset/data',
                                 'data/rg-dataset/val_w_bg.txt')
    # Fix: dataset items are (image, mask) pairs; the old 3-way unpacking
    # (image, mask, masked_image) raised ValueError on the first sample.
    image, mask = rgtrain[0]
    to_pil_image(image).save('image.png')
    # mask is (H, W); mask_to_rgb broadcasts it to (1, 3, H, W).
    rgb_mask = mask_to_rgb(mask)[0]
    # Cast to float: to_pil_image does not accept int64 tensors.
    to_pil_image(rgb_mask.float()).save('mask.png')
    bs = 256
    loader = torch.utils.data.DataLoader(
        rgtrain, batch_size=bs, shuffle=False, num_workers=16)
    for image, mask in loader:
        # Give the batch the singleton channel dim mask_to_rgb expects:
        # (B, H, W) -> (B, 1, H, W).
        rgb_mask = mask_to_rgb(mask.unsqueeze(1))
        nrow = int(math.sqrt(bs))
        grid = make_grid(rgb_mask.float(), nrow=nrow, padding=0)
        save_image(grid, f'mask_{nrow}x{nrow}.png')
        break