# Copyright 2024 Adobe. All rights reserved.

#%%
from torchvision.transforms import ToPILImage
import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import cv2
import tqdm
import torchvision.transforms.functional as F
from PIL import Image
from torchvision.utils import save_image
import time
import os
import sys
import pathlib
from torchvision.utils import flow_to_image
from torch.utils.data import DataLoader
from einops import rearrange
# %matplotlib inline
from kornia.filters.median import MedianBlur

# Median blur used to smooth the aggregated flow field before thresholding.
median_filter = MedianBlur(kernel_size=(15, 15))

from moments_dataset import MomentsDataset

try:
    from processing_utils import aggregate_frames
    import processing_utils
except Exception as e:
    print(e)
    print('process failed')
    # sys.exit() rather than the site-provided exit() helper, which is not
    # guaranteed to exist (e.g. under ``python -S``). Exit status is unchanged.
    sys.exit()

import pytorch_lightning as pl
from omegaconf import OmegaConf


# %%
def load_image(img_path, resize_size=None, crop_size=None):
    """Load an image file into a tensor with a leading batch dimension.

    Args:
        img_path: Path of the image, passed to ``PIL.Image.open``.
        resize_size: Optional size forwarded to torchvision's ``resize``.
        crop_size: Optional size forwarded to torchvision's ``center_crop``
            (applied after the resize).

    Returns:
        The ``pil_to_tensor`` output, optionally resized/cropped, with a new
        dimension inserted at position 0.
    """
    # Context manager so the underlying file handle is released promptly.
    with Image.open(img_path) as img1_pil:
        img1_frames = torchvision.transforms.functional.pil_to_tensor(img1_pil)

    if resize_size:
        img1_frames = torchvision.transforms.functional.resize(img1_frames, resize_size)

    if crop_size:
        img1_frames = torchvision.transforms.functional.center_crop(img1_frames, crop_size)

    img1_batch = torch.unsqueeze(img1_frames, dim=0)

    return img1_batch


def get_grid(size):
    """Return a ``size x size x 2`` integer coordinate grid.

    ``out[i, j, 0] == i`` (row index) and ``out[i, j, 1] == j`` (column index).
    """
    coords = np.arange(size)
    y, x = np.meshgrid(coords, coords, indexing='ij')
    return np.stack([y, x], -1)


def collage_from_frames(frames_t):
    """Build one (source, target, flow-warp, collage) training sample from a clip.

    A random temporal direction and a random target depth are chosen, then
    frames are aggregated one at a time until the 90th-percentile flow
    magnitude exceeds 300 or the guessed depth is reached. The source frame is
    then warped both by the dense flow (``processing_utils.forward_warp``) and
    by per-segment similarity transforms (``collage_warp``).

    Args:
        frames_t: Sequence of frames indexable along dim 0.
            NOTE(review): assumed to be a float tensor in [0, 1] laid out as
            (T, C, H, W) with H == W, since the frame is multiplied by 255 and
            concatenated with a square grid -- confirm against MomentsDataset.

    Returns:
        ``None`` when the clip is rejected (too few frames, too little motion,
        or the flow-warp mask covers less than 60% of the image); otherwise
        the tuple ``(src_frame, tgt_frame, rgb_grid_splatted_np,
        composite_grid, flow_warping_mask, collage_mask)``.
    """
    # np.random.randint(1, 1) raises ValueError, so a clip needs at least a
    # source frame and one candidate target frame.
    if len(frames_t) < 2:
        return None

    # decide forward or backward
    if np.random.randint(0, 2) == 0:
        # flip
        frames_t = frames_t.flip(0)

    # decide how deep you would go
    tgt_idx_guess = np.random.randint(1, min(len(frames_t), 20))
    tgt_idx = 1
    pairwise_flows = []
    flow = None
    init_time = time.time()
    unsmoothed_agg = None
    for cur_idx in range(1, tgt_idx_guess + 1):
        # passing pairwise flows (and the previous aggregate) for efficiency
        cur_flow, pairwise_flows = aggregate_frames(
            frames_t[:cur_idx + 1], pairwise_flows, unsmoothed_agg
        )
        unsmoothed_agg = cur_flow.clone()
        agg_cur_flow = median_filter(cur_flow)

        flow_norm = torch.norm(agg_cur_flow.squeeze(), dim=0).flatten()
        # Convert once and compute both percentiles in a single call.
        flow_90, flow_95 = np.percentile(flow_norm.cpu().numpy(), [90, 95])

        if cur_idx == 5:
            # if still small flow then drop
            if flow_95 < 20.0:
                # no motion in the frame. skip
                print('flow is tiny :(')
                return None

        if cur_idx == tgt_idx_guess - 1:
            # if still small flow then drop
            if flow_95 < 50.0:
                # no motion in the frame. skip
                print('flow is tiny :(')
                return None

        if flow is None:  # means first iter
            if flow_90 < 1.0:
                # no motion in the frame. skip
                return None
            flow = agg_cur_flow

        if flow_90 <= 300:  # maybe should increase this part
            # update idx
            tgt_idx = cur_idx
            flow = agg_cur_flow
        else:
            break
    final_time = time.time()
    print('time guessing idx', final_time - init_time)

    _, flow_warping_mask = processing_utils.forward_warp(
        frames_t[0], frames_t[tgt_idx], flow, grid=None, alpha_mask=None
    )
    # .cpu() for consistency with every other tensor -> numpy conversion here.
    flow_warping_mask = flow_warping_mask.squeeze().cpu().numpy() > 0.5

    # Reject samples where the flow warp covers too little of the image.
    if np.mean(flow_warping_mask) < 0.6:
        return None

    src_array = frames_t[0].moveaxis(0, -1).cpu().numpy() * 1.0
    init_time = time.time()
    depth = get_depth_from_array(frames_t[0])
    finish_time = time.time()
    print('time getting depth', finish_time - init_time)

    src_array_uint = src_array * 255.0
    src_array_uint = src_array_uint.astype(np.uint8)
    segments = processing_utils.mask_generator.generate(src_array_uint)

    size = src_array.shape[1]
    # Normalised coordinate grid, channels-last; presumably 512 x 512 x 2.
    grid_np = get_grid(size).astype(np.float16) / size
    grid_t = torch.tensor(grid_np).moveaxis(-1, 0)  # channels-first copy

    collage, canvas_alpha, lost_alpha = collage_warp(
        src_array, flow.squeeze(), depth, segments, grid_array=grid_np
    )
    lost_alpha_t = torch.tensor(lost_alpha).squeeze().unsqueeze(0)
    # Pixels not claimed by any segment are the ones handed to the flow warp.
    warping_alpha = (lost_alpha_t < 0.5).float()

    rgb_grid_splatted, actual_warped_mask = processing_utils.forward_warp(
        frames_t[0], frames_t[tgt_idx], flow, grid=grid_t, alpha_mask=warping_alpha
    )

    # basic blending now
    warped_src = (rgb_grid_splatted * actual_warped_mask).moveaxis(0, -1).cpu().numpy()
    canvas_alpha_mask = canvas_alpha == 0.0
    collage_mask = canvas_alpha.squeeze() + actual_warped_mask.squeeze().cpu().numpy()
    collage_mask = collage_mask > 0.5

    # Flow-warped pixels fill only where the segment collage left the canvas empty.
    composite_grid = warped_src * canvas_alpha_mask + collage
    rgb_grid_splatted_np = rgb_grid_splatted.moveaxis(0, -1).cpu().numpy()

    return (
        frames_t[0],
        frames_t[tgt_idx],
        rgb_grid_splatted_np,
        composite_grid,
        flow_warping_mask,
        collage_mask,
    )


def collage_warp(rgb_array, flow, depth, segments, grid_array):
    """Warp each large segment with its own similarity transform and composite.

    For every segment with ``area > 300``, 50 pixels are sampled (with
    replacement), displaced by ``flow`` and used to fit a partial affine
    transform with ``cv2.estimateAffinePartial2D``. Segments are then warped in
    ascending order of ``median(1 / depth)`` and composited with a z-buffer on
    the warped ``depth`` values.

    Args:
        rgb_array: Channels-last image array.
        flow: Tensor indexed as ``flow[:, y, x] -> (dx, dy)``.
        depth: Per-pixel tensor from ``get_depth_from_array``; treated here as
            disparity (larger value wins the z-test).
        segments: Iterable of dicts with ``'area'`` and ``'segmentation'``
            (boolean mask) keys.
        grid_array: Channels-last coordinate grid, concatenated to
            ``rgb_array`` so it is warped identically.

    Returns:
        ``(canvas, canvas_alpha, lost_regions)``: the composited image+grid,
        the accumulated eroded coverage of the warped segments, and the
        accumulated source-side masks of the segments that were moved.
    """
    src_array = np.concatenate([rgb_array, grid_array], axis=-1)
    canvas = np.zeros_like(src_array)
    canvas_alpha = np.zeros_like(canvas[..., -1:]).astype(float)
    lost_regions = np.zeros_like(canvas[..., -1:]).astype(float)
    z_buffer = np.ones_like(depth)[..., None] * -1.0
    unsqueezed_depth = depth[..., None]

    # Parallel lists: entry i of each describes the same usable segment.
    kept_segments = []
    avg_depths = []
    affine_transforms = []

    for segment in segments:
        if not segment['area'] > 300:
            continue

        seg_mask = segment['segmentation']
        all_y, all_x = np.nonzero(seg_mask)
        if len(all_y) == 0:
            # Nothing to sample from; np.random.randint(0, 0) would raise.
            continue

        sample_idx = np.random.randint(0, len(all_y), size=50)
        rand_x = all_x[sample_idx]
        rand_y = all_y[sample_idx]

        # Source points as (x, y); targets are the same points displaced by the flow.
        src_pts = np.stack([rand_x, rand_y], axis=-1).astype(np.float32)
        sampled_flow = flow[:, torch.as_tensor(rand_y), torch.as_tensor(rand_x)]  # 2 x 50
        tgt_pts = src_pts + sampled_flow.detach().cpu().numpy().T.astype(np.float32)

        affine_trans, _inliers = cv2.estimateAffinePartial2D(src_pts, tgt_pts)
        if affine_trans is None:
            # Estimation can fail; cv2.warpAffine would crash on a None matrix,
            # so leave this segment to the flow warp instead.
            continue

        # median depth (conversion from disparity)
        avg_depth = torch.median(1.0 / (depth[seg_mask] + 1e-6))

        kept_segments.append(segment)
        avg_depths.append(avg_depth)
        affine_transforms.append(affine_trans)

    depth_sorted_indices = sorted(range(len(avg_depths)), key=lambda i: avg_depths[i])

    kernel = np.ones((3, 3), float)  # erosion kernel, loop-invariant
    out_size = (src_array.shape[1], src_array.shape[0])

    for idx in depth_sorted_indices:
        # Only pixels not already claimed by an earlier segment are moved.
        alpha_mask = kept_segments[idx]['segmentation'][..., None] * (lost_regions < 0.5).astype(float)
        src_rgba = np.concatenate([src_array, alpha_mask, unsqueezed_depth], axis=-1)
        warp_dst = cv2.warpAffine(src_rgba, affine_transforms[idx], out_size)
        warped_mask = warp_dst[..., -2:-1]  # this is warped alpha
        warped_depth = warp_dst[..., -1:]
        warped_rgb = warp_dst[..., :-2]

        good_z_region = warped_depth > z_buffer

        warped_mask = np.logical_and(warped_mask > 0.5, good_z_region).astype(float)

        canvas_alpha += cv2.erode(warped_mask, kernel)[..., None]

        lost_regions += alpha_mask
        canvas = canvas * (1.0 - warped_mask) + warped_mask * warped_rgb  # TODO check if need to dilate here
        z_buffer = z_buffer * (1.0 - warped_mask) + warped_mask * warped_depth  # TODO check if need to dilate here

    return canvas, canvas_alpha, lost_regions


def get_depth_from_array(img_t):
    """Run the MiDaS model from ``processing_utils`` on a single frame.

    Args:
        img_t: Channels-first image tensor.
            NOTE(review): assumed to hold values in [0, 1], since it is scaled
            by 255 and cast to uint8 -- confirm against the dataset.

    Returns:
        The prediction as a CPU tensor, bicubically resized to the input's
        height and width. Requires a CUDA device (the input batch is moved
        with ``.cuda()``).
    """
    img_arr = img_t.moveaxis(0, -1).cpu().numpy() * 1.0
    img_arr *= 255.0
    img_arr = img_arr.astype(np.uint8)
    input_batch = processing_utils.depth_transform(img_arr).cuda()

    with torch.no_grad():
        prediction = processing_utils.midas(input_batch)

        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img_arr.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()

    output = prediction.cpu()
    return output


# %%
def main():
    """Process ``./example_videos`` into ``./processed_data``."""
    print('starting main')
    video_folder = './example_videos'
    save_dir = pathlib.Path('./processed_data')
    process_video_folder(video_folder, save_dir)


def process_video_folder(video_folder, save_dir):
    """Generate and save collage samples for every clip in ``video_folder``.

    Each accepted sample is written to ``save_dir`` as six PNGs (``src``,
    ``tgt``, ``flow_warped``, ``composite``, ``flow_mask``, ``composite_mask``)
    and two ``.npy`` coordinate grids resized to 64 x 64, all suffixed with a
    zero-padded running index that starts at 1.

    Args:
        video_folder: Folder handed to ``MomentsDataset``.
        save_dir: Output directory (``str`` or ``pathlib.Path``); created if
            missing.
    """
    save_dir = pathlib.Path(save_dir)  # accept plain strings as well as Paths
    all_counter = 0  # clips attempted (currently not reported)
    success_counter = 0  # clips saved; also drives the output file index

    os.makedirs(save_dir, exist_ok=True)

    dataset = MomentsDataset(videos_folder=video_folder, num_frames=20, samples_per_video=5)
    batch_size = 4
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    with torch.no_grad():
        # len(dataloader) also counts the final partial batch, unlike
        # len(dataset) // batch_size.
        for i, batch in tqdm.tqdm(enumerate(dataloader), total=len(dataloader)):
            frames_to_visualize = batch["frames"]
            bs = frames_to_visualize.shape[0]

            for j in range(bs):
                frames = frames_to_visualize[j]
                caption = batch["caption"][j]  # currently unused

                collage_init_time = time.time()
                out = collage_from_frames(frames)
                collage_finish_time = time.time()
                print('collage processing time', collage_finish_time - collage_init_time)
                all_counter += 1
                if out is None:
                    continue

                src_image, tgt_image, splatted, collage, flow_mask, collage_mask = out

                # First three channels are RGB, the remainder is the coordinate grid.
                splatted_rgb = splatted[..., :3]
                splatted_grid = splatted[..., 3:].astype(np.float16)

                collage_rgb = collage[..., :3]
                collage_grid = collage[..., 3:].astype(np.float16)
                success_counter += 1

                id_str = f'{success_counter:08d}'

                src_path = str(save_dir / f'src_{id_str}.png')
                tgt_path = str(save_dir / f'tgt_{id_str}.png')
                flow_warped_path = str(save_dir / f'flow_warped_{id_str}.png')
                composite_path = str(save_dir / f'composite_{id_str}.png')
                flow_mask_path = str(save_dir / f'flow_mask_{id_str}.png')
                composite_mask_path = str(save_dir / f'composite_mask_{id_str}.png')

                flow_grid_path = str(save_dir / f'flow_warped_grid_{id_str}.npy')
                composite_grid_path = str(save_dir / f'composite_grid_{id_str}.npy')

                save_image(src_image, src_path)
                save_image(tgt_image, tgt_path)

                collage_pil = Image.fromarray((collage_rgb * 255).astype(np.uint8))
                collage_pil.save(composite_path)

                splatted_pil = Image.fromarray((splatted_rgb * 255).astype(np.uint8))
                splatted_pil.save(flow_warped_path)

                flow_mask_pil = Image.fromarray((flow_mask.astype(float) * 255).astype(np.uint8))
                flow_mask_pil.save(flow_mask_path)

                composite_mask_pil = Image.fromarray((collage_mask.astype(float) * 255).astype(np.uint8))
                composite_mask_pil.save(composite_mask_path)

                splatted_grid_t = torch.tensor(splatted_grid).moveaxis(-1, 0)
                splatted_grid_resized = torchvision.transforms.functional.resize(splatted_grid_t, (64, 64))

                collage_grid_t = torch.tensor(collage_grid).moveaxis(-1, 0)
                collage_grid_resized = torchvision.transforms.functional.resize(collage_grid_t, (64, 64))
                np.save(flow_grid_path, splatted_grid_resized.cpu().numpy())
                np.save(composite_grid_path, collage_grid_resized.cpu().numpy())

                # Drop per-sample references before the next (large) allocation.
                del out
                del splatted_grid
                del collage_grid
                del frames

            # Only release the batch once every sample in it has been handled.
            del frames_to_visualize


#%%
if __name__ == '__main__':
    try:
        main()
    except Exception as e:
        print(e)
        print('process failed')