""" | |
Adapted from: | |
https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/encode_videos.py | |
https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/embed_captions.py | |
""" | |
import traceback
from pathlib import Path

import click
import torch
import torchvision
from diffusers import AutoencoderKLMochi, MochiPipeline
from tqdm.auto import tqdm
from transformers import T5EncoderModel, T5Tokenizer
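# Expected dataset layout (inferred from the globbing and path handling below):
# each clip is stored as <name>.mp4 with its caption in a sibling <name>.txt file;
# outputs are written alongside them as <name>.latent.pt and <name>.embed.pt.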

def encode_videos(model: torch.nn.Module, vid_path: Path, shape: str):
    T, H, W = [int(s) for s in shape.split("x")]
    assert (T - 1) % 6 == 0, "Expected T to be 1 mod 6"
    video, _, metadata = torchvision.io.read_video(str(vid_path), output_format="THWC", pts_unit="secs")
    fps = metadata["video_fps"]
    video = video.permute(3, 0, 1, 2)  # THWC -> CTHW
    og_shape = video.shape
    assert video.shape[2] == H, f"Expected {vid_path} to have height {H}, got {video.shape}"
    assert video.shape[3] == W, f"Expected {vid_path} to have width {W}, got {video.shape}"
    assert video.shape[1] >= T, f"Expected {vid_path} to have at least {T} frames, got {video.shape}"
    if video.shape[1] > T:
        video = video[:, :T]
        print(f"Trimmed video from {og_shape[1]} to first {T} frames")
    video = video.unsqueeze(0)  # add a batch dimension: CTHW -> BCTHW
    video = video.float() / 127.5 - 1.0  # scale uint8 pixels from [0, 255] to [-1, 1]
    video = video.to(model.device)
    assert video.ndim == 5
    with torch.inference_mode():
        with torch.autocast("cuda", dtype=torch.bfloat16):
            ldist = model._encode(video)
    torch.save(dict(ldist=ldist), vid_path.with_suffix(".latent.pt"))

# The CLI definition below is restored so the bare batch_process() call at the
# bottom works; the option names mirror the function signature, and the defaults
# are the Mochi 1 preview settings (an assumption, adjust as needed).
@click.command()
@click.option("--output_dir", type=click.Path(exists=True, file_okay=False, path_type=Path), required=True, help="Directory containing the .mp4 clips and .txt captions.")
@click.option("--model_id", type=str, default="genmo/mochi-1-preview", help="Model to load the VAE and text encoder from.")
@click.option("--shape", type=str, default="163x480x848", help="Expected video shape as TxHxW; T must be 1 mod 6.")
@click.option("--overwrite", is_flag=True, help="Re-encode files whose outputs already exist.")
def batch_process(output_dir: Path, model_id: str, shape: str, overwrite: bool) -> None:
    """Process all videos and captions in a directory using a single GPU."""
    # TF32 speeds up float32 matmuls on Ampere+ GPUs; comment out on unsupported hardware.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Get all video and caption paths.
    video_paths = list(output_dir.glob("**/*.mp4"))
    if not video_paths:
        print(f"No MP4 files found in {output_dir}")
        return
    text_paths = list(output_dir.glob("**/*.txt"))
    if not text_paths:
        print(f"No text files found in {output_dir}")
        return

    # Load the models. The pipeline is built without its transformer and VAE
    # since only the text-encoding path is needed here.
    vae = AutoencoderKLMochi.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32).to("cuda")
    text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder")
    tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer")
    pipeline = MochiPipeline.from_pretrained(
        model_id, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None, vae=None
    ).to("cuda")
    for video_path in tqdm(sorted(video_paths)):
        print(f"Processing {video_path}")
        try:
            # Encode the video into VAE latents unless a cached copy already exists.
            if video_path.with_suffix(".latent.pt").exists() and not overwrite:
                print(f"Skipping {video_path} - latents already exist")
            else:
                encode_videos(vae, vid_path=video_path, shape=shape)

            # Embed the caption stored next to the video (clip.mp4 -> clip.txt).
            prompt_path = video_path.with_suffix(".txt")
            embed_path = prompt_path.with_suffix(".embed.pt")
            if embed_path.exists() and not overwrite:
                print(f"Skipping {prompt_path} - embeddings already exist")
                continue
            with open(prompt_path) as f:
                text = f.read().strip()
            with torch.inference_mode():
                conditioning = pipeline.encode_prompt(prompt=[text])
            conditioning = {"prompt_embeds": conditioning[0], "prompt_attention_mask": conditioning[1]}
            torch.save(conditioning, embed_path)
        except Exception as e:
            traceback.print_exc()
            print(f"Error processing {video_path}: {e}")

if __name__ == "__main__":
    batch_process()
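
# Example invocation (assuming this file is saved as encode_videos.py and uses
# the click options defined above):
#   python encode_videos.py --output_dir ./videos --shape 163x480x848 --overwrite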