garyuzair committed on
Commit
b48b44d
·
verified ·
1 Parent(s): b9f8827

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -133
app.py CHANGED
@@ -1,145 +1,106 @@
1
  import streamlit as st
2
- import cv2
3
  import numpy as np
4
- from transformers import (
5
- pipeline,
6
- AutoProcessor,
7
- MusicgenForCausalLM,
8
- set_seed,
9
- )
10
  import torch
11
- import tempfile
12
  import os
13
- import ffmpeg
14
- from PIL import Image
15
- import warnings
16
- import io
17
-
18
- warnings.filterwarnings("ignore")
19
-
20
- def analyze_video(video_path):
21
- """Analyze video with optimized BLIP processing"""
22
- cap = cv2.VideoCapture(video_path)
23
- frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
24
- sample_rate = max(frame_count // 30, 1) # Increased frame sampling
25
-
26
- frames = []
27
- try:
28
- for i in range(0, frame_count, sample_rate):
29
- cap.set(cv2.CAP_PROP_POS_FRAMES, i)
30
- ret, frame = cap.read()
31
- if ret:
32
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
33
- small_frame = cv2.resize(frame_rgb, (128, 128)) # Reduced resolution
34
- frames.append(Image.fromarray(small_frame))
35
- finally:
36
- cap.release()
37
-
38
- if not frames:
39
- return ""
40
 
41
- outputs = st.session_state.model(frames)
42
- captions = [out[0]['generated_text'] for out in outputs]
43
 
44
- return " ".join(list(set(captions)))
45
-
46
- def generate_audio(description, duration):
47
- """Optimized audio generation with corrected parameters"""
48
- inputs = st.session_state.processor(
49
- text=[description],
50
- padding=True,
51
- return_tensors="pt",
52
- ).to(st.session_state.device)
53
-
54
- set_seed(42) # For reproducibility
55
-
56
- max_new_tokens = int(duration * 16) # 16 tokens per second
57
 
 
58
  try:
59
- audio = st.session_state.music_gen.generate(
60
- **inputs,
61
- do_sample=True,
62
- top_k=250,
63
- max_new_tokens=max_new_tokens,
64
- )
65
- if audio.shape[0] == 1 and len(audio.shape) == 3:
66
- return audio[0][0].cpu().numpy()
67
- else:
68
- st.error(f"Unexpected audio shape: {audio.shape}")
69
- return None
70
-
71
- except torch.cuda.OutOfMemoryError:
72
- st.error("Out of GPU memory. Try a shorter video or smaller model.")
73
- return None
74
- except Exception as e:
75
- st.error(f"Error during audio generation: {e}")
76
- return None
77
-
78
- def process_video(uploaded_file):
79
- """Generates audio from video description"""
80
- with tempfile.TemporaryDirectory() as tmp_dir:
81
- video_path = os.path.join(tmp_dir, "input.mp4")
82
- with open(video_path, "wb") as f:
83
  f.write(uploaded_file.getbuffer())
84
-
85
- probe = ffmpeg.probe(video_path)
86
- duration = min(float(probe['format']['duration']), 10)
87
-
88
- description = analyze_video(video_path)
89
- audio_array = generate_audio(description, duration)
90
-
91
- if audio_array is None:
92
- return None
93
-
94
- return audio_array
95
-
96
- # Streamlit UI
97
- st.set_page_config(page_title="Video Sound FX", layout="wide")
98
- st.title("🎬 Video Sound Generator (Audio Only)")
99
-
100
- if 'initialized' not in st.session_state:
101
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
102
- if device.type == 'cpu':
103
- st.warning("Running on CPU. Audio generation will be slow.")
104
-
105
- st.session_state.model = pipeline(
106
- "image-to-text",
107
- model="Salesforce/blip-image-captioning-base",
108
- device=device
109
- )
110
-
111
- st.session_state.music_gen = MusicgenForCausalLM.from_pretrained(
112
- "facebook/musicgen-small"
113
- ).to(device)
114
-
115
- st.session_state.processor = AutoProcessor.from_pretrained(
116
- "facebook/musicgen-small"
117
- )
118
-
119
- st.session_state.device = device
120
- st.session_state.initialized = True
121
-
122
- uploaded_file = st.file_uploader("Upload video (MP4, max 10s and low resolution)", type=["mp4"])
123
-
124
- if uploaded_file and st.button("Generate Audio"):
125
- with st.spinner("Processing..."):
126
- try:
127
- audio_array = process_video(uploaded_file)
128
- if audio_array is not None:
129
- st.success("Audio Generation Complete!")
130
-
131
- # Convert numpy array to bytes for download
132
- audio_bytes = (audio_array * 32767).astype(np.int16).tobytes()
133
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  st.download_button(
135
- "Download Audio (WAV)",
136
- data=audio_bytes,
137
- file_name="generated_audio.wav",
138
  mime="audio/wav"
139
  )
140
- st.audio(audio_bytes, format='audio/wav')
141
- else:
142
- st.error("Error during audio generation. Please check the video and try again. Ensure video is max 10 seconds and low resolution.")
143
-
144
- except Exception as e:
145
- st.error(f"Error: {str(e)}")
 
 
 
 
 
 
1
  import streamlit as st
2
+ import imageio
3
  import numpy as np
4
+ from PIL import Image
5
+ from transformers import AutoProcessor, BlipForConditionalGeneration, MusicgenForConditionalGeneration
6
+ import soundfile as sf
 
 
 
7
  import torch
 
8
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
# Longest clip accepted by the app, in seconds.
MAX_DURATION_SECONDS = 10
# Number of frames sampled from the clip and captioned by BLIP.
NUM_FRAMES_TO_EXTRACT = 10
# Upper bound on the number of MusicGen tokens generated per request.
MAX_NEW_TOKENS = 512


def count_up_to(items, limit):
    """Count the elements of *items*, giving up once *limit* is reached.

    Consumes the iterable lazily, so nothing is kept in memory and an
    over-long input is abandoned as soon as the limit is hit.

    Returns:
        The number of elements seen, never more than ``max(limit, 1)``.
    """
    count = 0
    for _ in items:
        count += 1
        if count >= limit:
            break
    return count


def evenly_spaced_indices(total, count):
    """Return up to *count* indices spread evenly over ``range(total)``.

    When there are no more than *count* items every index is returned.
    Otherwise index ``i`` maps to ``i * total // count`` so the samples
    cover the whole range instead of only its beginning.
    """
    if total <= 0 or count <= 0:
        return []
    if total <= count:
        return list(range(total))
    return [i * total // count for i in range(count)]


def to_soundfile_layout(audio):
    """Convert ``(channels, samples)`` audio to the layout soundfile expects.

    soundfile interprets 2-D data as ``(frames, channels)``, whereas the
    generated audio is channels-first. Mono audio is returned as a 1-D
    array, multi-channel audio is transposed.

    Raises:
        ValueError: If *audio* is neither 1-D nor 2-D.
    """
    samples = np.asarray(audio)
    if samples.ndim == 1:
        return samples
    if samples.ndim != 2:
        raise ValueError(f"Unexpected audio shape: {samples.shape}")
    if samples.shape[0] == 1:
        return samples[0]
    return samples.T


def build_prompt(descriptions):
    """Join frame captions into one prompt, dropping blanks and repeats.

    Neighbouring frames usually produce identical captions; repeating them
    only lengthens the prompt without adding information. First-seen order
    is preserved.
    """
    unique = dict.fromkeys(text.strip() for text in descriptions)
    return ". ".join(text for text in unique if text)


def sample_frames(video_path):
    """Read evenly spaced frames from the video at *video_path*.

    Returns:
        A list of PIL images, or ``None`` when the clip is longer than
        ``MAX_DURATION_SECONDS``.

    Raises:
        ValueError: If the frame rate is unusable or no frame can be read.
    """
    reader = imageio.get_reader(video_path)
    try:
        fps = reader.get_meta_data()['fps']
        if not fps or fps <= 0:
            raise ValueError("Could not determine the video frame rate.")

        # A clip is too long exactly when it has more frames than fit into
        # the allowed duration, so counting can stop one frame past that.
        frame_budget = int(fps * MAX_DURATION_SECONDS)
        num_frames = count_up_to(reader.iter_data(), frame_budget + 1)
        if num_frames > frame_budget:
            return None
        if num_frames == 0:
            raise ValueError("The video contains no readable frames.")

        return [
            Image.fromarray(reader.get_data(index))
            for index in evenly_spaced_indices(num_frames, NUM_FRAMES_TO_EXTRACT)
        ]
    finally:
        # Release the decoder so the temporary file can be deleted.
        reader.close()


def describe_frames(frames):
    """Caption every frame with BLIP and return the list of captions."""

    @st.cache_resource
    def load_blip_model():
        processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
        return processor, model

    processor, model = load_blip_model()

    descriptions = []
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for frame in frames:
            inputs = processor(images=frame, return_tensors="pt")
            out = model.generate(**inputs)
            descriptions.append(processor.decode(out[0], skip_special_tokens=True))
    return descriptions


def generate_audio(text_prompt):
    """Generate audio for *text_prompt* with MusicGen.

    Returns:
        A ``(audio_array, sample_rate)`` tuple where ``audio_array`` is laid
        out the way ``soundfile.write`` expects.
    """

    @st.cache_resource
    def load_musicgen_model():
        processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
        model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
        return processor, model

    musicgen_processor, musicgen_model = load_musicgen_model()

    inputs = musicgen_processor(
        text=[text_prompt],
        padding=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        audio_values = musicgen_model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)

    sample_rate = musicgen_model.config.audio_encoder.sampling_rate
    # audio_values[0] is channels-first; soundfile wants frames-first.
    audio_array = to_soundfile_layout(audio_values[0].cpu().numpy())
    return audio_array, sample_rate


def main():
    """Render the page: upload a clip, caption it, and offer generated audio."""
    import tempfile

    # Set page title
    st.title("Video Sound Effect Generator")

    # File uploader for video
    uploaded_file = st.file_uploader(
        "Upload a short video (MP4, max 10 seconds, high resolution)",
        type=["mp4"]
    )
    if uploaded_file is None:
        return

    try:
        # A private directory per run keeps concurrent sessions from
        # overwriting each other's files and guarantees cleanup.
        with tempfile.TemporaryDirectory() as work_dir:
            video_path = os.path.join(work_dir, "temp_video.mp4")
            with open(video_path, "wb") as f:
                f.write(uploaded_file.getbuffer())

            frames = sample_frames(video_path)
            if frames is None:
                st.error("Video is too long. Please upload a video of maximum 10 seconds.")
                return
            st.success("Video uploaded successfully!")

            # Combine descriptions into a single prompt
            text_prompt = build_prompt(describe_frames(frames))
            st.write("Generated text prompt:", text_prompt)

            # Generate sound effect and save it as a WAV file
            audio_array, sample_rate = generate_audio(text_prompt)
            audio_path = os.path.join(work_dir, "output.wav")
            sf.write(audio_path, audio_array, sample_rate)

            # Read the bytes back so the widgets do not depend on the
            # temporary file outliving this block.
            with open(audio_path, "rb") as audio_file:
                audio_bytes = audio_file.read()

        # Provide audio playback and download options
        st.audio(audio_bytes, format="audio/wav")
        st.download_button(
            label="Download Sound Effect",
            data=audio_bytes,
            file_name="sound_effect.wav",
            mime="audio/wav"
        )

    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
        st.write("Please try uploading a different video or check your connection.")


# `streamlit run` executes this file as `__main__`; the guard keeps a plain
# import of the module free of UI side effects.
if __name__ == "__main__":
    main()