Bils committed on
Commit
653eb14
·
verified ·
1 Parent(s): cc173f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -39,6 +39,7 @@ def clean_text(text: str) -> str:
39
  """
40
  Removes undesired characters (e.g., asterisks) that might not be recognized by the model's vocabulary.
41
  """
 
42
  return re.sub(r'\*', '', text)
43
 
44
  # ---------------------------------------------------------------------
@@ -74,6 +75,7 @@ def get_musicgen_model(model_key: str = "facebook/musicgen-large"):
74
 
75
  model = MusicgenForConditionalGeneration.from_pretrained(model_key)
76
  processor = AutoProcessor.from_pretrained(model_key)
 
77
  device = "cuda" if torch.cuda.is_available() else "cpu"
78
  model.to(device)
79
  MUSICGEN_MODELS[model_key] = (model, processor)
@@ -203,9 +205,7 @@ def generate_music(prompt: str, audio_length: int):
203
  musicgen_model, musicgen_processor = get_musicgen_model(model_key)
204
 
205
  device = "cuda" if torch.cuda.is_available() else "cpu"
206
- # Process the input and move each tensor to the proper device
207
- inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt")
208
- inputs = {k: v.to(device) for k, v in inputs.items()}
209
 
210
  with torch.inference_mode():
211
  outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
 
39
  """
40
  Removes undesired characters (e.g., asterisks) that might not be recognized by the model's vocabulary.
41
  """
42
+ # Remove all asterisks. You can add more cleaning steps here as needed.
43
  return re.sub(r'\*', '', text)
44
 
45
  # ---------------------------------------------------------------------
 
75
 
76
  model = MusicgenForConditionalGeneration.from_pretrained(model_key)
77
  processor = AutoProcessor.from_pretrained(model_key)
78
+
79
  device = "cuda" if torch.cuda.is_available() else "cpu"
80
  model.to(device)
81
  MUSICGEN_MODELS[model_key] = (model, processor)
 
205
  musicgen_model, musicgen_processor = get_musicgen_model(model_key)
206
 
207
  device = "cuda" if torch.cuda.is_available() else "cpu"
208
+ inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt").to(device)
 
 
209
 
210
  with torch.inference_mode():
211
  outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)