garyuzair committed · verified
Commit c6fb5b9 · Parent(s): 8237612

Update app.py

Files changed (1):
  1. app.py (+78 -58)

app.py CHANGED
@@ -6,89 +6,109 @@ from transformers import AutoProcessor, BlipForConditionalGeneration, MusicgenForConditionalGeneration
 import soundfile as sf
 import torch
 import os
+import tempfile
+import time
 
-# Set page title
+# Set page title and instructions
 st.title("Video Sound Effect Generator")
+st.write("Upload an MP4 video to generate a sound effect based on its content.")
+
+# User-configurable settings
+num_frames_to_extract = st.slider("Number of frames to analyze", 1, 10, 3, help="Fewer frames = faster processing")
 
 # File uploader for video
-uploaded_file = st.file_uploader(
-    "Upload a short video (MP4, high resolution)",
-    type=["mp4"]
-)
+uploaded_file = st.file_uploader("Upload an MP4 video (high resolution)", type=["mp4"])
 
 if uploaded_file is not None:
     try:
-        # Save the uploaded video temporarily
-        with open("temp_video.mp4", "wb") as f:
-            f.write(uploaded_file.getbuffer())
-
-        # Extract frames using ffmpeg backend
-        video = imageio.get_reader("temp_video.mp4", "ffmpeg")
-        num_frames = len(list(video.iter_data()))
-
-        # Extract 10 evenly spaced frames
-        num_frames_to_extract = 10
-        step = max(1, num_frames // num_frames_to_extract)
+        # Use a temporary file for video
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video:
+            temp_video.write(uploaded_file.getbuffer())
+            temp_video_path = temp_video.name
+
+        # Progress bar setup
+        progress_bar = st.progress(0)
+        status_text = st.empty()
+
+        # Extract frames
+        status_text.text("Extracting frames...")
+        video = imageio.get_reader(temp_video_path, "ffmpeg")
+        total_frames = len(list(video.iter_data()))
+        step = max(1, total_frames // num_frames_to_extract)
         frames = [
             Image.fromarray(video.get_data(i))
-            for i in range(0, num_frames, step)
+            for i in range(0, min(total_frames, num_frames_to_extract * step), step)
         ][:num_frames_to_extract]
-
-        # Load BLIP model for image captioning
+        progress_bar.progress(25)
+
+        # Load BLIP model with FP16 if GPU available
         @st.cache_resource
         def load_blip_model():
             processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
             model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+            if torch.cuda.is_available():
+                model = model.half().to("cuda")
             return processor, model
 
         processor, model = load_blip_model()
-
-        # Generate text descriptions for each frame
+
+        # Generate text descriptions
+        status_text.text("Analyzing frames with BLIP...")
         descriptions = []
-        for frame in frames:
+        for i, frame in enumerate(frames):
             inputs = processor(images=frame, return_tensors="pt")
+            if torch.cuda.is_available():
+                inputs = {k: v.to("cuda") for k, v in inputs.items()}
             out = model.generate(**inputs)
             description = processor.decode(out[0], skip_special_tokens=True)
             descriptions.append(description)
-
-        # Combine descriptions into a single prompt
+            progress_bar.progress(25 + int(25 * (i + 1) / len(frames)))
+
         text_prompt = ". ".join(descriptions)
         st.write("Generated text prompt:", text_prompt)
-
-        # Load MusicGen model for sound generation
+
+        # Load MusicGen model with FP16 if GPU available
         @st.cache_resource
         def load_musicgen_model():
             processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
             model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+            if torch.cuda.is_available():
+                model = model.half().to("cuda")
             return processor, model
 
         musicgen_processor, musicgen_model = load_musicgen_model()
+
+        # Generate sound effect (limit to ~5 seconds)
+        status_text.text("Generating sound effect with MusicGen...")
+        inputs = musicgen_processor(
+            text=[text_prompt],
+            padding=True,
+            return_tensors="pt",
+        )
+        if torch.cuda.is_available():
+            inputs = {k: v.to("cuda") for k, v in inputs.items()}
 
-        # Generate sound effect
-        with st.spinner("Generating sound effect..."):
-            inputs = musicgen_processor(
-                text=[text_prompt],
-                padding=True,
-                return_tensors="pt",
-            )
-            audio_values = musicgen_model.generate(**inputs, max_new_tokens=512)
-
-            # Convert audio_values to a 1D NumPy array and normalize
-            audio_array = audio_values[0].cpu().numpy()  # Move to CPU and convert to NumPy
-            if audio_array.ndim > 1:  # Ensure it’s 1D
-                audio_array = audio_array.flatten()
-            audio_array = audio_array / np.max(np.abs(audio_array))  # Normalize to [-1, 1]
-
-            # Define sample rate (MusicGen small uses 32kHz)
-            sample_rate = 32000
-
-            # Save audio to WAV file
-            sf.write("output.wav", audio_array, sample_rate)
-
-        # Verify file exists and provide playback/download
-        if os.path.exists("output.wav"):
-            st.audio("output.wav", format="audio/wav")
-            with open("output.wav", "rb") as audio_file:
+        # max_new_tokens = 160 (5 seconds at 32kHz)
+        audio_values = musicgen_model.generate(**inputs, max_new_tokens=160)
+        audio_array = audio_values[0].cpu().numpy()
+        if audio_array.ndim > 1:
+            audio_array = audio_array.flatten()
+        audio_array = audio_array / np.max(np.abs(audio_array))  # Normalize
+        sample_rate = 32000  # MusicGen small uses 32kHz
+        progress_bar.progress(75)
+
+        # Save audio to temporary file
+        status_text.text("Saving audio...")
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
+            sf.write(temp_audio.name, audio_array, sample_rate)
+            temp_audio_path = temp_audio.name
+
+        # Provide playback and download
+        progress_bar.progress(100)
+        status_text.text("Done!")
+        if os.path.exists(temp_audio_path):
+            st.audio(temp_audio_path, format="audio/wav")
+            with open(temp_audio_path, "rb") as audio_file:
                 st.download_button(
                     label="Download Sound Effect",
                     data=audio_file,
@@ -97,14 +117,14 @@ if uploaded_file is not None:
                 )
         else:
             st.error("Failed to generate the audio file.")
-
+
     except Exception as e:
         st.error(f"An error occurred: {str(e)}")
-        st.write("Please try uploading a different video or check your connection.")
-
+        st.write("Try reducing the number of frames or uploading a smaller video.")
+
     finally:
         # Clean up temporary files
-        if os.path.exists("temp_video.mp4"):
-            os.remove("temp_video.mp4")
-        if os.path.exists("output.wav"):
-            os.remove("output.wav")
+        if os.path.exists(temp_video_path):
+            os.remove(temp_video_path)
+        if os.path.exists(temp_audio_path):
+            os.remove(temp_audio_path)
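
For context on the new max_new_tokens=160 setting: MusicGen's output length is set by the token budget, not by the 32 kHz sample rate passed to sf.write. Below is a minimal sketch of that arithmetic, assuming facebook/musicgen-small emits roughly 50 audio tokens per second on top of its 32 kHz EnCodec codec; the helper names are illustrative and not part of app.py.

# Rough duration math for MusicGen token budgets.
# Assumption: facebook/musicgen-small generates ~50 audio tokens per second
# over a 32 kHz EnCodec codec; names below are illustrative only.
TOKENS_PER_SECOND = 50

def seconds_for_tokens(max_new_tokens: int) -> float:
    """Approximate clip length produced by a given max_new_tokens."""
    return max_new_tokens / TOKENS_PER_SECOND

def tokens_for_seconds(seconds: float) -> int:
    """max_new_tokens needed for a clip of roughly `seconds` length."""
    return int(seconds * TOKENS_PER_SECOND)

print(seconds_for_tokens(512))  # previous setting: ~10 s of audio
print(seconds_for_tokens(160))  # new setting: ~3.2 s of audio
print(tokens_for_seconds(5))    # ~250 tokens for a true 5-second clip

Under that assumption, 160 tokens yields roughly three seconds of audio rather than five; a budget of about 250 tokens would match the "~5 seconds" comment while still staying well below the previous 512-token setting.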