mlang238965 committed on
Commit
00613da
·
1 Parent(s): f217a11

Add demo inference

Browse files
Files changed (3) hide show
  1. README.md +24 -1
  2. app.py +135 -0
  3. requirements.txt +5 -0
README.md CHANGED
@@ -11,4 +11,27 @@ license: apache-2.0
11
  short_description: Demo deployment for voice safety classifier
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  short_description: Demo deployment for voice safety classifier
12
  ---
13
 
14
+ # Voice Safety Classifier Demo
15
+
16
+ This is a demo application for the [Roblox Voice Safety Classifier v2](https://huggingface.co/Roblox/voice-safety-classifier-v2) model.
17
+
18
+ ## Model Information
19
+
20
+ The Voice Safety Classifier is designed to detect potentially unsafe content in audio. It can classify audio into various safety categories to help identify problematic content.
21
+
22
+ ## Usage
23
+
24
+ 1. Upload an audio file or record audio directly in your browser
25
+ 2. The model will process the audio and return classification results
26
+ 3. Results are displayed with confidence scores for each category
27
+
28
+ ## Technical Details
29
+
30
+ This demo uses:
31
+ - Hugging Face Transformers
32
+ - Gradio for the web interface
33
+ - PyTorch for inference and librosa for audio loading
34
+
35
+ ## License
36
+
37
+ This demo uses the Roblox Voice Safety Classifier v2 model. Please refer to the [model card](https://huggingface.co/Roblox/voice-safety-classifier-v2) for licensing information.
app.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict
2
+
3
+ import gradio as gr
4
+ import librosa
5
+ import numpy as np
6
+ import torch
7
+ from transformers import WavLMForSequenceClassification
8
+
9
+
10
def feature_extract_simple(
    wav,
    sr=16_000,
    win_len=15.0,
    win_stride=15.0,
    do_normalize=False,
) -> np.ndarray:
    """Split an audio signal into fixed-length windows for WavLM.

    Parameters
    ----------
    wav : str, os.PathLike or array-like
        Path to an audio file (loaded and resampled to ``sr`` with librosa),
        or the samples themselves. Array input is squeezed and used as-is;
        it is not resampled.
    sr : int, optional
        Sample rate in Hz, by default 16_000.
    win_len : float, optional
        Window length in seconds, by default 15.0.
    win_stride : float, optional
        Hop between window starts in seconds, by default 15.0.
    do_normalize : bool, optional
        Whether to scale every window to zero mean and unit variance,
        by default False.

    Returns
    -------
    np.ndarray
        Batched input to WavLM. Signals longer than ``win_len`` seconds
        yield ``[N, int(win_len * sr)]`` with the final window zero-padded;
        shorter signals are returned unpadded as a single row.

    Raises
    ------
    RuntimeError
        If ``wav`` is not a path and cannot be converted to an array.
    """
    import os  # local import: only needed for the path-like check below

    # isinstance (rather than an exact type comparison) also accepts str
    # subclasses and pathlib.Path objects.
    if isinstance(wav, (str, os.PathLike)):
        signal, _ = librosa.load(wav, sr=sr)
    else:
        try:
            signal = np.array(wav).squeeze()
        except Exception as e:
            # Keep the exception type, but preserve the cause and message.
            raise RuntimeError(
                f"could not convert the input audio to an array: {e}"
            ) from e

    def _prepare(chunk):
        """Normalize one window when requested; the epsilon avoids 0-division."""
        if do_normalize:
            return (chunk - np.mean(chunk)) / (np.std(chunk) + 1e-7)
        return chunk

    # Short clip: a single, unpadded row.
    if len(signal) / sr <= win_len:
        return np.stack([_prepare(signal)])  # [1, T]

    win_samples = int(win_len * sr)
    stride = int(win_stride * sr)

    batched_input = []
    for start in range(0, len(signal), stride):
        chunk = signal[start : start + win_samples]
        is_last = start + win_samples > len(signal)
        if is_last:
            # Pad the last chunk so it has the same length as the others.
            chunk = np.pad(chunk, (0, win_samples - len(chunk)))
        batched_input.append(_prepare(chunk))
        if is_last:
            # Later starts would only produce windows of trailing padding.
            break
    return np.stack(batched_input)  # [N, T]
63
+
64
+
65
def infer(model, inputs) -> torch.Tensor:
    """Run the classifier and convert its logits to per-label probabilities.

    Parameters
    ----------
    model : callable
        Called as ``model(inputs)``; the result must expose a ``logits``
        attribute.
    inputs : torch.Tensor
        Batched model input, passed to ``model`` unchanged.

    Returns
    -------
    torch.Tensor
        Element-wise sigmoid of the logits (each label is scored
        independently), detached from the autograd graph.
    """
    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        output = model(inputs)
        # as_tensor is a no-op for tensors on any device/dtype and still
        # converts array-like logits, unlike the legacy torch.Tensor(...) call.
        probs = torch.sigmoid(torch.as_tensor(output.logits))
    return probs
69
+
70
+
71
def predict(audio_file) -> Dict[str, Any]:
    """Classify an audio file and return per-label scores for the UI.

    Reads the module-level ``model`` and ``labels`` globals.

    Parameters
    ----------
    audio_file : str or None
        Path to the uploaded/recorded audio, or ``None`` when nothing was
        provided.

    Returns
    -------
    Dict[str, Any]
        Label name -> probability for the first audio segment, ordered from
        highest to lowest score. A placeholder entry is returned when no
        audio is given, and ``{"Error": message}`` when processing fails.
    """
    if audio_file is None:
        return {"No prediction available": 0.0}

    try:
        input_np = feature_extract_simple(audio_file, sr=16000, do_normalize=True)
        input_pt = torch.as_tensor(input_np, dtype=torch.float32)

        probs = infer(model, input_pt)
        probs_list = probs.reshape(-1, len(labels)).detach().tolist()

        if not probs_list:
            return {"Error": "No segments detected in audio"}

        if len(probs_list) > 1:
            # Surface the note as a UI message instead of storing it in the
            # score dict: a string value there made the sort below compare
            # float with str and raise TypeError for every multi-segment clip.
            gr.Info(
                f"Audio contains {len(probs_list)} segments. Showing first segment only."
            )

        results = {
            label: float(prob) for label, prob in zip(labels, probs_list[0])
        }

        # Sort by confidence score, highest first.
        return dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
    except Exception as e:
        # UI boundary: report the failure in the output rather than raising.
        return {"Error": str(e)}
103
+
104
+
105
# Checkpoint identifier passed to `from_pretrained`.
model_path = "Roblox/voice-safety-classifier-v2"

# Category names; `predict` pairs them positionally with the model's output
# scores, so the order matters. Kept at module level (pure constants, no side
# effects) so the name resolves even when this file is imported rather than
# run as a script.
# NOTE(review): order presumably follows the checkpoint's label order — confirm
# against the model card.
labels = [
    "Discrimination",
    "Harassment",
    "Sexual",
    "IllegalAndRegulated",
    "DatingAndRomantic",
    "Profanity",
]

if __name__ == "__main__":
    # Loading the weights and starting the server only happen when the file
    # is executed directly.
    model = WavLMForSequenceClassification.from_pretrained(
        model_path, num_labels=len(labels)
    )
    model.eval()  # inference mode: disables dropout and similar layers

    demo = gr.Interface(
        fn=predict,
        inputs=gr.Audio(type="filepath", label="Upload or record audio"),
        outputs=gr.Label(num_top_classes=6, label="Classification Results"),
        title="Voice Safety Classifier",
        description="""This app uses the Roblox Voice Safety Classifier v2 model to identify potentially unsafe content in audio.
Upload or record an audio file to get started. The model classifies audio into categories including Discrimination,
Harassment, Sexual, IllegalAndRegulated, DatingAndRomantic, and Profanity.

The model processes audio in 15-second chunks and returns probability scores for each category.""",
        examples=[],
        flagging_mode="never",
    )

    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio>=5
2
+ librosa>=0.10.0
3
+ numpy>=1.24.0
4
+ torch>=2.0.0
5
+ transformers>=4.30.0