import streamlit as st
import imageio
import numpy as np
from PIL import Image
from transformers import AutoProcessor, BlipForConditionalGeneration, MusicgenForConditionalGeneration
import soundfile as sf
import torch
import os
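
# NOTE: assumes the usual dependencies are installed, e.g.
#   pip install streamlit imageio imageio-ffmpeg transformers soundfile torch pillow
# (imageio needs the ffmpeg plugin to read MP4 files)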

# Set page title
st.title("Video Sound Effect Generator")

# File uploader for video
uploaded_file = st.file_uploader(
    "Upload a short video (MP4, max 10 seconds, high resolution)",
    type=["mp4"]
)

if uploaded_file is not None:
    try:
        # Save the uploaded video temporarily
        with open("temp_video.mp4", "wb") as f:
            f.write(uploaded_file.getbuffer())

        # Check video duration
        video = imageio.get_reader("temp_video.mp4")
        fps = video.get_meta_data()['fps']
        num_frames = len(list(video.iter_data()))
        duration = num_frames / fps

        if duration > 10:
            st.error("Video is too long. Please upload a video of maximum 10 seconds.")
        else:
            st.success("Video uploaded successfully!")

            # Extract 10 evenly spaced frames
            num_frames_to_extract = 10
            step = max(1, num_frames // num_frames_to_extract)
            frames = [
                Image.fromarray(video.get_data(i))
                for i in range(0, num_frames, step)
            ][:num_frames_to_extract]

            # Load BLIP model with caching (st.cache_resource keeps it loaded across reruns)
            @st.cache_resource
            def load_blip_model():
                processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
                model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
                return processor, model

            processor, model = load_blip_model()

            # Generate text descriptions for each frame
            descriptions = []
            for frame in frames:
                inputs = processor(images=frame, return_tensors="pt")
                out = model.generate(**inputs)
                description = processor.decode(out[0], skip_special_tokens=True)
                descriptions.append(description)

            # Combine descriptions into a single prompt
            text_prompt = ". ".join(descriptions)
            st.write("Generated text prompt:", text_prompt)

            # Load MusicGen model with caching (same pattern as above)
            @st.cache_resource
            def load_musicgen_model():
                processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
                model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
                return processor, model

            musicgen_processor, musicgen_model = load_musicgen_model()

            # Generate sound effect
            inputs = musicgen_processor(
                text=[text_prompt],
                padding=True,
                return_tensors="pt",
            )
            audio_values = musicgen_model.generate(**inputs, max_new_tokens=512)
            # generate() returns a (batch, channels, samples) tensor; take the mono channel as a 1-D array
            audio_array = audio_values[0, 0].cpu().numpy()
            sample_rate = musicgen_model.config.audio_encoder.sampling_rate

            # Save audio to a WAV file
            sf.write("output.wav", audio_array, sample_rate)

            # Provide audio playback and download options
            st.audio("output.wav", format="audio/wav")
            with open("output.wav", "rb") as audio_file:
                st.download_button(
                    label="Download Sound Effect",
                    data=audio_file,
                    file_name="sound_effect.wav",
                    mime="audio/wav"
                )

    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
        st.write("Please try uploading a different video or check your connection.")
    finally:
        # Clean up temporary files
        if os.path.exists("temp_video.mp4"):
            os.remove("temp_video.mp4")
        if os.path.exists("output.wav"):
            os.remove("output.wav")
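
# To run locally (assuming this script is saved as app.py):
#   streamlit run app.py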