|
|
|
|
|
import argparse |
|
from functools import lru_cache |
|
import json |
|
import logging |
|
from pathlib import Path |
|
import platform |
|
import tempfile |
|
import time |
|
import uuid |
|
|
|
import gradio as gr |
|
import librosa |
|
import numpy as np |
|
from scipy.io import wavfile |
|
|
|
import log |
|
from project_settings import environment, project_path, log_directory |
|
from toolbox.os.command import Command |
|
from toolbox.age_and_gender.models.audeering import AudeeringModel |
|
|
|
# Set up file logging through the project's `log` helper before anything is
# logged (size-based rotation, judging by the helper's name — confirm in `log`).
log.setup_size_rotating(log_directory=log_directory)

# Shared logger for this script, published under the "main" channel.
logger = logging.getLogger(name="main")
|
|
|
|
|
def get_args():
    """Parse the demo's command-line options.

    Returns:
        argparse.Namespace with one attribute, ``examples_dir`` (str): the
        directory scanned for ``*.wav`` example clips. It defaults to
        ``<project_path>/data/examples``.
    """
    default_examples_dir = (project_path / "data/examples").as_posix()

    parser = argparse.ArgumentParser()
    parser.add_argument("--examples_dir", type=str, default=default_examples_dir)
    return parser.parse_args()
|
|
|
|
|
def save_input_audio(sample_rate: int, signal: np.ndarray) -> str:
    """Write ``signal`` to a uniquely named wav file in the system temp dir.

    The file goes into ``<tempdir>/input_audio/`` and that directory is
    created if needed. The caller owns the returned file; it is not removed
    here.

    Args:
        sample_rate: sample rate (Hz) written into the wav header.
        signal: sample array passed unchanged to ``scipy.io.wavfile.write``.

    Returns:
        POSIX-style path string of the written file.
    """
    out_dir = Path(tempfile.gettempdir()).joinpath("input_audio")
    out_dir.mkdir(parents=True, exist_ok=True)

    # A random UUID name keeps concurrent requests from clobbering each other.
    out_path = (out_dir / f"{uuid.uuid4()}.wav").as_posix()
    wavfile.write(out_path, sample_rate, signal)
    return out_path
|
|
|
|
|
def shell(cmd: str):
    """Execute ``cmd`` via the project's ``Command.popen`` helper.

    Returns whatever ``Command.popen`` returns; the "shell" tab shows it in
    its output textbox.

    NOTE(review): ``cmd`` comes straight from a web textbox. The server
    binds ``0.0.0.0``, so anyone who can reach it can run arbitrary
    commands. Keep this out of public deployments or put it behind an
    explicit opt-in.
    """
    output = Command.popen(cmd)
    return output
|
|
|
|
|
def _audeering_model_path(model_name: str) -> str:
    """Return the location an audeering checkpoint is loaded from.

    On Windows this is a local copy under ``<project_path>/pretrained_models/``.
    On other platforms it is the identifier ``audeering/<model_name>``
    (presumably a Hugging Face hub id resolved by ``AudeeringModel``).
    """
    if platform.system() == "Windows":
        return (project_path / "pretrained_models" / model_name).as_posix()
    return f"audeering/{model_name}"


# Registry of selectable engines. Each entry maps a dropdown label to:
#   infer_cls   - the class instantiated (and cached) for inference,
#   kwargs      - keyword arguments passed to that class's constructor,
#   sample_rate - the rate the input audio is resampled to before inference.
age_and_gender_model_map = {
    "audeering-6-ft": {
        "infer_cls": AudeeringModel,
        "kwargs": {
            "model_path": _audeering_model_path("wav2vec2-large-robust-6-ft-age-gender"),
        },
        "sample_rate": 16000,
    },
    "audeering-24-ft": {
        "infer_cls": AudeeringModel,
        "kwargs": {
            "model_path": _audeering_model_path("wav2vec2-large-robust-24-ft-age-gender"),
        },
        "sample_rate": 16000,
    },
}
|
|
|
|
|
@lru_cache(maxsize=3)
def load_get_age_and_gender_model(infer_cls, **kwargs):
    """Instantiate ``infer_cls(**kwargs)`` and memoize the instance.

    Repeat calls with the same class and keyword arguments return the same
    object, so each model is only constructed once per process. At most 3
    distinct configurations stay cached (LRU eviction). All ``kwargs``
    values must be hashable.
    """
    return infer_cls(**kwargs)
|
|
|
|
|
def when_click_get_age_and_gender_button(audio_t, engine: str):
    """Run the selected age/gender engine on an audio clip and return JSON.

    Args:
        audio_t: value of the ``gr.Audio`` input. It is unpacked as a
            ``(sample_rate, signal)`` pair, or is ``None`` when no audio has
            been provided.
        engine: a key of ``age_and_gender_model_map``.

    Returns:
        Pretty-printed JSON string containing the engine's output merged
        with ``duration`` (seconds of resampled audio), ``time_cost`` (seconds
        spent in inference) and ``rtf`` (time_cost / duration). Each extra
        field is rounded to 4 decimals.

    Raises:
        gr.Error: if no audio was given, the engine is unknown, or
            resampling / model loading / inference fails.
    """
    if audio_t is None:
        # Clicking the button before uploading/recording yields None.
        raise gr.Error("please provide an audio input.")

    logger.info(f"run get_age_and_gender; engine: {engine}.")

    # Validate before touching the disk so bad requests leave no temp file.
    infer_engine_param = age_and_gender_model_map.get(engine)
    if infer_engine_param is None:
        raise gr.Error(f"invalid engine: {engine}.")

    sample_rate, signal = audio_t
    # Round-trip through a wav file so librosa can load and resample the clip
    # to the rate the engine expects.
    filename = save_input_audio(sample_rate, signal)

    try:
        infer_cls = infer_engine_param["infer_cls"]
        kwargs = infer_engine_param["kwargs"]
        target_sr = infer_engine_param["sample_rate"]

        signal, _ = librosa.load(filename, sr=target_sr)
        duration = len(signal) / target_sr
        if duration == 0:
            raise ValueError("input audio is empty.")

        infer_engine = load_get_age_and_gender_model(infer_cls=infer_cls, **kwargs)

        # perf_counter is monotonic and high-resolution, which suits timing.
        time_begin = time.perf_counter()
        age_and_gender = infer_engine(signal, target_sr)
        time_cost = time.perf_counter() - time_begin

        rtf = time_cost / duration

        result = {
            **age_and_gender,
            "duration": round(duration, 4),
            "time_cost": round(time_cost, 4),
            "rtf": round(rtf, 4),
        }
        result = json.dumps(result, ensure_ascii=False, indent=4)
    except Exception as e:
        logger.exception("get_age_and_gender failed.")
        raise gr.Error(f"get_age_and_gender failed, error type: {type(e)}, error text: {str(e)}.") from e
    finally:
        # The temp wav is only needed for loading; remove it so repeated
        # requests do not accumulate files in the temp directory.
        Path(filename).unlink(missing_ok=True)

    return result
|
|
|
|
|
def main():
    """Build the Gradio demo UI and launch the server.

    The UI has two tabs:
      * ``age_and_gender``: an audio input, an engine dropdown (keys of
        ``age_and_gender_model_map``), a run button, a text output, and
        examples taken from ``*.wav`` files in ``--examples_dir``.
      * ``shell``: a command textbox wired to ``shell``.

    The server listens on ``0.0.0.0`` at the port from
    ``environment.get("port", 7860)``, with request queuing enabled.
    """
    args = get_args()

    examples_dir = Path(args.examples_dir)
    age_and_gender_model_choices = list(age_and_gender_model_map.keys())

    # glob() order is filesystem-dependent; sorting keeps the examples
    # table stable across runs and platforms.
    examples = [
        [filename.as_posix(), age_and_gender_model_choices[0]]
        for filename in sorted(examples_dir.glob("*.wav"))
    ]

    with gr.Blocks() as blocks:
        with gr.Tabs():
            with gr.TabItem("age_and_gender"):
                with gr.Row():
                    with gr.Column(variant="panel", scale=5):
                        ag_audio = gr.Audio(label="audio")
                        ag_engine = gr.Dropdown(
                            choices=age_and_gender_model_choices,
                            value=age_and_gender_model_choices[0],
                            label="engine",
                        )
                        ag_button = gr.Button(variant="primary")
                    with gr.Column(variant="panel", scale=5):
                        ag_output = gr.Text(label="output")

                # Only render the examples table when there is something to
                # show, e.g. when the examples directory is missing or empty.
                if examples:
                    gr.Examples(
                        examples=examples,
                        inputs=[ag_audio, ag_engine],
                        outputs=[ag_output],
                        fn=when_click_get_age_and_gender_button,
                    )

                ag_button.click(
                    when_click_get_age_and_gender_button,
                    inputs=[ag_audio, ag_engine],
                    outputs=[ag_output],
                )

            with gr.TabItem("shell"):
                # NOTE(review): this tab runs arbitrary commands for any client
                # that can reach the server (bound to 0.0.0.0 below). Consider
                # gating it behind an explicit opt-in for public deployments.
                shell_text = gr.Textbox(label="cmd")
                shell_button = gr.Button("run")
                shell_output = gr.Textbox(label="output")

                shell_button.click(
                    shell,
                    inputs=[shell_text],
                    outputs=[shell_output],
                )

    blocks.queue().launch(
        # The original chose False on every platform, so no condition is needed.
        share=False,
        server_name="0.0.0.0",
        server_port=environment.get("port", 7860, dtype=int),
    )
|
|
|
|
|
if __name__ == "__main__":
    # Build and launch the demo only when executed as a script, not on import.
    main()
|
|