Spaces:
Running
Running
Anurag Bhardwaj
committed on
Update app.py
Browse files
app.py
CHANGED
@@ -1,92 +1,32 @@
|
|
1 |
-
import os
|
2 |
-
|
3 |
-
# Set CUDA environment variables _before_ any torch imports.
|
4 |
-
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
5 |
-
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
6 |
-
|
7 |
import subprocess
|
8 |
-
import sys
|
9 |
-
import torch
|
10 |
-
|
11 |
-
# For debugging, print CUDA availability.
|
12 |
-
print("torch.cuda.is_available():", torch.cuda.is_available())
|
13 |
-
|
14 |
-
# If torch reports no CUDA devices, apply monkey-patches.
|
15 |
-
if not torch.cuda.is_available():
|
16 |
-
print("Monkey patching torch.cuda to force CUDA availability.")
|
17 |
-
torch.cuda.is_available = lambda: True
|
18 |
-
torch.cuda.current_device = lambda: 0
|
19 |
-
torch.cuda.set_device = lambda device: None
|
20 |
-
# Override lazy initialization to do nothing.
|
21 |
-
torch.cuda._lazy_init = lambda: None
|
22 |
-
|
23 |
-
# Try to force initialization (will now use our patch instead of raising an error).
|
24 |
-
try:
|
25 |
-
_ = torch.cuda.current_device()
|
26 |
-
print("Current CUDA device (patched):", torch.cuda.current_device())
|
27 |
-
except Exception as e:
|
28 |
-
print("Error forcing CUDA initialization:", e)
|
29 |
-
|
30 |
-
# List of required packages.
|
31 |
-
required_packages = ["easydict", "diffusers", "ftfy", "transformers"]
|
32 |
-
|
33 |
-
for package in required_packages:
|
34 |
-
try:
|
35 |
-
__import__(package)
|
36 |
-
except ModuleNotFoundError:
|
37 |
-
print(f"{package} not found, installing now...")
|
38 |
-
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
|
39 |
-
|
40 |
-
from easydict import EasyDict
|
41 |
-
import gradio as gr
|
42 |
-
from huggingface_hub import snapshot_download
|
43 |
-
|
44 |
-
# Define local directory names.
|
45 |
-
REPO_DIR = "Wan2_1" # Renamed from Wan2.1 to avoid invalid module names.
|
46 |
-
CHECKPOINT_PARENT_DIR = "model_checkpoints"
|
47 |
-
CHECKPOINT_SUBDIR = "Wan2.1-T2V-14B" # Example checkpoint subdirectory.
|
48 |
-
|
49 |
-
# Clone the repository if it does not exist.
|
50 |
-
if not os.path.exists(REPO_DIR):
|
51 |
-
print("Cloning Wan2.1 repository...")
|
52 |
-
subprocess.run(["git", "clone", "https://github.com/Wan-Video/Wan2.1.git", REPO_DIR])
|
53 |
-
|
54 |
-
# Add the cloned repository to Python's module search path.
|
55 |
-
sys.path.insert(0, os.path.abspath(REPO_DIR))
|
56 |
-
|
57 |
-
# Download the model checkpoint snapshot from Hugging Face Hub.
|
58 |
-
os.makedirs(CHECKPOINT_PARENT_DIR, exist_ok=True)
|
59 |
-
checkpoint_dir = os.path.join(CHECKPOINT_PARENT_DIR, CHECKPOINT_SUBDIR)
|
60 |
-
print("Downloading model checkpoint from Hugging Face Hub...")
|
61 |
-
model_id = "Wan-AI/Wan2.1-T2V-14B" # Update if necessary.
|
62 |
-
ckpt_path = snapshot_download(repo_id=model_id, cache_dir=checkpoint_dir)
|
63 |
-
print(f"Model checkpoint downloaded to: {ckpt_path}")
|
64 |
-
|
65 |
-
# Import WanPipeline from the local repository.
|
66 |
-
from wan.pipeline import WanPipeline
|
67 |
-
|
68 |
-
# Initialize the pipeline.
|
69 |
-
pipe = WanPipeline(
|
70 |
-
model_dir=ckpt_path, # Directory containing model files.
|
71 |
-
device="cuda" if torch.cuda.is_available() else "cpu"
|
72 |
-
)
|
73 |
-
|
74 |
-
def generate_video(prompt):
|
75 |
-
"""
|
76 |
-
Generate a video from a text prompt using Wan2.1.
|
77 |
-
Calls the pipeline's generate() method and returns the video frames.
|
78 |
-
"""
|
79 |
-
video_frames = pipe.generate(prompt)
|
80 |
-
return video_frames
|
81 |
-
|
82 |
-
# Create the Gradio interface.
|
83 |
-
iface = gr.Interface(
|
84 |
-
fn=generate_video,
|
85 |
-
inputs=gr.Textbox(label="Prompt", placeholder="Enter a video prompt here..."),
|
86 |
-
outputs=gr.Video(label="Generated Video"),
|
87 |
-
title="Wan2.1 Text-to-Video Generation",
|
88 |
-
description="Generate videos from text prompts using the Wan2.1 model."
|
89 |
-
)
|
90 |
|
91 |
-
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
2 |
import subprocess
|
3 |
+
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
+
# Title
|
6 |
+
st.title("🎥 WAN 2.1 - 14B AI Text-to-Video Generator")
|
7 |
+
|
8 |
+
# Input fields
|
9 |
+
prompt = st.text_area("Enter your text prompt:", "A cat in military dress wearing headphones, laughing and walking.")
|
10 |
+
frame_num = st.slider("Number of frames:", min_value=30, max_value=120, value=60, step=10)
|
11 |
+
resolution = st.selectbox("Select resolution:", ["832*480", "1280*720"])
|
12 |
+
sample_steps = st.slider("Sampling steps:", min_value=10, max_value=50, value=20, step=5)
|
13 |
+
|
14 |
+
# Button to generate video
|
15 |
+
if st.button("Generate Video"):
|
16 |
+
st.info("Generating video... This may take a few minutes.")
|
17 |
+
|
18 |
+
# Run WAN 2.1 - 14B Model
|
19 |
+
command = f"python generate.py --task t2v-14B --size {resolution} --frame_num {frame_num} --sample_steps {sample_steps} --ckpt_dir ./Wan2.1-T2V-14B --offload_model True --prompt \"{prompt}\""
|
20 |
+
|
21 |
+
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
22 |
+
stdout, stderr = process.communicate()
|
23 |
+
|
24 |
+
# Print logs for debugging
|
25 |
+
st.text_area("📜 Logs", stdout.decode() + stderr.decode())
|
26 |
+
|
27 |
+
# Check if video was created
|
28 |
+
if os.path.exists("output.mp4"):
|
29 |
+
st.video("output.mp4")
|
30 |
+
st.success("✅ Video generated successfully!")
|
31 |
+
else:
|
32 |
+
st.error("❌ Video generation failed! Check logs above.")
|