VoiceStar / app.py
mrfakename's picture
Update app.py
c1ad8f4 verified
raw
history blame
6.72 kB
import spaces
import gradio as gr
import sys
import threading
import queue
from io import TextIOBase
from inference import inference_patch
import datetime
import subprocess
import os
# Predefined valid combinations set
with open('prompts.txt', 'r') as f:
prompts = f.readlines()
valid_combinations = set()
for prompt in prompts:
prompt = prompt.strip()
parts = prompt.split('_')
valid_combinations.add((parts[0], parts[1], parts[2]))
# Generate available options
periods = sorted({p for p, _, _ in valid_combinations})
composers = sorted({c for _, c, _ in valid_combinations})
instruments = sorted({i for _, _, i in valid_combinations})
# Dynamic component updates
def update_components(period, composer):
if not period:
return [
gr.Dropdown(choices=[], value=None, interactive=False),
gr.Dropdown(choices=[], value=None, interactive=False)
]
valid_composers = sorted({c for p, c, _ in valid_combinations if p == period})
valid_instruments = sorted({i for p, c, i in valid_combinations if p == period and c == composer}) if composer else []
return [
gr.Dropdown(
choices=valid_composers,
value=composer if composer in valid_composers else None,
interactive=True
),
gr.Dropdown(
choices=valid_instruments,
value=None,
interactive=bool(valid_instruments)
)
]
class RealtimeStream(TextIOBase):
def __init__(self, queue):
self.queue = queue
def write(self, text):
self.queue.put(text)
return len(text)
def save_and_convert(abc_content, period, composer, instrumentation):
if not all([period, composer, instrumentation]):
raise gr.Error("Please complete a valid generation first before saving")
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
prompt_str = f"{period}_{composer}_{instrumentation}"
filename_base = f"{timestamp}_{prompt_str}"
abc_filename = f"{filename_base}.abc"
with open(abc_filename, "w", encoding="utf-8") as f:
f.write(abc_content)
xml_filename = f"{filename_base}.xml"
try:
subprocess.run(
["python", "abc2xml.py", '-o', '.', abc_filename, ],
check=True,
capture_output=True,
text=True
)
except subprocess.CalledProcessError as e:
error_msg = f"Conversion failed: {e.stderr}" if e.stderr else "Unknown error"
raise gr.Error(f"ABC to XML conversion failed: {error_msg}. Please try to generate another composition.")
return f"Saved successfully: {abc_filename} -> {xml_filename}"
@spaces.GPU
def generate_music(period, composer, instrumentation):
if (period, composer, instrumentation) not in valid_combinations:
raise gr.Error("Invalid prompt combination! Please re-select from the period options")
output_queue = queue.Queue()
original_stdout = sys.stdout
sys.stdout = RealtimeStream(output_queue)
result_container = []
def run_inference():
try:
result_container.append(inference_patch(period, composer, instrumentation))
finally:
sys.stdout = original_stdout
thread = threading.Thread(target=run_inference)
thread.start()
process_output = ""
while thread.is_alive():
try:
text = output_queue.get(timeout=0.1)
process_output += text
yield process_output, None
except queue.Empty:
continue
while not output_queue.empty():
text = output_queue.get()
process_output += text
yield process_output, None
final_result = result_container[0] if result_container else ""
yield process_output, final_result
with gr.Blocks() as demo:
gr.Markdown("## NotaGen")
with gr.Row():
# 左侧栏
with gr.Column():
period_dd = gr.Dropdown(
choices=periods,
value=None,
label="Period",
interactive=True
)
composer_dd = gr.Dropdown(
choices=[],
value=None,
label="Composer",
interactive=False
)
instrument_dd = gr.Dropdown(
choices=[],
value=None,
label="Instrumentation",
interactive=False
)
generate_btn = gr.Button("Generate!", variant="primary")
process_output = gr.Textbox(
label="Generation process",
interactive=False,
lines=15,
max_lines=15,
placeholder="Generation progress will be shown here...",
elem_classes="process-output"
)
# 右侧栏
with gr.Column():
final_output = gr.Textbox(
label="Post-processed ABC notation scores",
interactive=True,
lines=23,
placeholder="Post-processed ABC scores will be shown here...",
elem_classes="final-output"
)
with gr.Row():
save_btn = gr.Button("💾 Save as ABC & XML files", variant="secondary")
save_status = gr.Textbox(
label="Save Status",
interactive=False,
visible=True,
max_lines=2
)
period_dd.change(
update_components,
inputs=[period_dd, composer_dd],
outputs=[composer_dd, instrument_dd]
)
composer_dd.change(
update_components,
inputs=[period_dd, composer_dd],
outputs=[composer_dd, instrument_dd]
)
generate_btn.click(
generate_music,
inputs=[period_dd, composer_dd, instrument_dd],
outputs=[process_output, final_output]
)
save_btn.click(
save_and_convert,
inputs=[final_output, period_dd, composer_dd, instrument_dd],
outputs=[save_status]
)
css = """
.process-output {
background-color: #f0f0f0;
font-family: monospace;
padding: 10px;
border-radius: 5px;
}
.final-output {
background-color: #ffffff;
font-family: sans-serif;
padding: 10px;
border-radius: 5px;
}
.process-output textarea {
max-height: 500px !important;
overflow-y: auto !important;
white-space: pre-wrap;
}
"""
css += """
button#💾-save-convert:hover {
background-color: #ffe6e6;
}
"""
demo.css = css
if __name__ == "__main__":
demo.queue().launch()