# QNT-ByteDance / app.py
# Author: quangnhat
# Commit 13aa2f4: "validation check" (6.05 kB)
import sys
import tempfile
import traceback

import gradio as gr
import torch

# Improved import with fallback handling: exit with a clear message if a
# required third-party package is missing.
try:
    from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
    from diffusers.utils import export_to_gif
    from huggingface_hub import hf_hub_download
    from safetensors.torch import load_file
except ImportError as e:
    print(f"Import Error: {e}")
    sys.exit(1)
# Hugging Face repo ids of the base checkpoints offered in the UI, keyed by
# display category. Insertion order determines the order in the dropdown.
BASE_MODELS = {}
BASE_MODELS["Realistic"] = [
    "emilianJR/epiCRealism",
    "SG161222/Realistic_Vision_V5.1_noVAE",
    "Lykon/dreamshaper-8",
    "digiplay/AbsoluteReality_v1.8.1",
]
BASE_MODELS["Anime & Cartoon"] = [
    "cagliostroaic/ToonYou",
    "Sangyun/IMP",
    "Lykon/Mistoon_Anime",
    "digiplay/DynaVision_v1.0",
]
def detect_device():
    """Return the torch device string to run on: "cuda", "mps" or "cpu".

    CUDA is preferred over Apple MPS. The choice is printed to stdout. Any
    exception raised while probing the backends is printed and results in
    "cpu".
    """
    try:
        if torch.cuda.is_available():
            gpu_name = torch.cuda.get_device_name(0)
            print(f"CUDA Available. Using GPU: {gpu_name}")
            return "cuda"
        if torch.backends.mps.is_available():
            print("Using MPS (Apple Silicon)")
            return "mps"
        print("No GPU detected. Falling back to CPU.")
    except Exception as probe_error:
        print(f"Device detection error: {probe_error}")
    # Reached both when no accelerator was found and when probing failed.
    return "cpu"
# Step counts for which a checkpoint named
# animatediff_lightning_{N}step_diffusers.safetensors is published in the
# ByteDance/AnimateDiff-Lightning repo; any other value has no checkpoint.
_LIGHTNING_STEPS = (1, 2, 4, 8)
# Base model used when the requested one fails to load.
_FALLBACK_BASE_MODEL = "emilianJR/epiCRealism"


def generate_video(prompt, base_model, steps=4, motion_strength=0.7, guidance_scale=1.0):
    """Generate a short animation with AnimateDiff-Lightning and save it as a GIF.

    Args:
        prompt: Text prompt passed to the pipeline.
        base_model: Hugging Face repo id of the base model. If it fails to
            load, "emilianJR/epiCRealism" is used instead.
        steps: Number of inference steps. It must equal one of 1, 2, 4 or 8,
            because it also selects which Lightning checkpoint is downloaded.
        motion_strength: Accepted for UI compatibility; currently not used.
        guidance_scale: Guidance scale passed to the pipeline.

    Returns:
        Path of the generated GIF (a unique temporary file), or None if any
        stage fails. Failures are printed, never raised.
    """
    # Validate before any download: an unsupported value would otherwise only
    # fail later with an opaque "file not found" from the hub.
    if steps not in _LIGHTNING_STEPS:
        print(f"Unsupported step count {steps!r}; choose one of {_LIGHTNING_STEPS}.")
        return None
    # Normalize e.g. 4.0 -> 4 so the checkpoint filename is "...4step...".
    steps = int(steps)

    try:
        device = detect_device()
        dtype = torch.float16 if device == "cuda" else torch.float32

        # Official AnimateDiff-Lightning repository; one checkpoint per step count.
        repo = "ByteDance/AnimateDiff-Lightning"
        ckpt = f"animatediff_lightning_{steps}step_diffusers.safetensors"

        # Motion adapter: empty MotionAdapter filled with the Lightning weights.
        try:
            adapter = MotionAdapter().to(device, dtype)
            adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
        except Exception as adapter_error:
            print(f"Motion Adapter Loading Error: {adapter_error}")
            traceback.print_exc()
            return None

        # Base pipeline, falling back to a known-good model once.
        try:
            pipe = AnimateDiffPipeline.from_pretrained(
                base_model,
                motion_adapter=adapter,
                torch_dtype=dtype,
            ).to(device)
        except Exception as model_error:
            print(f"Model loading error with {base_model}: {model_error}")
            if base_model == _FALLBACK_BASE_MODEL:
                # The fallback itself just failed; retrying it is pointless.
                traceback.print_exc()
                return None
            base_model = _FALLBACK_BASE_MODEL
            pipe = AnimateDiffPipeline.from_pretrained(
                base_model,
                motion_adapter=adapter,
                torch_dtype=dtype,
            ).to(device)

        # Scheduler settings used with Lightning checkpoints: Euler with
        # trailing timestep spacing and a linear beta schedule.
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config,
            timestep_spacing="trailing",
            beta_schedule="linear",
        )

        try:
            output = pipe(
                prompt=prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=steps,
            )
            # A unique file per call, so concurrent requests cannot overwrite
            # each other's result (the old fixed "animation.gif" could). The
            # file is left in place for the UI to serve.
            with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp:
                gif_path = tmp.name
            export_to_gif(output.frames[0], gif_path)
            return gif_path
        except Exception as gen_error:
            print(f"Video Generation Error: {gen_error}")
            traceback.print_exc()
            return None
    except Exception as e:
        print(f"Unexpected error in video generation: {e}")
        traceback.print_exc()
        return None
def create_interface():
    """Build the Gradio Blocks UI for the AnimateDiff-Lightning generator.

    Returns:
        The assembled (not yet launched) gr.Blocks demo. Its Generate button
        calls generate_video with (prompt, base_model, steps,
        motion_strength, guidance_scale).
    """
    # Flatten every category, so adding a category to BASE_MODELS needs no
    # change here.
    all_models = [model for group in BASE_MODELS.values() for model in group]

    with gr.Blocks() as demo:
        gr.Markdown("## AnimateDiff-Lightning Video Generator")
        with gr.Row():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Describe the video you want to generate...",
                lines=3,
            )
            base_model = gr.Dropdown(
                choices=all_models,
                label="Base Model",
                value="emilianJR/epiCRealism",
            )
        with gr.Row():
            # Only 1-, 2-, 4- and 8-step Lightning checkpoints exist. The
            # previous 1..8 slider let users pick counts (3, 5, 6, 7) that
            # always failed, so offer exactly the supported values.
            steps = gr.Radio(
                choices=[1, 2, 4, 8],
                value=4,
                label="Inference Steps",
            )
            guidance_scale = gr.Slider(
                minimum=0.1,
                maximum=20,
                step=0.1,
                value=1.0,
                label="Guidance Scale",
            )
            # NOTE(review): forwarded to generate_video, which currently
            # ignores it.
            motion_strength = gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.7,
                label="Motion Strength",
            )
        generate_btn = gr.Button("Generate Video", variant="primary")
        # generate_video returns a GIF file path (or None on failure).
        output = gr.Image(label="Generated Video")
        generate_btn.click(
            generate_video,
            inputs=[prompt, base_model, steps, motion_strength, guidance_scale],
            outputs=output,
        )
    return demo
def main():
    """Build the UI and launch the Gradio server.

    If building or launching fails, print the traceback and exit with status
    1. This matches the import-failure handling and lets a supervisor see
    the failure instead of a clean exit.
    """
    print("Initializing AnimateDiff-Lightning Gradio Interface...")
    try:
        demo = create_interface()
        # Public sharing with detailed config.
        # NOTE: share=True plus server_name="0.0.0.0" exposes the app beyond
        # localhost; this is intentional for a public demo.
        demo.launch(
            share=True,  # Create public link
            debug=True,  # Detailed error reporting
            show_error=True,  # Display errors in UI
            server_name="0.0.0.0",  # Accessible from any IP
        )
    except Exception as e:
        print(f"Gradio Launch Error: {e}")
        traceback.print_exc()
        # SystemExit is not an Exception subclass, so this is not re-caught.
        sys.exit(1)


if __name__ == "__main__":
    main()