import argparse
from typing import Any, Dict, List, Literal, Tuple

import pandas as pd
import os
import sys
import torch
from diffusers import (
    CogVideoXPipeline,
    CogVideoXDDIMScheduler,
    CogVideoXDPMScheduler,
    CogVideoXImageToVideoPipeline,
    CogVideoXVideoToVideoPipeline,
)
from diffusers.utils import export_to_video, load_image, load_video
import numpy as np
import random
import cv2
from pathlib import Path
import decord
from torchvision import transforms
from torchvision.transforms.functional import resize
import PIL.Image
from PIL import Image

current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(current_dir, '..'))

from models.cogvideox_tracking import CogVideoXImageToVideoPipelineTracking, CogVideoXPipelineTracking, CogVideoXVideoToVideoPipelineTracking
from training.dataset import VideoDataset, VideoDatasetWithResizingTracking
class VideoDatasetWithResizingTrackingEval(VideoDataset):
    def __init__(self, *args, **kwargs) -> None:
        self.tracking_column = kwargs.pop("tracking_column", None)
        self.image_paths = kwargs.pop("image_paths", None)
        super().__init__(*args, **kwargs)

    def _preprocess_video(self, path: Path, tracking_path: Path, image_paths: Path = None) -> torch.Tensor:
        if self.load_tensors:
            return self._load_preprocessed_latents_and_embeds(path, tracking_path)
        else:
            video_reader = decord.VideoReader(uri=path.as_posix())
            video_num_frames = len(video_reader)
            nearest_frame_bucket = min(
                self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames))
            )

            frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))

            frames = video_reader.get_batch(frame_indices)
            frames = frames[:nearest_frame_bucket].float()
            frames = frames.permute(0, 3, 1, 2).contiguous()

            nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
            frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0)
            frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)

            image = Image.open(image_paths)
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image = torch.from_numpy(np.array(image)).float()
            image = image.permute(2, 0, 1).contiguous()
            image = resize(image, nearest_res)
            image = self.video_transforms(image)

            tracking_reader = decord.VideoReader(uri=tracking_path.as_posix())
            tracking_frames = tracking_reader.get_batch(frame_indices)
            tracking_frames = tracking_frames[:nearest_frame_bucket].float()
            tracking_frames = tracking_frames.permute(0, 3, 1, 2).contiguous()
            tracking_frames_resized = torch.stack([resize(tracking_frame, nearest_res) for tracking_frame in tracking_frames], dim=0)
            tracking_frames = torch.stack([self.video_transforms(tracking_frame) for tracking_frame in tracking_frames_resized], dim=0)

            return image, frames, tracking_frames, None

    def _find_nearest_resolution(self, height, width):
        nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
        return nearest_res[1], nearest_res[2]
    def _load_dataset_from_local_path(self) -> Tuple[List[str], List[str], List[str]]:
        if not self.data_root.exists():
            raise ValueError("Root folder for videos does not exist")

        prompt_path = self.data_root.joinpath(self.caption_column)
        video_path = self.data_root.joinpath(self.video_column)
        tracking_path = self.data_root.joinpath(self.tracking_column)
        image_paths = self.data_root.joinpath(self.image_paths)

        if not prompt_path.exists() or not prompt_path.is_file():
            raise ValueError(
                "Expected `--caption_column` to be path to a file in `--data_root` containing line-separated text prompts."
            )
        if not video_path.exists() or not video_path.is_file():
            raise ValueError(
                "Expected `--video_column` to be path to a file in `--data_root` containing line-separated paths to video data in the same directory."
            )
        if not tracking_path.exists() or not tracking_path.is_file():
            raise ValueError(
                "Expected `--tracking_column` to be path to a file in `--data_root` containing line-separated tracking information."
            )

        with open(prompt_path, "r", encoding="utf-8") as file:
            prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0]
        with open(video_path, "r", encoding="utf-8") as file:
            video_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0]
        with open(tracking_path, "r", encoding="utf-8") as file:
            tracking_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0]
        with open(image_paths, "r", encoding="utf-8") as file:
            image_paths_list = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0]

        if not self.load_tensors and any(not path.is_file() for path in video_paths):
            raise ValueError(
                f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found at least one path that is not a valid file."
            )

        self.tracking_paths = tracking_paths
        self.image_paths = image_paths_list
        return prompts, video_paths
    def _load_dataset_from_csv(self) -> Tuple[List[str], List[str], List[str]]:
        df = pd.read_csv(self.dataset_file)
        prompts = df[self.caption_column].tolist()
        video_paths = df[self.video_column].tolist()
        tracking_paths = df[self.tracking_column].tolist()
        image_paths = df[self.image_paths].tolist()
        video_paths = [self.data_root.joinpath(line.strip()) for line in video_paths]
        tracking_paths = [self.data_root.joinpath(line.strip()) for line in tracking_paths]
        image_paths = [self.data_root.joinpath(line.strip()) for line in image_paths]

        if any(not path.is_file() for path in video_paths):
            raise ValueError(
                f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found at least one path that is not a valid file."
            )

        self.tracking_paths = tracking_paths
        self.image_paths = image_paths
        return prompts, video_paths
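
    # When a CSV dataset file is used, each row is expected to provide a text prompt in
    # `caption_column` plus paths (relative to `data_root`) in `video_column`,
    # `tracking_column`, and the column named by `image_paths`.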
    def __getitem__(self, index: int) -> Dict[str, Any]:
        if isinstance(index, list):
            return index

        if self.load_tensors:
            image_latents, video_latents, tracking_map, prompt_embeds = self._preprocess_video(
                self.video_paths[index], self.tracking_paths[index]
            )

            # The VAE's temporal compression ratio is 4.
            # The VAE's spatial compression ratio is 8.
            latent_num_frames = video_latents.size(1)
            if latent_num_frames % 2 == 0:
                num_frames = latent_num_frames * 4
            else:
                num_frames = (latent_num_frames - 1) * 4 + 1

            height = video_latents.size(2) * 8
            width = video_latents.size(3) * 8

            return {
                "prompt": prompt_embeds,
                "image": image_latents,
                "video": video_latents,
                "tracking_map": tracking_map,
                "video_metadata": {
                    "num_frames": num_frames,
                    "height": height,
                    "width": width,
                },
            }
        else:
            image, video, tracking_map, _ = self._preprocess_video(
                self.video_paths[index], self.tracking_paths[index], self.image_paths[index]
            )

            return {
                "prompt": self.id_token + self.prompts[index],
                "image": image,
                "video": video,
                "tracking_map": tracking_map,
                "video_metadata": {
                    "num_frames": video.shape[0],
                    "height": video.shape[2],
                    "width": video.shape[3],
                },
            }
    def _load_preprocessed_latents_and_embeds(self, path: Path, tracking_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
        filename_without_ext = path.name.split(".")[0]
        pt_filename = f"{filename_without_ext}.pt"

        # The current path is something like: /a/b/c/d/videos/00001.mp4
        # We need to reach: /a/b/c/d/video_latents/00001.pt
        image_latents_path = path.parent.parent.joinpath("image_latents")
        video_latents_path = path.parent.parent.joinpath("video_latents")
        tracking_map_path = path.parent.parent.joinpath("tracking_map")
        embeds_path = path.parent.parent.joinpath("prompt_embeds")

        if (
            not video_latents_path.exists()
            or not embeds_path.exists()
            or not tracking_map_path.exists()
            or (self.image_to_video and not image_latents_path.exists())
        ):
            raise ValueError(
                f"When setting the load_tensors parameter to `True`, it is expected that the `{self.data_root=}` contains folders named `video_latents`, `prompt_embeds`, and `tracking_map`. However, these folders were not found. Please make sure to have prepared your data correctly using `prepare_data.py`. Additionally, if you're training image-to-video, it is expected that an `image_latents` folder is also present."
            )

        if self.image_to_video:
            image_latent_filepath = image_latents_path.joinpath(pt_filename)
        video_latent_filepath = video_latents_path.joinpath(pt_filename)
        tracking_map_filepath = tracking_map_path.joinpath(pt_filename)
        embeds_filepath = embeds_path.joinpath(pt_filename)

        if not video_latent_filepath.is_file() or not embeds_filepath.is_file() or not tracking_map_filepath.is_file():
            if self.image_to_video:
                image_latent_filepath = image_latent_filepath.as_posix()
            video_latent_filepath = video_latent_filepath.as_posix()
            tracking_map_filepath = tracking_map_filepath.as_posix()
            embeds_filepath = embeds_filepath.as_posix()
            raise ValueError(
                f"The file {video_latent_filepath=} or {embeds_filepath=} or {tracking_map_filepath=} could not be found. Please ensure that you've correctly executed `prepare_dataset.py`."
            )

        images = (
            torch.load(image_latent_filepath, map_location="cpu", weights_only=True) if self.image_to_video else None
        )
        latents = torch.load(video_latent_filepath, map_location="cpu", weights_only=True)
        tracking_map = torch.load(tracking_map_filepath, map_location="cpu", weights_only=True)
        embeds = torch.load(embeds_filepath, map_location="cpu", weights_only=True)

        return images, latents, tracking_map, embeds
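

# A sketch of the directory layout this eval dataset expects (inferred from the class
# above; the concrete list-file names are illustrative, only the latent/embedding folder
# names are mandated by the code):
#
#   data_root/
#       prompts.txt      # --caption_column: one text prompt per line
#       videos.txt       # --video_column: relative video paths, one per line
#       trackings.txt    # --tracking_column: relative tracking-video paths, one per line
#       images.txt       # --image_paths: relative conditioning-image paths, one per line
#       videos/..., tracking/..., images/...
#
# With load_tensors=True, sibling folders `video_latents/`, `prompt_embeds/`,
# `tracking_map/` (and `image_latents/` for image-to-video) must contain a
# `<video_name>.pt` file for every listed video.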
def sample_from_dataset(
    data_root: str,
    caption_column: str,
    tracking_column: str,
    image_paths: str,
    video_column: str,
    num_samples: int = -1,
    random_seed: int = 42
):
    """Sample from dataset"""
    if image_paths:
        # If image_paths is provided, use VideoDatasetWithResizingTrackingEval
        dataset = VideoDatasetWithResizingTrackingEval(
            data_root=data_root,
            caption_column=caption_column,
            tracking_column=tracking_column,
            image_paths=image_paths,
            video_column=video_column,
            max_num_frames=49,
            load_tensors=False,
            random_flip=None,
            frame_buckets=[49],
            image_to_video=True
        )
    else:
        # If image_paths is not provided, use VideoDatasetWithResizingTracking
        dataset = VideoDatasetWithResizingTracking(
            data_root=data_root,
            caption_column=caption_column,
            tracking_column=tracking_column,
            video_column=video_column,
            max_num_frames=49,
            load_tensors=False,
            random_flip=None,
            frame_buckets=[49],
            image_to_video=True
        )

    # Set random seed
    random.seed(random_seed)

    # Randomly sample from dataset
    total_samples = len(dataset)
    if num_samples == -1:
        # If num_samples is -1, process all samples
        selected_indices = range(total_samples)
    else:
        selected_indices = random.sample(range(total_samples), min(num_samples, total_samples))

    samples = []
    for idx in selected_indices:
        sample = dataset[idx]
        # Get data based on dataset.__getitem__ return value
        image = sample["image"]                # Already processed tensor
        video = sample["video"]                # Already processed tensor
        tracking_map = sample["tracking_map"]  # Already processed tensor
        prompt = sample["prompt"]

        samples.append({
            "prompt": prompt,
            "tracking_frame": tracking_map[0],  # First tracking frame
            "video_frame": image,               # Conditioning image (first frame)
            "video": video,                     # Complete video
            "tracking_maps": tracking_map,      # Complete tracking maps
            "height": sample["video_metadata"]["height"],
            "width": sample["video_metadata"]["width"]
        })

    return samples
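

# A minimal usage sketch (hypothetical paths; the returned dicts follow the dataset's
# __getitem__ contract above):
#
#   samples = sample_from_dataset(
#       data_root="datasets/my_eval_set",
#       caption_column="prompts.txt",
#       tracking_column="trackings.txt",
#       image_paths="images.txt",
#       video_column="videos.txt",
#       num_samples=4,
#   )
#   # samples[0]["video"]: [F, C, H, W] tensor normalized to [-1, 1]
#   # samples[0]["tracking_frame"]: first tracking frame with the same normalization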
def generate_video(
    prompt: str,
    model_path: str,
    tracking_path: str = None,
    output_path: str = "./output.mp4",
    image_or_video_path: str = "",
    num_inference_steps: int = 50,
    guidance_scale: float = 6.0,
    num_videos_per_prompt: int = 1,
    dtype: torch.dtype = torch.bfloat16,
    generate_type: Literal["i2v", "i2vo"] = "i2v",  # i2v: image to video with tracking, i2vo: original CogVideoX-5b-I2V
    seed: int = 42,
    data_root: str = None,
    caption_column: str = None,
    tracking_column: str = None,
    video_column: str = None,
    image_paths: str = None,
    num_samples: int = -1,
    evaluation_dir: str = "evaluations",
    fps: int = 8,
):
device = "cuda" if torch.cuda.is_available() else "cpu" | |
# If dataset parameters are provided, sample from dataset | |
samples = None | |
if all([data_root, caption_column, tracking_column, video_column]): | |
samples = sample_from_dataset( | |
data_root=data_root, | |
caption_column=caption_column, | |
tracking_column=tracking_column, | |
image_paths=image_paths, | |
video_column=video_column, | |
num_samples=num_samples, | |
random_seed=seed | |
) | |
# Load model and data | |
if generate_type == "i2v": | |
pipe = CogVideoXImageToVideoPipelineTracking.from_pretrained(model_path, torch_dtype=dtype) | |
if not samples: | |
image = load_image(image=image_or_video_path) | |
height, width = image.height, image.width | |
else: | |
pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=dtype) | |
if not samples: | |
image = load_image(image=image_or_video_path) | |
height, width = image.height, image.width | |
# Set model parameters | |
pipe.to(device, dtype=dtype) | |
pipe.vae.enable_slicing() | |
pipe.vae.enable_tiling() | |
pipe.transformer.eval() | |
pipe.text_encoder.eval() | |
pipe.vae.eval() | |
pipe.transformer.gradient_checkpointing = False | |
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") | |
    # Generate videos
    if samples:
        from tqdm import tqdm
        for i, sample in tqdm(enumerate(samples), total=len(samples), desc="Samples"):
            print(f"Prompt: {sample['prompt'][:30]}")
            tracking_frame = sample["tracking_frame"].to(device=device, dtype=dtype)
            video_frame = sample["video_frame"].to(device=device, dtype=dtype)
            video = sample["video"].to(device=device, dtype=dtype)
            tracking_maps = sample["tracking_maps"].to(device=device, dtype=dtype)

            # Encode the tracking maps into VAE latents
            print("encoding tracking maps")
            tracking_video = tracking_maps
            tracking_maps = tracking_maps.unsqueeze(0)
            tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
            with torch.no_grad():
                tracking_latent_dist = pipe.vae.encode(tracking_maps).latent_dist
                tracking_maps = tracking_latent_dist.sample() * pipe.vae.config.scaling_factor
                tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]

            pipeline_args = {
                "prompt": sample["prompt"],
                "negative_prompt": "The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion.",
                "num_inference_steps": num_inference_steps,
                "num_frames": 49,
                "use_dynamic_cfg": True,
                "guidance_scale": guidance_scale,
                "generator": torch.Generator(device=device).manual_seed(seed),
                "height": sample["height"],
                "width": sample["width"]
            }
            pipeline_args["image"] = (video_frame + 1.0) / 2.0  # [-1, 1] -> [0, 1]

            if tracking_column and generate_type == "i2v":
                pipeline_args["tracking_maps"] = tracking_maps
                pipeline_args["tracking_image"] = (tracking_frame.unsqueeze(0) + 1.0) / 2.0

            with torch.no_grad():
                video_generate = pipe(**pipeline_args).frames[0]

            output_dir = os.path.join(data_root, evaluation_dir)
            output_name = f"{i:04d}.mp4"
            output_file = os.path.join(output_dir, output_name)
            os.makedirs(output_dir, exist_ok=True)
            export_concat_video(video_generate, video, tracking_video, output_file, fps=fps)
    else:
        pipeline_args = {
            "prompt": prompt,
            "num_videos_per_prompt": num_videos_per_prompt,
            "num_inference_steps": num_inference_steps,
            "num_frames": 49,
            "use_dynamic_cfg": True,
            "guidance_scale": guidance_scale,
            "generator": torch.Generator().manual_seed(seed),
        }
        pipeline_args["image"] = image
        pipeline_args["height"] = height
        pipeline_args["width"] = width

        if tracking_path and generate_type == "i2v":
            tracking_maps = load_video(tracking_path)
            tracking_maps = torch.stack([
                torch.from_numpy(np.array(frame)).permute(2, 0, 1).float() / 255.0
                for frame in tracking_maps
            ]).to(device=device, dtype=dtype)
            tracking_video = tracking_maps
            tracking_maps = tracking_maps.unsqueeze(0)
            tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
            with torch.no_grad():
                tracking_latent_dist = pipe.vae.encode(tracking_maps).latent_dist
                tracking_maps = tracking_latent_dist.sample() * pipe.vae.config.scaling_factor
                tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]

            pipeline_args["tracking_maps"] = tracking_maps
            pipeline_args["tracking_image"] = tracking_video[:1]  # first tracking frame in [0, 1] pixel space

        with torch.no_grad():
            video_generate = pipe(**pipeline_args).frames[0]

        output_dir = os.path.join(data_root, evaluation_dir)
        output_name = f"{os.path.splitext(os.path.basename(image_or_video_path))[0]}.mp4"
        output_file = os.path.join(output_dir, output_name)
        os.makedirs(output_dir, exist_ok=True)
        # No ground-truth video is available in this branch, so only the generated frames are exported
        export_to_video(video_generate, output_file, fps=fps)
def create_frame_grid(frames: List[np.ndarray], interval: int = 9, max_cols: int = 7) -> np.ndarray:
    """
    Arrange video frames into a grid image by sampling at intervals

    Args:
        frames: List of video frames
        interval: Sampling interval
        max_cols: Maximum number of frames per row

    Returns:
        Grid image array
    """
    # Sample frames at intervals
    sampled_frames = frames[::interval]

    # Calculate number of rows and columns
    n_frames = len(sampled_frames)
    n_cols = min(max_cols, n_frames)
    n_rows = (n_frames + n_cols - 1) // n_cols

    # Get height and width of a single frame
    frame_height, frame_width = sampled_frames[0].shape[:2]

    # Create blank canvas
    grid = np.zeros((frame_height * n_rows, frame_width * n_cols, 3), dtype=np.uint8)

    # Fill frames
    for idx, frame in enumerate(sampled_frames):
        i = idx // n_cols
        j = idx % n_cols
        grid[i * frame_height:(i + 1) * frame_height, j * frame_width:(j + 1) * frame_width] = frame

    return grid
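

# create_frame_grid is not called anywhere in this script; one possible use is turning the
# concatenated comparison frames built inside export_concat_video into a single contact
# sheet, e.g. (illustrative only):
#
#   grid = create_frame_grid(concat_frames, interval=9)
#   Image.fromarray(grid).save("contact_sheet.png")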
def export_concat_video(
    generated_frames: List[PIL.Image.Image],
    original_video: torch.Tensor,
    tracking_maps: torch.Tensor = None,
    output_video_path: str = None,
    fps: int = 8
) -> str:
    """
    Export generated video frames, original video and tracking maps as video files,
    and save sampled frames to different folders
    """
    import imageio
    import os
    import tempfile

    if output_video_path is None:
        output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name

    # Create subdirectories
    base_dir = os.path.dirname(output_video_path)
    generated_dir = os.path.join(base_dir, "generated")  # For storing generated videos
    group_dir = os.path.join(base_dir, "group")          # For storing concatenated videos

    # Get filename (without path) and create a video-specific folder
    filename = os.path.basename(output_video_path)
    name_without_ext = os.path.splitext(filename)[0]
    video_frames_dir = os.path.join(base_dir, "frames", name_without_ext)  # frames/video_name/

    # Create three subdirectories under the video-specific folder
    groundtruth_dir = os.path.join(video_frames_dir, "gt")
    generated_frames_dir = os.path.join(video_frames_dir, "generated")
    tracking_dir = os.path.join(video_frames_dir, "tracking")

    # Create all required directories
    os.makedirs(generated_dir, exist_ok=True)
    os.makedirs(group_dir, exist_ok=True)
    os.makedirs(groundtruth_dir, exist_ok=True)
    os.makedirs(generated_frames_dir, exist_ok=True)
    os.makedirs(tracking_dir, exist_ok=True)

    # Convert the original video tensor to uint8 numpy arrays
    original_frames = []
    for frame in original_video:
        frame = frame.permute(1, 2, 0).to(dtype=torch.float32, device="cpu").numpy()
        frame = ((frame + 1.0) * 127.5).astype(np.uint8)
        original_frames.append(frame)

    tracking_frames = []
    if tracking_maps is not None:
        for frame in tracking_maps:
            frame = frame.permute(1, 2, 0).to(dtype=torch.float32, device="cpu").numpy()
            frame = ((frame + 1.0) * 127.5).astype(np.uint8)
            tracking_frames.append(frame)

    # Ensure all videos have the same number of frames
    num_frames = min(len(generated_frames), len(original_frames))
    if tracking_maps is not None:
        num_frames = min(num_frames, len(tracking_frames))

    generated_frames = generated_frames[:num_frames]
    original_frames = original_frames[:num_frames]
    if tracking_maps is not None:
        tracking_frames = tracking_frames[:num_frames]

    # Convert generated PIL images to numpy arrays
    generated_frames_np = [np.array(frame) for frame in generated_frames]

    # Save the generated video separately to the generated folder
    gen_video_path = os.path.join(generated_dir, f"{name_without_ext}_generated.mp4")
    with imageio.get_writer(gen_video_path, fps=fps) as writer:
        for frame in generated_frames_np:
            writer.append_data(frame)

    # Concatenate frames and save sampled frames
    concat_frames = []
    for i in range(num_frames):
        gen_frame = generated_frames_np[i]
        orig_frame = original_frames[i]

        width = min(gen_frame.shape[1], orig_frame.shape[1])
        height = orig_frame.shape[0]

        gen_frame = Image.fromarray(gen_frame).resize((width, height))
        gen_frame = np.array(gen_frame)
        orig_frame = Image.fromarray(orig_frame).resize((width, height))
        orig_frame = np.array(orig_frame)

        if tracking_maps is not None:
            track_frame = tracking_frames[i]
            track_frame = Image.fromarray(track_frame).resize((width, height))
            track_frame = np.array(track_frame)

            # Stack ground truth and tracking vertically, downscale by 2, then place next to the generated frame
            right_concat = np.concatenate([orig_frame, track_frame], axis=0)
            right_concat_pil = Image.fromarray(right_concat)
            new_height = right_concat.shape[0] // 2
            new_width = right_concat.shape[1] // 2
            right_concat_resized = right_concat_pil.resize((new_width, new_height))
            right_concat_resized = np.array(right_concat_resized)

            concat_frame = np.concatenate([gen_frame, right_concat_resized], axis=1)
        else:
            orig_frame_pil = Image.fromarray(orig_frame)
            new_height = orig_frame.shape[0] // 2
            new_width = orig_frame.shape[1] // 2
            orig_frame_resized = orig_frame_pil.resize((new_width, new_height))
            orig_frame_resized = np.array(orig_frame_resized)
            concat_frame = np.concatenate([gen_frame, orig_frame_resized], axis=1)

        concat_frames.append(concat_frame)

        # Save every 9th frame of each type
        if i % 9 == 0:
            # Save generated frame
            gen_frame_path = os.path.join(generated_frames_dir, f"{i:04d}.png")
            Image.fromarray(gen_frame).save(gen_frame_path)

            # Save original frame
            gt_frame_path = os.path.join(groundtruth_dir, f"{i:04d}.png")
            Image.fromarray(orig_frame).save(gt_frame_path)

            # If tracking maps exist, save the tracking frame
            if tracking_maps is not None:
                track_frame_path = os.path.join(tracking_dir, f"{i:04d}.png")
                Image.fromarray(track_frame).save(track_frame_path)

    # Export the concatenated video to the group folder
    group_video_path = os.path.join(group_dir, filename)
    with imageio.get_writer(group_video_path, fps=fps) as writer:
        for frame in concat_frames:
            writer.append_data(frame)

    return group_video_path
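

# Output layout produced by export_concat_video, relative to the directory that contains
# output_video_path (i.e. <data_root>/<evaluation_dir>/ when called from generate_video):
#
#   generated/<name>_generated.mp4     generated video only
#   group/<name>.mp4                   generated frames side by side with ground truth / tracking
#   frames/<name>/generated/0000.png   every 9th generated frame
#   frames/<name>/gt/0000.png          every 9th ground-truth frame
#   frames/<name>/tracking/0000.png    every 9th tracking frame (when tracking maps are given)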
if __name__ == "__main__": | |
parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX") | |
parser.add_argument("--prompt", type=str, help="Optional: override the prompt from dataset") | |
parser.add_argument( | |
"--image_or_video_path", | |
type=str, | |
default=None, | |
help="The path of the image to be used as the background of the video", | |
) | |
parser.add_argument( | |
"--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used" | |
) | |
parser.add_argument( | |
"--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved" | |
) | |
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance") | |
parser.add_argument( | |
"--num_inference_steps", type=int, default=50, help="Number of steps for the inference process" | |
) | |
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt") | |
parser.add_argument( | |
"--generate_type", type=str, default="i2v", help="The type of video generation (e.g., 'i2v', 'i2vo')" | |
) | |
parser.add_argument( | |
"--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')" | |
) | |
parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility") | |
parser.add_argument("--tracking_path", type=str, default=None, help="The path of the tracking maps to be used") | |
# Dataset related parameters are required | |
parser.add_argument("--data_root", type=str, required=True, help="Root directory of the dataset") | |
parser.add_argument("--caption_column", type=str, required=True, help="Name of the caption column") | |
parser.add_argument("--tracking_column", type=str, required=True, help="Name of the tracking column") | |
parser.add_argument("--video_column", type=str, required=True, help="Name of the video column") | |
parser.add_argument("--image_paths", type=str, required=False, help="Name of the image column") | |
# Add num_samples parameter | |
parser.add_argument("--num_samples", type=int, default=-1, | |
help="Number of samples to process. -1 means process all samples") | |
# Add evaluation_dir parameter | |
parser.add_argument("--evaluation_dir", type=str, default="evaluations", | |
help="Name of the directory to store evaluation results") | |
# Add fps parameter | |
parser.add_argument("--fps", type=int, default=8, | |
help="Frames per second for the output video") | |
args = parser.parse_args() | |
dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16 | |
# If prompt is not provided, generate_video function will use prompts from dataset | |
generate_video( | |
prompt=args.prompt, # Can be None | |
model_path=args.model_path, | |
tracking_path=args.tracking_path, | |
image_paths=args.image_paths, | |
output_path=args.output_path, | |
image_or_video_path=args.image_or_video_path, | |
num_inference_steps=args.num_inference_steps, | |
guidance_scale=args.guidance_scale, | |
num_videos_per_prompt=args.num_videos_per_prompt, | |
dtype=dtype, | |
generate_type=args.generate_type, | |
seed=args.seed, | |
data_root=args.data_root, | |
caption_column=args.caption_column, | |
tracking_column=args.tracking_column, | |
video_column=args.video_column, | |
num_samples=args.num_samples, | |
evaluation_dir=args.evaluation_dir, | |
fps=args.fps, | |
) |