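"""Command-line entry point for Diffusion as Shader.

Loads an input video or image, optionally repaints the first frame, builds a
tracking video (SpaTracker, CoTracker, or MoGe with VGGT camera estimation),
optionally applies synthetic camera and/or object motion to the tracks, and
finally renders the result with DiffusionAsShaderPipeline.
"""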
import os
import sys
import argparse

# Make the bundled submodules importable before importing from them below, and
# silence the HuggingFace tokenizers fork warning.
project_root = os.path.dirname(os.path.abspath(__file__))
for submodule in ("submodules/MoGe", "submodules/vggt"):
    submodule_path = os.path.join(project_root, submodule)
    if os.path.isdir(submodule_path):
        sys.path.append(submodule_path)
    else:
        print(f"Warning: {submodule} not found; tracking features that depend on it will be unavailable")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from moviepy.editor import VideoFileClip
from diffusers.utils import load_image, load_video

from models.pipelines import DiffusionAsShaderPipeline, FirstFrameRepainter, CameraMotionGenerator, ObjectMotionGenerator
from submodules.MoGe.moge.model import MoGeModel
from submodules.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
from submodules.vggt.vggt.models.vggt import VGGT

def load_media(media_path, max_frames=49, transform=None):
    """Load video or image frames and convert to tensor
    
    Args:
        media_path (str): Path to video or image file
        max_frames (int): Maximum number of frames to load
        transform (callable): Transform to apply to frames
        
    Returns:
        Tuple[torch.Tensor, float, bool]: Video tensor [T,C,H,W], FPS, and is_video flag
    """
    if transform is None:
        transform = transforms.Compose([
            transforms.Resize((480, 720)),
            transforms.ToTensor()
        ])
    
    # Determine if input is video or image based on extension
    ext = os.path.splitext(media_path)[1].lower()
    is_video = ext in ['.mp4', '.avi', '.mov']
    
    if is_video:
        # Load video file info
        video_clip = VideoFileClip(media_path)
        duration = video_clip.duration
        original_fps = video_clip.fps
        
        # Case 1: Video longer than 6 seconds, sample first 6 seconds + 1 frame
        if duration > 6.0:
            sampling_fps = 8  # 8 frames per second
            frames = load_video(media_path, sampling_fps=sampling_fps, max_frames=max_frames)
            fps = sampling_fps
        # Cases 2 and 3: Video shorter than 6 seconds
        else:
            # Load all frames
            frames = load_video(media_path)
            
            # Evenly resample to exactly max_frames by repeating (Case 2) or dropping
            # (Case 3) frames; no temporal interpolation is performed.
            indices = np.linspace(0, len(frames) - 1, max_frames)
            resampled = [frames[int(i)] for i in indices]

            if len(frames) < max_frames:
                # Case 2: fewer frames than max_frames, keep the original fps
                fps = len(frames) / duration
            else:
                # Case 3: more frames than max_frames, adjust fps to preserve the duration
                fps = max_frames / duration
            frames = resampled
    else:
        # Handle image as single frame
        image = load_image(media_path)
        frames = [image]
        fps = 8  # Default fps for images
        
        # Duplicate frame to max_frames
        while len(frames) < max_frames:
            frames.append(frames[0].copy())
    
    # Convert frames to tensor
    video_tensor = torch.stack([transform(frame) for frame in frames])
    
    return video_tensor, fps, is_video
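
# Illustrative usage of load_media (assumes a local "clip.mp4"; not executed by this script):
#   frames, fps, is_video = load_media("clip.mp4")
#   frames.shape == torch.Size([49, 3, 480, 720]) with the default transform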

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, default=None, help='Path to input video/image')
    parser.add_argument('--prompt', type=str, required=True, help='Repaint prompt')
    parser.add_argument('--output_dir', type=str, default='outputs', help='Output directory')
    parser.add_argument('--gpu', type=int, default=0, help='GPU device ID')
    parser.add_argument('--checkpoint_path', type=str, default="EXCAI/Diffusion-As-Shader", help='Path to model checkpoint')
    parser.add_argument('--depth_path', type=str, default=None, help='Path to depth image')
    parser.add_argument('--tracking_path', type=str, default=None, help='Path to tracking video; if provided, camera motion and object manipulation will not be applied')
    parser.add_argument('--repaint', type=str, default=None,
                       help='Path to a repainted first-frame image, or "true" to perform repainting; if not provided, the original first frame is used')
    parser.add_argument('--camera_motion', type=str, default=None, 
                    help='Camera motion mode: "trans <dx> <dy> <dz>" or "rot <axis> <angle>" or "spiral <radius>"')
    parser.add_argument('--object_motion', type=str, default=None, help='Object motion mode: up/down/left/right')
    parser.add_argument('--object_mask', type=str, default=None, help='Path to object mask image (binary image)')
    parser.add_argument('--tracking_method', type=str, default='spatracker', choices=['spatracker', 'moge', 'cotracker'], 
                    help='Tracking method to use (spatracker, cotracker or moge)')
    args = parser.parse_args()
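
    # Example invocation (script name, paths, and prompt below are illustrative placeholders):
    #   python demo.py --input_path input.mp4 --prompt "a car driving at night" \
    #       --camera_motion "rot y 30" --tracking_method spatracker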
    
    # Load input video/image
    video_tensor, fps, is_video = load_media(args.input_path)
    if not is_video:
        args.tracking_method = "moge"
        print("Image input detected, using MoGe for tracking video generation.")

    # Initialize pipeline
    das = DiffusionAsShaderPipeline(gpu_id=args.gpu, output_dir=args.output_dir)
    das.fps = fps
    if args.tracking_method == "moge" and args.tracking_path is None:
        moge = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(das.device)
    
    # Repaint first frame if requested
    repaint_img_tensor = None
    if args.repaint:
        if args.repaint.lower() == "true":
            repainter = FirstFrameRepainter(gpu_id=args.gpu, output_dir=args.output_dir)
            repaint_img_tensor = repainter.repaint(
                video_tensor[0], 
                prompt=args.prompt,
                depth_path=args.depth_path
            )
        else:
            repaint_img_tensor, _, _ = load_media(args.repaint)
            repaint_img_tensor = repaint_img_tensor[0]  # Take first frame

    # Generate tracking if not provided
    tracking_tensor = None
    pred_tracks = None
    cam_motion = CameraMotionGenerator(args.camera_motion)

    if args.tracking_path:
        tracking_tensor, _, _ = load_media(args.tracking_path)
        
    elif args.tracking_method == "moge":
        # Use the first frame from previously loaded video_tensor
        infer_result = moge.infer(video_tensor[0].to(das.device))  # [C, H, W] in range [0,1]
        H, W = infer_result["points"].shape[0:2]
        pred_tracks = infer_result["points"].unsqueeze(0).repeat(49, 1, 1, 1) #[T, H, W, 3]
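        # MoGe infers a single-frame 3D point map (plus camera intrinsics); repeating it across
        # all 49 frames treats the scene as static until object/camera motion is applied below.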
        cam_motion.set_intr(infer_result["intrinsics"])

        # Apply object motion if specified
        if args.object_motion:
            if args.object_mask is None:
                raise ValueError("Object motion specified but no mask provided. Please provide a mask image with --object_mask")
                
            # Load mask image
            mask_image = Image.open(args.object_mask).convert('L')  # Convert to grayscale
            mask_image = transforms.Resize((480, 720))(mask_image)  # Resize to match video size
            # Convert to binary mask
            mask = torch.from_numpy(np.array(mask_image) > 127)  # Threshold at 127
            
            motion_generator = ObjectMotionGenerator(device=das.device)

            pred_tracks = motion_generator.apply_motion(
                pred_tracks=pred_tracks,
                mask=mask,
                motion_type=args.object_motion,
                distance=50,
                num_frames=49,
                tracking_method="moge"
            )
            print("Object motion applied")

        # Apply camera motion if specified
        if args.camera_motion:
            poses = cam_motion.get_default_motion() # shape: [49, 4, 4]
            print("Camera motion applied")
        else:
            # Static camera: identity pose for every frame
            poses = torch.eye(4).unsqueeze(0).repeat(49, 1, 1)
        # change pred_tracks into screen coordinate
        pred_tracks_flatten = pred_tracks.reshape(video_tensor.shape[0], H*W, 3)
        pred_tracks = cam_motion.w2s_moge(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3]) # [T, H, W, 3]
        _, tracking_tensor = das.visualize_tracking_moge(
            pred_tracks.cpu().numpy(), 
            infer_result["mask"].cpu().numpy()
        )
        print("Tracking video generated with MoGe.")

    else:

        if args.tracking_method == "cotracker":
            pred_tracks, pred_visibility = das.generate_tracking_cotracker(video_tensor) # T N 3, T N
        else:
            pred_tracks, pred_visibility, T_Firsts = das.generate_tracking_spatracker(video_tensor) # T N 3, T N, B N

        # Preprocess video tensor to match VGGT requirements
        t, c, h, w = video_tensor.shape
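        # Resize for VGGT: fix the width to 518 px, keep the aspect ratio, and round the
        # height to a multiple of 14 (the ViT patch size); taller results are center-cropped below.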
        new_width = 518
        new_height = round(h * (new_width / w) / 14) * 14
        resize_transform = transforms.Resize((new_height, new_width), interpolation=transforms.InterpolationMode.BICUBIC)
        video_vggt = resize_transform(video_tensor)  # [T, C, H, W]
        
        if new_height > 518:
            start_y = (new_height - 518) // 2
            video_vggt = video_vggt[:, :, start_y:start_y + 518, :]

        # Get extrinsic and intrinsic matrices
        vggt_model = VGGT.from_pretrained("facebook/VGGT-1B").to(das.device)

        with torch.no_grad():
            with torch.cuda.amp.autocast(dtype=das.dtype):

                # Move the frames to the model's device once so the aggregator and depth head see the same tensor
                video_vggt = video_vggt.unsqueeze(0).to(das.device)  # [1, T, C, H, W]
                aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_vggt)
            
                # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
                extr, intr = pose_encoding_to_extri_intri(vggt_model.camera_head(aggregated_tokens_list)[-1], video_vggt.shape[-2:])
                depth_map, depth_conf = vggt_model.depth_head(aggregated_tokens_list, video_vggt, ps_idx)
        
        cam_motion.set_intr(intr)
        cam_motion.set_extr(extr)

        # Apply camera motion if specified
        if args.camera_motion:
            poses = cam_motion.get_default_motion() # shape: [49, 4, 4]
            pred_tracks_world = cam_motion.s2w_vggt(pred_tracks, extr, intr)
            pred_tracks = cam_motion.w2s_vggt(pred_tracks_world, extr, intr, poses) # [T, N, 3]
            print("Camera motion applied")
        
        # Apply object motion if specified
        if args.object_motion:
            if args.object_mask is None:
                raise ValueError("Object motion specified but no mask provided. Please provide a mask image with --object_mask")
                
            # Load mask image
            mask_image = Image.open(args.object_mask).convert('L')  # Convert to grayscale
            mask_image = transforms.Resize((480, 720))(mask_image)  # Resize to match video size
            # Convert to binary mask
            mask = torch.from_numpy(np.array(mask_image) > 127)  # Threshold at 127
            
            motion_generator = ObjectMotionGenerator(device=das.device)
            
            pred_tracks = motion_generator.apply_motion(
                pred_tracks=pred_tracks,
                mask=mask,
                motion_type=args.object_motion,
                distance=50,
                num_frames=49,
                tracking_method="spatracker"
            ).unsqueeze(0)
            print(f"Object motion '{args.object_motion}' applied using mask from {args.object_mask}")
    
        if args.tracking_method == "cotracker":
            _, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks, pred_visibility)
        else:
            _, tracking_tensor = das.visualize_tracking_spatracker(video_tensor, pred_tracks, pred_visibility, T_Firsts)
    
    das.apply_tracking(
        video_tensor=video_tensor,
        fps=fps,
        tracking_tensor=tracking_tensor,
        img_cond_tensor=repaint_img_tensor,
        prompt=args.prompt,
        checkpoint_path=args.checkpoint_path
    )