Bils committed on
Commit
3b58485
·
verified ·
1 Parent(s): 3168a3e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -9
app.py CHANGED
@@ -51,6 +51,7 @@ def generate_script(user_prompt: str, model_id: str, token: str):
51
  # Load MusicGen Model (Encapsulated)
52
  # ---------------------------------------------------------------------
53
  @spaces.GPU(duration=300)
 
54
  def generate_audio(prompt: str, audio_length: int):
55
  try:
56
  musicgen_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
@@ -59,19 +60,21 @@ def generate_audio(prompt: str, audio_length: int):
59
  musicgen_model.to("cuda")
60
  inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt")
61
  outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
62
- musicgen_model.to("cpu") # Return the model to CPU
63
 
64
  sr = musicgen_model.config.audio_encoder.sampling_rate
65
  audio_data = outputs[0, 0].cpu().numpy()
66
  normalized_audio = (audio_data / max(abs(audio_data)) * 32767).astype("int16")
67
 
68
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav:
69
- write(temp_wav.name, sr, normalized_audio)
70
- return temp_wav.name
 
71
  except Exception as e:
72
  return f"Error generating audio: {e}"
73
 
74
 
 
75
  # ---------------------------------------------------------------------
76
  # Gradio Interface Functions
77
  # ---------------------------------------------------------------------
@@ -127,12 +130,11 @@ with gr.Blocks() as demo:
127
  value=512,
128
  info="Select the desired audio token length."
129
  )
130
- generate_audio_button = gr.Button("Generate Audio 🎶")
131
  audio_output = gr.Audio(
132
- label="🎶 Generated Audio File",
133
- type="filepath",
134
- interactive=False
135
- )
136
 
137
  # Footer
138
  gr.Markdown("""
 
51
  # Load MusicGen Model (Encapsulated)
52
  # ---------------------------------------------------------------------
53
  @spaces.GPU(duration=300)
54
+ @spaces.GPU(duration=300)
55
  def generate_audio(prompt: str, audio_length: int):
56
  try:
57
  musicgen_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
 
60
  musicgen_model.to("cuda")
61
  inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt")
62
  outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
63
+ musicgen_model.to("cpu")
64
 
65
  sr = musicgen_model.config.audio_encoder.sampling_rate
66
  audio_data = outputs[0, 0].cpu().numpy()
67
  normalized_audio = (audio_data / max(abs(audio_data)) * 32767).astype("int16")
68
 
69
+ output_path = f"{tempfile.gettempdir()}/generated_audio.wav"
70
+ write(output_path, sr, normalized_audio)
71
+
72
+ return output_path
73
  except Exception as e:
74
  return f"Error generating audio: {e}"
75
 
76
 
77
+
78
  # ---------------------------------------------------------------------
79
  # Gradio Interface Functions
80
  # ---------------------------------------------------------------------
 
130
  value=512,
131
  info="Select the desired audio token length."
132
  )
 
133
  audio_output = gr.Audio(
134
+ label="🎶 Generated Audio File",
135
+ type="file",
136
+ interactive=False
137
+ )
138
 
139
  # Footer
140
  gr.Markdown("""