Audio Classification
Transformers
Safetensors
wavlm
jpRBX committed · Commit 710301d · 1 Parent(s): c61b971

Voice Safety Classifier v2

Files changed (5):
  1. README.md +56 -1
  2. config.json +141 -0
  3. inference.py +111 -0
  4. model.safetensors +3 -0
  5. requirements.txt +4 -0
README.md CHANGED
@@ -10,4 +10,59 @@ language:
  - ko
  - ja
  library_name: transformers
- ---
+ ---
+ ## Model Description
+ We present a voice-safety classification model that can be used for voice-toxicity detection and classification.
+ The model has been distilled into the [WavLM](https://arxiv.org/abs/2110.13900) architecture from a larger teacher model.
+ All model training was conducted on Roblox internal voice chat datasets,
+ using both machine-labeled and human-labeled data, totaling over 100k hours of training data.
+ We have also published a blog post about this work.
+
+ The model supports eight languages: English, Spanish, German, French, Portuguese, Italian, Korean, and Japanese.
+ It classifies the input audio into six toxicity classes in a multilabel fashion. The class labels are:
+ `Discrimination`, `Harassment`, `Sexual`, `IllegalAndRegulated`, `DatingAndRomantic`, and `Profanity`.
+ Please refer to the [Roblox Community Standards](https://en.help.roblox.com/hc/en-us/articles/203313410-Roblox-Community-Standards)
+ for a detailed explanation of the policy used to label the datasets.
+ The model outputs have been calibrated for the Roblox voice chat environment,
+ so the per-class scores after a sigmoid can be interpreted as probabilities.
+
+ The classifier expects 16kHz audio segments as input. The ideal segment length is 15 seconds,
+ but the classifier can operate on shorter segments as well. Prediction accuracy may degrade
+ for longer segments.
+
+ The table below shows evaluation precision and recall for each supported language,
+ calculated over internal language-specific held-out datasets that resemble Roblox voice chat traffic.
+ The operating thresholds for all categories were kept equal per language and optimized
+ to achieve a false positive rate of 1%. The classifier was then evaluated as a binary classifier,
+ tagging the audio as positive if any of the heads exceeded the threshold.
+
+ |Language|Precision|Recall|
+ |---|---|---|
+ |English   |63.9%|58.2%|
+ |Spanish   |76.1%|63.2%|
+ |German    |69.9%|74.1%|
+ |French    |70.3%|69.8%|
+ |Portuguese|85.4%|58.0%|
+ |Italian   |86.6%|52.4%|
+ |Korean    |78.0%|64.6%|
+ |Japanese  |56.7%|57.7%|
+
+ Compared to the v1 voice safety classifier, the v2 model expands support from English
+ to seven additional languages and significantly improves classification accuracy.
+ At the 1% false positive rate described above, binary recall for English improves by 92%.
+
+ ## Usage
+ The dependencies for the inference script can be installed as follows:
+ ```
+ pip install -r requirements.txt
+ ```
+ The provided Python file demonstrates how to use the classifier with arbitrary 16kHz audio input.
+ To run inference, use the following command:
+ ```
+ python inference.py --audio_file <your audio file path> --model_path <path to Hugging Face model>
+ ```
+ You can download the model weights from the [model releases page](https://github.com/Roblox/voice-safety-classifier/releases/tag/vs-classifier-v2),
+ or from Hugging Face under [`roblox/voice-safety-classifier-v2`](https://huggingface.co/Roblox/voice-safety-classifier-v2).
+ If `model_path` isn't specified, the model will be loaded directly from Hugging Face.
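
For quick reference, here is a minimal sketch (not part of the committed files) of the flow the README describes: load the checkpoint with `WavLMForSequenceClassification`, pass in a 16kHz waveform, and apply a sigmoid to the logits to obtain the calibrated per-class probabilities. The clip path `example.wav` is a placeholder, and the snippet assumes a mono clip no longer than 15 seconds; for longer audio, segment it first as `inference.py` below does.
```
import torch
import librosa
from transformers import WavLMForSequenceClassification

labels = [
    "Discrimination", "Harassment", "Sexual",
    "IllegalAndRegulated", "DatingAndRomantic", "Profanity",
]

# Load the v2 checkpoint directly from the Hugging Face Hub.
model = WavLMForSequenceClassification.from_pretrained(
    "Roblox/voice-safety-classifier-v2", num_labels=len(labels)
)
model.eval()

# Load a placeholder clip, resampled to the 16kHz rate the model expects.
audio, _ = librosa.load("example.wav", sr=16000)
inputs = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)  # [1, T]

with torch.no_grad():
    logits = model(inputs).logits      # [1, 6]
probs = torch.sigmoid(logits)[0]       # calibrated per-class probabilities

for name, p in zip(labels, probs.tolist()):
    print(f"{name}: {p:.3f}")
```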
config.json ADDED
@@ -0,0 +1,141 @@
+ {
+   "_name_or_path": "patrickvonplaten/wavlm-libri-clean-100h-base-plus",
+   "activation_dropout": 0.0,
+   "adapter_kernel_size": 3,
+   "adapter_stride": 2,
+   "add_adapter": false,
+   "apply_spec_augment": true,
+   "architectures": [
+     "WavLMForSequenceClassification"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 1,
+   "classifier_proj_size": 256,
+   "codevector_dim": 256,
+   "contrastive_logits_temperature": 0.1,
+   "conv_bias": false,
+   "conv_dim": [
+     512,
+     512,
+     512,
+     512,
+     512,
+     512,
+     512,
+     512
+   ],
+   "conv_kernel": [
+     10,
+     3,
+     3,
+     3,
+     3,
+     2,
+     2,
+     7
+   ],
+   "conv_stride": [
+     5,
+     2,
+     2,
+     2,
+     2,
+     2,
+     2,
+     2
+   ],
+   "ctc_loss_reduction": "mean",
+   "ctc_zero_infinity": false,
+   "diversity_loss_weight": 0.1,
+   "do_stable_layer_norm": false,
+   "eos_token_id": 2,
+   "feat_extract_activation": "gelu",
+   "feat_extract_norm": "group",
+   "feat_proj_dropout": 0.0,
+   "feat_quantizer_dropout": 0.0,
+   "final_dropout": 0.0,
+   "freeze_feat_extract_train": true,
+   "hidden_act": "gelu",
+   "hidden_dropout": 0.0,
+   "hidden_size": 1024,
+   "id2label": {
+     "0": "LABEL_0",
+     "1": "LABEL_1",
+     "2": "LABEL_2",
+     "3": "LABEL_3",
+     "4": "LABEL_4",
+     "5": "LABEL_5"
+   },
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "label2id": {
+     "LABEL_0": 0,
+     "LABEL_1": 1,
+     "LABEL_2": 2,
+     "LABEL_3": 3,
+     "LABEL_4": 4,
+     "LABEL_5": 5
+   },
+   "layer_norm_eps": 1e-05,
+   "layerdrop": 0.0,
+   "mask_channel_length": 10,
+   "mask_channel_min_space": 1,
+   "mask_channel_other": 0.0,
+   "mask_channel_prob": 0.0,
+   "mask_channel_selection": "static",
+   "mask_feature_length": 10,
+   "mask_feature_min_masks": 0,
+   "mask_feature_prob": 0.0,
+   "mask_time_length": 10,
+   "mask_time_min_masks": 2,
+   "mask_time_min_space": 1,
+   "mask_time_other": 0.0,
+   "mask_time_prob": 0.05,
+   "mask_time_selection": "static",
+   "max_bucket_distance": 800,
+   "model_type": "wavlm",
+   "no_mask_channel_overlap": false,
+   "no_mask_time_overlap": false,
+   "num_adapter_layers": 3,
+   "num_attention_heads": 16,
+   "num_buckets": 320,
+   "num_codevector_groups": 2,
+   "num_codevectors_per_group": 320,
+   "num_conv_pos_embedding_groups": 16,
+   "num_conv_pos_embeddings": 128,
+   "num_ctc_classes": 80,
+   "num_feat_extract_layers": 7,
+   "num_hidden_layers": 10,
+   "num_negatives": 100,
+   "output_hidden_size": 768,
+   "pad_token_id": 28,
+   "proj_codevector_dim": 256,
+   "replace_prob": 0.5,
+   "tdnn_dilation": [
+     1,
+     2,
+     3,
+     1,
+     1
+   ],
+   "tdnn_dim": [
+     512,
+     512,
+     512,
+     512,
+     1500
+   ],
+   "tdnn_kernel": [
+     5,
+     3,
+     3,
+     1,
+     1
+   ],
+   "tokenizer_class": "Wav2Vec2CTCTokenizer",
+   "torch_dtype": "float32",
+   "transformers_version": "4.38.2",
+   "use_weighted_layer_sum": false,
+   "vocab_size": 31,
+   "xvector_output_dim": 512
+ }
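
Note that `id2label` above uses the generic `LABEL_0` through `LABEL_5` names. The mapping to the human-readable class names is not stated in the config itself; the sketch below assumes the index order follows `labels_name_list` in `inference.py` (`LABEL_0` = `Discrimination`, ..., `LABEL_5` = `Profanity`) and renames the labels accordingly.
```
from transformers import AutoConfig

# Assumed ordering, taken from labels_name_list in inference.py below.
labels_name_list = [
    "Discrimination", "Harassment", "Sexual",
    "IllegalAndRegulated", "DatingAndRomantic", "Profanity",
]

config = AutoConfig.from_pretrained("Roblox/voice-safety-classifier-v2")
config.id2label = {i: name for i, name in enumerate(labels_name_list)}
config.label2id = {name: i for i, name in enumerate(labels_name_list)}
print(config.id2label)  # {0: 'Discrimination', ..., 5: 'Profanity'}
```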
inference.py ADDED
@@ -0,0 +1,111 @@
+ # Copyright © 2024 Roblox Corporation
+
+ """
+ This file gives a sample demonstration of how to use the given functions in Python, for the Voice Safety Classifier model.
+ """
+
+ import torch
+ import librosa
+ import numpy as np
+ import argparse
+ from transformers import WavLMForSequenceClassification
+
+
+ def feature_extract_simple(
+     wav,
+     sr=16_000,
+     win_len=15.0,
+     win_stride=15.0,
+     do_normalize=False,
+ ):
+     """simple feature extraction for wavLM
+     Parameters
+     ----------
+     wav : str or array-like
+         path to the wav file, or array-like
+     sr : int, optional
+         sample rate, by default 16_000
+     win_len : float, optional
+         window length, by default 15.0
+     win_stride : float, optional
+         window stride, by default 15.0
+     do_normalize: bool, optional
+         whether to normalize the input, by default False.
+     Returns
+     -------
+     np.ndarray
+         batched input to wavLM
+     """
+     if type(wav) == str:
+         signal, _ = librosa.core.load(wav, sr=sr)
+     else:
+         try:
+             signal = np.array(wav).squeeze()
+         except Exception as e:
+             print(e)
+             raise RuntimeError
+     batched_input = []
+     stride = int(win_stride * sr)
+     l = int(win_len * sr)
+     if len(signal) / sr > win_len:
+         for i in range(0, len(signal), stride):
+             if i + int(win_len * sr) > len(signal):
+                 # padding the last chunk to make it the same length as others
+                 chunked = np.pad(signal[i:], (0, l - len(signal[i:])))
+             else:
+                 chunked = signal[i : i + l]
+             if do_normalize:
+                 chunked = (chunked - np.mean(chunked)) / (np.std(chunked) + 1e-7)
+             batched_input.append(chunked)
+             if i + int(win_len * sr) > len(signal):
+                 break
+     else:
+         if do_normalize:
+             signal = (signal - np.mean(signal)) / (np.std(signal) + 1e-7)
+         batched_input.append(signal)
+     return np.stack(batched_input)  # [N, T]
+
+
+ def infer(model, inputs):
+     output = model(inputs)
+     probs = torch.sigmoid(torch.Tensor(output.logits))
+     return probs
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--audio_file",
+         type=str,
+         help="File to run inference",
+     )
+     parser.add_argument(
+         "--model_path",
+         type=str,
+         default="roblox/voice-safety-classifier",
+         help="checkpoint file of model",
+     )
+     args = parser.parse_args()
+     labels_name_list = [
+         "Discrimination",
+         "Harassment",
+         "Sexual",
+         "IllegalAndRegulated",
+         "DatingAndRomantic",
+         "Profanity",
+     ]
+
+     # Model is trained on only 16kHz audio
+     audio, _ = librosa.core.load(args.audio_file, sr=16000)
+     input_np = feature_extract_simple(audio, sr=16000)
+     input_pt = torch.Tensor(input_np)
+     model = WavLMForSequenceClassification.from_pretrained(
+         args.model_path, num_labels=len(labels_name_list)
+     )
+     probs = infer(model, input_pt)
+     probs = probs.reshape(-1, 6).detach().tolist()
+     print(f"Probabilities for {args.audio_file}:")
+     for chunk_idx in range(len(probs)):
+         print(f"\nSegment {chunk_idx}:")
+         for label_idx, label in enumerate(labels_name_list):
+             print(f"{label} : {probs[chunk_idx][label_idx]}")
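
The functions above can also be reused directly from Python instead of going through the CLI. Below is a small sketch, assuming `inference.py` is importable from the working directory and using the v2 repository id from the README (the script's own default, `roblox/voice-safety-classifier`, points at the v1 model); the audio path is a placeholder.
```
import librosa
import torch
from transformers import WavLMForSequenceClassification
from inference import feature_extract_simple, infer

# Placeholder file; feature_extract_simple splits it into 15-second segments.
audio, _ = librosa.load("long_recording.wav", sr=16000)
batch = torch.Tensor(feature_extract_simple(audio, sr=16000))  # [N, T]

model = WavLMForSequenceClassification.from_pretrained(
    "Roblox/voice-safety-classifier-v2", num_labels=6
)
probs = infer(model, batch)  # [N, 6] sigmoid scores, one row per segment
print(probs)
```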
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8311e17845ff80ca00973d0aca5e898c20a3e390f1fd2783a4c227d5b9aec559
+ size 480863848
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch
+ transformers
+ librosa
+ numpy