Spaces:
Paused
Paused
from typing import Tuple, Optional, Dict | |
import logging | |
import os | |
import shutil | |
from os import path | |
from PIL import Image | |
import torch | |
import torch.nn.functional as F | |
import numpy as np | |
import pycocotools.mask as mask_util | |
from threading import Thread | |
from queue import Queue | |
from dataclasses import dataclass | |
import copy | |
from tracker.utils.pano_utils import ID2RGBConverter | |
from tracker.utils.palette import davis_palette_np | |
from tracker.inference.object_manager import ObjectManager | |
from tracker.inference.object_info import ObjectInfo | |
# `None` selects the root logger (same object as logging.getLogger()).
log = logging.getLogger(None)

# hickle is optional: in this module it is only used to dump score maps when
# saving scores (multi-scale testing), so a missing install is just a warning.
try:
    import hickle as hkl
except ImportError:
    log.warning('Failed to import hickle. Fine if not using multi-scale testing.')
class ResultSaver:
    """Save per-frame segmentation results on a background thread.

    ``process`` runs on the caller's thread. It optionally resizes the
    probability map, converts it to an object-id mask and enqueues a
    ``ResultArgs`` record on a bounded queue. A daemon thread running
    ``save_result`` consumes the queue and does the file I/O. Call ``end``
    after the last frame to flush pending work and stop the worker.

    Args:
        output_root: directory under which masks are written, one
            sub-directory per video.
        video_name: name of this video's sub-directory.
        dataset: dataset name. It is lower-cased, and a name containing
            'burst' enables BURST-style json output (requires ``init_json``).
        object_manager: provides the temporary-id <-> object mappings
            (``tmp_id_to_obj`` / ``obj_to_tmp_id``).
        use_long_id: if True, masks are written as RGB images encoding object
            ids via ``ID2RGBConverter``; otherwise as 8-bit index images.
        palette: optional palette applied to 8-bit masks. When visualizing,
            it also provides the overlay colours.
        save_mask: write the mask images.
        save_scores: also dump quantised probabilities under
            ``score_output_root`` (uses hickle).
        score_output_root: root directory for the score dumps.
        visualize_output_root: root directory for overlay visualizations.
        visualize: write image/mask overlays (``process`` then needs
            ``path_to_image``).
        init_json: BURST annotation json for this video.

    Raises:
        ValueError: if the dataset is BURST-style but ``init_json`` is None.
    """

    def __init__(self,
                 output_root,
                 video_name,
                 *,
                 dataset,
                 object_manager: ObjectManager,
                 use_long_id,
                 palette=None,
                 save_mask=True,
                 save_scores=False,
                 score_output_root=None,
                 visualize_output_root=None,
                 visualize=False,
                 init_json=None):
        self.output_root = output_root
        self.video_name = video_name
        self.dataset = dataset.lower()
        self.use_long_id = use_long_id
        self.palette = palette
        self.object_manager = object_manager
        self.save_mask = save_mask
        self.save_scores = save_scores
        self.score_output_root = score_output_root
        self.visualize_output_root = visualize_output_root
        self.visualize = visualize
        if self.visualize:
            # one RGB colour per object id for the overlay
            if self.palette is not None:
                self.colors = np.array(self.palette, dtype=np.uint8).reshape(-1, 3)
            else:
                self.colors = davis_palette_np

        self.need_remapping = True
        self.json_style = None
        self.id2rgb_converter = ID2RGBConverter()

        if 'burst' in self.dataset:
            # explicit check: an `assert` would be stripped under `python -O`
            if init_json is None:
                raise ValueError('init_json is required for BURST datasets')
            self.input_segmentations = init_json['segmentations']
            # output segmentations start empty, one dict per annotated frame
            self.segmentations = [{} for _ in init_json['segmentations']]
            self.annotated_frames = init_json['annotated_image_paths']
            self.video_json = {k: v for k, v in init_json.items() if k != 'segmentations'}
            self.video_json['segmentations'] = self.segmentations
            self.json_style = 'burst'

        # bounded queue: process() blocks once 10 frames are waiting to be written
        self.queue = Queue(maxsize=10)
        self.thread = Thread(target=save_result, args=(self.queue, ), daemon=True)
        self.thread.start()

    def process(self,
                prob: torch.Tensor,
                frame_name: str,
                resize_needed: bool = False,
                shape: Optional[Tuple[int, int]] = None,
                last_frame: bool = False,
                path_to_image: Optional[str] = None):
        """Turn ``prob`` into an object-id mask and queue it for saving.

        Args:
            prob: per-channel probabilities. Assumed to be (C, H, W), with the
                channel index equal to the temporary object id. TODO confirm
                with caller.
            frame_name: file name of the frame; output names are derived from it.
            resize_needed: bilinearly resize ``prob`` to ``shape`` first.
            shape: target (H, W) used when ``resize_needed`` is True.
            last_frame: marks the final frame of the video (triggers the
                id-mapping dump when scores are saved).
            path_to_image: source image, required when visualizing.
        """
        if resize_needed:
            # interpolate wants (N, C, H, W): resize every channel as its own image
            prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear',
                                 align_corners=False)[:, 0]

        # Probability mask -> index mask
        mask = torch.argmax(prob, dim=0)

        # probabilities are only needed by the worker when scores are dumped
        prob = prob.cpu() if self.save_scores else None

        # remap temporary ids to object ids; pixels of unmapped channels become 0
        if self.need_remapping:
            new_mask = torch.zeros_like(mask)
            for tmp_id, obj in self.object_manager.tmp_id_to_obj.items():
                new_mask[mask == tmp_id] = obj.id
            mask = new_mask

        # the mappings are deep-copied so that later changes made through the
        # object manager do not affect this already-queued frame
        args = ResultArgs(saver=self,
                          prob=prob,
                          mask=mask.cpu(),
                          frame_name=frame_name,
                          path_to_image=path_to_image,
                          tmp_id_to_obj=copy.deepcopy(self.object_manager.tmp_id_to_obj),
                          obj_to_tmp_id=copy.deepcopy(self.object_manager.obj_to_tmp_id),
                          last_frame=last_frame)
        self.queue.put(args)

    def end(self):
        """Flush all queued results and wait for the writer thread to exit.

        If the writer thread is no longer running (for example, it raised while
        saving a frame), nobody will call ``task_done()`` again. Waiting on the
        queue would then block forever, so a warning is logged and the method
        returns instead.
        """
        if not self.thread.is_alive():
            log.warning('Result saver thread for %s is not running; skipping flush.',
                        self.video_name)
            return
        self.queue.put(None)  # shutdown sentinel for save_result
        self.queue.join()
        self.thread.join()
@dataclass
class ResultArgs:
    """One queued frame: everything ``save_result`` needs to write its outputs.

    ``ResultSaver.process`` builds this with keyword arguments, so it must be a
    dataclass (a plain annotated class has no matching ``__init__``).
    """
    saver: ResultSaver
    prob: Optional[torch.Tensor]  # CPU probabilities, or None when scores are not saved
    mask: torch.Tensor  # object-id mask, already moved to CPU
    frame_name: str
    path_to_image: Optional[str]  # only required when visualizing
    tmp_id_to_obj: Dict[int, ObjectInfo]  # snapshot: temporary id -> object
    obj_to_tmp_id: Dict[ObjectInfo, int]  # snapshot: object -> temporary id
    last_frame: bool
def save_result(queue: Queue):
    """Consume ``ResultArgs`` records from ``queue`` and write their outputs.

    This is the body of the background thread started by ``ResultSaver``.
    A ``None`` item is the shutdown sentinel. ``queue.task_done()`` is called
    exactly once per item, even when saving that frame raises. Otherwise a
    single failure would kill the worker, leave producers blocked on the
    bounded queue, and make ``queue.join()`` wait forever.
    """
    while True:
        args: Optional[ResultArgs] = queue.get()
        if args is None:
            queue.task_done()
            break
        try:
            _save_frame(args)
        except Exception:
            # log and keep draining the queue instead of letting the worker die
            log.exception('Failed to save results for frame %s', args.frame_name)
        finally:
            queue.task_done()


def _frame_stem(frame_name):
    """Return ``frame_name`` without its extension (e.g. '00005.jpg' -> '00005')."""
    return path.splitext(frame_name)[0]


def _save_frame(args):
    """Write every enabled output (BURST json, mask, scores, overlay) for one frame."""
    saver = args.saver
    all_obj_ids = [obj.id for obj in args.obj_to_tmp_id]

    # record output in the json file
    if saver.json_style == 'burst':
        _record_burst_segments(saver, args.frame_name, args.mask, all_obj_ids)

    # the numpy masks feed both the saved mask and the visualization, so build
    # them whenever either output is enabled
    out_mask = rgb_mask = None
    if saver.save_mask or saver.visualize:
        out_mask, rgb_mask = _to_numpy_masks(saver, args.mask, all_obj_ids)

    if saver.save_mask:
        _write_mask(saver, args.frame_name, out_mask, rgb_mask)

    # save scores for multi-scale testing
    if saver.save_scores:
        _write_scores(saver, args.frame_name, args.prob, args.obj_to_tmp_id, args.last_frame)

    if saver.visualize:
        _write_visualization(saver, args.frame_name, args.path_to_image, out_mask, rgb_mask,
                             all_obj_ids)


def _record_burst_segments(saver, frame_name, mask, all_obj_ids):
    """Store RLE-encoded object masks in the BURST json if ``frame_name`` is annotated."""
    if frame_name not in saver.annotated_frames:
        return
    frame_index = saver.annotated_frames.index(frame_name)
    input_segments = saver.input_segmentations[frame_index]
    frame_segments = saver.segmentations[frame_index]
    for obj_id in all_obj_ids:
        if obj_id in input_segments:
            # if this frame has been given as input, just copy
            frame_segments[obj_id] = input_segments[obj_id]
            continue
        segment_mask = (mask == obj_id)
        if segment_mask.sum() > 0:
            # pycocotools' encoder expects a Fortran-ordered uint8 array
            binary = np.asfortranarray(segment_mask.numpy().astype(np.uint8))
            coco_mask = mask_util.encode(binary)
            frame_segments[obj_id] = {'rle': coco_mask['counts'].decode('utf-8')}


def _to_numpy_masks(saver, mask, all_obj_ids):
    """Return ``(out_mask, rgb_mask)`` as numpy arrays.

    With long ids, ``out_mask`` is uint32 and ``rgb_mask`` encodes each object
    id as the colour returned by ``saver.id2rgb_converter``. Otherwise
    ``out_mask`` is uint8 and ``rgb_mask`` is None.
    """
    if saver.use_long_id:
        out_mask = mask.numpy().astype(np.uint32)
        rgb_mask = np.zeros((*out_mask.shape[-2:], 3), dtype=np.uint8)
        for obj_id in all_obj_ids:
            _, color = saver.id2rgb_converter.convert(obj_id)
            rgb_mask[out_mask == obj_id] = color
        return out_mask, rgb_mask
    return mask.numpy().astype(np.uint8), None


def _write_mask(saver, frame_name, out_mask, rgb_mask):
    """Save the mask as ``<output_root>/<video_name>/<frame stem>.png``."""
    if rgb_mask is not None:
        out_img = Image.fromarray(rgb_mask)
    else:
        out_img = Image.fromarray(out_mask)
        # a palette only applies to the single-channel index image
        if saver.palette is not None:
            out_img.putpalette(saver.palette)
    out_dir = path.join(saver.output_root, saver.video_name)
    os.makedirs(out_dir, exist_ok=True)
    out_img.save(path.join(out_dir, _frame_stem(frame_name) + '.png'))


def _write_scores(saver, frame_name, prob, obj_to_tmp_id, last_frame):
    """Dump uint8-quantised probabilities with hickle.

    On the last frame, also dump the object id -> temporary id mapping
    (``backward.hkl``). The temporary id is the channel index of ``prob``.
    """
    out_dir = path.join(saver.score_output_root, saver.video_name)
    os.makedirs(out_dir, exist_ok=True)
    prob = (prob.detach().numpy() * 255).astype(np.uint8)
    if last_frame:
        obj_id_to_tmp_id = {obj.id: tmp_id for obj, tmp_id in obj_to_tmp_id.items()}
        hkl.dump(obj_id_to_tmp_id, path.join(out_dir, 'backward.hkl'), mode='w')
    hkl.dump(prob,
             path.join(out_dir, f'{_frame_stem(frame_name)}.hkl'),
             mode='w',
             compression='lzf')


def _write_visualization(saver, frame_name, path_to_image, out_mask, rgb_mask, all_obj_ids):
    """Blend the coloured mask over the source image and save it as a .jpg.

    Raises:
        ValueError: if ``path_to_image`` is None.
    """
    if path_to_image is None:
        raise ValueError('Cannot visualize without path_to_image')
    with Image.open(path_to_image) as img:
        # force 3 channels so the blend broadcasts against the (H, W, 3) colour mask
        image_np = np.array(img.convert('RGB'))
    if rgb_mask is None:
        # we need to apply a palette
        rgb_mask = np.zeros((*out_mask.shape, 3), dtype=np.uint8)
        for obj_id in all_obj_ids:
            rgb_mask[out_mask == obj_id] = saver.colors[obj_id]
    # background (id 0) keeps the image unchanged; object pixels are a 50/50 blend
    alpha = (out_mask == 0).astype(np.float32) * 0.5 + 0.5
    alpha = alpha[:, :, None]
    blend = (image_np * alpha + rgb_mask * (1 - alpha)).astype(np.uint8)
    vis_dir = path.join(saver.visualize_output_root, saver.video_name)
    os.makedirs(vis_dir, exist_ok=True)
    Image.fromarray(blend).save(path.join(vis_dir, _frame_stem(frame_name) + '.jpg'))
def make_zip(dataset, run_dir, exp_id, mask_output_root):
    """Create ``<run_dir>/<exp_id>_<dataset>.zip`` from the saved masks.

    * Names starting with 'y' (YouTubeVOS) and 'lvos-test': the archive holds
      the ``Annotations`` folder found in ``run_dir``.
    * 'd17-test-dev' and 'mose-val': the contents of ``mask_output_root`` are
      placed at the archive root.

    Any other dataset only produces an info log message; nothing is archived.
    """
    if dataset.startswith('y'):
        description, root_dir, base_dir = 'YouTubeVOS', run_dir, 'Annotations'
    elif dataset == 'd17-test-dev':
        # zip from within the Annotation folder
        description, root_dir, base_dir = 'DAVIS test-dev', mask_output_root, None
    elif dataset == 'mose-val':
        # same layout as DAVIS test-dev
        description, root_dir, base_dir = 'MOSE validation', mask_output_root, None
    elif dataset == 'lvos-test':
        # same layout as YouTubeVOS
        description, root_dir, base_dir = 'LVOS test', run_dir, 'Annotations'
    else:
        log.info(f'Not making zip for {dataset}.')
        return

    log.info(f'Making zip for {description}...')
    # base_dir=None is make_archive's default (archive everything under root_dir)
    archive_base = path.join(run_dir, f'{exp_id}_{dataset}')
    shutil.make_archive(archive_base, 'zip', root_dir, base_dir)