Anurag Bhardwaj commited on
Commit
e028809
·
verified ·
1 Parent(s): 08fe821

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -90
app.py CHANGED
@@ -1,92 +1,32 @@
1
- import os
2
-
3
- # Set CUDA environment variables _before_ any torch imports.
4
- os.environ["CUDA_VISIBLE_DEVICES"] = "0"
5
- os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
6
-
7
  import subprocess
8
- import sys
9
- import torch
10
-
11
- # For debugging, print CUDA availability.
12
- print("torch.cuda.is_available():", torch.cuda.is_available())
13
-
14
- # If torch reports no CUDA devices, apply monkey-patches.
15
- if not torch.cuda.is_available():
16
- print("Monkey patching torch.cuda to force CUDA availability.")
17
- torch.cuda.is_available = lambda: True
18
- torch.cuda.current_device = lambda: 0
19
- torch.cuda.set_device = lambda device: None
20
- # Override lazy initialization to do nothing.
21
- torch.cuda._lazy_init = lambda: None
22
-
23
- # Try to force initialization (will now use our patch instead of raising an error).
24
- try:
25
- _ = torch.cuda.current_device()
26
- print("Current CUDA device (patched):", torch.cuda.current_device())
27
- except Exception as e:
28
- print("Error forcing CUDA initialization:", e)
29
-
30
- # List of required packages.
31
- required_packages = ["easydict", "diffusers", "ftfy", "transformers"]
32
-
33
- for package in required_packages:
34
- try:
35
- __import__(package)
36
- except ModuleNotFoundError:
37
- print(f"{package} not found, installing now...")
38
- subprocess.check_call([sys.executable, "-m", "pip", "install", package])
39
-
40
- from easydict import EasyDict
41
- import gradio as gr
42
- from huggingface_hub import snapshot_download
43
-
44
- # Define local directory names.
45
- REPO_DIR = "Wan2_1" # Renamed from Wan2.1 to avoid invalid module names.
46
- CHECKPOINT_PARENT_DIR = "model_checkpoints"
47
- CHECKPOINT_SUBDIR = "Wan2.1-T2V-14B" # Example checkpoint subdirectory.
48
-
49
- # Clone the repository if it does not exist.
50
- if not os.path.exists(REPO_DIR):
51
- print("Cloning Wan2.1 repository...")
52
- subprocess.run(["git", "clone", "https://github.com/Wan-Video/Wan2.1.git", REPO_DIR])
53
-
54
- # Add the cloned repository to Python's module search path.
55
- sys.path.insert(0, os.path.abspath(REPO_DIR))
56
-
57
- # Download the model checkpoint snapshot from Hugging Face Hub.
58
- os.makedirs(CHECKPOINT_PARENT_DIR, exist_ok=True)
59
- checkpoint_dir = os.path.join(CHECKPOINT_PARENT_DIR, CHECKPOINT_SUBDIR)
60
- print("Downloading model checkpoint from Hugging Face Hub...")
61
- model_id = "Wan-AI/Wan2.1-T2V-14B" # Update if necessary.
62
- ckpt_path = snapshot_download(repo_id=model_id, cache_dir=checkpoint_dir)
63
- print(f"Model checkpoint downloaded to: {ckpt_path}")
64
-
65
- # Import WanPipeline from the local repository.
66
- from wan.pipeline import WanPipeline
67
-
68
- # Initialize the pipeline.
69
- pipe = WanPipeline(
70
- model_dir=ckpt_path, # Directory containing model files.
71
- device="cuda" if torch.cuda.is_available() else "cpu"
72
- )
73
-
74
- def generate_video(prompt):
75
- """
76
- Generate a video from a text prompt using Wan2.1.
77
- Calls the pipeline's generate() method and returns the video frames.
78
- """
79
- video_frames = pipe.generate(prompt)
80
- return video_frames
81
-
82
- # Create the Gradio interface.
83
- iface = gr.Interface(
84
- fn=generate_video,
85
- inputs=gr.Textbox(label="Prompt", placeholder="Enter a video prompt here..."),
86
- outputs=gr.Video(label="Generated Video"),
87
- title="Wan2.1 Text-to-Video Generation",
88
- description="Generate videos from text prompts using the Wan2.1 model."
89
- )
90
 
91
- if __name__ == "__main__":
92
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
 
 
 
 
 
2
  import subprocess
3
+ import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
+ # Title
6
+ st.title("🎥 WAN 2.1 - 14B AI Text-to-Video Generator")
7
+
8
+ # Input fields
9
+ prompt = st.text_area("Enter your text prompt:", "A cat in military dress wearing headphones, laughing and walking.")
10
+ frame_num = st.slider("Number of frames:", min_value=30, max_value=120, value=60, step=10)
11
+ resolution = st.selectbox("Select resolution:", ["832*480", "1280*720"])
12
+ sample_steps = st.slider("Sampling steps:", min_value=10, max_value=50, value=20, step=5)
13
+
14
+ # Button to generate video
15
+ if st.button("Generate Video"):
16
+ st.info("Generating video... This may take a few minutes.")
17
+
18
+ # Run WAN 2.1 - 14B Model
19
+ command = f"python generate.py --task t2v-14B --size {resolution} --frame_num {frame_num} --sample_steps {sample_steps} --ckpt_dir ./Wan2.1-T2V-14B --offload_model True --prompt \"{prompt}\""
20
+
21
+ process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
22
+ stdout, stderr = process.communicate()
23
+
24
+ # Print logs for debugging
25
+ st.text_area("📜 Logs", stdout.decode() + stderr.decode())
26
+
27
+ # Check if video was created
28
+ if os.path.exists("output.mp4"):
29
+ st.video("output.mp4")
30
+ st.success("✅ Video generated successfully!")
31
+ else:
32
+ st.error("❌ Video generation failed! Check logs above.")