Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -39,6 +39,7 @@ def clean_text(text: str) -> str:
|
|
39 |
"""
|
40 |
Removes undesired characters (e.g., asterisks) that might not be recognized by the model's vocabulary.
|
41 |
"""
|
|
|
42 |
return re.sub(r'\*', '', text)
|
43 |
|
44 |
# ---------------------------------------------------------------------
|
@@ -74,6 +75,7 @@ def get_musicgen_model(model_key: str = "facebook/musicgen-large"):
|
|
74 |
|
75 |
model = MusicgenForConditionalGeneration.from_pretrained(model_key)
|
76 |
processor = AutoProcessor.from_pretrained(model_key)
|
|
|
77 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
78 |
model.to(device)
|
79 |
MUSICGEN_MODELS[model_key] = (model, processor)
|
@@ -203,9 +205,7 @@ def generate_music(prompt: str, audio_length: int):
|
|
203 |
musicgen_model, musicgen_processor = get_musicgen_model(model_key)
|
204 |
|
205 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
206 |
-
|
207 |
-
inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt")
|
208 |
-
inputs = {k: v.to(device) for k, v in inputs.items()}
|
209 |
|
210 |
with torch.inference_mode():
|
211 |
outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
|
|
|
39 |
"""
|
40 |
Removes undesired characters (e.g., asterisks) that might not be recognized by the model's vocabulary.
|
41 |
"""
|
42 |
+
# Remove all asterisks. You can add more cleaning steps here as needed.
|
43 |
return re.sub(r'\*', '', text)
|
44 |
|
45 |
# ---------------------------------------------------------------------
|
|
|
75 |
|
76 |
model = MusicgenForConditionalGeneration.from_pretrained(model_key)
|
77 |
processor = AutoProcessor.from_pretrained(model_key)
|
78 |
+
|
79 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
80 |
model.to(device)
|
81 |
MUSICGEN_MODELS[model_key] = (model, processor)
|
|
|
205 |
musicgen_model, musicgen_processor = get_musicgen_model(model_key)
|
206 |
|
207 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
208 |
+
inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt").to(device)
|
|
|
|
|
209 |
|
210 |
with torch.inference_mode():
|
211 |
outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
|