Bils committed on
Commit
73e3afa
·
verified ·
1 Parent(s): 6386945

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -23,7 +23,14 @@ from transformers import (
23
  from TTS.api import TTS
24
 
25
  # Diffusers for sound design generation
26
- from diffusers import DiffusionPipeline
 
 
 
 
 
 
 
27
 
28
  # ---------------------------------------------------------------------
29
  # Setup Logging and Environment Variables
@@ -100,13 +107,11 @@ def get_tts_model(model_name: str = "tts_models/en/ljspeech/tacotron2-DDC"):
100
  def get_sound_design_pipeline(model_name: str, token: str):
101
  """
102
  Returns a cached DiffusionPipeline for sound design if available;
103
- otherwise, it loads and caches the pipeline using the correct pipeline class.
104
  """
105
  if model_name in SOUND_DESIGN_PIPELINES:
106
  return SOUND_DESIGN_PIPELINES[model_name]
107
- # Import the correct pipeline class from diffusers
108
- from diffusers import AudioLDMPipeline
109
- pipe = DiffusionPipeline.from_pretrained(model_name, pipeline_class=AudioLDMPipeline, use_auth_token=token)
110
  SOUND_DESIGN_PIPELINES[model_name] = pipe
111
  return pipe
112
 
@@ -221,7 +226,7 @@ def generate_music(prompt: str, audio_length: int):
221
  @spaces.GPU(duration=200)
222
  def generate_sound_design(prompt: str):
223
  """
224
- Generates a sound design audio file based on the provided prompt using Audioldm.
225
  Returns the file path to the generated .wav file.
226
  """
227
  try:
 
23
  from TTS.api import TTS
24
 
25
  # Diffusers for sound design generation
26
+ from diffusers import DiffusionPipeline, AudioLDMPipeline
27
+ import diffusers
28
+
29
+ # Monkey-patch: Create a patched pipeline class so that any reference to AudioLDM2Pipeline is resolved correctly.
30
+ class PatchedAudioLDM2Pipeline(AudioLDMPipeline):
31
+ pass
32
+
33
+ setattr(diffusers, "AudioLDM2Pipeline", PatchedAudioLDM2Pipeline)
34
 
35
  # ---------------------------------------------------------------------
36
  # Setup Logging and Environment Variables
 
107
  def get_sound_design_pipeline(model_name: str, token: str):
108
  """
109
  Returns a cached DiffusionPipeline for sound design if available;
110
+ otherwise, it loads and caches the pipeline using the patched pipeline class.
111
  """
112
  if model_name in SOUND_DESIGN_PIPELINES:
113
  return SOUND_DESIGN_PIPELINES[model_name]
114
+ pipe = DiffusionPipeline.from_pretrained(model_name, pipeline_class=PatchedAudioLDM2Pipeline, use_auth_token=token)
 
 
115
  SOUND_DESIGN_PIPELINES[model_name] = pipe
116
  return pipe
117
 
 
226
  @spaces.GPU(duration=200)
227
  def generate_sound_design(prompt: str):
228
  """
229
+ Generates a sound design audio file based on the provided prompt using AudioLDM 2.
230
  Returns the file path to the generated .wav file.
231
  """
232
  try: