import os
import sys
import argparse

project_root = os.path.dirname(os.path.abspath(__file__))
try:
    sys.path.append(os.path.join(project_root, "submodules/MoGe"))
    sys.path.append(os.path.join(project_root, "submodules/vggt"))
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
except Exception:
    print("Warning: MoGe not found, motion transfer will not be applied")

import torch
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from moviepy.editor import VideoFileClip
from diffusers.utils import load_image, load_video
from models.pipelines import DiffusionAsShaderPipeline, FirstFrameRepainter, CameraMotionGenerator, ObjectMotionGenerator
from submodules.MoGe.moge.model import MoGeModel
from submodules.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
from submodules.vggt.vggt.models.vggt import VGGT
def load_media(media_path, max_frames=49, transform=None):
"""Load video or image frames and convert to tensor
Args:
media_path (str): Path to video or image file
max_frames (int): Maximum number of frames to load
transform (callable): Transform to apply to frames
Returns:
Tuple[torch.Tensor, float, bool]: Video tensor [T,C,H,W], FPS, and is_video flag
"""
if transform is None:
transform = transforms.Compose([
transforms.Resize((480, 720)),
transforms.ToTensor()
])
    # Determine if input is video or image based on extension
    ext = os.path.splitext(media_path)[1].lower()
    is_video = ext in ['.mp4', '.avi', '.mov']

    if is_video:
        # Load video file info
        video_clip = VideoFileClip(media_path)
        duration = video_clip.duration
        original_fps = video_clip.fps

        # Case 1: Video longer than 6 seconds, sample the first 6 seconds + 1 frame at 8 fps
        if duration > 6.0:
            sampling_fps = 8  # 8 frames per second
            frames = load_video(media_path, sampling_fps=sampling_fps, max_frames=max_frames)
            fps = sampling_fps
        # Cases 2 and 3: Video shorter than 6 seconds
        else:
            # Load all frames
            frames = load_video(media_path)

            # Case 2: Fewer frames than max_frames, pad by evenly repeating frames
            if len(frames) < max_frames:
                fps = len(frames) / duration  # Keep original fps
                indices = np.linspace(0, len(frames) - 1, max_frames)
                frames = [frames[int(i)] for i in indices]
            # Case 3: More frames than max_frames but video shorter than 6 seconds
            else:
                # Evenly sample down to max_frames
                indices = np.linspace(0, len(frames) - 1, max_frames)
                frames = [frames[int(i)] for i in indices]
                fps = max_frames / duration  # New fps to preserve duration
    else:
        # Handle image as a single frame
        image = load_image(media_path)
        frames = [image]
        fps = 8  # Default fps for images

        # Duplicate the frame up to max_frames
        while len(frames) < max_frames:
            frames.append(frames[0].copy())

    # Convert frames to tensor
    video_tensor = torch.stack([transform(frame) for frame in frames])

    return video_tensor, fps, is_video
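
# Example (illustrative path, not a file shipped with this script): a clip is resized to
# 480x720 and padded/sampled to 49 frames, so the returned tensor is [49, 3, 480, 720].
#
#   frames, fps, is_video = load_media("assets/example.mp4", max_frames=49)
#   print(frames.shape, fps, is_video)  # e.g. torch.Size([49, 3, 480, 720]) 8 True
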
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, default=None, help='Path to input video/image')
    parser.add_argument('--prompt', type=str, required=True, help='Repaint prompt')
    parser.add_argument('--output_dir', type=str, default='outputs', help='Output directory')
    parser.add_argument('--gpu', type=int, default=0, help='GPU device ID')
    parser.add_argument('--checkpoint_path', type=str, default="EXCAI/Diffusion-As-Shader", help='Path to model checkpoint')
    parser.add_argument('--depth_path', type=str, default=None, help='Path to depth image')
    parser.add_argument('--tracking_path', type=str, default=None,
                        help='Path to tracking video; if provided, camera motion and object manipulation will not be applied')
    parser.add_argument('--repaint', type=str, default=None,
                        help='Path to repainted image, or "true" to perform repainting; if not provided, the original frame is used')
    parser.add_argument('--camera_motion', type=str, default=None,
                        help='Camera motion mode: "trans <dx> <dy> <dz>", "rot <axis> <angle>", or "spiral <radius>"')
    parser.add_argument('--object_motion', type=str, default=None, help='Object motion mode: up/down/left/right')
    parser.add_argument('--object_mask', type=str, default=None, help='Path to object mask image (binary image)')
    parser.add_argument('--tracking_method', type=str, default='spatracker', choices=['spatracker', 'moge', 'cotracker'],
                        help='Tracking method to use (spatracker, cotracker or moge)')
    args = parser.parse_args()
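
    # Example invocation (script name, paths, prompt, and motion values are illustrative only):
    #   python demo.py --input_path assets/example.mp4 \
    #       --prompt "a red car driving along a coastal road" \
    #       --camera_motion "trans 0 0 -0.1" --tracking_method spatracker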

    # Load input video/image
    video_tensor, fps, is_video = load_media(args.input_path)
    if not is_video:
        args.tracking_method = "moge"
        print("Image input detected, using MoGe for tracking video generation.")

    # Initialize pipeline
    das = DiffusionAsShaderPipeline(gpu_id=args.gpu, output_dir=args.output_dir)
    das.fps = fps

    if args.tracking_method == "moge" and args.tracking_path is None:
        moge = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(das.device)

    # Repaint first frame if requested
    repaint_img_tensor = None
    if args.repaint:
        if args.repaint.lower() == "true":
            repainter = FirstFrameRepainter(gpu_id=args.gpu, output_dir=args.output_dir)
            repaint_img_tensor = repainter.repaint(
                video_tensor[0],
                prompt=args.prompt,
                depth_path=args.depth_path
            )
        else:
            repaint_img_tensor, _, _ = load_media(args.repaint)
            repaint_img_tensor = repaint_img_tensor[0]  # Take first frame

    # Generate tracking if not provided
    tracking_tensor = None
    pred_tracks = None
    cam_motion = CameraMotionGenerator(args.camera_motion)

    if args.tracking_path:
        tracking_tensor, _, _ = load_media(args.tracking_path)

    elif args.tracking_method == "moge":
        # Use the first frame from the previously loaded video_tensor
        infer_result = moge.infer(video_tensor[0].to(das.device))  # [C, H, W] in range [0, 1]
        H, W = infer_result["points"].shape[0:2]
        pred_tracks = infer_result["points"].unsqueeze(0).repeat(49, 1, 1, 1)  # [T, H, W, 3]
        cam_motion.set_intr(infer_result["intrinsics"])

        # Apply object motion if specified
        if args.object_motion:
            if args.object_mask is None:
                raise ValueError("Object motion specified but no mask provided. Please provide a mask image with --object_mask")

            # Load mask image and convert it to a binary mask
            mask_image = Image.open(args.object_mask).convert('L')  # Convert to grayscale
            mask_image = transforms.Resize((480, 720))(mask_image)  # Resize to match video size
            mask = torch.from_numpy(np.array(mask_image) > 127)  # Threshold at 127

            motion_generator = ObjectMotionGenerator(device=das.device)
            pred_tracks = motion_generator.apply_motion(
                pred_tracks=pred_tracks,
                mask=mask,
                motion_type=args.object_motion,
                distance=50,
                num_frames=49,
                tracking_method="moge"
            )
            print("Object motion applied")

        # Apply camera motion if specified
        if args.camera_motion:
            poses = cam_motion.get_default_motion()  # shape: [49, 4, 4]
            print("Camera motion applied")
        else:
            # No camera motion: repeat the identity pose for every frame
            poses = torch.eye(4).unsqueeze(0).repeat(49, 1, 1)

        # Project pred_tracks into screen coordinates
        pred_tracks_flatten = pred_tracks.reshape(video_tensor.shape[0], H * W, 3)
        pred_tracks = cam_motion.w2s_moge(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3])  # [T, H, W, 3]
        _, tracking_tensor = das.visualize_tracking_moge(
            pred_tracks.cpu().numpy(),
            infer_result["mask"].cpu().numpy()
        )
        print("Tracking video exported via MoGe.")

    else:
        if args.tracking_method == "cotracker":
            pred_tracks, pred_visibility = das.generate_tracking_cotracker(video_tensor)  # T N 3, T N
        else:
            pred_tracks, pred_visibility, T_Firsts = das.generate_tracking_spatracker(video_tensor)  # T N 3, T N, B N

        # Preprocess video tensor to match VGGT requirements
        t, c, h, w = video_tensor.shape
        new_width = 518
        new_height = round(h * (new_width / w) / 14) * 14
        resize_transform = transforms.Resize((new_height, new_width), interpolation=Image.BICUBIC)
        video_vggt = resize_transform(video_tensor)  # [T, C, H, W]

        if new_height > 518:
            start_y = (new_height - 518) // 2
            video_vggt = video_vggt[:, :, start_y:start_y + 518, :]
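
        # Note (derived from the defaults above, not from VGGT itself): for a 480x720 input,
        # new_height = round(480 * 518 / 720 / 14) * 14 = 350, so the center crop is skipped.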

        # Get extrinsic and intrinsic matrices
        vggt_model = VGGT.from_pretrained("facebook/VGGT-1B").to(das.device)
        with torch.no_grad():
            with torch.cuda.amp.autocast(dtype=das.dtype):
                video_vggt = video_vggt.unsqueeze(0)  # [1, T, C, H, W]
                aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_vggt.to(das.device))
            # Extrinsic and intrinsic matrices, following the OpenCV convention (camera from world)
            extr, intr = pose_encoding_to_extri_intri(vggt_model.camera_head(aggregated_tokens_list)[-1], video_vggt.shape[-2:])
            depth_map, depth_conf = vggt_model.depth_head(aggregated_tokens_list, video_vggt, ps_idx)
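
        # The exact tensor shapes here are an assumption based on VGGT's usage examples:
        # extr is expected to be roughly [1, T, 3, 4] and intr [1, T, 3, 3], with the leading
        # batch dimension coming from the unsqueeze above.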

        cam_motion.set_intr(intr)
        cam_motion.set_extr(extr)

        # Apply camera motion if specified
        if args.camera_motion:
            poses = cam_motion.get_default_motion()  # shape: [49, 4, 4]
            pred_tracks_world = cam_motion.s2w_vggt(pred_tracks, extr, intr)
            pred_tracks = cam_motion.w2s_vggt(pred_tracks_world, extr, intr, poses)  # [T, N, 3]
            print("Camera motion applied")

        # Apply object motion if specified
        if args.object_motion:
            if args.object_mask is None:
                raise ValueError("Object motion specified but no mask provided. Please provide a mask image with --object_mask")

            # Load mask image and convert it to a binary mask
            mask_image = Image.open(args.object_mask).convert('L')  # Convert to grayscale
            mask_image = transforms.Resize((480, 720))(mask_image)  # Resize to match video size
            mask = torch.from_numpy(np.array(mask_image) > 127)  # Threshold at 127

            motion_generator = ObjectMotionGenerator(device=das.device)
            pred_tracks = motion_generator.apply_motion(
                pred_tracks=pred_tracks,
                mask=mask,
                motion_type=args.object_motion,
                distance=50,
                num_frames=49,
                tracking_method="spatracker"
            ).unsqueeze(0)
            print(f"Object motion '{args.object_motion}' applied using mask from {args.object_mask}")

        if args.tracking_method == "cotracker":
            _, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks, pred_visibility)
        else:
            _, tracking_tensor = das.visualize_tracking_spatracker(video_tensor, pred_tracks, pred_visibility, T_Firsts)

    das.apply_tracking(
        video_tensor=video_tensor,
        fps=fps,
        tracking_tensor=tracking_tensor,
        img_cond_tensor=repaint_img_tensor,
        prompt=args.prompt,
        checkpoint_path=args.checkpoint_path
    )