from typing import Any, Dict

import gradio as gr
import librosa
import numpy as np
import torch
from transformers import WavLMForSequenceClassification


def feature_extract_simple(
    wav,
    sr=16_000,
    win_len=15.0,
    win_stride=15.0,
    do_normalize=False,
) -> np.ndarray:
    """Simple feature extraction for WavLM.

    Parameters
    ----------
    wav : str or array-like
        path to the wav file, or array-like
    sr : int, optional
        sample rate, by default 16_000
    win_len : float, optional
        window length in seconds, by default 15.0
    win_stride : float, optional
        window stride in seconds, by default 15.0
    do_normalize : bool, optional
        whether to normalize the input, by default False

    Returns
    -------
    np.ndarray
        batched input to WavLM
    """
    if isinstance(wav, str):
        signal, _ = librosa.load(wav, sr=sr)
    else:
        try:
            signal = np.array(wav).squeeze()
        except Exception as e:
            raise RuntimeError("Could not convert input audio to an array") from e
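    # Slice the signal into fixed-length windows so every row of the batch
    # has the same shape; a trailing partial window is zero-padded below.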
    batched_input = []
    stride = int(win_stride * sr)
    win_samples = int(win_len * sr)
    if len(signal) / sr > win_len:
        for i in range(0, len(signal), stride):
            if i + win_samples > len(signal):
                chunked = np.pad(signal[i:], (0, win_samples - len(signal[i:])))
            else:
                chunked = signal[i : i + win_samples]
            if do_normalize:
                chunked = (chunked - np.mean(chunked)) / (np.std(chunked) + 1e-7)
            batched_input.append(chunked)
            if i + win_samples > len(signal):
                break
    else:
        if do_normalize:
            signal = (signal - np.mean(signal)) / (np.std(signal) + 1e-7)
        batched_input.append(signal)
    return np.stack(batched_input)


def infer(model, inputs) -> torch.Tensor:
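    # The classifier head is multi-label: each logit gets an independent
    # sigmoid rather than a softmax, so per-label scores need not sum to 1.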
    with torch.no_grad():
        output = model(inputs)
    return torch.sigmoid(output.logits)


def predict(audio_file) -> Dict[str, Any]:
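    # Gradio passes the recording as a file path (gr.Audio(type="filepath")
    # below); the returned dict of label -> score is rendered by gr.Label.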
    if audio_file is None:
        return {"No prediction available": 0.0}

    try:
        input_np = feature_extract_simple(audio_file, sr=16_000, do_normalize=True)
        input_pt = torch.Tensor(input_np)

        probs = infer(model, input_pt)
        probs_list = probs.reshape(-1, len(labels)).tolist()
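        # probs_list holds one row of len(labels) scores per 15-second
        # window; only the first window is shown in the UI.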
        if len(probs_list) == 0:
            return {"Error": "No segments detected in audio"}

        first_segment_probs = probs_list[0]
        results = {
            label: float(prob) for label, prob in zip(labels, first_segment_probs)
        }

        # Sort the numeric scores before attaching the non-numeric note;
        # sorted() over a mix of float and str values raises a TypeError.
        sorted_results = dict(
            sorted(results.items(), key=lambda x: x[1], reverse=True)
        )

        if len(probs_list) > 1:
            sorted_results["Note"] = (
                f"Audio contains {len(probs_list)} segments. Showing first segment only."
            )

        return sorted_results
    except Exception as e:
        return {"Error": str(e)}


if __name__ == "__main__":
    model_path = "Roblox/voice-safety-classifier-v2"
    labels = [
        "Discrimination",
        "Harassment",
        "Sexual",
        "IllegalAndRegulated",
        "DatingAndRomantic",
        "Profanity",
    ]

    model = WavLMForSequenceClassification.from_pretrained(
        model_path, num_labels=len(labels)
    )
    model.eval()
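    # flagging_mode="never" assumes a recent Gradio release; older 4.x
    # versions exposed the same switch as allow_flagging="never".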
    demo = gr.Interface(
        fn=predict,
        inputs=gr.Audio(type="filepath", label="Upload or record audio"),
        outputs=gr.Label(num_top_classes=6, label="Classification Results"),
        title="Voice Safety Classifier",
        description="""This app uses the Roblox Voice Safety Classifier v2 model to identify potentially unsafe content in audio.
Upload or record an audio file to get started. The model classifies audio into categories including Discrimination,
Harassment, Sexual, IllegalAndRegulated, DatingAndRomantic, and Profanity.

The model processes audio in 15-second chunks and returns probability scores for each category.""",
        examples=[],
        flagging_mode="never",
    )

    demo.launch()