nx_denoise / main.py
HoneyTian's picture
update
8ce0f99
raw
history blame
3.72 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from pathlib import Path
import platform
import gradio as gr
from huggingface_hub import snapshot_download
import numpy as np
import torch
from project_settings import environment, project_path
from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet
def get_args():
    """
    Parse the command-line options of the demo.

    Options:
      --examples_dir       defaults to ``<project_path>/data/examples``
      --models_repo_id     repo id handed to ``snapshot_download`` by ``main``
      --trained_model_dir  defaults to ``<project_path>/trained_models``
      --hf_token           defaults to ``environment["hf_token"]`` (may be None)
      --server_port        defaults to ``environment["server_port"]`` or 7860

    :return: argparse.Namespace holding the parsed options.
    """
    default_examples_dir = (project_path / "data/examples").as_posix()
    default_trained_model_dir = (project_path / "trained_models").as_posix()

    parser = argparse.ArgumentParser()
    parser.add_argument("--examples_dir", default=default_examples_dir, type=str)
    parser.add_argument("--models_repo_id", default="qgyd2021/vm_sound_classification", type=str)
    parser.add_argument("--trained_model_dir", default=default_trained_model_dir, type=str)
    parser.add_argument("--hf_token", default=environment.get("hf_token"), type=str)
    parser.add_argument("--server_port", default=environment.get("server_port", 7860), type=int)

    return parser.parse_args()
# Denoise engines selectable in the UI, keyed by engine name.
# NOTE(review): the engine is constructed when this module is imported, i.e.
# before main() runs its model download step — presumably the zip file must
# already exist locally at import time; confirm.
denoise_engines = dict(
    mpnet=InferenceMPNet(
        pretrained_model_path_or_zip_file=(
            project_path / "trained_models/mpnet_aishell_20250221.zip"
        ).as_posix(),
    ),
)
def when_click_denoise_button(noisy_audio_t, engine: str):
    """
    Denoise the submitted audio with the selected engine.

    :param noisy_audio_t: value of the input ``gr.Audio`` component, a
        ``(sample_rate, signal)`` tuple, or ``None`` when no audio was given.
        The signal is treated as 16-bit PCM (it is scaled by 1 / 32768).
    :param engine: key into ``denoise_engines``.
    :return: ``(enhanced_audio_t, None)``; ``enhanced_audio_t`` is a
        ``(sample_rate, int16 ndarray)`` tuple for the "enhanced_audio" output,
        ``None`` is the value sent to the "clean_audio" output.
    :raises gr.Error: when no audio was given, the engine name is unknown,
        or the enhancement itself fails.
    """
    # Without this guard the tuple unpacking below fails with a bare TypeError.
    if noisy_audio_t is None:
        raise gr.Error("noisy audio is required.")

    sample_rate, signal = noisy_audio_t
    # int16 PCM -> float32 in [-1.0, 1.0)
    noisy_audio = np.array(signal / (1 << 15), dtype=np.float32)

    infer_engine = denoise_engines.get(engine)
    if infer_engine is None:
        raise gr.Error(f"invalid denoise engine: {engine}.")

    try:
        enhanced_audio = infer_engine.enhancement_by_ndarray(noisy_audio)
        # Clip before casting: a sample at or beyond +/-1.0 would otherwise
        # overflow int16 and wrap around, producing loud clicks.
        scaled = np.asarray(enhanced_audio * (1 << 15))
        enhanced_audio = np.clip(scaled, -(1 << 15), (1 << 15) - 1).astype(np.int16)
    except Exception as e:
        raise gr.Error(f"enhancement failed, error type: {type(e)}, error text: {str(e)}.") from e

    # NOTE(review): the input sample rate is passed through unchanged; whether
    # the engine expects a specific rate is not visible here — confirm.
    enhanced_audio_t = (sample_rate, enhanced_audio)
    return enhanced_audio_t, None
def main():
    """
    Download the models if needed, build the Gradio UI and start the server.

    The model download only happens when ``--trained_model_dir`` does not
    exist yet. The server binds to 127.0.0.1 on Windows and to 0.0.0.0
    elsewhere, on ``--server_port``; ``launch`` is called with ``share=False``.
    """
    args = get_args()

    trained_model_dir = Path(args.trained_model_dir)

    # download models
    if not trained_model_dir.exists():
        trained_model_dir.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=args.models_repo_id,
            local_dir=trained_model_dir.as_posix(),
            token=args.hf_token,
        )

    # choices
    denoise_engine_choices = list(denoise_engines.keys())

    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value="nx denoise.")
        with gr.Tabs():
            with gr.TabItem("denoise"):
                with gr.Row():
                    with gr.Column(variant="panel", scale=5):
                        dn_noisy_audio = gr.Audio(label="noisy_audio")
                        dn_engine = gr.Dropdown(
                            choices=denoise_engine_choices,
                            value=denoise_engine_choices[0],
                            label="engine",
                        )
                        dn_button = gr.Button(variant="primary")
                    with gr.Column(variant="panel", scale=5):
                        dn_enhanced_audio = gr.Audio(label="enhanced_audio")
                        dn_clean_audio = gr.Audio(label="clean_audio")

                dn_button.click(
                    when_click_denoise_button,
                    inputs=[dn_noisy_audio, dn_engine],
                    outputs=[dn_enhanced_audio, dn_clean_audio],
                )

    # served at http://<server_name>:<server_port>/
    is_windows = platform.system() == "Windows"
    blocks.queue().launch(
        share=False,
        server_name="127.0.0.1" if is_windows else "0.0.0.0",
        server_port=args.server_port,
    )
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()