AkitoP's picture
Update app.py
5dbf2c1 verified
raw
history blame contribute delete
1.4 kB
import argparse
import json
from pathlib import Path
import gradio as gr
import torch
from models import AudioClassifier
from utils import logger
# Run on the GPU when torch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Device: {device}")

# Checkpoint directory, resolved relative to the current working directory.
ckpt_dir = Path("ckpt/")
config_path = ckpt_dir / "config.json"
# Raise explicitly rather than `assert`: assertions are stripped under
# `python -O`, which would defer the failure to a less descriptive error.
if not config_path.exists():
    raise FileNotFoundError(f"config.json not found in {ckpt_dir}")
# JSON text is UTF-8 by specification, so do not rely on the locale default.
config = json.loads(config_path.read_text(encoding="utf-8"))
# The "model" section of the config supplies the constructor keyword arguments.
model = AudioClassifier(device=device, **config["model"]).to(device)

# Latest checkpoint: prefer model_final.pth, otherwise take the last *.pth
# file in sorted path order.
# NOTE(review): the fallback sorts paths lexicographically, so "latest" depends
# on the file naming scheme (e.g. "model_10.pth" sorts before "model_9.pth").
ckpt = ckpt_dir / "model_final.pth"
if not ckpt.exists():
    candidates = sorted(ckpt_dir.glob("*.pth"))
    if not candidates:
        # Fail with a clear message instead of an IndexError on an empty list.
        raise FileNotFoundError(f"No .pth checkpoint found in {ckpt_dir}")
    ckpt = candidates[-1]
logger.info(f"Loading {ckpt}...")
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints,
# and consider weights_only=True if the installed torch version supports it.
model.load_state_dict(torch.load(ckpt, map_location=device))
def classify_audio(audio_file: str):
    """Classify one audio file with the module-level ``model``.

    Args:
        audio_file: Path to the audio file to classify. Gradio passes
            ``None`` when the form is submitted without any audio.

    Returns:
        Whatever ``model.infer_from_file`` returns for the file; the value
        is logged and handed back to the caller unchanged.

    Raises:
        gr.Error: If no audio file was provided.
    """
    # Reject a missing input up front so the user gets a readable message
    # instead of whatever the model raises for a None path.
    if audio_file is None:
        raise gr.Error("No audio provided. Please upload or record audio first.")

    logger.info(f"Classifying {audio_file}...")
    output = model.infer_from_file(audio_file)
    logger.success(f"Predicted: {output}")
    return output
# User-facing page description; the text is passed verbatim to the Interface.
desc = """
# NSFW音声分類器
出力は以下の2つのクラスの確率です。
- usual: 通常の音声
- ecchi: エッチな音声
元Space:https://huggingface.co/spaces/litagin/Japanese-Ero-Voice-Classifier
"""

# Build the input and output components up front (input first, then output)
# so the Interface call below reads as plain wiring.
audio_input = gr.Audio(label="Input audio", type="filepath")
text_output = gr.Text(label="Classification")

# Wire classify_audio to the components and start the app; flagging is
# switched off via allow_flagging="never".
with gr.Interface(
    fn=classify_audio,
    inputs=audio_input,
    outputs=text_output,
    description=desc,
    allow_flagging="never",
) as iface:
    iface.launch()