Commit 00613da
Parent(s): f217a11

Add demo inference

Files changed:
- README.md +24 -1
- app.py +135 -0
- requirements.txt +5 -0
README.md
CHANGED
@@ -11,4 +11,27 @@ license: apache-2.0
 short_description: Demo deployment for voice safety classifier
 ---
 
-
+# Voice Safety Classifier Demo
+
+This is a demo application for the [Roblox Voice Safety Classifier v2](https://huggingface.co/Roblox/voice-safety-classifier-v2) model.
+
+## Model Information
+
+The Voice Safety Classifier detects potentially unsafe content in audio, scoring each clip against a set of safety categories to help identify problematic content.
+
+## Usage
+
+1. Upload an audio file or record audio directly in your browser
+2. The model processes the audio and returns classification results
+3. Results are displayed with confidence scores for each category
+
+## Technical Details
+
+This demo uses:
+- Hugging Face Transformers
+- Gradio for the web interface
+- PyTorch and librosa for audio processing
+
+## License
+
+This demo uses the Roblox Voice Safety Classifier v2 model. Please refer to the [model card](https://huggingface.co/Roblox/voice-safety-classifier-v2) for licensing information.
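Once the Space is running, the usage flow above can also be exercised programmatically. A minimal sketch using `gradio_client`; the Space id is a hypothetical placeholder, `sample.wav` is a placeholder local file, and `/predict` is the default endpoint name that `gr.Interface` exposes:

```python
# Minimal sketch of calling the Space from Python; the Space id below is
# hypothetical, and "sample.wav" is a placeholder for a local audio file.
from gradio_client import Client, handle_file

client = Client("your-username/voice-safety-classifier-demo")  # hypothetical id
result = client.predict(
    handle_file("sample.wav"),  # placeholder local file path
    api_name="/predict",        # default endpoint name for gr.Interface
)
print(result)  # gr.Label output: top label plus per-category confidences
```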
app.py
ADDED
@@ -0,0 +1,135 @@
from typing import Any, Dict

import gradio as gr
import librosa
import numpy as np
import torch
from transformers import WavLMForSequenceClassification


def feature_extract_simple(
    wav,
    sr=16_000,
    win_len=15.0,
    win_stride=15.0,
    do_normalize=False,
) -> np.ndarray:
    """Simple feature extraction for WavLM.

    Parameters
    ----------
    wav : str or array-like
        path to the wav file, or array-like
    sr : int, optional
        sample rate, by default 16_000
    win_len : float, optional
        window length in seconds, by default 15.0
    win_stride : float, optional
        window stride in seconds, by default 15.0
    do_normalize : bool, optional
        whether to normalize the input, by default False

    Returns
    -------
    np.ndarray
        batched input to WavLM, shape [N, T]
    """
    if isinstance(wav, str):
        signal, _ = librosa.load(wav, sr=sr)
    else:
        try:
            signal = np.array(wav).squeeze()
        except Exception as e:
            raise RuntimeError(f"Could not convert input to an array: {e}") from e

    batched_input = []
    stride = int(win_stride * sr)
    win_samples = int(win_len * sr)
    if len(signal) / sr > win_len:
        # Slide a fixed-length window over the signal; zero-pad the last
        # chunk so every batch element has the same length.
        for i in range(0, len(signal), stride):
            if i + win_samples > len(signal):
                chunked = np.pad(signal[i:], (0, win_samples - len(signal[i:])))
            else:
                chunked = signal[i : i + win_samples]
            if do_normalize:
                chunked = (chunked - np.mean(chunked)) / (np.std(chunked) + 1e-7)
            batched_input.append(chunked)
            if i + win_samples > len(signal):
                break
    else:
        if do_normalize:
            signal = (signal - np.mean(signal)) / (np.std(signal) + 1e-7)
        batched_input.append(signal)
    return np.stack(batched_input)  # [N, T]


def infer(model, inputs) -> torch.Tensor:
    # Multi-label head: apply a per-class sigmoid rather than a softmax.
    with torch.no_grad():
        output = model(inputs)
    return torch.sigmoid(output.logits)


def predict(audio_file) -> Dict[str, Any]:
    if audio_file is None:
        return {"No prediction available": 0.0}

    try:
        input_np = feature_extract_simple(audio_file, sr=16000, do_normalize=True)
        input_pt = torch.Tensor(input_np)

        probs = infer(model, input_pt)
        probs_list = probs.reshape(-1, len(labels)).detach().tolist()

        # Build a label -> confidence dictionary from the first segment
        if len(probs_list) > 0:
            first_segment_probs = probs_list[0]
            results = {
                label: float(prob) for label, prob in zip(labels, first_segment_probs)
            }
        else:
            results = {"Error": "No segments detected in audio"}

        # Sort by confidence score
        sorted_results = dict(sorted(results.items(), key=lambda x: x[1], reverse=True))

        # If there are multiple segments, surface that as a toast: gr.Label
        # expects numeric confidences only, so mixing a string note into the
        # results dict would raise a TypeError during sorting.
        if len(probs_list) > 1:
            gr.Info(
                f"Audio contains {len(probs_list)} segments. Showing first segment only."
            )

        return sorted_results
    except Exception as e:
        return {"Error": str(e)}


if __name__ == "__main__":
    model_path = "Roblox/voice-safety-classifier-v2"
    labels = [
        "Discrimination",
        "Harassment",
        "Sexual",
        "IllegalAndRegulated",
        "DatingAndRomantic",
        "Profanity",
    ]

    model = WavLMForSequenceClassification.from_pretrained(
        model_path, num_labels=len(labels)
    )
    model.eval()

    demo = gr.Interface(
        fn=predict,
        inputs=gr.Audio(type="filepath", label="Upload or record audio"),
        outputs=gr.Label(num_top_classes=6, label="Classification Results"),
        title="Voice Safety Classifier",
        description="""This app uses the Roblox Voice Safety Classifier v2 model to identify potentially unsafe content in audio.
        Upload or record an audio file to get started. The model classifies audio into categories including Discrimination,
        Harassment, Sexual, IllegalAndRegulated, DatingAndRomantic, and Profanity.

        The model processes audio in 15-second chunks and returns probability scores for each category.""",
        examples=[],
        flagging_mode="never",
    )

    demo.launch()
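For quick testing outside the web UI, the helpers above can be driven directly. A minimal sketch, assuming `app.py` is importable from the working directory and `sample.wav` is a placeholder path to a local audio file:

```python
# Minimal sketch: run the classifier on a local file without Gradio.
# "sample.wav" is a placeholder path; the labels mirror those in app.py.
import torch
from transformers import WavLMForSequenceClassification

from app import feature_extract_simple, infer

labels = ["Discrimination", "Harassment", "Sexual",
          "IllegalAndRegulated", "DatingAndRomantic", "Profanity"]

model = WavLMForSequenceClassification.from_pretrained(
    "Roblox/voice-safety-classifier-v2", num_labels=len(labels)
)
model.eval()

# 15 s windows at 16 kHz -> each batch row holds 240_000 samples.
batch = torch.Tensor(feature_extract_simple("sample.wav", sr=16000, do_normalize=True))
probs = infer(model, batch)  # [num_segments, num_labels] sigmoid scores

for i, row in enumerate(probs):
    scores = ", ".join(f"{name}={p:.3f}" for name, p in zip(labels, row.tolist()))
    print(f"segment {i}: {scores}")
```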
requirements.txt
ADDED
@@ -0,0 +1,5 @@
gradio>=5
librosa>=0.10.0
numpy>=1.24.0
torch>=2.0.0
transformers>=4.30.0
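With these dependencies installed via `pip install -r requirements.txt`, the demo starts locally with `python app.py`; Gradio serves the interface on `http://127.0.0.1:7860` by default.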