File size: 3,737 Bytes
f670afc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import os

import boto3
import torch
from torch import nn, distributed as dist
from torch.nn import functional as F

from imaginaire.utils.distributed import is_local_master
from imaginaire.utils.io import download_file_from_google_drive


def get_segmentation_hist_model(dataset_name, aws_credentials=None):
    """Build a ``SegmentationHistModel`` with a pretrained segmentation net.

    Supported datasets: ``celebamask_hq`` (Unet, weights downloaded on first
    use) and ``cocostuff``/``getty`` (DeepLabV2). For any other name, prints
    a message and returns ``None``.

    Args:
        dataset_name (str): Dataset key selecting the segmentation network.
        aws_credentials (dict, optional): Keyword arguments forwarded to
            ``boto3.client('s3', ...)``. If given, weights are fetched from
            S3; otherwise from Google Drive.

    Returns:
        SegmentationHistModel or None: Wrapper around the network moved to
        CUDA in eval mode, or ``None`` for unknown dataset names.
    """
    if dist.is_initialized() and not is_local_master():
        # Make sure only the first process in distributed training downloads
        # the model, and the others will use the cache
        # noinspection PyUnresolvedReferences
        torch.distributed.barrier()

    # Load the segmentation network.
    if dataset_name == "celebamask_hq":
        from imaginaire.evaluation.segmentation.celebamask_hq import Unet
        seg_network = Unet()
        # Cache weights under torch hub's checkpoint directory.
        os.makedirs(os.path.join(torch.hub.get_dir(), 'checkpoints'), exist_ok=True)
        model_path = os.path.join(os.path.join(torch.hub.get_dir(), 'checkpoints'), "celebamask_hq.pt")
        if not os.path.exists(model_path):
            if aws_credentials is not None:
                s3 = boto3.client('s3', **aws_credentials)
                s3.download_file('lpi-poe', 'model_zoo/celebamask_hq.pt', model_path)
            else:
                download_file_from_google_drive("1o1m-eT38zNCIFldcRaoWcLvvBtY8S4W3", model_path)
        state_dict = torch.load(model_path, map_location='cpu')
        seg_network.load_state_dict(state_dict)
    elif dataset_name == "cocostuff" or dataset_name == "getty":
        # DeepLabV2 presumably loads its own weights internally — confirm
        # against imaginaire.evaluation.segmentation.cocostuff.
        from imaginaire.evaluation.segmentation.cocostuff import DeepLabV2
        seg_network = DeepLabV2()
    else:
        print(f"No segmentation network for {dataset_name} was found.")
        return None

    if dist.is_initialized() and is_local_master():
        # The local master reaches this barrier after the (possible) download
        # above, releasing the non-master ranks waiting at the first barrier.
        # noinspection PyUnresolvedReferences
        torch.distributed.barrier()

    # NOTE(review): seg_network is always bound on the surviving paths, so
    # this guard is effectively always true; device is hard-coded to CUDA.
    if seg_network is not None:
        seg_network = seg_network.to('cuda').eval()

    return SegmentationHistModel(seg_network)


class SegmentationHistModel(nn.Module):
    """Wraps a segmentation network and produces per-image confusion
    histograms between its predictions on generated images and the
    ground-truth segmentation maps."""

    def __init__(self, seg_network):
        super().__init__()
        self.seg_network = seg_network

    def forward(self, data, fake_images, align_corners=True):
        """Segment ``fake_images`` and histogram them against ``data["segmaps"]``.

        Args:
            data (dict): Batch dict; ``data["segmaps"]`` holds ground-truth
                maps scaled to [0, 1] (ids divided by 255).
            fake_images (Tensor): Generated images fed to the network.
            align_corners (bool): Forwarded to the segmentation network.

        Returns:
            Tensor: Stacked per-image confusion histograms from
            ``compute_hist``.
        """
        # Undo the [0, 1] scaling to recover integer label ids.
        labels = (data["segmaps"] * 255.0).long()
        predictions = self.seg_network(fake_images, align_corners=align_corners)
        return compute_hist(predictions, labels,
                            self.seg_network.n_classes,
                            self.seg_network.use_dont_care)


def compute_hist(pred, gt, n_classes, use_dont_care):
    """Compute one confusion histogram per image in a batch.

    Args:
        pred (Tensor): Predicted label ids, shape (N, H, W).
        gt (Tensor): Ground-truth maps, shape (N, 1, H', W'); resized to
            (H, W) with nearest-neighbour so ids are never blended.
        n_classes (int): Number of valid classes.
        use_dont_care (bool): If True, pixels labelled ``n_classes`` in the
            ground truth are excluded from the histogram.

    Returns:
        Tensor: Long tensor of shape (N, n_classes, n_classes) where entry
        [i, p, g] counts pixels of image i predicted as p with GT class g.
    """
    _, H, W = pred.size()
    # Bring the GT maps to the prediction's resolution, then drop the
    # singleton channel dimension.
    gt = F.interpolate(gt.float(), (H, W), mode="nearest").long().squeeze(1)
    ignore_idx = n_classes if use_dont_care else -1
    histograms = []
    for cur_pred, cur_gt in zip(pred, gt):
        valid = cur_gt != ignore_idx
        # Encode (pred, gt) pairs as a single index into an n_classes**2 grid.
        flat = cur_pred[valid] * n_classes + cur_gt[valid]
        counts = torch.bincount(flat, minlength=n_classes ** 2)
        histograms.append(counts.reshape(n_classes, n_classes))
    return torch.stack(histograms)


def get_miou(hist, eps=1e-8):
    """Mean intersection-over-union (percent) from stacked confusion hists.

    Args:
        hist (Tensor): Shape (N, C, C) confusion histograms; summed over the
            batch dimension before computing IoU.
        eps (float): Small constant guarding against division by zero for
            classes absent from both prediction and ground truth.

    Returns:
        dict: ``{"seg_mIOU": value}`` with the mean IoU scaled to [0, 100].
    """
    total = hist.sum(0)
    intersection = torch.diag(total)
    # union = pred-pixels + gt-pixels - overlap, per class.
    union = total.sum(dim=0) + total.sum(dim=1) - intersection
    iou_per_class = intersection / (union + eps)
    return {"seg_mIOU": 100 * torch.mean(iou_per_class).item()}