"""Gradio front-end for EMelodyGen: emotion-conditioned melody generation.

The UI exposes two generation modes (by valence/arousal template and by
explicit feature control), a feedback channel that copies mislabelled
outputs into a feedback folder, and a way to save feature templates as
JSON lines.
"""

import os
import json
import shutil
import argparse

import gradio as gr

from generate import generate_music, get_args
from utils import _L, WEIGHTS_DIR, TEMP_DIR, EN_US


def _emotion_quadrant(valence, arousal, valence_labels, arousal_labels) -> str:
    """Map a (valence, arousal) pair of UI labels to an emotion quadrant.

    ``valence_labels`` and ``arousal_labels`` are ``(low, high)`` label pairs.
    low valence + high arousal -> "Q2", low + low -> "Q3",
    high valence + low arousal -> "Q4"; every other combination (high + high,
    or labels that match neither entry) falls back to "Q1".
    """
    v_low, v_high = valence_labels
    a_low, a_high = arousal_labels
    quadrants = {
        (v_low, a_high): "Q2",
        (v_low, a_low): "Q3",
        (v_high, a_low): "Q4",
    }
    return quadrants.get((valence, arousal), "Q1")


def _weights_path(dataset: str) -> str:
    """Return the checkpoint path for ``dataset`` (name is lower-cased)."""
    return f"{WEIGHTS_DIR}/{dataset.lower()}/weights.pth"


def _error_message(err: Exception) -> str:
    """Render an exception for the status bar, never as an empty string."""
    return str(err) or type(err).__name__


def infer_by_template(dataset: str, v: str, a: str, add_chord: bool):
    """Generate a melody from a valence/arousal template.

    Args:
        dataset: Dataset name selecting ``<WEIGHTS_DIR>/<dataset>/weights.pth``.
        v: Valence label, ``_L("Low")`` or ``_L("High")``.
        a: Arousal label, ``_L("Low")`` or ``_L("High")``.
        add_chord: Chord generation is not implemented yet; when set, only a
            notice is printed.

    Returns:
        ``(status, audio, midi, pdf, xml, mxl, tunes, jpg)``. ``status`` is
        ``"Success"`` or an error message; on failure all outputs are None.
    """
    status = "Success"
    audio = midi = pdf = xml = mxl = tunes = jpg = None
    try:
        emotion = _emotion_quadrant(
            v, a, (_L("Low"), _L("High")), (_L("Low"), _L("High"))
        )
        if add_chord:
            print("Chord generation comes soon!")

        # NOTE(review): get_args presumably fills in default generation
        # options from this fresh parser -- confirm it does not choke on the
        # Gradio process's own sys.argv.
        parser = argparse.ArgumentParser()
        args = get_args(parser)
        args.template = True
        audio, midi, pdf, xml, mxl, tunes, jpg = generate_music(
            args,
            emo=emotion,
            weights=_weights_path(dataset),
        )
    except Exception as e:  # UI boundary: surface the error in the status bar
        status = _error_message(e)

    return status, audio, midi, pdf, xml, mxl, tunes, jpg


def infer_by_features(
    dataset: str,
    pitch_std: str,
    mode: str,
    tempo: int,
    octave: int,
    rms: int,
    add_chord: bool,
):
    """Generate a melody from explicit musical-feature controls.

    Mode acts as the valence axis (Minor = low, Major = high) and pitch
    standard deviation as the arousal axis (Low / High) when choosing the
    emotion quadrant.

    Args:
        dataset: Dataset name selecting ``<WEIGHTS_DIR>/<dataset>/weights.pth``.
        pitch_std: ``_L("Low")`` or ``_L("High")``.
        mode: ``_L("Minor")`` or ``_L("Major")``.
        tempo: Passed to ``generate_music`` as ``fix_tempo`` (BPM slider).
        octave: Passed as ``fix_pitch`` (octave-shift slider).
        rms: Passed as ``fix_volume`` (volume slider, dB).
        add_chord: Chord generation is not implemented yet; when set, only a
            notice is printed.

    Returns:
        ``(status, audio, midi, pdf, xml, mxl, tunes, jpg)``. ``status`` is
        ``"Success"`` or an error message; on failure all outputs are None.
    """
    status = "Success"
    audio = midi = pdf = xml = mxl = tunes = jpg = None
    try:
        emotion = _emotion_quadrant(
            mode, pitch_std, (_L("Minor"), _L("Major")), (_L("Low"), _L("High"))
        )
        if add_chord:
            print("Chord generation comes soon!")

        parser = argparse.ArgumentParser()
        args = get_args(parser)
        args.template = False
        audio, midi, pdf, xml, mxl, tunes, jpg = generate_music(
            args,
            emo=emotion,
            weights=_weights_path(dataset),
            fix_tempo=tempo,
            fix_pitch=octave,
            fix_volume=rms,
        )
    except Exception as e:  # UI boundary: surface the error in the status bar
        status = _error_message(e)

    return status, audio, midi, pdf, xml, mxl, tunes, jpg


def feedback(
    fixed_emo: str,
    source_dir=f"./{TEMP_DIR}/output",
    target_dir=f"./{TEMP_DIR}/feedback",
):
    """Record the user's corrected emotion label for a generated score.

    Walks ``source_dir`` and takes the first ``.mxl`` file found. Its prompt
    emotion is parsed from the file name (assumed to look like
    ``[Q1]...mxl`` -- TODO confirm against generate_music). If it differs from
    ``fixed_emo`` the file is copied into ``target_dir`` with
    ``_<fixed_emo>`` inserted before the extension.

    Returns:
        A status message. Errors (including a missing selection) are returned
        as their message text rather than raised.
    """
    suffix = ".mxl"
    try:
        if not fixed_emo:
            raise ValueError("Please select feedback before submitting!")

        os.makedirs(target_dir, exist_ok=True)
        for root, _, files in os.walk(source_dir):
            for file in files:
                if not file.endswith(suffix):
                    continue

                # Only the first .mxl file is considered; presumably the output
                # directory holds just the latest generation -- TODO confirm.
                prompt_emo = file.split("]")[0][1:]
                if prompt_emo == fixed_emo:
                    return "Thanks for your feedback!"

                file_path = os.path.join(root, file)
                # Replace only the trailing extension, not every ".mxl" substring.
                target_name = f"{file[: -len(suffix)]}_{fixed_emo}{suffix}"
                target_path = os.path.join(target_dir, target_name)
                shutil.copy(file_path, target_path)
                return f"Copied {file_path} to {target_path}"

        return "No .mxl files found in the source directory."
    except Exception as e:
        return _error_message(e)


def save_template(
    label: str, pitch_std: str, mode: str, tempo: int, octave: int, rms: int
):
    """Append the current feature settings as one JSON line to the template file.

    The record stores ``pitch_std`` and ``mode`` as booleans (High / Major ->
    True) and ``rms`` under the key ``"volume"``.

    Returns:
        ``(status, template_path)``. ``template_path`` is None on failure.
    """
    status = "Success"
    template = None
    template_path = f"./{TEMP_DIR}/feedback/templates.jsonl"
    try:
        # 0 is a valid octave / volume, so numeric fields are checked against
        # None explicitly instead of by truthiness.
        if not (label and pitch_std and mode) or any(
            value is None for value in (tempo, octave, rms)
        ):
            raise ValueError("Please check features")

        json_str = json.dumps(
            {
                "label": label,
                "pitch_std": pitch_std == _L("High"),
                "mode": mode == _L("Major"),
                "tempo": tempo,
                "octave": octave,
                "volume": rms,
            }
        )
        # The feedback directory is otherwise only created by feedback(), so a
        # template saved first would fail with FileNotFoundError.
        os.makedirs(os.path.dirname(template_path), exist_ok=True)
        with open(template_path, "a", encoding="utf-8") as file:
            file.write(json_str + "\n")

        template = template_path
    except Exception as e:
        status = _error_message(e)

    return status, template


if __name__ == "__main__":
    with gr.Blocks() as demo:
        if EN_US:
            gr.Markdown(
                "## The current CPU-based version on HuggingFace has slow inference, you can access the GPU-based mirror on [ModelScope](https://www.modelscope.cn/studios/monetjoe/EMelodyGen)"
            )

        with gr.Row():
            # Left column: info, dataset choice and the two generation modes.
            with gr.Column():
                with gr.Accordion(label=_L("Additional info & option"), open=False):
                    gr.Video(
                        "./demo.mp4" if EN_US else "./src/tutorial.mp4",
                        label=_L("Video demo"),
                        show_download_button=False,
                        show_share_button=False,
                    )
                    gr.Markdown(
                        f"## {_L('Cite')}"
                        + """
```bibtex
@misc{zhou2025emelodygenemotionconditionedmelodygeneration,
  title = {EMelodyGen: Emotion-Conditioned Melody Generation in ABC Notation with the Musical Feature Template},
  author = {Monan Zhou and Xiaobing Li and Feng Yu and Wei Li},
  year = {2025},
  eprint = {2309.13259},
  archiveprefix = {arXiv},
  primaryclass = {cs.IR},
  url = {https://arxiv.org/abs/2309.13259}
}
```"""
                    )

                with gr.Row():
                    data_opt = gr.Dropdown(
                        ["VGMIDI", "EMOPIA", "Rough4Q"],
                        label=_L("Dataset"),
                        value="Rough4Q",
                    )
                    chord_chk = gr.Checkbox(
                        label=_L("Generate chords coming soon"),
                        value=False,
                    )

                with gr.Tab(_L("By template")):
                    gr.Image(
                        (
                            "https://www.modelscope.cn/studio/monetjoe/EMelodyGen/resolve/master/src/4q.jpg"
                            if EN_US
                            else "./src/4q.jpg"
                        ),
                        show_label=False,
                        show_download_button=False,
                        show_fullscreen_button=False,
                        show_share_button=False,
                    )
                    v_radio = gr.Radio(
                        [_L("Low"), _L("High")],
                        label=_L(
                            "Valence: reflects negative-positive levels of emotion"
                        ),
                        value=_L("High"),
                    )
                    a_radio = gr.Radio(
                        [_L("Low"), _L("High")],
                        label=_L(
                            "Arousal: reflects the calmness-intensity of the emotion"
                        ),
                        value=_L("High"),
                    )
                    gen1_btn = gr.Button(_L("Generate"))

                with gr.Tab(_L("By feature control")):
                    std_opt = gr.Radio(
                        [_L("Low"), _L("High")], label=_L("Pitch SD"), value=_L("High")
                    )
                    mode_opt = gr.Radio(
                        [_L("Minor"), _L("Major")], label=_L("Mode"), value=_L("Major")
                    )
                    tempo_opt = gr.Slider(
                        minimum=40,
                        maximum=228,
                        step=1,
                        value=120,
                        label=_L("BPM tempo"),
                    )
                    octave_opt = gr.Slider(
                        minimum=-24,
                        maximum=24,
                        step=12,
                        value=0,
                        label=_L("±12 octave"),
                    )
                    volume_opt = gr.Slider(
                        minimum=-5,
                        maximum=10,
                        step=5,
                        value=0,
                        label=_L("Volume in dB"),
                    )
                    gen2_btn = gr.Button(_L("Generate"))

                    with gr.Accordion(label=_L("Save template"), open=False):
                        with gr.Row():
                            with gr.Column(min_width=160):
                                save_radio = gr.Radio(
                                    ["Q1", "Q2", "Q3", "Q4"],
                                    label=_L(
                                        "The emotion to which the current template belongs"
                                    ),
                                )
                                save_btn = gr.Button(_L("Save"))

                            with gr.Column(min_width=160):
                                save_file = gr.File(label=_L("Download template"))

            # Right column: generated outputs, feedback and status.
            with gr.Column():
                wav_audio = gr.Audio(label=_L("Audio"), type="filepath")
                with gr.Accordion(label=_L("Feedback"), open=False):
                    fdb_radio = gr.Radio(
                        ["Q1", "Q2", "Q3", "Q4"],
                        label=_L(
                            "The emotion you believe the generated result should belong to"
                        ),
                    )
                    fdb_btn = gr.Button(_L("Submit"))

                status_bar = gr.Textbox(label=_L("Status"), show_copy_button=True)
                with gr.Row():
                    mid_file = gr.File(label=_L("Download MIDI"), min_width=80)
                    pdf_file = gr.File(label=_L("Download PDF score"), min_width=80)
                    xml_file = gr.File(label=_L("Download MusicXML"), min_width=80)
                    mxl_file = gr.File(label=_L("Download MXL"), min_width=80)

                with gr.Row():
                    abc_txt = gr.TextArea(
                        label=_L("ABC notation"),
                        show_copy_button=True,
                    )
                    staff_img = gr.Image(label=_L("Staff"), type="filepath")

        # actions
        # Both generators return (status, audio, midi, pdf, xml, mxl, tunes, jpg),
        # wired to the same output components in this order.
        generation_outputs = [
            status_bar,
            wav_audio,
            mid_file,
            pdf_file,
            xml_file,
            mxl_file,
            abc_txt,
            staff_img,
        ]
        gen1_btn.click(
            fn=infer_by_template,
            inputs=[data_opt, v_radio, a_radio, chord_chk],
            outputs=generation_outputs,
        )
        gen2_btn.click(
            fn=infer_by_features,
            inputs=[
                data_opt,
                std_opt,
                mode_opt,
                tempo_opt,
                octave_opt,
                volume_opt,
                chord_chk,
            ],
            outputs=generation_outputs,
        )
        save_btn.click(
            fn=save_template,
            inputs=[
                save_radio,
                std_opt,
                mode_opt,
                tempo_opt,
                octave_opt,
                volume_opt,
            ],
            outputs=[status_bar, save_file],
        )
        fdb_btn.click(fn=feedback, inputs=fdb_radio, outputs=status_bar)

    demo.launch()