import streamlit as st
import imageio
import numpy as np
from PIL import Image
from transformers import AutoProcessor, BlipForConditionalGeneration, MusicgenForConditionalGeneration
import soundfile as sf
import torch
import os
import tempfile
# Try importing moviepy, with a fallback error if it is missing
try:
    import moviepy.editor as mpy
except ModuleNotFoundError:
    st.error("The 'moviepy' library is not installed. Please ensure 'moviepy' is in requirements.txt and installed correctly.")
    st.stop()
# Set page title and instructions
st.title("Story Video Sound Effect Sync Generator")
st.write("Upload an MP4 video to auto-generate and sync a high-quality sound effect.")
# User-configurable settings
num_frames_to_extract = st.slider("Number of frames to analyze", 1, 3, 1, help="Fewer frames = faster processing")
mix_original_audio = st.checkbox("Mix with original audio", value=False, help="Blend sound effect with video’s original sound")
# Prompt enhancement function
def enhance_prompt(base_description):
    """Enhance a BLIP caption with sound-specific details."""
    base = base_description.lower()
    if "walk" in base or "run" in base:
        return f"{base} with crisp footsteps on a wooden floor"
    elif "car" in base or "drive" in base:
        return f"{base} with the roar of an engine and tires screeching"
    elif "talk" in base or "person" in base:
        return f"{base} with soft voices and background crowd murmur"
    elif "wind" in base or "tree" in base or "forest" in base:
        return f"{base} with gentle wind rustling through leaves"
    elif "crash" in base or "fall" in base:
        return f"{base} with a loud crash and debris scattering"
    else:
        return f"{base} with subtle ambient hum and faint echoes"
# File uploader for video
uploaded_file = st.file_uploader("Upload an MP4 video (high resolution)", type=["mp4"])
if uploaded_file is not None:
    # Initialize temp-file paths so the cleanup in `finally` is safe even if
    # an error occurs before they are assigned
    temp_video_path = temp_audio_path = output_path = None
    try:
        # Write the upload to a temporary video file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video:
            temp_video.write(uploaded_file.getbuffer())
            temp_video_path = temp_video.name
        # Progress bar setup
        progress_bar = st.progress(0)
        status_text = st.empty()

        # Extract frames, sampled evenly across the video
        status_text.text("Extracting frames...")
        video = imageio.get_reader(temp_video_path, "ffmpeg")
        # count_frames() gets the frame count without decoding the whole
        # clip; len(list(video.iter_data())) loaded every frame into memory
        total_frames = video.count_frames()
        step = max(1, total_frames // num_frames_to_extract)
        frames = [
            Image.fromarray(video.get_data(i))
            for i in range(0, min(total_frames, num_frames_to_extract * step), step)
        ][:num_frames_to_extract]
        progress_bar.progress(20)
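        # The sampling above picks evenly spaced frames: e.g. a 300-frame video
        # with 3 requested frames gives step=100, so frames 0, 100 and 200 are used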
        # Load BLIP model
        @st.cache_resource
        def load_blip_model():
            processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
            model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
            if torch.cuda.is_available():
                model = model.half().to("cuda")
            return processor, model

        processor, model = load_blip_model()
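        # st.cache_resource keeps the loaded model in memory across Streamlit
        # reruns, so the weights are downloaded and initialized only once per process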
        # Generate and enhance text descriptions
        status_text.text("Analyzing frames...")
        descriptions = []
        for i, frame in enumerate(frames):
            inputs = processor(images=frame, return_tensors="pt")
            if torch.cuda.is_available():
                # Cast floating-point inputs to fp16 to match the half-precision model
                inputs = {k: (v.half() if v.is_floating_point() else v).to("cuda") for k, v in inputs.items()}
            out = model.generate(**inputs)
            base_description = processor.decode(out[0], skip_special_tokens=True)
            enhanced_description = enhance_prompt(base_description)
            descriptions.append(enhanced_description)
            progress_bar.progress(20 + int(30 * (i + 1) / len(frames)))
        text_prompt = ". ".join(descriptions)
        st.write("Enhanced text prompt:", text_prompt)
        # Load MusicGen model
        @st.cache_resource
        def load_musicgen_model():
            processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
            model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
            if torch.cuda.is_available():
                model = model.half().to("cuda")
            return processor, model

        musicgen_processor, musicgen_model = load_musicgen_model()
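        # musicgen-small is the smallest public MusicGen checkpoint (~300M
        # parameters); the medium/large variants give better audio at the
        # cost of more GPU memory and slower generation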
        # Generate the sound effect; MusicGen emits ~50 tokens per second of
        # audio, so 256 new tokens yields roughly 5 seconds
        status_text.text("Generating sound effect...")
        inputs = musicgen_processor(
            text=[text_prompt],
            padding=True,
            return_tensors="pt",
        )
        if torch.cuda.is_available():
            inputs = {k: v.to("cuda") for k, v in inputs.items()}
        audio_values = musicgen_model.generate(
            **inputs,
            max_new_tokens=256,   # controls clip length
            do_sample=True,       # sampling rather than greedy decoding
            guidance_scale=3.0,   # classifier-free guidance: higher = closer to the prompt
            top_k=50,
            top_p=0.95
        )
        # Convert to float32 before NumPy/soundfile, since the model may be fp16
        audio_array = audio_values[0].cpu().float().numpy()
        if audio_array.ndim > 1:
            audio_array = audio_array.flatten()
        # Normalize to 90% full scale, guarding against an all-silent output
        peak = np.max(np.abs(audio_array))
        if peak > 0:
            audio_array = audio_array / peak * 0.9
        audio_array = np.clip(audio_array, -1.0, 1.0)
        # Use the model's native sampling rate (32 kHz for MusicGen)
        sample_rate = musicgen_model.config.audio_encoder.sampling_rate
        progress_bar.progress(60)
        # Save temporary audio
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            sf.write(temp_audio.name, audio_array, sample_rate)
            temp_audio_path = temp_audio.name
        # Synchronize audio with the video using moviepy
        status_text.text("Syncing audio with video...")
        video_clip = mpy.VideoFileClip(temp_video_path)
        video_duration = video_clip.duration
        audio_clip = mpy.AudioFileClip(temp_audio_path)
        # Adjust audio length
        if audio_clip.duration < video_duration:
            loops_needed = int(np.ceil(video_duration / audio_clip.duration))
            audio_clip = mpy.concatenate_audioclips([audio_clip] * loops_needed).subclip(0, video_duration)
        else:
            audio_clip = audio_clip.subclip(0, video_duration)
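        # Looping a short effect is a simple way to cover a longer video; a
        # crossfade at each loop boundary would reduce audible seams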
        # Mix or replace audio; moviepy clips don't support `+` for mixing,
        # so overlay the two tracks with CompositeAudioClip
        if mix_original_audio and video_clip.audio:
            final_audio = mpy.CompositeAudioClip([
                video_clip.audio.volumex(0.5),
                audio_clip.volumex(0.5),
            ])
        else:
            final_audio = audio_clip
        # Set audio to video
        final_video = video_clip.set_audio(final_audio)
        # Save the final video with a fast x264 preset
        output_path = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
        final_video.write_videofile(
            output_path,
            codec="libx264",
            audio_codec="aac",
            preset="ultrafast",
            temp_audiofile="temp-audio.m4a",
            remove_temp=True
        )
        # Release ffmpeg readers held by the clips
        video_clip.close()
        audio_clip.close()
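        # "ultrafast" minimizes encode time at the cost of a larger file;
        # slower x264 presets such as "medium" or "slow" compress better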
        progress_bar.progress(90)

        # Provide playback and download
        status_text.text("Done!")
        st.video(output_path)
        with open(output_path, "rb") as video_file:
            st.download_button(
                label="Download Synced Video",
                data=video_file,
                file_name="synced_story_video.mp4",
                mime="video/mp4"
            )
        progress_bar.progress(100)
    except Exception as e:
        st.error(f"An error occurred: {e}")
        st.write("Try reducing frames or uploading a smaller video.")
    finally:
        # Clean up temporary files; the paths were initialized to None above,
        # so this is safe even when an error occurred before they were created
        for path in (temp_video_path, temp_audio_path, output_path):
            if path is not None and os.path.exists(path):
                os.remove(path)