# Beijia11
# init
# 3aba902
import argparse
from typing import Literal
import os
import sys
import torch
from diffusers import (
CogVideoXDPMScheduler,
CogVideoXImageToVideoPipeline,
)
from diffusers.utils import export_to_video, load_image, load_video
import numpy as np
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(current_dir, '..'))
from models.cogvideox_tracking import CogVideoXImageToVideoPipelineTracking, CogVideoXPipelineTracking, CogVideoXVideoToVideoPipelineTracking
from models.cogvideox_tracking import CogVideoXTransformer3DModelTracking
def generate_video(
    prompt: str,
    model_path: str,
    tracking_path: str = None,
    tracking_video: torch.Tensor = None,
    output_path: str = "./output.mp4",
    image_or_video_path: str = "",
    num_inference_steps: int = 50,
    guidance_scale: float = 6.0,
    num_videos_per_prompt: int = 1,
    dtype: torch.dtype = torch.bfloat16,
    generate_type: str = "t2v",
    fps: int = 24,
    seed: int = 42,
):
    """
    Generate a video from a prompt and a conditioning image and save it to disk.

    Two modes are selected by ``generate_type``:

    - ``"i2v"``: loads ``CogVideoXImageToVideoPipelineTracking`` from
      ``model_path`` and, when tracking maps are supplied, VAE-encodes them
      and passes them to the pipeline.
    - anything else (default ``"t2v"``): loads the stock
      ``CogVideoXImageToVideoPipeline`` from ``THUDM/CogVideoX-5b-I2V``;
      ``model_path`` and any tracking input are ignored in this mode.

    Parameters:
    - prompt (str): The description of the video to be generated.
    - model_path (str): The path of the pre-trained tracking model (used only for 'i2v').
    - tracking_path (str): The path of a tracking-map video. Takes precedence over ``tracking_video``.
    - tracking_video (torch.Tensor): Tracking frames as a tensor; divided by 255 before use.
      Presumably laid out as [T, C, H, W] with values in 0..255 - confirm against the caller.
    - output_path (str): Where the video is saved. If empty, a name is derived from
      ``generate_type``, the image file name and the prompt.
    - image_or_video_path (str): The path (or URL) of the conditioning image. Required.
    - num_inference_steps (int): Number of steps for the inference process.
    - guidance_scale (float): The scale for classifier-free guidance.
    - num_videos_per_prompt (int): Number of videos to generate per prompt.
    - dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
    - generate_type (str): 'i2v' for the tracking pipeline; any other value uses the original pipeline.
    - fps (int): Frame rate of the exported video file.
    - seed (int): The seed for reproducibility.

    Raises:
    - ValueError: If ``image_or_video_path`` is empty or None.
    """
    # Fail fast: both modes need a conditioning image, and discovering that
    # only after the pipeline weights are loaded wastes a lot of time.
    if not image_or_video_path:
        raise ValueError(
            "image_or_video_path must point to the conditioning image "
            f"(got {image_or_video_path!r})"
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_tracking = generate_type == "i2v"
    negative_prompt = (
        "The video is not of a high quality, it has a low resolution. "
        "Watermark present in each frame. The background is solid. "
        "Strange body and strange trajectory. Distortion."
    )

    # Load the (cheap) image before the (expensive) pipeline so that a bad
    # path is reported immediately.
    image = load_image(image=image_or_video_path)
    height, width = image.height, image.width

    # 1. Load the pre-trained CogVideoX pipeline with the specified precision.
    # add device_map="balanced" in the from_pretrained function and remove the
    # pipe.to(device) call to use Multi GPUs.
    if use_tracking:
        pipe = CogVideoXImageToVideoPipelineTracking.from_pretrained(model_path, torch_dtype=dtype)
    else:
        pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=dtype)

    # 2. Prepare the tracking maps. Only the tracking pipeline consumes them,
    # so nothing is loaded in the other mode.
    tracking_maps = None
    tracking_first_frame = None
    if use_tracking:
        if tracking_path is not None:
            # Convert list of PIL Images to tensor [T, C, H, W] scaled to 0..1.
            tracking_maps = torch.stack([
                torch.from_numpy(np.array(frame)).permute(2, 0, 1).float() / 255.0
                for frame in load_video(tracking_path)
            ])
        elif tracking_video is not None:
            tracking_maps = tracking_video.float() / 255.0
        if tracking_maps is not None:
            tracking_maps = tracking_maps.to(device=device, dtype=dtype)
            tracking_first_frame = tracking_maps[0:1]  # first frame, batch dim kept
            # The output resolution follows the tracking maps, not the image.
            height, width = tracking_first_frame.shape[2], tracking_first_frame.shape[3]

    # 3. Set scheduler, move to the device and switch everything to inference mode.
    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
    pipe.to(device, dtype=dtype)
    # pipe.enable_sequential_cpu_offload()
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()
    pipe.transformer.eval()
    pipe.text_encoder.eval()
    pipe.vae.eval()
    for param in pipe.transformer.parameters():
        param.requires_grad = False
    pipe.transformer.gradient_checkpointing = False

    # 4. Encode the tracking maps into VAE latents.
    if tracking_maps is not None:
        print("Encoding tracking maps")
        # [T, C, H, W] -> [B, T, C, H, W] -> [B, C, T, H, W] for the VAE.
        vae_input = tracking_maps.unsqueeze(0).permute(0, 2, 1, 3, 4)
        with torch.no_grad():
            tracking_latent_dist = pipe.vae.encode(vae_input).latent_dist
            tracking_maps = tracking_latent_dist.sample() * pipe.vae.config.scaling_factor
        tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]

    # 5. Generate the video frames based on the prompt.
    pipe_kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=image,
        num_videos_per_prompt=num_videos_per_prompt,
        num_inference_steps=num_inference_steps,
        num_frames=49,
        use_dynamic_cfg=True,
        guidance_scale=guidance_scale,
        generator=torch.Generator().manual_seed(seed),
    )
    if use_tracking:
        # Extra arguments understood only by the tracking pipeline.
        pipe_kwargs.update(
            tracking_maps=tracking_maps,
            tracking_image=tracking_first_frame,
            height=height,
            width=width,
        )
    with torch.no_grad():
        video_generate = pipe(**pipe_kwargs).frames[0]

    # 6. Export the generated frames to a video file.
    # (Original note: fps must be 8 for the original video model.)
    if not output_path:
        image_stem = os.path.splitext(os.path.basename(image_or_video_path))[0]
        output_path = f"{generate_type}_img[{image_stem}]_txt[{prompt}].mp4"
    # A bare filename has an empty dirname, which os.makedirs rejects.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    export_to_video(video_generate, output_path, fps=fps)
if __name__ == "__main__":
    # Command-line entry point: parse the options and forward them to
    # generate_video().
    parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
    parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
    parser.add_argument(
        "--image_or_video_path",
        type=str,
        # Every generation mode loads this image, so reject a missing value
        # here instead of failing after the model has been loaded.
        required=True,
        help="The path of the image to be used as the background of the video",
    )
    parser.add_argument(
        "--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used"
    )
    parser.add_argument(
        "--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
    )
    parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
    parser.add_argument(
        "--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
    )
    parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
    parser.add_argument(
        "--generate_type",
        type=str,
        default="t2v",
        help="The type of video generation: 'i2v' uses the tracking pipeline, "
        "any other value uses the original CogVideoX-5b-I2V pipeline",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        # Previously any unrecognised string silently became bfloat16.
        choices=["float16", "bfloat16"],
        help="The data type for computation (e.g., 'float16' or 'bfloat16')",
    )
    parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
    parser.add_argument("--fps", type=int, default=24, help="The frame rate of the exported video")
    parser.add_argument("--tracking_path", type=str, default=None, help="The path of the tracking maps to be used")

    args = parser.parse_args()
    dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16

    generate_video(
        prompt=args.prompt,
        model_path=args.model_path,
        tracking_path=args.tracking_path,
        output_path=args.output_path,
        image_or_video_path=args.image_or_video_path,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        num_videos_per_prompt=args.num_videos_per_prompt,
        dtype=dtype,
        generate_type=args.generate_type,
        fps=args.fps,
        seed=args.seed,
    )