Spaces:
Sleeping
Sleeping
import torch | |
import gradio as gr | |
import sys | |
import traceback | |
# Improved import with fallback handling | |
try: | |
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler | |
from diffusers.utils import export_to_gif | |
from huggingface_hub import hf_hub_download | |
from safetensors.torch import load_file | |
except ImportError as e: | |
print(f"Import Error: {e}") | |
sys.exit(1) | |
# Selectable base checkpoints, grouped by visual style. The UI flattens both
# groups (Realistic first) into a single model dropdown.
BASE_MODELS = {}
BASE_MODELS["Realistic"] = [
    "emilianJR/epiCRealism",
    "SG161222/Realistic_Vision_V5.1_noVAE",
    "Lykon/dreamshaper-8",
    "digiplay/AbsoluteReality_v1.8.1",
]
BASE_MODELS["Anime & Cartoon"] = [
    "cagliostroaic/ToonYou",
    "Sangyun/IMP",
    "Lykon/Mistoon_Anime",
    "digiplay/DynaVision_v1.0",
]
def detect_device():
    """Pick the torch backend to run on and log the choice.

    Returns:
        ``"cuda"`` when a CUDA GPU is available, ``"mps"`` when Apple's Metal
        backend is available, otherwise ``"cpu"``. Any exception raised while
        probing the backends is logged and also yields ``"cpu"``.
    """
    try:
        if torch.cuda.is_available():
            gpu_name = torch.cuda.get_device_name(0)
            print(f"CUDA Available. Using GPU: {gpu_name}")
            return "cuda"
        if torch.backends.mps.is_available():
            print("Using MPS (Apple Silicon)")
            return "mps"
    except Exception as probe_error:
        print(f"Device detection error: {probe_error}")
        return "cpu"

    print("No GPU detected. Falling back to CPU.")
    return "cpu"
# Step counts for which the ByteDance/AnimateDiff-Lightning repo ships
# "animatediff_lightning_{N}step_diffusers.safetensors" checkpoints (per its
# model card). Any other value has no file to download.
_LIGHTNING_STEPS = (1, 2, 4, 8)
_LIGHTNING_REPO = "ByteDance/AnimateDiff-Lightning"
# Base model retried when the requested one fails to load.
_FALLBACK_BASE_MODEL = "emilianJR/epiCRealism"
# Most recently built pipeline, keyed by (base_model, steps, device, dtype).
# Rebuilding the adapter + pipeline on every click dominates request latency,
# so the last successful build is reused. At most one entry is kept to bound
# memory use.
_PIPELINE_CACHE = {}


def _resolve_steps(steps):
    """Return *steps* as an int snapped to the nearest entry of ``_LIGHTNING_STEPS``.

    Accepts ints, floats (a slider may deliver ``4.0``, which would otherwise
    build a ``..._4.0step_...`` filename) and numeric strings. Ties round up
    to the larger step count, e.g. 3 -> 4 and 6 -> 8.
    """
    requested = int(round(float(steps)))
    resolved = min(_LIGHTNING_STEPS, key=lambda s: (abs(s - requested), -s))
    if resolved != requested:
        print(f"No Lightning checkpoint for {requested} steps; using {resolved} instead.")
    return resolved


def _build_pipeline(base_model, adapter, device, dtype):
    """Load *base_model* with *adapter* attached and install the Euler scheduler."""
    pipe = AnimateDiffPipeline.from_pretrained(
        base_model,
        motion_adapter=adapter,
        torch_dtype=dtype,
    ).to(device)
    # Euler with trailing timestep spacing and a linear beta schedule, the
    # scheduler configuration used with the Lightning checkpoints.
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config,
        timestep_spacing="trailing",
        beta_schedule="linear",
    )
    return pipe


def _get_pipeline(base_model, steps, device, dtype):
    """Return a ready pipeline, reusing the cached one when all settings match.

    Raises:
        Whatever the adapter download/load raised, or whatever loading the
        fallback model raised when *base_model* could not be loaded either.
    """
    key = (base_model, steps, device, dtype)
    cached = _PIPELINE_CACHE.get(key)
    if cached is not None:
        return cached

    # Drop the previous pipeline before building a new one to limit peak memory.
    _PIPELINE_CACHE.clear()

    ckpt = f"animatediff_lightning_{steps}step_diffusers.safetensors"
    try:
        adapter = MotionAdapter().to(device, dtype)
        adapter.load_state_dict(
            load_file(hf_hub_download(_LIGHTNING_REPO, ckpt), device=device)
        )
    except Exception as adapter_error:
        print(f"Motion Adapter Loading Error: {adapter_error}")
        raise

    try:
        pipe = _build_pipeline(base_model, adapter, device, dtype)
    except Exception as model_error:
        if base_model == _FALLBACK_BASE_MODEL:
            # Retrying the same repo would just fail again.
            raise
        print(f"Model loading error with {base_model}: {model_error}")
        # Deliberately not cached: the failure may be transient, so the
        # requested model is tried again on the next call.
        return _build_pipeline(_FALLBACK_BASE_MODEL, adapter, device, dtype)

    _PIPELINE_CACHE[key] = pipe
    return pipe


def generate_video(prompt, base_model, steps=4, motion_strength=0.7, guidance_scale=1.0):
    """Generate an animated GIF for *prompt* with AnimateDiff-Lightning.

    Args:
        prompt: Text description of the clip.
        base_model: Hugging Face repo id of the base checkpoint. If it fails to
            load, ``_FALLBACK_BASE_MODEL`` is used instead.
        steps: Requested inference steps. Snapped to a published Lightning
            checkpoint (1, 2, 4 or 8); the snapped value selects the
            checkpoint and is passed as ``num_inference_steps``.
        motion_strength: Accepted for UI/interface compatibility but currently
            unused; nothing in this pipeline consumes it.
        guidance_scale: Guidance scale forwarded to the pipeline call.

    Returns:
        The path ``"animation.gif"`` on success, or ``None`` if any step fails
        (errors and tracebacks are printed to the console).
    """
    try:
        steps = _resolve_steps(steps)
        device = detect_device()
        dtype = torch.float16 if device == "cuda" else torch.float32
        pipe = _get_pipeline(base_model, steps, device, dtype)

        try:
            output = pipe(
                prompt=prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=steps,
            )
            # NOTE(review): fixed filename; if the Gradio queue is configured to
            # run jobs concurrently, simultaneous requests could overwrite it.
            gif_path = "animation.gif"
            export_to_gif(output.frames[0], gif_path)
            return gif_path
        except Exception as gen_error:
            print(f"Video Generation Error: {gen_error}")
            traceback.print_exc()
            return None
    except Exception as e:
        print(f"Unexpected error in video generation: {e}")
        traceback.print_exc()
        return None
def create_interface():
    """Build the Gradio Blocks UI and wire the button to ``generate_video``.

    Returns:
        The (not yet launched) ``gr.Blocks`` application.
    """
    model_choices = BASE_MODELS["Realistic"] + BASE_MODELS["Anime & Cartoon"]

    with gr.Blocks() as demo:
        gr.Markdown("## AnimateDiff-Lightning Video Generator")

        # Top row: prompt text and base-model picker.
        with gr.Row():
            prompt_box = gr.Textbox(
                label="Prompt",
                placeholder="Describe the video you want to generate...",
                lines=3,
            )
            model_dropdown = gr.Dropdown(
                choices=model_choices,
                label="Base Model",
                value="emilianJR/epiCRealism",
            )

        # Second row: numeric generation controls.
        with gr.Row():
            steps_slider = gr.Slider(
                minimum=1, maximum=8, step=1, value=4, label="Inference Steps"
            )
            guidance_slider = gr.Slider(
                minimum=0.1, maximum=20, step=0.1, value=1.0, label="Guidance Scale"
            )
            motion_slider = gr.Slider(
                minimum=0, maximum=1, step=0.1, value=0.7, label="Motion Strength"
            )

        generate_btn = gr.Button("Generate Video", variant="primary")
        result_view = gr.Image(label="Generated Video")

        # Input order must match generate_video's positional parameters.
        generate_btn.click(
            generate_video,
            inputs=[prompt_box, model_dropdown, steps_slider, motion_slider, guidance_slider],
            outputs=result_view,
        )

    return demo
def main():
    """Build the UI and launch the Gradio server.

    Any exception while building or launching is printed with its traceback
    and the process exits with status 1, so a supervisor or container runtime
    sees the failure instead of a clean exit.
    """
    print("Initializing AnimateDiff-Lightning Gradio Interface...")
    try:
        demo = create_interface()
        demo.launch(
            share=True,  # Request a public share link
            debug=True,  # Block the main thread and report errors
            show_error=True,  # Display errors in the UI
            server_name="0.0.0.0",  # Listen on all network interfaces
        )
    except Exception as e:
        print(f"Gradio Launch Error: {e}")
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()