import os
from typing import Any

import einops
import mmengine
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from lightning.pytorch.utilities import grad_norm
from mmengine.structures import InstanceData

from mmpl.registry import MODELS
from mmseg.utils import SampleList
from ..builder import build_backbone, build_loss, build_neck, build_head
from .base_pler import BasePLer
from mmpl.structures import ClsDataSample
from .base import BaseClassifier
import lightning.pytorch as pl
import torch.nn.functional as F

# NOTE: `sam_model_registry` and `build_all_layer_point_grids` are used below
# but were not imported in the original file; they are assumed to come from the
# project's bundled Segment Anything code (mirroring the public
# `segment_anything` package layout).
from segment_anything import sam_model_registry
from segment_anything.utils.amg import build_all_layer_point_grids


@MODELS.register_module()
class SegPLer(BasePLer):
    def __init__(self,
                 sam=None,
                 sam_checkpoint='',
                 points_per_side=None,
                 sam_prompt_generator=None,
                 only_img_encoder=False,
                 only_decoder=False,
                 global_prompt=None,
                 need_train_names=None,
                 head=None,
                 with_clip=False,
                 train_head=False,
                 threshold=0.5,
                 ignore_index=255,
                 train_cfg=None,
                 test_cfg=None,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()
        self.need_train_names = need_train_names
        self.ignore_index = ignore_index
        self.threshold = threshold
        self.only_img_encoder = only_img_encoder
        self.only_decoder = only_decoder
        self.global_prompt = global_prompt
        self.train_head = train_head

        if sam is not None:
            if self.only_img_encoder:
                self.sam = sam_model_registry[sam](sam_checkpoint).image_encoder
            elif self.only_decoder:
                self.prompt_encoder = sam_model_registry[sam](sam_checkpoint).prompt_encoder
                self.mask_decoder = sam_model_registry[sam](sam_checkpoint).mask_decoder
            else:
                sam = sam_model_registry[sam](sam_checkpoint, train_head=train_head)
                self.img_encoder = sam.image_encoder
                self.prompt_encoder = sam.prompt_encoder
                self.mask_decoder = sam.mask_decoder
                self.prompt_encoder_no_mask_embed = sam.prompt_encoder.no_mask_embed

        if points_per_side is not None:
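            # build_all_layer_point_grids(points_per_side, 0, 1) is assumed to
            # return a single grid of points_per_side**2 normalized (x, y)
            # coordinates in [0, 1], as in the segment-anything helper of the
            # same name.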
            self.point_grids = build_all_layer_point_grids(
                points_per_side, 0, 1)
        if sam_prompt_generator is not None:
            self.sam_prompt_generator = MODELS.build(sam_prompt_generator)
        if head is not None:
            self.head = MODELS.build(head)
        self.with_clip = with_clip
        if global_prompt is not None:
            if with_clip:
                self.logits_prompt = nn.Sequential(
                    nn.Linear(1, 8),
                    nn.ReLU(),
                    nn.Linear(8, 16)
                )
                self.global_prompt = nn.Sequential(
                    nn.Conv2d(768+16, 256, kernel_size=3, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(256, 256, kernel_size=3, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(256, 1, kernel_size=3, padding=1),
                )
            else:
                self.global_prompt = nn.Sequential(
                    nn.Conv2d(256, 128, kernel_size=3, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(128, 1, kernel_size=3, padding=1),
                )
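        # NOTE (assumed shapes): with CLIP features the global prompt head
        # consumes dense embeddings reshaped to roughly (B, 768+16, 24, 24)
        # (576 tokens -> a 24x24 grid, see forward_only_img_encoder) and
        # predicts a single-channel mask; without CLIP it maps SAM image
        # embeddings of shape (B, 256, 64, 64) directly to a single-channel
        # mask.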

    def setup(self, stage: str) -> None:
        if self.need_train_names is not None:
            self._set_grad(self.need_train_names, noneed_train_names=[])

    def configure_sharded_model(self) -> None:
        if self.trainer.strategy.__class__.__name__ == 'FSDPStrategy':
            from torch.distributed.fsdp.wrap import wrap
            self.sam_prompt_generator = wrap(self.sam_prompt_generator)
            self.img_encoder = wrap(self.img_encoder)
            self.prompt_encoder_no_mask_embed = wrap(self.prompt_encoder_no_mask_embed)
            self.mask_decoder = wrap(self.mask_decoder)
            self.prompt_encoder = wrap(self.prompt_encoder)
            from torch.distributed.fsdp import CPUOffload
            # from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
            # import functools
            # strategy = dict(
            #     type='FSDPStrategy',
            #     cpu_offload=CPUOffload(offload_params=True),
            #     auto_wrap_policy=functools.partial(
            #         size_based_auto_wrap_policy, min_num_params=int(1e8)
            #     )
            #
            # )
        else:
            super().configure_sharded_model()

    def configure_optimizers(self):
        if self.trainer.strategy.__class__.__name__ == 'DeepSpeedStrategy':
            import deepspeed
            # optimizer = deepspeed.runtime.
            optimizer = deepspeed.ops.adam.FusedAdam(self.sam_prompt_generator.parameters(), lr=1e-4)
            # optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam(self.sam_prompt_generator.parameters(), lr=1e-4)
            # optimizer = torch.optim.Adam(self.sam_prompt_generator.parameters(), lr=1e-4)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
            return [optimizer], [lr_scheduler]
        else:
            return super().configure_optimizers()

    def init_weights(self):
        # Weights are expected to come from the SAM checkpoint loaded in
        # __init__, so no extra initialization is performed here.
        pass

    # def on_fit_start(self) -> None:
    #     if hasattr(self, 'train_evaluator'):
    #         self.train_evaluator = self.train_evaluator.to(self.device)
    #     if hasattr(self, 'val_evaluator'):
    #         self.val_evaluator = self.val_evaluator.to(self.device)

    def train(self, mode=True):
        if self.need_train_names is not None:
            return self._set_train_module(mode, self.need_train_names)
        else:
            super().train(mode)
            return self

    def validation_step(self, batch, batch_idx):
        seg_label = torch.stack([x.gt_sem_seg.data for x in batch['data_samples']], dim=0)
        if self.only_img_encoder:
            masks_pred = self.forward_only_img_encoder(batch)
            masks_pred = F.interpolate(masks_pred, size=seg_label.shape[-2:], mode='bilinear',
                                       align_corners=True)
            seg_logits = masks_pred > 0
        elif self.only_decoder:
            cls_logits, masks, n_iou_preds = self.forward_sam_prompt_generator(batch)  # 1x100x2, 1x100x1x256x256, 1x100x1
            masks = masks.squeeze(2)
            masks = F.interpolate(masks, size=seg_label.shape[-2:], mode='bilinear', align_corners=True)
            # cls_logits[..., 1:2] = cls_logits[..., 1:2] * n_iou_preds
            seg_logits = self.post_process(cls_logits.detach(), masks.detach())
            seg_logits = seg_logits > self.threshold
        else:
            cls_logits, pred_masks, n_iou_preds = self.forward_sam_prompt_generator_all(
                batch)  # 1x100x2, 1x100x1x256x256, 1x100x1
            pred_masks = pred_masks.squeeze(2)
            pred_masks = F.interpolate(pred_masks, size=seg_label.shape[-2:], mode='bilinear', align_corners=True)
            # cls_logits[..., 1:2] = cls_logits[..., 1:2] * n_iou_preds
            seg_logits = self.post_process(cls_logits.detach(), pred_masks.detach())
            seg_logits = seg_logits > self.threshold
        # import ipdb; ipdb.set_trace()
        self.val_evaluator.update(seg_logits, seg_label)

    def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any):
        cls_logits, n_img_masks = self.forward(batch)

        seg_label = torch.stack([x.gt_sem_seg.data for x in batch['data_samples']], dim=0)
        seg_label = seg_label.squeeze(1)
        masks = F.interpolate(n_img_masks, size=seg_label.shape[-2:], mode='bilinear', align_corners=True)
        masks = masks.squeeze(1) > 0
        self.evaluator.update(masks, seg_label)

    def _seg_data_to_instance_data(self, batch_data_samples: SampleList):
        """Perform forward propagation to convert paradigm from MMSegmentation
        to MMDetection to ensure ``MMDET_Mask2FormerHead`` could be called
        normally. Specifically, ``batch_gt_instances`` would be added.

        Args:
            batch_data_samples (List[:obj:`SegDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_sem_seg`.

        Returns:
            tuple[Tensor]: A tuple contains two lists.

                - batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                    gt_instance. It usually includes ``labels``, each is
                    unique ground truth label id of images, with
                    shape (num_gt, ) and ``masks``, each is ground truth
                    masks of each instances of a image, shape (num_gt, h, w).
                - batch_img_metas (list[dict]): List of image meta information.
        """
        batch_img_metas = []
        batch_gt_instances = []

        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            gt_masks = data_sample.instances_data.long()
            gt_labels = data_sample.instances_label.long()

            instance_data = InstanceData(labels=gt_labels, masks=gt_masks)
            batch_gt_instances.append(instance_data)
        return batch_gt_instances, batch_img_metas
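    # Example with hypothetical numbers: for one image with 3 annotated
    # instances at 512x512 resolution, `instances_data` is a (3, 512, 512)
    # binary mask tensor and `instances_label` a (3,) tensor of class ids;
    # the pair is wrapped into one InstanceData per image so that
    # MMDetection-style heads can consume it.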

    def training_step(self, batch, batch_idx):
        if self.only_img_encoder:
            masks_pred = self.forward_only_img_encoder(batch)
            seg_label = torch.stack([x.gt_sem_seg.data for x in batch['data_samples']], dim=0)
            masks_pred = F.interpolate(masks_pred, size=seg_label.shape[-2:], mode='bilinear', align_corners=True)
            losses = self.head.loss(masks_pred, seg_label)
            masks_pred_result = masks_pred > 0
            self.train_evaluator.update(masks_pred_result.detach(), seg_label.detach())

        elif self.only_decoder:
            cls_logits, masks, n_iou_preds = self.forward_sam_prompt_generator(batch)  # 1x100x2, 1x100x1x256x256, 1x100x1
            masks = masks.squeeze(2)
            seg_label = torch.stack([x.gt_sem_seg.data for x in batch['data_samples']], dim=0)
            masks = F.interpolate(masks, size=seg_label.shape[-2:], mode='bilinear', align_corners=True)
            # cls_logits[..., 1:2] = cls_logits[..., 1:2] * n_iou_preds
            seg_logits = self.post_process(cls_logits.clone().detach(), masks.clone().detach())
            seg_logits = seg_logits > self.threshold
            self.train_evaluator.update(seg_logits, seg_label)

            batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data(
                batch['data_samples'])

            losses = self.head.loss(cls_logits, masks, batch_gt_instances, batch_img_metas)
        else:
            cls_logits, pred_masks, n_iou_preds = self.forward_sam_prompt_generator_all(
                batch)  # 1x100x2, 1x100x1x256x256, 1x100x1
            pred_masks = pred_masks.squeeze(2)
            if torch.isinf(pred_masks).any() or torch.isnan(pred_masks).any():
                # Skip this step entirely if the decoder produced NaN/Inf masks,
                # returning a zero loss so training can continue.
                print('Warning: pred_masks contain NaN or Inf; skipping this training step.')
                return torch.tensor(0.0, requires_grad=True, device=self.device)
            seg_label = torch.stack([x.gt_sem_seg.data for x in batch['data_samples']], dim=0)
            pred_masks = F.interpolate(pred_masks, size=seg_label.shape[-2:], mode='bilinear', align_corners=True)
            # cls_logits[..., 1:2] = cls_logits[..., 1:2] * n_iou_preds
            seg_logits = self.post_process(cls_logits.clone().detach(), pred_masks.clone().detach())
            seg_logits = seg_logits > self.threshold
            self.train_evaluator.update(seg_logits, seg_label)

            batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data(
                batch['data_samples'])

            losses = self.head.loss(cls_logits, pred_masks, batch_gt_instances, batch_img_metas)

        parsed_losses, log_vars = self.parse_losses(losses)
        log_vars = {f'train_{k}': v for k, v in log_vars.items()}
        log_vars['loss'] = parsed_losses
        self.log_dict(log_vars, prog_bar=True)
        return log_vars

    def on_before_optimizer_step(self, optimizer) -> None:
        self.log_grad(module=self.sam_prompt_generator)

    def post_process(self, mask_cls_results, mask_pred_results):
        cls_score = F.softmax(mask_cls_results, dim=-1)[..., 1:2]
        mask_pred = mask_pred_results.sigmoid()
        seg_logits = torch.einsum('bqc, bqhw->bchw', cls_score, mask_pred)
        return seg_logits
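    # Shape sketch for post_process (sizes assumed from the inline comments):
    # with B images, Q=100 queries and HxW masks,
    #   mask_cls_results:  (B, Q, 2)     -> softmax, keep the foreground column
    #   mask_pred_results: (B, Q, H, W)  -> per-query mask probabilities
    # einsum('bqc, bqhw->bchw') weights each query mask by its foreground
    # probability and sums over queries, giving (B, 1, H, W) seg logits that
    # the callers threshold against `self.threshold`.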

    def forward_only_img_encoder(self, batch, *args: Any, **kwargs: Any) -> Any:
        if self.with_clip:
            clip_dense_embs = torch.stack([x.clip_dense_embs for x in batch['data_samples']], dim=0)
            logits_per_images = torch.stack([x.logits_per_image for x in batch['data_samples']], dim=0)
            logits_per_images = self.logits_prompt(logits_per_images)  # Bx576x16
            clip_dense_embs = torch.cat([clip_dense_embs, logits_per_images], dim=-1)
            clip_dense_embs = rearrange(clip_dense_embs, 'b (h w) c -> b c h w', h=int(clip_dense_embs.shape[1]**0.5))
            masks_pred = self.global_prompt(clip_dense_embs)
        else:
            image_embeddings = torch.stack([x.image_embeddings for x in batch['data_samples']], dim=0)
            masks_pred = self.global_prompt(image_embeddings)
        return masks_pred

    def forward_sam_prompt_generator(self, batch, *args: Any, **kwargs: Any) -> Any:
        inner_states = [x.inner_states for x in batch['data_samples']]
        image_embeddings = torch.stack([x.image_embeddings for x in batch['data_samples']], dim=0)

        inner_states_tmp = []
        for idx in range(len(inner_states[0])):
            inner_states_tmp.append(torch.stack([x[idx] for x in inner_states], dim=0).to(image_embeddings.device))

        point_embs, cls_logits = self.sam_prompt_generator(inner_states_tmp)

        # If a regular point grid was configured, encode it as point prompts.
        # NOTE: the original code referenced an undefined `img` tensor here;
        # the image size is assumed to be recoverable from the first data
        # sample's meta information instead.
        if hasattr(self, 'point_grids'):
            img_shape = batch['data_samples'][0].metainfo['img_shape']
            points_scale = np.array(img_shape[:2], dtype=np.float32).reshape(1, -1)  # (1, 2)
            points_for_image = self.point_grids[0] * points_scale
            in_points = torch.as_tensor(points_for_image, device=image_embeddings.device)
            in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
            in_points = rearrange(in_points, 'n c -> n () c')
            in_labels = rearrange(in_labels, 'n -> n ()')
            points = (in_points, in_labels)

            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=None,
                masks=None,
            )  # 1024x2x256; 1024x256x64x64
        else:
            # points embeddings, shape B x T x N x C
            sparse_embeddings = point_embs
            dense_embeddings = self.prompt_encoder.no_mask_embed.weight.view(1, 1, -1, 1, 1).expand(
                sparse_embeddings.shape[0], sparse_embeddings.shape[1], -1,
                self.prompt_encoder.image_embedding_size[0], self.prompt_encoder.image_embedding_size[1]
                )


        n_img_masks = []
        n_iou_preds = []
        n_class_aware_probs = []
        for curr_img_embedding, cur_s_emb, cur_d_emb in zip(image_embeddings, sparse_embeddings, dense_embeddings):
            lr_masks, iou_pred, class_aware_prob = self.mask_decoder(
                image_embeddings=curr_img_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=cur_s_emb,
                dense_prompt_embeddings=cur_d_emb
            )
            mask_slice = slice(0, 1)
            masks = lr_masks[:, mask_slice, :, :]
            iou_pred = iou_pred[:, mask_slice]
            class_aware_prob = class_aware_prob[:, mask_slice]

            n_img_masks.append(masks)
            n_iou_preds.append(iou_pred)
        n_img_masks = torch.stack(n_img_masks, dim=0)
        n_iou_preds = torch.stack(n_iou_preds, dim=0)

        return cls_logits, n_img_masks, n_iou_preds

    def forward_sam_prompt_generator_all(self, batch, *args: Any, **kwargs: Any) -> Any:
        x = torch.stack(batch['inputs'], dim=0)
        # if self.local_rank == 0:
        #     import pdb; pdb.set_trace()
        # self.trainer.strategy.barrier()
        x = x[:, [2, 1, 0], :, :]  # BGR -> RGB
        x = (x - self.img_encoder.pixel_mean) / self.img_encoder.pixel_std
        with torch.no_grad():
            image_embeddings, inner_states = self.img_encoder(x)
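        # The SAM image encoder runs under no_grad and stays frozen; only the
        # modules selected via `need_train_names` (typically the prompt
        # generator) receive gradient updates.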

        point_embs, cls_logits = self.sam_prompt_generator(inner_states)

        # If a regular point grid was configured, encode it as point prompts.
        if hasattr(self, 'point_grids'):
            points_scale = np.array(x.shape[-2:], dtype=np.float32).reshape(1, -1)  # (1, 2)
            points_for_image = self.point_grids[0] * points_scale
            in_points = torch.as_tensor(points_for_image, device=x.device)
            in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
            in_points = rearrange(in_points, 'n c -> n () c')
            in_labels = rearrange(in_labels, 'n -> n ()')
            points = (in_points, in_labels)

            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=None,
                masks=None,
            )  # 1024x2x256; 1024x256x64x64
        else:
            # points embeddings, shape B x T x N x C
            sparse_embeddings = point_embs
            dense_embeddings = self.prompt_encoder_no_mask_embed(torch.tensor([0], device=self.device)).view(1, 1, -1, 1, 1).expand(
                sparse_embeddings.shape[0], sparse_embeddings.shape[1], -1,
                image_embeddings.shape[-2], image_embeddings.shape[-1]
                )


        n_img_masks = []
        n_iou_preds = []
        n_class_aware_probs = []
        for curr_img_embedding, cur_s_emb, cur_d_emb in zip(image_embeddings, sparse_embeddings, dense_embeddings):
            lr_masks, iou_pred, class_aware_prob = self.mask_decoder(
                image_embeddings=curr_img_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=cur_s_emb,
                dense_prompt_embeddings=cur_d_emb
            )
            if self.train_head:
                # Keep every predicted mask and IoU score when training the head.
                masks = lr_masks
            else:
                mask_slice = slice(0, 1)
                masks = lr_masks[:, mask_slice, :, :]
                iou_pred = iou_pred[:, mask_slice]

            n_img_masks.append(masks)
            n_iou_preds.append(iou_pred)
        n_img_masks = torch.stack(n_img_masks, dim=0)
        n_iou_preds = torch.stack(n_iou_preds, dim=0)

        return cls_logits, n_img_masks, n_iou_preds

    def vis_inter_states(self, batch, masks, *args: Any, **kwargs: Any):
        """Dump the input image, the ground-truth mask and the predicted masks
        to disk for quick visual inspection."""
        import cv2
        folder = 'results/tmp'
        os.makedirs(folder, exist_ok=True)
        seg_label = torch.stack([x.gt_sem_seg.data for x in batch['data_samples']], dim=0)
        cv2.imwrite(os.path.join(folder, 'img.png'), batch['inputs'][0].permute((1, 2, 0)).detach().cpu().numpy())
        cv2.imwrite(os.path.join(folder, 'label_mask.png'), seg_label[0][0].detach().cpu().numpy() * 255)
        masks = masks > 0
        for idx, mask_pred in enumerate(masks[0]):
            cv2.imwrite(os.path.join(folder, f'pred_mask_{idx}.png'), mask_pred[0].detach().cpu().numpy() * 255)