from typing import Any, Dict
import gradio as gr
import librosa
import numpy as np
import torch
from transformers import WavLMForSequenceClassification
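
# Gradio demo for the Roblox Voice Safety Classifier v2 (a WavLM sequence
# classifier): audio is processed in 15-second chunks and scored for six
# safety categories.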
def feature_extract_simple(
    wav,
    sr=16_000,
    win_len=15.0,
    win_stride=15.0,
    do_normalize=False,
) -> np.ndarray:
"""Simple feature extraction for WavLM.
Parameters
----------
wav : str or array-like
path to the wav file, or array-like
sr : int, optional
sample rate, by default 16_000
win_len : float, optional
window length, by default 15.0
win_stride : float, optional
window stride, by default 15.0
do_normalize: bool, optional
whether to normalize the input, by default False.
Returns
-------
np.ndarray
batched input to WavLM
"""
    if isinstance(wav, str):
        signal, _ = librosa.load(wav, sr=sr)
    else:
        try:
            signal = np.array(wav).squeeze()
        except Exception as e:
            raise RuntimeError("Could not convert input to a numpy array") from e
    batched_input = []
    stride = int(win_stride * sr)
    win_samples = int(win_len * sr)
    if len(signal) / sr > win_len:
        for i in range(0, len(signal), stride):
            if i + win_samples > len(signal):
                # Zero-pad the last chunk so it matches the other windows.
                chunked = np.pad(signal[i:], (0, win_samples - len(signal[i:])))
            else:
                chunked = signal[i : i + win_samples]
            if do_normalize:
                chunked = (chunked - np.mean(chunked)) / (np.std(chunked) + 1e-7)
            batched_input.append(chunked)
            if i + win_samples > len(signal):
                # The padded chunk was the last one.
                break
    else:
        # Clip is shorter than one window; use it as a single batch entry.
        if do_normalize:
            signal = (signal - np.mean(signal)) / (np.std(signal) + 1e-7)
        batched_input.append(signal)
    return np.stack(batched_input)  # [N, T]
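
# Illustrative example: a 40 s clip at 16 kHz produces windows starting at
# 0 s, 15 s, and 30 s (the last zero-padded), so the returned batch has
# shape (3, 240000).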
def infer(model, inputs) -> torch.Tensor:
    # Run the classifier without building a gradient graph; sigmoid turns the
    # logits into independent per-label probabilities (multi-label output).
    with torch.no_grad():
        output = model(inputs)
    return torch.sigmoid(output.logits)
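
# Gradio callback: gr.Audio(type="filepath") passes the uploaded file's path,
# and gr.Label renders the returned {label: confidence} mapping.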
def predict(audio_file) -> Dict[str, Any]:
    if audio_file is None:
        return {"No prediction available": 0.0}
    try:
        input_np = feature_extract_simple(audio_file, sr=16000, do_normalize=True)
        input_pt = torch.Tensor(input_np)
        probs = infer(model, input_pt)
        probs_list = probs.reshape(-1, len(labels)).tolist()
        # Build a results dictionary from the first 15-second segment.
        if len(probs_list) > 0:
            first_segment_probs = probs_list[0]
            results = {
                label: float(prob)
                for label, prob in zip(labels, first_segment_probs)
            }
        else:
            return {"Error": "No segments detected in audio"}
        # Sort by confidence score.
        sorted_results = dict(
            sorted(results.items(), key=lambda x: x[1], reverse=True)
        )
        # If there are multiple segments, say so; appended after sorting so the
        # string value is never compared against the float scores.
        if len(probs_list) > 1:
            sorted_results["Note"] = (
                f"Audio contains {len(probs_list)} segments. Showing first segment only."
            )
        return sorted_results
    except Exception as e:
        return {"Error": str(e)}
if __name__ == "__main__":
    model_path = "Roblox/voice-safety-classifier-v2"
    labels = [
        "Discrimination",
        "Harassment",
        "Sexual",
        "IllegalAndRegulated",
        "DatingAndRomantic",
        "Profanity",
    ]
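    # NOTE: this label order is assumed to match the index order of the
    # model's classification head; zip(labels, probs) in predict() relies on it.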
    model = WavLMForSequenceClassification.from_pretrained(
        model_path, num_labels=len(labels)
    )
    model.eval()

    demo = gr.Interface(
        fn=predict,
        inputs=gr.Audio(type="filepath", label="Upload or record audio"),
        outputs=gr.Label(num_top_classes=6, label="Classification Results"),
        title="Voice Safety Classifier",
        description="""This app uses the Roblox Voice Safety Classifier v2 model to identify potentially unsafe content in audio.
        Upload or record an audio file to get started. The model classifies audio into categories including Discrimination,
        Harassment, Sexual, IllegalAndRegulated, DatingAndRomantic, and Profanity.
        The model processes audio in 15-second chunks and returns probability scores for each category.""",
        examples=[],
        flagging_mode="never",
    )
    demo.launch()