# Diffusion-As-Shader: testing/evaluation.py
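"""Evaluation script for tracking-conditioned CogVideoX (Diffusion-As-Shader).

Samples clips from a prepared dataset (or takes a single image plus an optional
tracking video), runs image-to-video generation conditioned on tracking maps,
and exports the generated clips next to the ground-truth and tracking videos
for side-by-side comparison.
"""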
import argparse
from typing import Any, Dict, List, Literal, Tuple
import pandas as pd
import os
import sys
import torch
from diffusers import (
CogVideoXPipeline,
CogVideoXDDIMScheduler,
CogVideoXDPMScheduler,
CogVideoXImageToVideoPipeline,
CogVideoXVideoToVideoPipeline,
)
from diffusers.utils import export_to_video, load_image, load_video
import numpy as np
import random
import cv2
from pathlib import Path
import decord
from torchvision import transforms
from torchvision.transforms.functional import resize
import PIL.Image
from PIL import Image
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(current_dir, '..'))
from models.cogvideox_tracking import CogVideoXImageToVideoPipelineTracking, CogVideoXPipelineTracking, CogVideoXVideoToVideoPipelineTracking
from training.dataset import VideoDataset, VideoDatasetWithResizingTracking
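# Evaluation-time dataset: extends VideoDataset with a tracking-map column and a
# per-sample list of conditioning images (image_paths), so the image condition is
# loaded from disk rather than taken from the video itself.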
class VideoDatasetWithResizingTrackingEval(VideoDataset):
def __init__(self, *args, **kwargs) -> None:
self.tracking_column = kwargs.pop("tracking_column", None)
self.image_paths = kwargs.pop("image_paths", None)
super().__init__(*args, **kwargs)
def _preprocess_video(self, path: Path, tracking_path: Path, image_paths: Path = None) -> torch.Tensor:
if self.load_tensors:
return self._load_preprocessed_latents_and_embeds(path, tracking_path)
else:
video_reader = decord.VideoReader(uri=path.as_posix())
video_num_frames = len(video_reader)
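                # Pick the frame bucket closest to the clip length (capped at
                # max_num_frames), then sample evenly spaced frame indices.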
nearest_frame_bucket = min(
self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames))
)
frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
frames = video_reader.get_batch(frame_indices)
frames = frames[:nearest_frame_bucket].float()
frames = frames.permute(0, 3, 1, 2).contiguous()
nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0)
frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
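                # Load the conditioning image and apply the same resize and
                # normalisation as the video frames.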
image = Image.open(image_paths)
if image.mode != 'RGB':
image = image.convert('RGB')
image = torch.from_numpy(np.array(image)).float()
image = image.permute(2, 0, 1).contiguous()
image = resize(image, nearest_res)
image = self.video_transforms(image)
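                # Decode the tracking-map video at the same frame indices so it
                # stays aligned with the RGB clip.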
tracking_reader = decord.VideoReader(uri=tracking_path.as_posix())
tracking_frames = tracking_reader.get_batch(frame_indices)
tracking_frames = tracking_frames[:nearest_frame_bucket].float()
tracking_frames = tracking_frames.permute(0, 3, 1, 2).contiguous()
tracking_frames_resized = torch.stack([resize(tracking_frame, nearest_res) for tracking_frame in tracking_frames], dim=0)
tracking_frames = torch.stack([self.video_transforms(tracking_frame) for tracking_frame in tracking_frames_resized], dim=0)
return image, frames, tracking_frames, None
def _find_nearest_resolution(self, height, width):
nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
return nearest_res[1], nearest_res[2]
    def _load_dataset_from_local_path(self) -> Tuple[List[str], List[str]]:
if not self.data_root.exists():
raise ValueError("Root folder for videos does not exist")
prompt_path = self.data_root.joinpath(self.caption_column)
video_path = self.data_root.joinpath(self.video_column)
tracking_path = self.data_root.joinpath(self.tracking_column)
image_paths = self.data_root.joinpath(self.image_paths)
if not prompt_path.exists() or not prompt_path.is_file():
raise ValueError(
"Expected `--caption_column` to be path to a file in `--data_root` containing line-separated text prompts."
)
if not video_path.exists() or not video_path.is_file():
raise ValueError(
"Expected `--video_column` to be path to a file in `--data_root` containing line-separated paths to video data in the same directory."
)
if not tracking_path.exists() or not tracking_path.is_file():
raise ValueError(
"Expected `--tracking_column` to be path to a file in `--data_root` containing line-separated tracking information."
)
with open(prompt_path, "r", encoding="utf-8") as file:
prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0]
with open(video_path, "r", encoding="utf-8") as file:
video_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0]
with open(tracking_path, "r", encoding="utf-8") as file:
tracking_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0]
with open(image_paths, "r", encoding="utf-8") as file:
image_paths_list = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0]
if not self.load_tensors and any(not path.is_file() for path in video_paths):
raise ValueError(
f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file."
)
self.tracking_paths = tracking_paths
self.image_paths = image_paths_list
return prompts, video_paths
    def _load_dataset_from_csv(self) -> Tuple[List[str], List[str]]:
df = pd.read_csv(self.dataset_file)
prompts = df[self.caption_column].tolist()
video_paths = df[self.video_column].tolist()
tracking_paths = df[self.tracking_column].tolist()
image_paths = df[self.image_paths].tolist()
video_paths = [self.data_root.joinpath(line.strip()) for line in video_paths]
tracking_paths = [self.data_root.joinpath(line.strip()) for line in tracking_paths]
image_paths = [self.data_root.joinpath(line.strip()) for line in image_paths]
if any(not path.is_file() for path in video_paths):
raise ValueError(
f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found at least one path that is not a valid file."
)
self.tracking_paths = tracking_paths
self.image_paths = image_paths
return prompts, video_paths
def __getitem__(self, index: int) -> Dict[str, Any]:
if isinstance(index, list):
return index
if self.load_tensors:
            image_latents, video_latents, tracking_map, prompt_embeds = self._preprocess_video(self.video_paths[index], self.tracking_paths[index])
# The VAE's temporal compression ratio is 4.
# The VAE's spatial compression ratio is 8.
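            # Recover pixel-space dimensions from the latent shape: temporal factor 4
            # (with the first frame handled separately for odd counts), spatial factor 8.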
latent_num_frames = video_latents.size(1)
if latent_num_frames % 2 == 0:
num_frames = latent_num_frames * 4
else:
num_frames = (latent_num_frames - 1) * 4 + 1
height = video_latents.size(2) * 8
width = video_latents.size(3) * 8
return {
"prompt": prompt_embeds,
"image": image_latents,
"video": video_latents,
"tracking_map": tracking_map,
"video_metadata": {
"num_frames": num_frames,
"height": height,
"width": width,
},
}
else:
image, video, tracking_map, _ = self._preprocess_video(self.video_paths[index], self.tracking_paths[index], self.image_paths[index])
return {
"prompt": self.id_token + self.prompts[index],
"image": image,
"video": video,
"tracking_map": tracking_map,
"video_metadata": {
"num_frames": video.shape[0],
"height": video.shape[2],
"width": video.shape[3],
},
}
    def _load_preprocessed_latents_and_embeds(self, path: Path, tracking_path: Path) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
filename_without_ext = path.name.split(".")[0]
pt_filename = f"{filename_without_ext}.pt"
# The current path is something like: /a/b/c/d/videos/00001.mp4
# We need to reach: /a/b/c/d/video_latents/00001.pt
image_latents_path = path.parent.parent.joinpath("image_latents")
video_latents_path = path.parent.parent.joinpath("video_latents")
tracking_map_path = path.parent.parent.joinpath("tracking_map")
embeds_path = path.parent.parent.joinpath("prompt_embeds")
if (
not video_latents_path.exists()
or not embeds_path.exists()
or not tracking_map_path.exists()
or (self.image_to_video and not image_latents_path.exists())
):
raise ValueError(
f"When setting the load_tensors parameter to `True`, it is expected that the `{self.data_root=}` contains folders named `video_latents`, `prompt_embeds`, and `tracking_map`. However, these folders were not found. Please make sure to have prepared your data correctly using `prepare_data.py`. Additionally, if you're training image-to-video, it is expected that an `image_latents` folder is also present."
)
if self.image_to_video:
image_latent_filepath = image_latents_path.joinpath(pt_filename)
video_latent_filepath = video_latents_path.joinpath(pt_filename)
tracking_map_filepath = tracking_map_path.joinpath(pt_filename)
embeds_filepath = embeds_path.joinpath(pt_filename)
if not video_latent_filepath.is_file() or not embeds_filepath.is_file() or not tracking_map_filepath.is_file():
if self.image_to_video:
image_latent_filepath = image_latent_filepath.as_posix()
video_latent_filepath = video_latent_filepath.as_posix()
tracking_map_filepath = tracking_map_filepath.as_posix()
embeds_filepath = embeds_filepath.as_posix()
raise ValueError(
f"The file {video_latent_filepath=} or {embeds_filepath=} or {tracking_map_filepath=} could not be found. Please ensure that you've correctly executed `prepare_dataset.py`."
)
images = (
torch.load(image_latent_filepath, map_location="cpu", weights_only=True) if self.image_to_video else None
)
latents = torch.load(video_latent_filepath, map_location="cpu", weights_only=True)
tracking_map = torch.load(tracking_map_filepath, map_location="cpu", weights_only=True)
embeds = torch.load(embeds_filepath, map_location="cpu", weights_only=True)
return images, latents, tracking_map, embeds
def sample_from_dataset(
data_root: str,
caption_column: str,
tracking_column: str,
image_paths: str,
video_column: str,
num_samples: int = -1,
random_seed: int = 42
):
"""Sample from dataset"""
if image_paths:
# If image_paths is provided, use VideoDatasetWithResizingTrackingEval
dataset = VideoDatasetWithResizingTrackingEval(
data_root=data_root,
caption_column=caption_column,
tracking_column=tracking_column,
image_paths=image_paths,
video_column=video_column,
max_num_frames=49,
load_tensors=False,
random_flip=None,
frame_buckets=[49],
image_to_video=True
)
else:
# If image_paths is not provided, use VideoDatasetWithResizingTracking
dataset = VideoDatasetWithResizingTracking(
data_root=data_root,
caption_column=caption_column,
tracking_column=tracking_column,
video_column=video_column,
max_num_frames=49,
load_tensors=False,
random_flip=None,
frame_buckets=[49],
image_to_video=True
)
# Set random seed
random.seed(random_seed)
# Randomly sample from dataset
total_samples = len(dataset)
if num_samples == -1:
# If num_samples is -1, process all samples
selected_indices = range(total_samples)
else:
selected_indices = random.sample(range(total_samples), min(num_samples, total_samples))
samples = []
for idx in selected_indices:
sample = dataset[idx]
# Get data based on dataset.__getitem__ return value
image = sample["image"] # Already processed tensor
video = sample["video"] # Already processed tensor
tracking_map = sample["tracking_map"] # Already processed tensor
prompt = sample["prompt"]
samples.append({
"prompt": prompt,
"tracking_frame": tracking_map[0], # Get first frame
"video_frame": image, # Get first frame
"video": video, # Complete video
"tracking_maps": tracking_map, # Complete tracking maps
"height": sample["video_metadata"]["height"],
"width": sample["video_metadata"]["width"]
})
return samples
def generate_video(
prompt: str,
model_path: str,
tracking_path: str = None,
output_path: str = "./output.mp4",
image_or_video_path: str = "",
num_inference_steps: int = 50,
guidance_scale: float = 6.0,
num_videos_per_prompt: int = 1,
dtype: torch.dtype = torch.bfloat16,
    generate_type: Literal["i2v", "i2vo"] = "i2v",  # i2v: tracking-conditioned image to video, i2vo: original CogVideoX-5b-I2V
seed: int = 42,
data_root: str = None,
caption_column: str = None,
tracking_column: str = None,
video_column: str = None,
image_paths: str = None,
num_samples: int = -1,
evaluation_dir: str = "evaluations",
fps: int = 8,
):
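    """Generate videos with a tracking-conditioned CogVideoX pipeline.

    If the dataset arguments (data_root, caption_column, tracking_column,
    video_column) are all given, samples are drawn from the dataset and each
    generated clip is exported next to its ground truth and tracking maps;
    otherwise a single clip is generated from `prompt`, `image_or_video_path`
    and (optionally) `tracking_path`.
    """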
device = "cuda" if torch.cuda.is_available() else "cpu"
# If dataset parameters are provided, sample from dataset
samples = None
if all([data_root, caption_column, tracking_column, video_column]):
samples = sample_from_dataset(
data_root=data_root,
caption_column=caption_column,
tracking_column=tracking_column,
image_paths=image_paths,
video_column=video_column,
num_samples=num_samples,
random_seed=seed
)
# Load model and data
if generate_type == "i2v":
pipe = CogVideoXImageToVideoPipelineTracking.from_pretrained(model_path, torch_dtype=dtype)
if not samples:
image = load_image(image=image_or_video_path)
height, width = image.height, image.width
else:
pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=dtype)
if not samples:
image = load_image(image=image_or_video_path)
height, width = image.height, image.width
# Set model parameters
pipe.to(device, dtype=dtype)
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
pipe.transformer.eval()
pipe.text_encoder.eval()
pipe.vae.eval()
pipe.transformer.gradient_checkpointing = False
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
# Generate video
if samples:
from tqdm import tqdm
        for i, sample in tqdm(enumerate(samples), total=len(samples), desc="Samples"):
            print(f"Prompt: {sample['prompt'][:30]}...")
tracking_frame = sample["tracking_frame"].to(device=device, dtype=dtype)
video_frame = sample["video_frame"].to(device=device, dtype=dtype)
video = sample["video"].to(device=device, dtype=dtype)
tracking_maps = sample["tracking_maps"].to(device=device, dtype=dtype)
# VAE
print("encoding tracking maps")
tracking_video = tracking_maps
tracking_maps = tracking_maps.unsqueeze(0)
tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
with torch.no_grad():
tracking_latent_dist = pipe.vae.encode(tracking_maps).latent_dist
tracking_maps = tracking_latent_dist.sample() * pipe.vae.config.scaling_factor
tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
pipeline_args = {
"prompt": sample["prompt"],
"negative_prompt": "The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion.",
"num_inference_steps": num_inference_steps,
"num_frames": 49,
"use_dynamic_cfg": True,
"guidance_scale": guidance_scale,
"generator": torch.Generator(device=device).manual_seed(seed),
"height": sample["height"],
"width": sample["width"]
}
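            # Dataset tensors are normalised to [-1, 1]; the pipeline expects images in [0, 1].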
pipeline_args["image"] = (video_frame + 1.0) / 2.0
if tracking_column and generate_type == "i2v":
pipeline_args["tracking_maps"] = tracking_maps
pipeline_args["tracking_image"] = (tracking_frame.unsqueeze(0) + 1.0) / 2.0
with torch.no_grad():
video_generate = pipe(**pipeline_args).frames[0]
output_dir = os.path.join(data_root, evaluation_dir)
output_name = f"{i:04d}.mp4"
output_file = os.path.join(output_dir, output_name)
os.makedirs(output_dir, exist_ok=True)
export_concat_video(video_generate, video, tracking_video, output_file, fps=fps)
else:
pipeline_args = {
"prompt": prompt,
"num_videos_per_prompt": num_videos_per_prompt,
"num_inference_steps": num_inference_steps,
"num_frames": 49,
"use_dynamic_cfg": True,
"guidance_scale": guidance_scale,
"generator": torch.Generator().manual_seed(seed),
}
pipeline_args["video"] = video
pipeline_args["image"] = image
pipeline_args["height"] = height
pipeline_args["width"] = width
if tracking_path and generate_type == "i2v":
tracking_maps = load_video(tracking_path)
tracking_maps = torch.stack([
torch.from_numpy(np.array(frame)).permute(2, 0, 1).float() / 255.0
for frame in tracking_maps
]).to(device=device, dtype=dtype)
tracking_video = tracking_maps
tracking_maps = tracking_maps.unsqueeze(0)
tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4)
with torch.no_grad():
tracking_latent_dist = pipe.vae.encode(tracking_maps).latent_dist
tracking_maps = tracking_latent_dist.sample() * pipe.vae.config.scaling_factor
tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4)
pipeline_args["tracking_maps"] = tracking_maps
pipeline_args["tracking_image"] = tracking_maps[:, :1]
with torch.no_grad():
video_generate = pipe(**pipeline_args).frames[0]
        output_dir = os.path.join(data_root, evaluation_dir) if data_root else os.path.dirname(output_path) or "."
        output_name = f"{os.path.splitext(os.path.basename(image_or_video_path))[0]}.mp4"
        output_file = os.path.join(output_dir, output_name)
        os.makedirs(output_dir, exist_ok=True)
        # No ground-truth video is available in this branch, so export only the generated frames.
        export_to_video(video_generate, output_file, fps=fps)
def create_frame_grid(frames: List[np.ndarray], interval: int = 9, max_cols: int = 7) -> np.ndarray:
"""
Arrange video frames into a grid image by sampling at intervals
Args:
frames: List of video frames
interval: Sampling interval
max_cols: Maximum number of frames per row
Returns:
Grid image array
"""
# Sample frames at intervals
sampled_frames = frames[::interval]
# Calculate number of rows and columns
n_frames = len(sampled_frames)
n_cols = min(max_cols, n_frames)
n_rows = (n_frames + n_cols - 1) // n_cols
# Get height and width of single frame
frame_height, frame_width = sampled_frames[0].shape[:2]
# Create blank canvas
grid = np.zeros((frame_height * n_rows, frame_width * n_cols, 3), dtype=np.uint8)
# Fill frames
for idx, frame in enumerate(sampled_frames):
i = idx // n_cols
j = idx % n_cols
grid[i*frame_height:(i+1)*frame_height, j*frame_width:(j+1)*frame_width] = frame
return grid
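# Example (illustrative): create_frame_grid(frames, interval=9, max_cols=7) lays out
# every 9th frame left to right in rows of at most 7 frames.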
def export_concat_video(
generated_frames: List[PIL.Image.Image],
original_video: torch.Tensor,
tracking_maps: torch.Tensor = None,
output_video_path: str = None,
fps: int = 8
) -> str:
"""
Export generated video frames, original video and tracking maps as video files,
and save sampled frames to different folders
"""
    import imageio
    import tempfile
if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
# Create subdirectories
base_dir = os.path.dirname(output_video_path)
generated_dir = os.path.join(base_dir, "generated") # For storing generated videos
group_dir = os.path.join(base_dir, "group") # For storing concatenated videos
# Get filename (without path) and create video-specific folder
filename = os.path.basename(output_video_path)
name_without_ext = os.path.splitext(filename)[0]
video_frames_dir = os.path.join(base_dir, "frames", name_without_ext) # frames/video_name/
# Create three subdirectories under video-specific folder
groundtruth_dir = os.path.join(video_frames_dir, "gt")
generated_frames_dir = os.path.join(video_frames_dir, "generated")
tracking_dir = os.path.join(video_frames_dir, "tracking")
# Create all required directories
os.makedirs(generated_dir, exist_ok=True)
os.makedirs(group_dir, exist_ok=True)
os.makedirs(groundtruth_dir, exist_ok=True)
os.makedirs(generated_frames_dir, exist_ok=True)
os.makedirs(tracking_dir, exist_ok=True)
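    # Resulting layout under the output directory:
    #   generated/<name>_generated.mp4           generated clip only
    #   group/<name>.mp4                         generated | ground truth (+ tracking), concatenated
    #   frames/<name>/{gt,generated,tracking}/   sampled PNG frames (every 9th frame)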
# Convert original video tensor to numpy array and adjust format
original_frames = []
for frame in original_video:
frame = frame.permute(1,2,0).to(dtype=torch.float32,device="cpu").numpy()
frame = ((frame + 1.0) * 127.5).astype(np.uint8)
original_frames.append(frame)
tracking_frames = []
if tracking_maps is not None:
for frame in tracking_maps:
frame = frame.permute(1,2,0).to(dtype=torch.float32,device="cpu").numpy()
frame = ((frame + 1.0) * 127.5).astype(np.uint8)
tracking_frames.append(frame)
# Ensure all videos have same number of frames
num_frames = min(len(generated_frames), len(original_frames))
if tracking_maps is not None:
num_frames = min(num_frames, len(tracking_frames))
generated_frames = generated_frames[:num_frames]
original_frames = original_frames[:num_frames]
if tracking_maps is not None:
tracking_frames = tracking_frames[:num_frames]
# Convert generated PIL images to numpy arrays
generated_frames_np = [np.array(frame) for frame in generated_frames]
# Save generated video separately to generated folder
gen_video_path = os.path.join(generated_dir, f"{name_without_ext}_generated.mp4")
with imageio.get_writer(gen_video_path, fps=fps) as writer:
for frame in generated_frames_np:
writer.append_data(frame)
# Concatenate frames vertically and save sampled frames
concat_frames = []
for i in range(num_frames):
gen_frame = generated_frames_np[i]
orig_frame = original_frames[i]
width = min(gen_frame.shape[1], orig_frame.shape[1])
height = orig_frame.shape[0]
gen_frame = Image.fromarray(gen_frame).resize((width, height))
gen_frame = np.array(gen_frame)
orig_frame = Image.fromarray(orig_frame).resize((width, height))
orig_frame = np.array(orig_frame)
if tracking_maps is not None:
track_frame = tracking_frames[i]
track_frame = Image.fromarray(track_frame).resize((width, height))
track_frame = np.array(track_frame)
right_concat = np.concatenate([orig_frame, track_frame], axis=0)
right_concat_pil = Image.fromarray(right_concat)
new_height = right_concat.shape[0] // 2
new_width = right_concat.shape[1] // 2
right_concat_resized = right_concat_pil.resize((new_width, new_height))
right_concat_resized = np.array(right_concat_resized)
concat_frame = np.concatenate([gen_frame, right_concat_resized], axis=1)
else:
orig_frame_pil = Image.fromarray(orig_frame)
new_height = orig_frame.shape[0] // 2
new_width = orig_frame.shape[1] // 2
orig_frame_resized = orig_frame_pil.resize((new_width, new_height))
orig_frame_resized = np.array(orig_frame_resized)
concat_frame = np.concatenate([gen_frame, orig_frame_resized], axis=1)
concat_frames.append(concat_frame)
# Save every 9 frames of each type of frame
if i % 9 == 0:
# Save generated frame
gen_frame_path = os.path.join(generated_frames_dir, f"{i:04d}.png")
Image.fromarray(gen_frame).save(gen_frame_path)
# Save original frame
gt_frame_path = os.path.join(groundtruth_dir, f"{i:04d}.png")
Image.fromarray(orig_frame).save(gt_frame_path)
# If tracking maps, save tracking frame
if tracking_maps is not None:
track_frame_path = os.path.join(tracking_dir, f"{i:04d}.png")
Image.fromarray(track_frame).save(track_frame_path)
# Export concatenated video to group folder
group_video_path = os.path.join(group_dir, filename)
with imageio.get_writer(group_video_path, fps=fps) as writer:
for frame in concat_frames:
writer.append_data(frame)
return group_video_path
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
parser.add_argument("--prompt", type=str, help="Optional: override the prompt from dataset")
parser.add_argument(
"--image_or_video_path",
type=str,
default=None,
help="The path of the image to be used as the background of the video",
)
parser.add_argument(
"--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used"
)
parser.add_argument(
"--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
)
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
parser.add_argument(
"--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
)
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
parser.add_argument(
"--generate_type", type=str, default="i2v", help="The type of video generation (e.g., 'i2v', 'i2vo')"
)
parser.add_argument(
"--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')"
)
parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
parser.add_argument("--tracking_path", type=str, default=None, help="The path of the tracking maps to be used")
# Dataset related parameters are required
parser.add_argument("--data_root", type=str, required=True, help="Root directory of the dataset")
parser.add_argument("--caption_column", type=str, required=True, help="Name of the caption column")
parser.add_argument("--tracking_column", type=str, required=True, help="Name of the tracking column")
parser.add_argument("--video_column", type=str, required=True, help="Name of the video column")
parser.add_argument("--image_paths", type=str, required=False, help="Name of the image column")
# Add num_samples parameter
parser.add_argument("--num_samples", type=int, default=-1,
help="Number of samples to process. -1 means process all samples")
# Add evaluation_dir parameter
parser.add_argument("--evaluation_dir", type=str, default="evaluations",
help="Name of the directory to store evaluation results")
# Add fps parameter
parser.add_argument("--fps", type=int, default=8,
help="Frames per second for the output video")
args = parser.parse_args()
dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16
# If prompt is not provided, generate_video function will use prompts from dataset
generate_video(
prompt=args.prompt, # Can be None
model_path=args.model_path,
tracking_path=args.tracking_path,
image_paths=args.image_paths,
output_path=args.output_path,
image_or_video_path=args.image_or_video_path,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
num_videos_per_prompt=args.num_videos_per_prompt,
dtype=dtype,
generate_type=args.generate_type,
seed=args.seed,
data_root=args.data_root,
caption_column=args.caption_column,
tracking_column=args.tracking_column,
video_column=args.video_column,
num_samples=args.num_samples,
evaluation_dir=args.evaluation_dir,
fps=args.fps,
)
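    # Example invocation (illustrative paths):
    #   python testing/evaluation.py \
    #       --model_path <path-or-hub-id-of-tracking-checkpoint> \
    #       --data_root data/my_dataset \
    #       --caption_column prompts.txt \
    #       --video_column videos.txt \
    #       --tracking_column trackings.txt \
    #       --num_samples 10 --evaluation_dir evaluations --fps 8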