# Spaces: Running  (page-status text captured along with this source; not part of the program)
import hashlib
import os
import re
import tempfile

import imageio
import numpy as np
import soundfile as sf
import streamlit as st
import torch
from PIL import Image
from transformers import AutoProcessor, BlipForConditionalGeneration, MusicgenForConditionalGeneration
# Import MoviePy 1.x. `moviepy.editor` no longer exists in MoviePy 2.0, which
# also renamed clip methods such as `subclip`/`volumex`/`set_audio`. Report
# *why* the import failed instead of always claiming moviepy is missing.
try:
    import moviepy.editor as mpy
except ModuleNotFoundError as err:
    if err.name == "moviepy":
        st.error("The 'moviepy' library is not installed. Please ensure 'moviepy' is in requirements.txt and installed correctly.")
    elif err.name == "moviepy.editor":
        # The moviepy package exists but lacks the 1.x `editor` submodule.
        st.error(
            "The installed 'moviepy' has no 'moviepy.editor' module (removed in MoviePy 2.0). "
            "Pin 'moviepy<2.0' in requirements.txt."
        )
    else:
        # moviepy.editor itself failed while importing one of its dependencies.
        st.error(f"'moviepy' could not be imported: {err}")
    st.stop()
# Page header and a one-line usage hint.
st.title("Story Video Sound Effect Sync Generator")
st.write("Upload an MP4 video to auto-generate and sync a high-quality sound effect.")

# User controls: how many frames get captioned (1-3, default 1) and whether
# the generated effect is layered over the clip's own soundtrack.
num_frames_to_extract = st.slider(
    "Number of frames to analyze",
    min_value=1,
    max_value=3,
    value=1,
    help="Fewer frames = faster processing",
)
mix_original_audio = st.checkbox(
    "Mix with original audio",
    value=False,
    help="Blend sound effect with video’s original sound",
)
# Sound cues appended to a caption, checked in priority order; the first
# category whose pattern matches a whole word wins. Whole-word patterns stop
# "window" matching "wind", "street" matching "tree", "carpet" matching "car"
# and "runway" matching "run", while still catching inflections such as
# "driving" (which a plain "drive" substring test misses).
_SOUND_CUES = (
    (
        re.compile(r"\b(?:walk(?:s|ed|ing|er|ers)?|run(?:s|ning|ner|ners)?)\b"),
        "with crisp footsteps on a wooden floor",
    ),
    (
        re.compile(r"\b(?:cars?|driv(?:e|es|en|er|ers|ing))\b"),
        "with the roar of an engine and tires screeching",
    ),
    (
        re.compile(r"\b(?:talk(?:s|ed|ing)?|persons?)\b"),
        "with soft voices and background crowd murmur",
    ),
    (
        re.compile(r"\b(?:wind(?:s|y)?|trees?|forests?)\b"),
        "with gentle wind rustling through leaves",
    ),
    (
        re.compile(r"\b(?:crash(?:es|ed|ing)?|fall(?:s|en|ing)?)\b"),
        "with a loud crash and debris scattering",
    ),
)
# Cue used when no category matches.
_DEFAULT_SOUND_CUE = "with subtle ambient hum and faint echoes"


def enhance_prompt(base_description):
    """Enhance a BLIP caption with sound-specific details.

    Args:
        base_description: Caption text, in any letter case.

    Returns:
        The lower-cased caption followed by a space and the cue of the first
        category in ``_SOUND_CUES`` with a whole-word match, or
        ``_DEFAULT_SOUND_CUE`` when nothing matches.
    """
    base = base_description.lower()
    for pattern, cue in _SOUND_CUES:
        if pattern.search(base):
            return f"{base} {cue}"
    return f"{base} {_DEFAULT_SOUND_CUE}"
# Inference device and precision shared by both models and their inputs.
_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
_DTYPE = torch.float16 if _DEVICE == "cuda" else torch.float32

# MusicGen emits roughly 50 audio tokens per second of output, so 256 new
# tokens gives about 5 seconds of sound (looped below to cover the video).
_MUSICGEN_MAX_NEW_TOKENS = 256

# st.session_state slot holding the last (run_key, text_prompt, video_bytes).
_RESULT_STATE_KEY = "synced_video_result"


@st.cache_resource
def load_blip_model():
    """Load the BLIP captioning processor and model once and reuse them across reruns.

    Returns:
        ``(processor, model)``; on CUDA the model is half precision on the GPU.
    """
    processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
    if _DEVICE == "cuda":
        model = model.half().to(_DEVICE)
    return processor, model


@st.cache_resource
def load_musicgen_model():
    """Load the MusicGen-small processor and model once and reuse them across reruns.

    Returns:
        ``(processor, model)``; on CUDA the model is half precision on the GPU.
    """
    processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
    model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
    if _DEVICE == "cuda":
        model = model.half().to(_DEVICE)
    return processor, model


def _to_device(batch):
    """Move processor outputs to ``_DEVICE`` as a plain dict of tensors.

    Floating-point tensors (e.g. BLIP pixel values) are also cast to
    ``_DTYPE`` so they match a half-precision model; integer tensors such as
    token ids keep their dtype.
    """
    return {
        key: value.to(_DEVICE, _DTYPE) if value.is_floating_point() else value.to(_DEVICE)
        for key, value in batch.items()
    }


def _make_temp_path(suffix, created):
    """Create an empty temp file ending in *suffix*, record it in *created*, return its path."""
    fd, path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)
    created.append(path)
    return path


def _extract_frames(video_path, num_frames):
    """Return up to *num_frames* evenly spaced frames of the video as PIL images."""
    with imageio.get_reader(video_path, "ffmpeg") as reader:
        # Ask ffmpeg for the frame count rather than decoding every frame into memory.
        total_frames = reader.count_frames()
        step = max(1, total_frames // num_frames)
        indices = range(0, min(total_frames, num_frames * step), step)
        return [Image.fromarray(reader.get_data(i)) for i in indices][:num_frames]


def _caption_frames(frames, progress_bar):
    """Caption each frame with BLIP, enhance each caption, and join them into one prompt.

    Advances *progress_bar* from 20 to 50 as frames are processed.
    """
    processor, model = load_blip_model()
    descriptions = []
    for done, frame in enumerate(frames, start=1):
        inputs = _to_device(processor(images=frame, return_tensors="pt"))
        output_ids = model.generate(**inputs)
        caption = processor.decode(output_ids[0], skip_special_tokens=True)
        descriptions.append(enhance_prompt(caption))
        progress_bar.progress(20 + int(30 * done / len(frames)))
    return ". ".join(descriptions)


def _generate_sound(text_prompt, wav_path):
    """Generate a sound effect for *text_prompt* with MusicGen and write it to *wav_path*."""
    processor, model = load_musicgen_model()
    inputs = _to_device(processor(text=[text_prompt], padding=True, return_tensors="pt"))
    audio_values = model.generate(
        **inputs,
        max_new_tokens=_MUSICGEN_MAX_NEW_TOKENS,
        do_sample=True,
        guidance_scale=3.0,
        top_k=50,
        top_p=0.95,
    )
    # First batch item. `.float()` because a half-precision model yields
    # float16 audio, which soundfile cannot write (float32/64, int16/32 only).
    audio = audio_values[0].float().cpu().numpy()
    if audio.ndim > 1:
        # Presumably (channels, samples); soundfile expects (frames, channels).
        audio = audio.T
    peak = float(np.max(np.abs(audio))) if audio.size else 0.0
    if peak > 0:  # silent output would otherwise become 0/0 = NaN
        audio = audio / peak * 0.9
    audio = np.clip(audio, -1.0, 1.0)
    sf.write(wav_path, audio, model.config.audio_encoder.sampling_rate)


def _sync_audio(video_path, audio_path, output_path, mix_audio, clips, temp_paths):
    """Fit the sound effect to the video's length and write the result to *output_path*.

    The effect is looped as needed and trimmed to the video duration. With
    *mix_audio* and an existing soundtrack, both play at half volume;
    otherwise the effect replaces the soundtrack. Every clip opened here is
    appended to *clips* so the caller can close it.
    """
    video_clip = mpy.VideoFileClip(video_path)
    clips.append(video_clip)
    effect = mpy.AudioFileClip(audio_path)
    clips.append(effect)
    duration = video_clip.duration

    if effect.duration < duration:
        loops = int(np.ceil(duration / effect.duration))
        effect = mpy.concatenate_audioclips([effect] * loops)
    effect = effect.subclip(0, duration)

    if mix_audio and video_clip.audio is not None:
        # CompositeAudioClip overlays tracks; MoviePy clips have no `+` mixing operator.
        final_audio = mpy.CompositeAudioClip([video_clip.audio.volumex(0.5), effect.volumex(0.5)])
    else:
        final_audio = effect

    video_clip.set_audio(final_audio).write_videofile(
        output_path,
        codec="libx264",
        audio_codec="aac",
        preset="ultrafast",
        # Per-run temp path instead of a fixed name in the working directory,
        # so concurrent sessions cannot clobber each other's audio track.
        temp_audiofile=_make_temp_path(".m4a", temp_paths),
        remove_temp=True,
    )


def _process_video(video_bytes, num_frames, mix_audio):
    """Run the whole pipeline on an uploaded MP4 and return ``(text_prompt, output_mp4_bytes)``.

    Every temp file and MoviePy clip created along the way is released before
    returning, whether or not a step raised.
    """
    progress_bar = st.progress(0)
    status_text = st.empty()
    temp_paths = []
    clips = []
    try:
        video_path = _make_temp_path(".mp4", temp_paths)
        with open(video_path, "wb") as fh:
            fh.write(video_bytes)

        status_text.text("Extracting frames...")
        frames = _extract_frames(video_path, num_frames)
        if not frames:
            raise ValueError("No frames could be read from the uploaded video.")
        progress_bar.progress(20)

        status_text.text("Analyzing frames...")
        text_prompt = _caption_frames(frames, progress_bar)
        st.write("Enhanced text prompt:", text_prompt)

        status_text.text("Generating sound effect...")
        audio_path = _make_temp_path(".wav", temp_paths)
        _generate_sound(text_prompt, audio_path)
        progress_bar.progress(60)

        status_text.text("Syncing audio with video...")
        output_path = _make_temp_path(".mp4", temp_paths)
        _sync_audio(video_path, audio_path, output_path, mix_audio, clips, temp_paths)
        progress_bar.progress(90)

        # Read the result into memory so the temp file can be removed below.
        with open(output_path, "rb") as fh:
            output_bytes = fh.read()
        status_text.text("Done!")
        progress_bar.progress(100)
        return text_prompt, output_bytes
    finally:
        for clip in clips:
            clip.close()
        for path in temp_paths:
            try:
                os.remove(path)
            except OSError:
                pass  # best-effort cleanup: already removed or still locked


# File uploader for video
uploaded_file = st.file_uploader("Upload an MP4 video (high resolution)", type=["mp4"])
if uploaded_file is not None:
    upload_bytes = uploaded_file.getvalue()
    # Streamlit reruns this script on every interaction, including the
    # download click; reuse the previous result while the inputs are unchanged.
    run_key = (
        hashlib.sha256(upload_bytes).hexdigest(),
        num_frames_to_extract,
        mix_original_audio,
    )
    cached = st.session_state.get(_RESULT_STATE_KEY)
    if cached is not None and cached[0] == run_key:
        _, text_prompt, output_bytes = cached
        st.write("Enhanced text prompt:", text_prompt)
    else:
        try:
            text_prompt, output_bytes = _process_video(
                upload_bytes, num_frames_to_extract, mix_original_audio
            )
        except Exception as e:  # UI boundary: report the failure instead of crashing the app
            st.error(f"An error occurred: {str(e)}")
            st.write("Try reducing frames or uploading a smaller video.")
            output_bytes = None
        else:
            st.session_state[_RESULT_STATE_KEY] = (run_key, text_prompt, output_bytes)

    if output_bytes is not None:
        # Provide playback and download
        st.video(output_bytes)
        st.download_button(
            label="Download Synced Video",
            data=output_bytes,
            file_name="synced_story_video.mp4",
            mime="video/mp4",
        )