import argparse
from typing import Any, Dict, List, Literal, Tuple

import pandas as pd
import os
import sys
import torch
from diffusers import (
    CogVideoXPipeline,
    CogVideoXDDIMScheduler,
    CogVideoXDPMScheduler,
    CogVideoXImageToVideoPipeline,
    CogVideoXVideoToVideoPipeline,
)
from diffusers.utils import export_to_video, load_image, load_video
import numpy as np
import random
import cv2
from pathlib import Path
import decord
from torchvision import transforms
from torchvision.transforms.functional import resize
import PIL.Image
from PIL import Image

current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(current_dir, '..'))

from models.cogvideox_tracking import (
    CogVideoXImageToVideoPipelineTracking,
    CogVideoXPipelineTracking,
    CogVideoXVideoToVideoPipelineTracking,
)
from training.dataset import VideoDataset, VideoDatasetWithResizingTracking


class VideoDatasetWithResizingTrackingEval(VideoDataset):
    def __init__(self, *args, **kwargs) -> None:
        self.tracking_column = kwargs.pop("tracking_column", None)
        self.image_paths = kwargs.pop("image_paths", None)
        super().__init__(*args, **kwargs)

    def _preprocess_video(self, path: Path, tracking_path: Path, image_paths: Path = None) -> torch.Tensor:
        if self.load_tensors:
            return self._load_preprocessed_latents_and_embeds(path, tracking_path)
        else:
            # Decode the video and pick the frame bucket closest to the clip length.
            video_reader = decord.VideoReader(uri=path.as_posix())
            video_num_frames = len(video_reader)
            nearest_frame_bucket = min(
                self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames))
            )

            frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))

            frames = video_reader.get_batch(frame_indices)
            frames = frames[:nearest_frame_bucket].float()
            frames = frames.permute(0, 3, 1, 2).contiguous()

            # Resize to the nearest training resolution and apply the dataset transforms.
            nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
            frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0)
            frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)

            # Load the conditioning image and apply the same resizing/transforms.
            image = Image.open(image_paths)
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image = torch.from_numpy(np.array(image)).float()
            image = image.permute(2, 0, 1).contiguous()
            image = resize(image, nearest_res)
            image = self.video_transforms(image)

            # Load the tracking video with the same frame indices as the RGB video.
            tracking_reader = decord.VideoReader(uri=tracking_path.as_posix())
            tracking_frames = tracking_reader.get_batch(frame_indices)
            tracking_frames = tracking_frames[:nearest_frame_bucket].float()
            tracking_frames = tracking_frames.permute(0, 3, 1, 2).contiguous()
            tracking_frames_resized = torch.stack(
                [resize(tracking_frame, nearest_res) for tracking_frame in tracking_frames], dim=0
            )
            tracking_frames = torch.stack(
                [self.video_transforms(tracking_frame) for tracking_frame in tracking_frames_resized], dim=0
            )

            return image, frames, tracking_frames, None

    def _find_nearest_resolution(self, height, width):
        nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
        return nearest_res[1], nearest_res[2]

    def _load_dataset_from_local_path(self) -> Tuple[List[str], List[str]]:
        if not self.data_root.exists():
            raise ValueError("Root folder for videos does not exist")

        prompt_path = self.data_root.joinpath(self.caption_column)
        video_path = self.data_root.joinpath(self.video_column)
        tracking_path = self.data_root.joinpath(self.tracking_column)
        image_paths = self.data_root.joinpath(self.image_paths)

        if not prompt_path.exists() or not prompt_path.is_file():
            raise ValueError(
                "Expected `--caption_column` to be path to a file in `--data_root` containing line-separated text prompts."
            )
        if not video_path.exists() or not video_path.is_file():
            raise ValueError(
                "Expected `--video_column` to be path to a file in `--data_root` containing line-separated paths to video data in the same directory."
            )
        if not tracking_path.exists() or not tracking_path.is_file():
            raise ValueError(
                "Expected `--tracking_column` to be path to a file in `--data_root` containing line-separated tracking information."
            )

        with open(prompt_path, "r", encoding="utf-8") as file:
            prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0]
        with open(video_path, "r", encoding="utf-8") as file:
            video_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0]
        with open(tracking_path, "r", encoding="utf-8") as file:
            tracking_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0]
        with open(image_paths, "r", encoding="utf-8") as file:
            image_paths_list = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0]

        if not self.load_tensors and any(not path.is_file() for path in video_paths):
            raise ValueError(
                f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found at least one path that is not a valid file."
            )

        self.tracking_paths = tracking_paths
        self.image_paths = image_paths_list
        return prompts, video_paths

    def _load_dataset_from_csv(self) -> Tuple[List[str], List[str]]:
        df = pd.read_csv(self.dataset_file)
        prompts = df[self.caption_column].tolist()
        video_paths = df[self.video_column].tolist()
        tracking_paths = df[self.tracking_column].tolist()
        image_paths = df[self.image_paths].tolist()
        video_paths = [self.data_root.joinpath(line.strip()) for line in video_paths]
        tracking_paths = [self.data_root.joinpath(line.strip()) for line in tracking_paths]
        image_paths = [self.data_root.joinpath(line.strip()) for line in image_paths]

        if any(not path.is_file() for path in video_paths):
            raise ValueError(
                f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found at least one path that is not a valid file."
            )

        self.tracking_paths = tracking_paths
        self.image_paths = image_paths
        return prompts, video_paths

    def __getitem__(self, index: int) -> Dict[str, Any]:
        if isinstance(index, list):
            return index

        if self.load_tensors:
            image_latents, video_latents, tracking_map, prompt_embeds = self._preprocess_video(
                self.video_paths[index], self.tracking_paths[index]
            )

            # The VAE's temporal compression ratio is 4.
            # The VAE's spatial compression ratio is 8.
            latent_num_frames = video_latents.size(1)
            if latent_num_frames % 2 == 0:
                num_frames = latent_num_frames * 4
            else:
                num_frames = (latent_num_frames - 1) * 4 + 1

            height = video_latents.size(2) * 8
            width = video_latents.size(3) * 8

            return {
                "prompt": prompt_embeds,
                "image": image_latents,
                "video": video_latents,
                "tracking_map": tracking_map,
                "video_metadata": {
                    "num_frames": num_frames,
                    "height": height,
                    "width": width,
                },
            }
        else:
            image, video, tracking_map, _ = self._preprocess_video(
                self.video_paths[index], self.tracking_paths[index], self.image_paths[index]
            )
            return {
                "prompt": self.id_token + self.prompts[index],
                "image": image,
                "video": video,
                "tracking_map": tracking_map,
                "video_metadata": {
                    "num_frames": video.shape[0],
                    "height": video.shape[2],
                    "width": video.shape[3],
                },
            }

    def _load_preprocessed_latents_and_embeds(self, path: Path, tracking_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
        filename_without_ext = path.name.split(".")[0]
        pt_filename = f"{filename_without_ext}.pt"

        # The current path is something like: /a/b/c/d/videos/00001.mp4
        # We need to reach: /a/b/c/d/video_latents/00001.pt
        image_latents_path = path.parent.parent.joinpath("image_latents")
        video_latents_path = path.parent.parent.joinpath("video_latents")
        tracking_map_path = path.parent.parent.joinpath("tracking_map")
        embeds_path = path.parent.parent.joinpath("prompt_embeds")

        if (
            not video_latents_path.exists()
            or not embeds_path.exists()
            or not tracking_map_path.exists()
            or (self.image_to_video and not image_latents_path.exists())
        ):
            raise ValueError(
                f"When setting the load_tensors parameter to `True`, it is expected that the `{self.data_root=}` contains folders named `video_latents`, `prompt_embeds`, and `tracking_map`. However, these folders were not found. Please make sure to have prepared your data correctly using `prepare_data.py`. Additionally, if you're training image-to-video, it is expected that an `image_latents` folder is also present."
            )

        if self.image_to_video:
            image_latent_filepath = image_latents_path.joinpath(pt_filename)
        video_latent_filepath = video_latents_path.joinpath(pt_filename)
        tracking_map_filepath = tracking_map_path.joinpath(pt_filename)
        embeds_filepath = embeds_path.joinpath(pt_filename)

        if not video_latent_filepath.is_file() or not embeds_filepath.is_file() or not tracking_map_filepath.is_file():
            if self.image_to_video:
                image_latent_filepath = image_latent_filepath.as_posix()
            video_latent_filepath = video_latent_filepath.as_posix()
            tracking_map_filepath = tracking_map_filepath.as_posix()
            embeds_filepath = embeds_filepath.as_posix()
            raise ValueError(
                f"The file {video_latent_filepath=} or {embeds_filepath=} or {tracking_map_filepath=} could not be found. Please ensure that you've correctly executed `prepare_dataset.py`."
            )

        images = (
            torch.load(image_latent_filepath, map_location="cpu", weights_only=True) if self.image_to_video else None
        )
        latents = torch.load(video_latent_filepath, map_location="cpu", weights_only=True)
        tracking_map = torch.load(tracking_map_filepath, map_location="cpu", weights_only=True)
        embeds = torch.load(embeds_filepath, map_location="cpu", weights_only=True)

        return images, latents, tracking_map, embeds


def sample_from_dataset(
    data_root: str,
    caption_column: str,
    tracking_column: str,
    image_paths: str,
    video_column: str,
    num_samples: int = -1,
    random_seed: int = 42,
):
    """Sample evaluation examples (prompt, image, video, tracking maps) from the dataset."""
    if image_paths:
        # If image_paths is provided, use VideoDatasetWithResizingTrackingEval
        dataset = VideoDatasetWithResizingTrackingEval(
            data_root=data_root,
            caption_column=caption_column,
            tracking_column=tracking_column,
            image_paths=image_paths,
            video_column=video_column,
            max_num_frames=49,
            load_tensors=False,
            random_flip=None,
            frame_buckets=[49],
            image_to_video=True,
        )
    else:
        # If image_paths is not provided, use VideoDatasetWithResizingTracking
        dataset = VideoDatasetWithResizingTracking(
            data_root=data_root,
            caption_column=caption_column,
            tracking_column=tracking_column,
            video_column=video_column,
            max_num_frames=49,
            load_tensors=False,
            random_flip=None,
            frame_buckets=[49],
            image_to_video=True,
        )

    # Set random seed
    random.seed(random_seed)

    # Randomly sample from dataset
    total_samples = len(dataset)
    if num_samples == -1:
        # If num_samples is -1, process all samples
        selected_indices = range(total_samples)
    else:
        selected_indices = random.sample(range(total_samples), min(num_samples, total_samples))

    samples = []
    for idx in selected_indices:
        sample = dataset[idx]
        # Get data based on dataset.__getitem__ return value
        image = sample["image"]  # Already processed tensor
        video = sample["video"]  # Already processed tensor
        tracking_map = sample["tracking_map"]  # Already processed tensor
        prompt = sample["prompt"]

        samples.append({
            "prompt": prompt,
            "tracking_frame": tracking_map[0],  # First tracking frame
            "video_frame": image,  # Conditioning image (first frame)
            "video": video,  # Complete video
            "tracking_maps": tracking_map,  # Complete tracking maps
            "height": sample["video_metadata"]["height"],
            "width": sample["video_metadata"]["width"],
        })

    return samples
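
# For reference, each entry returned by sample_from_dataset has the following layout
# (tensor shapes follow the dataset transforms above and depend on the resolution bucket):
#   {
#       "prompt": str,
#       "tracking_frame": Tensor[C, H, W],    # first tracking frame, later passed as `tracking_image`
#       "video_frame": Tensor[C, H, W],       # conditioning image, later passed as `image`
#       "video": Tensor[F, C, H, W],          # ground-truth clip
#       "tracking_maps": Tensor[F, C, H, W],  # full tracking video
#       "height": int,
#       "width": int,
#   }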


def generate_video(
    prompt: str,
    model_path: str,
    tracking_path: str = None,
    output_path: str = "./output.mp4",
    image_or_video_path: str = "",
    num_inference_steps: int = 50,
    guidance_scale: float = 6.0,
    num_videos_per_prompt: int = 1,
    dtype: torch.dtype = torch.bfloat16,
    generate_type: Literal["i2v", "i2vo"] = "i2v",  # i2v: image to video, i2vo: original CogVideoX-5b-I2V
    seed: int = 42,
    data_root: str = None,
    caption_column: str = None,
    tracking_column: str = None,
    video_column: str = None,
    image_paths: str = None,
    num_samples: int = -1,
    evaluation_dir: str = "evaluations",
    fps: int = 8,
):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # If dataset parameters are provided, sample from the dataset
    samples = None
    if all([data_root, caption_column, tracking_column, video_column]):
        samples = sample_from_dataset(
            data_root=data_root,
            caption_column=caption_column,
            tracking_column=tracking_column,
            image_paths=image_paths,
            video_column=video_column,
            num_samples=num_samples,
            random_seed=seed,
        )

    # Load model and data
    if generate_type == "i2v":
        pipe = CogVideoXImageToVideoPipelineTracking.from_pretrained(model_path, torch_dtype=dtype)
        if not samples:
            image = load_image(image=image_or_video_path)
            height, width = image.height, image.width
    else:
        pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=dtype)
        if not samples:
            image = load_image(image=image_or_video_path)
            height, width = image.height, image.width

    # Set model parameters
    pipe.to(device, dtype=dtype)
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()
    pipe.transformer.eval()
    pipe.text_encoder.eval()
    pipe.vae.eval()
    pipe.transformer.gradient_checkpointing = False

    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

    # Generate videos
    if samples:
        from tqdm import tqdm

        for i, sample in tqdm(enumerate(samples), total=len(samples), desc="Generating samples"):
            print(f"Prompt: {sample['prompt'][:30]}")
            tracking_frame = sample["tracking_frame"].to(device=device, dtype=dtype)
            video_frame = sample["video_frame"].to(device=device, dtype=dtype)
            video = sample["video"].to(device=device, dtype=dtype)
            tracking_maps = sample["tracking_maps"].to(device=device, dtype=dtype)

            # Encode the tracking video into VAE latents
            print("encoding tracking maps")
            tracking_video = tracking_maps
            tracking_maps = tracking_maps.unsqueeze(0)
            tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
            with torch.no_grad():
                tracking_latent_dist = pipe.vae.encode(tracking_maps).latent_dist
                tracking_maps = tracking_latent_dist.sample() * pipe.vae.config.scaling_factor
                tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]

            pipeline_args = {
                "prompt": sample["prompt"],
                "negative_prompt": "The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion.",
                "num_inference_steps": num_inference_steps,
                "num_frames": 49,
                "use_dynamic_cfg": True,
                "guidance_scale": guidance_scale,
                "generator": torch.Generator(device=device).manual_seed(seed),
                "height": sample["height"],
                "width": sample["width"],
            }

            # Map the conditioning image from [-1, 1] back to [0, 1]
            pipeline_args["image"] = (video_frame + 1.0) / 2.0

            if tracking_column and generate_type == "i2v":
                pipeline_args["tracking_maps"] = tracking_maps
                pipeline_args["tracking_image"] = (tracking_frame.unsqueeze(0) + 1.0) / 2.0

            with torch.no_grad():
                video_generate = pipe(**pipeline_args).frames[0]

            output_dir = os.path.join(data_root, evaluation_dir)
            output_name = f"{i:04d}.mp4"
            output_file = os.path.join(output_dir, output_name)
            os.makedirs(output_dir, exist_ok=True)

            export_concat_video(video_generate, video, tracking_video, output_file, fps=fps)
    else:
        # Fallback: single-sample generation driven by --image_or_video_path / --tracking_path
        pipeline_args = {
            "prompt": prompt,
            "num_videos_per_prompt": num_videos_per_prompt,
            "num_inference_steps": num_inference_steps,
            "num_frames": 49,
            "use_dynamic_cfg": True,
            "guidance_scale": guidance_scale,
            "generator": torch.Generator().manual_seed(seed),
        }

        pipeline_args["video"] = video
        pipeline_args["image"] = image
        pipeline_args["height"] = height
        pipeline_args["width"] = width

        if tracking_path and generate_type == "i2v":
            tracking_maps = load_video(tracking_path)
            tracking_maps = torch.stack([
                torch.from_numpy(np.array(frame)).permute(2, 0, 1).float() / 255.0
                for frame in tracking_maps
            ]).to(device=device, dtype=dtype)
            tracking_video = tracking_maps
            tracking_maps = tracking_maps.unsqueeze(0)
            tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
            with torch.no_grad():
                tracking_latent_dist = pipe.vae.encode(tracking_maps).latent_dist
                tracking_maps = tracking_latent_dist.sample() * pipe.vae.config.scaling_factor
                tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
            pipeline_args["tracking_maps"] = tracking_maps
            pipeline_args["tracking_image"] = tracking_maps[:, :1]

        with torch.no_grad():
            video_generate = pipe(**pipeline_args).frames[0]

        output_dir = os.path.join(data_root, evaluation_dir)
        output_name = f"{os.path.splitext(os.path.basename(image_or_video_path))[0]}.mp4"
        output_file = os.path.join(output_dir, output_name)
        os.makedirs(output_dir, exist_ok=True)

        export_concat_video(video_generate, video, tracking_video, output_file, fps=fps)


def create_frame_grid(frames: List[np.ndarray], interval: int = 9, max_cols: int = 7) -> np.ndarray:
    """
    Arrange video frames into a grid image by sampling at intervals.

    Args:
        frames: List of video frames
        interval: Sampling interval
        max_cols: Maximum number of frames per row

    Returns:
        Grid image array
    """
    # Sample frames at intervals
    sampled_frames = frames[::interval]

    # Calculate number of rows and columns
    n_frames = len(sampled_frames)
    n_cols = min(max_cols, n_frames)
    n_rows = (n_frames + n_cols - 1) // n_cols

    # Get height and width of a single frame
    frame_height, frame_width = sampled_frames[0].shape[:2]

    # Create blank canvas
    grid = np.zeros((frame_height * n_rows, frame_width * n_cols, 3), dtype=np.uint8)

    # Fill frames
    for idx, frame in enumerate(sampled_frames):
        i = idx // n_cols
        j = idx % n_cols
        grid[i * frame_height:(i + 1) * frame_height, j * frame_width:(j + 1) * frame_width] = frame

    return grid
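
# Note: create_frame_grid is not invoked elsewhere in this script. A hypothetical use would be
# building a contact sheet from the uint8 frames assembled in export_concat_video, e.g.:
#   grid = create_frame_grid(generated_frames_np, interval=9, max_cols=7)
#   Image.fromarray(grid).save(os.path.join(base_dir, f"{name_without_ext}_grid.png"))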


def export_concat_video(
    generated_frames: List[PIL.Image.Image],
    original_video: torch.Tensor,
    tracking_maps: torch.Tensor = None,
    output_video_path: str = None,
    fps: int = 8,
) -> str:
    """
    Export the generated frames, the original video and the tracking maps as video files,
    and save sampled frames into separate folders.
    """
    import imageio
    import os
    import tempfile  # needed for the temporary-file fallback below

    if output_video_path is None:
        output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name

    # Create subdirectories
    base_dir = os.path.dirname(output_video_path)
    generated_dir = os.path.join(base_dir, "generated")  # For storing generated videos
    group_dir = os.path.join(base_dir, "group")  # For storing concatenated videos

    # Get filename (without path) and create a video-specific folder
    filename = os.path.basename(output_video_path)
    name_without_ext = os.path.splitext(filename)[0]
    video_frames_dir = os.path.join(base_dir, "frames", name_without_ext)  # frames/video_name/

    # Create three subdirectories under the video-specific folder
    groundtruth_dir = os.path.join(video_frames_dir, "gt")
    generated_frames_dir = os.path.join(video_frames_dir, "generated")
    tracking_dir = os.path.join(video_frames_dir, "tracking")

    # Create all required directories
    os.makedirs(generated_dir, exist_ok=True)
    os.makedirs(group_dir, exist_ok=True)
    os.makedirs(groundtruth_dir, exist_ok=True)
    os.makedirs(generated_frames_dir, exist_ok=True)
    os.makedirs(tracking_dir, exist_ok=True)

    # Convert the original video tensor ([-1, 1]) to uint8 numpy frames
    original_frames = []
    for frame in original_video:
        frame = frame.permute(1, 2, 0).to(dtype=torch.float32, device="cpu").numpy()
        frame = ((frame + 1.0) * 127.5).astype(np.uint8)
        original_frames.append(frame)

    tracking_frames = []
    if tracking_maps is not None:
        for frame in tracking_maps:
            frame = frame.permute(1, 2, 0).to(dtype=torch.float32, device="cpu").numpy()
            frame = ((frame + 1.0) * 127.5).astype(np.uint8)
            tracking_frames.append(frame)

    # Ensure all videos have the same number of frames
    num_frames = min(len(generated_frames), len(original_frames))
    if tracking_maps is not None:
        num_frames = min(num_frames, len(tracking_frames))

    generated_frames = generated_frames[:num_frames]
    original_frames = original_frames[:num_frames]
    if tracking_maps is not None:
        tracking_frames = tracking_frames[:num_frames]

    # Convert generated PIL images to numpy arrays
    generated_frames_np = [np.array(frame) for frame in generated_frames]

    # Save the generated video separately to the "generated" folder
    gen_video_path = os.path.join(generated_dir, f"{name_without_ext}_generated.mp4")
    with imageio.get_writer(gen_video_path, fps=fps) as writer:
        for frame in generated_frames_np:
            writer.append_data(frame)

    # Concatenate frames and save sampled frames
    concat_frames = []
    for i in range(num_frames):
        gen_frame = generated_frames_np[i]
        orig_frame = original_frames[i]

        width = min(gen_frame.shape[1], orig_frame.shape[1])
        height = orig_frame.shape[0]

        gen_frame = Image.fromarray(gen_frame).resize((width, height))
        gen_frame = np.array(gen_frame)
        orig_frame = Image.fromarray(orig_frame).resize((width, height))
        orig_frame = np.array(orig_frame)

        if tracking_maps is not None:
            track_frame = tracking_frames[i]
            track_frame = Image.fromarray(track_frame).resize((width, height))
            track_frame = np.array(track_frame)

            # Stack ground truth and tracking vertically, downscale by half, then place next to the generated frame
            right_concat = np.concatenate([orig_frame, track_frame], axis=0)
            right_concat_pil = Image.fromarray(right_concat)
            new_height = right_concat.shape[0] // 2
            new_width = right_concat.shape[1] // 2
            right_concat_resized = right_concat_pil.resize((new_width, new_height))
            right_concat_resized = np.array(right_concat_resized)

            concat_frame = np.concatenate([gen_frame, right_concat_resized], axis=1)
        else:
            orig_frame_pil = Image.fromarray(orig_frame)
            new_height = orig_frame.shape[0] // 2
            new_width = orig_frame.shape[1] // 2
            orig_frame_resized = orig_frame_pil.resize((new_width, new_height))
            orig_frame_resized = np.array(orig_frame_resized)

            concat_frame = np.concatenate([gen_frame, orig_frame_resized], axis=1)

        concat_frames.append(concat_frame)

        # Save one frame of each type every 9 frames
        if i % 9 == 0:
            # Save generated frame
            gen_frame_path = os.path.join(generated_frames_dir, f"{i:04d}.png")
            Image.fromarray(gen_frame).save(gen_frame_path)

            # Save original frame
            gt_frame_path = os.path.join(groundtruth_dir, f"{i:04d}.png")
            Image.fromarray(orig_frame).save(gt_frame_path)

            # If tracking maps were provided, save the tracking frame
            if tracking_maps is not None:
                track_frame_path = os.path.join(tracking_dir, f"{i:04d}.png")
                Image.fromarray(track_frame).save(track_frame_path)

    # Export the concatenated video to the "group" folder
    group_video_path = os.path.join(group_dir, filename)
    with imageio.get_writer(group_video_path, fps=fps) as writer:
        for frame in concat_frames:
            writer.append_data(frame)

    return group_video_path


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
    parser.add_argument("--prompt", type=str, help="Optional: override the prompt from the dataset")
    parser.add_argument(
        "--image_or_video_path",
        type=str,
        default=None,
        help="The path of the image to be used as the background of the video",
    )
    parser.add_argument(
        "--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used"
    )
    parser.add_argument(
        "--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
    )
    parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
    parser.add_argument(
        "--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
    )
    parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
    parser.add_argument(
        "--generate_type", type=str, default="i2v", help="The type of video generation (e.g., 'i2v', 'i2vo')"
    )
    parser.add_argument(
        "--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')"
    )
    parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
    parser.add_argument("--tracking_path", type=str, default=None, help="The path of the tracking maps to be used")

    # Dataset-related parameters are required
    parser.add_argument("--data_root", type=str, required=True, help="Root directory of the dataset")
    parser.add_argument("--caption_column", type=str, required=True, help="Name of the caption column")
    parser.add_argument("--tracking_column", type=str, required=True, help="Name of the tracking column")
    parser.add_argument("--video_column", type=str, required=True, help="Name of the video column")
    parser.add_argument("--image_paths", type=str, required=False, help="Name of the image column")

    parser.add_argument("--num_samples", type=int, default=-1, help="Number of samples to process. -1 means process all samples")
    parser.add_argument("--evaluation_dir", type=str, default="evaluations", help="Name of the directory to store evaluation results")
    parser.add_argument("--fps", type=int, default=8, help="Frames per second for the output video")

    args = parser.parse_args()
    dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16

    # If a prompt is not provided, generate_video will use prompts from the dataset
    generate_video(
        prompt=args.prompt,  # Can be None
        model_path=args.model_path,
        tracking_path=args.tracking_path,
        image_paths=args.image_paths,
        output_path=args.output_path,
        image_or_video_path=args.image_or_video_path,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        num_videos_per_prompt=args.num_videos_per_prompt,
        dtype=dtype,
        generate_type=args.generate_type,
        seed=args.seed,
        data_root=args.data_root,
        caption_column=args.caption_column,
        tracking_column=args.tracking_column,
        video_column=args.video_column,
        num_samples=args.num_samples,
        evaluation_dir=args.evaluation_dir,
        fps=args.fps,
    )