Spaces:
Runtime error
Runtime error
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import os | |
import boto3 | |
import torch | |
from torch import nn, distributed as dist | |
from torch.nn import functional as F | |
from imaginaire.utils.distributed import is_local_master | |
from imaginaire.utils.io import download_file_from_google_drive | |
def _load_segmentation_network(dataset_name, aws_credentials=None):
    """Instantiate the segmentation network associated with a dataset.

    Args:
        dataset_name (str): ``"celebamask_hq"``, ``"cocostuff"`` or
            ``"getty"``.
        aws_credentials (dict or None): Keyword arguments forwarded to
            ``boto3.client('s3', ...)``. When ``None``, a missing
            ``celebamask_hq`` checkpoint is fetched from Google Drive instead.

    Returns:
        The network instance, or ``None`` (after printing a notice) when no
        network is available for ``dataset_name``.
    """
    if dataset_name == "celebamask_hq":
        from imaginaire.evaluation.segmentation.celebamask_hq import Unet
        seg_network = Unet()
        checkpoint_dir = os.path.join(torch.hub.get_dir(), 'checkpoints')
        os.makedirs(checkpoint_dir, exist_ok=True)
        model_path = os.path.join(checkpoint_dir, "celebamask_hq.pt")
        # Download only when the checkpoint is not already cached on disk.
        if not os.path.exists(model_path):
            if aws_credentials is not None:
                s3 = boto3.client('s3', **aws_credentials)
                s3.download_file('lpi-poe', 'model_zoo/celebamask_hq.pt', model_path)
            else:
                download_file_from_google_drive("1o1m-eT38zNCIFldcRaoWcLvvBtY8S4W3", model_path)
        # NOTE(review): torch.load unpickles the downloaded file; consider
        # weights_only=True if the supported torch version allows it.
        state_dict = torch.load(model_path, map_location='cpu')
        seg_network.load_state_dict(state_dict)
        return seg_network

    if dataset_name in ("cocostuff", "getty"):
        from imaginaire.evaluation.segmentation.cocostuff import DeepLabV2
        return DeepLabV2()

    print(f"No segmentation network for {dataset_name} was found.")
    return None


def get_segmentation_hist_model(dataset_name, aws_credentials=None):
    """Create the segmentation-histogram evaluation model for a dataset.

    In distributed runs the local master loads (and, if needed, downloads)
    the network first while the other ranks wait at a barrier; they then load
    from the local cache.

    Args:
        dataset_name (str): ``"celebamask_hq"``, ``"cocostuff"`` or
            ``"getty"``.
        aws_credentials (dict or None): Keyword arguments forwarded to
            ``boto3.client('s3', ...)`` for downloading the ``celebamask_hq``
            checkpoint. When ``None``, Google Drive is used.

    Returns:
        SegmentationHistModel wrapping the network (moved to ``'cuda'`` and
        set to eval mode), or ``None`` when ``dataset_name`` is unsupported.
    """
    is_distributed = dist.is_initialized()
    waits_for_master = is_distributed and not is_local_master()

    if waits_for_master:
        # Make sure only the first process in distributed training downloads
        # the model, and the others will use the cache.
        dist.barrier()

    seg_network = _load_segmentation_network(dataset_name, aws_credentials)

    if is_distributed and not waits_for_master:
        # Release the waiting ranks. This must run even when no network was
        # found: returning before this barrier would leave the other ranks
        # blocked in the barrier above.
        dist.barrier()

    if seg_network is None:
        return None

    return SegmentationHistModel(seg_network.to('cuda').eval())
class SegmentationHistModel(nn.Module):
    """Module that turns generated images into per-sample confusion histograms.

    Args:
        seg_network (nn.Module): Segmentation network. It is called as
            ``seg_network(images, align_corners=...)`` and must expose the
            attributes ``n_classes`` and ``use_dont_care``.
    """

    def __init__(self, seg_network):
        super().__init__()
        self.seg_network = seg_network

    def forward(self, data, fake_images, align_corners=True):
        """Segment ``fake_images`` and compare against the ground-truth maps.

        Args:
            data (dict): Batch dictionary; only ``data["segmaps"]`` is read.
            fake_images (tensor): Images passed to the segmentation network.
            align_corners (bool): Forwarded to the segmentation network.

        Returns:
            The result of ``compute_hist`` for the predicted and ground-truth
            label maps.
        """
        predictions = self.seg_network(fake_images, align_corners=align_corners)
        # Scale by 255 and truncate to integers to obtain label indices
        # (presumably segmaps are stored as label / 255 -- TODO confirm).
        label_maps = (data["segmaps"] * 255.0).long()
        network = self.seg_network
        return compute_hist(predictions, label_maps, network.n_classes, network.use_dont_care)
def compute_hist(pred, gt, n_classes, use_dont_care):
    """Compute one confusion histogram per sample.

    Args:
        pred (tensor): Integer predicted labels of shape (N, H, W).
        gt (tensor): Ground-truth labels of shape (N, 1, h, w); resized to
            (H, W) with nearest-neighbour interpolation.
        n_classes (int): Number of valid classes.
        use_dont_care (bool): Whether ground truth marks "don't care" pixels
            with the label ``n_classes``. Kept for interface compatibility:
            that label (and ``-1``) lies outside ``[0, n_classes)`` and is
            therefore excluded by the valid-range mask below.

    Returns:
        Integer tensor of shape (N, n_classes, n_classes) where entry
        ``[i, p, g]`` counts pixels of sample ``i`` predicted as ``p`` whose
        ground-truth label is ``g``.
    """
    _, height, width = pred.size()
    gt = F.interpolate(gt.float(), (height, width), mode="nearest").long().squeeze(1)

    all_hist = []
    for cur_pred, cur_gt in zip(pred, gt):
        # Keep only pixels with a valid ground-truth label. Any other label
        # would either be counted in the wrong cell or push the flattened
        # index past n_classes ** 2 and break the reshape below.
        keep = (cur_gt >= 0) & (cur_gt < n_classes)
        # Flatten each (pred, gt) pair into a single bin index.
        merge = cur_pred[keep] * n_classes + cur_gt[keep]
        hist = torch.bincount(merge, minlength=n_classes ** 2)
        all_hist.append(hist.view(n_classes, n_classes))
    return torch.stack(all_hist)
def get_miou(hist, eps=1e-8):
    """Compute the mean intersection-over-union from stacked histograms.

    Args:
        hist (tensor): Confusion histograms stacked along dimension 0; they
            are summed over that dimension before the metric is computed.
        eps (float): Small constant added to the denominator to avoid
            division by zero.

    Returns:
        dict: ``{"seg_mIOU": value}`` where ``value`` is the mean per-class
        IoU as a percentage (Python float).
    """
    confusion = hist.sum(0)
    intersection = torch.diag(confusion)
    # Union = column totals + row totals - diagonal (counted twice otherwise).
    union = confusion.sum(dim=0) + confusion.sum(dim=1) - intersection + eps
    per_class_iou = intersection / union
    mean_iou = 100 * per_class_iou.mean().item()
    return {"seg_mIOU": mean_iou}