josephliu-roblox commited on
Commit
a6b8100
·
1 Parent(s): 39d9319

Add README and supporting code

Browse files
Files changed (4) hide show
  1. README.md +45 -3
  2. images/human_eval_pr_curve.png +0 -0
  3. inference.py +110 -0
  4. requirements.txt +4 -0
README.md CHANGED
@@ -1,3 +1,45 @@
1
- ---
2
- license: cc-by-sa-3.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Model Description
2
+ We present a large classification model trained on a manually curated real-world dataset that can be used as a new
3
+ benchmark for advancing research in toxicity detection and classification.
4
+ Our model is fine-tuned on the [WavLM base plus](https://arxiv.org/abs/2110.13900) with 2,374 hours of audio clips from
5
+ voice chat for multilabel classification. The audio clips are automatically labeled using a synthetic data pipeline
6
+ described in [our blog post](link to blog post here). A single output can have multiple labels.
7
+ The model outputs a n by 6 output tensor where the inferred labels are `Profanity`, `DatingAndSexting`, `Racist`,
8
+ `Bullying`, `Other`, `NoViolation`. `Other` consists of policy violation categories with low prevalence such as drugs
9
+ and alcohol or self-harm that are combined into a single category.
10
+
11
+
12
+ We evaluated this model on a data set with human annotated labels that contained a total of 9,795 samples with the class
13
+ distribution shown below. Note that we did not include the "other" category in this evaluation data set.
14
+
15
+ |Class|Number of examples| Duration (hours)|% of dataset|
16
+ |---|---|---|---|
17
+ |Profanity | 4893| 15.38 | 49.95%|
18
+ |DatingAndSexting | 688 | 2.52 | 7.02% |
19
+ |Racist | 889 | 3.10 | 9.08% |
20
+ |Bullying | 1256 | 4.25 | 12.82% |
21
+ |NoViolation | 4185 | 9.93 | 42.73% |
22
+
23
+
24
+ If we set the same threshold across all classes and treat the model as a binary classifier across all 4 toxicity classes
25
+ (`Profanity`, `DatingAndSexting`, `Racist`, `Bullying`), we get a binarized average precision of 94.48%. The precision
26
+ recall curve is as shown below.
27
+
28
+
29
+ <p align="center">
30
+ <img src="images/human_eval_pr_curve.png" alt="PR Curve" width="500"/>
31
+ </p>
32
+
33
+ ## Usage
34
+ The dependencies for the inference file can be installed as follows:
35
+ ```
36
+ pip install -r requirements.txt
37
+ ```
38
+ The inference file contains useful helper functions to preprocess the audio file for proper inference.
39
+ To run the inference file, please run the following command:
40
+ ```
41
+ python inference.py --audio_file <your audio file path> --model_path <path to Huggingface model>
42
+ ```
43
+ You can get the model weights either by downloading from the model releases page [here](https://github.com/Roblox/voice-safety-classifier/releases/tag/vs-classifier-v1), or from HuggingFace under
44
+ [`roblox/voice-safety-classifier`](https://huggingface.co/Roblox/voice-safety-classifier). If `model_path` isn’t
45
+ specified, the model will be loaded directly from HuggingFace.
images/human_eval_pr_curve.png ADDED
inference.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright © 2024 Roblox Corporation
2
+
3
+ """
4
+ This file gives a sample demonstration of how to use the given functions in Python, for the Voice Safety Classifier model.
5
+ """
6
+
7
+ import torch
8
+ import librosa
9
+ import numpy as np
10
+ import argparse
11
+ from transformers import WavLMForSequenceClassification
12
+
13
+
14
+ def feature_extract_simple(
15
+ wav,
16
+ sr=16_000,
17
+ win_len=15.0,
18
+ win_stride=15.0,
19
+ do_normalize=False,
20
+ ):
21
+ """simple feature extraction for wavLM
22
+ Parameters
23
+ ----------
24
+ wav : str or array-like
25
+ path to the wav file, or array-like
26
+ sr : int, optional
27
+ sample rate, by default 16_000
28
+ win_len : float, optional
29
+ window length, by default 15.0
30
+ win_stride : float, optional
31
+ window stride, by default 15.0
32
+ do_normalize: bool, optional
33
+ whether to normalize the input, by default False.
34
+ Returns
35
+ -------
36
+ np.ndarray
37
+ batched input to wavLM
38
+ """
39
+ if type(wav) == str:
40
+ signal, _ = librosa.core.load(wav, sr=sr)
41
+ else:
42
+ try:
43
+ signal = np.array(wav).squeeze()
44
+ except Exception as e:
45
+ print(e)
46
+ raise RuntimeError
47
+ batched_input = []
48
+ stride = int(win_stride * sr)
49
+ l = int(win_len * sr)
50
+ if len(signal) / sr > win_len:
51
+ for i in range(0, len(signal), stride):
52
+ if i + int(win_len * sr) > len(signal):
53
+ # padding the last chunk to make it the same length as others
54
+ chunked = np.pad(signal[i:], (0, l - len(signal[i:])))
55
+ else:
56
+ chunked = signal[i : i + l]
57
+ if do_normalize:
58
+ chunked = (chunked - np.mean(chunked)) / (np.std(chunked) + 1e-7)
59
+ batched_input.append(chunked)
60
+ if i + int(win_len * sr) > len(signal):
61
+ break
62
+ else:
63
+ if do_normalize:
64
+ signal = (signal - np.mean(signal)) / (np.std(signal) + 1e-7)
65
+ batched_input.append(signal)
66
+ return np.stack(batched_input) # [N, T]
67
+
68
+
69
+ def infer(model, inputs):
70
+ output = model(inputs)
71
+ probs = torch.sigmoid(torch.Tensor(output.logits))
72
+ return probs
73
+
74
+
75
+ if __name__ == "__main__":
76
+ parser = argparse.ArgumentParser()
77
+ parser.add_argument(
78
+ "--audio_file",
79
+ type=str,
80
+ help="File to run inference",
81
+ )
82
+ parser.add_argument(
83
+ "--model_path",
84
+ type=str,
85
+ default="roblox/voice-safety-classifier",
86
+ help="checkpoint file of model",
87
+ )
88
+ args = parser.parse_args()
89
+ labels_name_list = [
90
+ "Profanity",
91
+ "DatingAndSexting",
92
+ "Racist",
93
+ "Bullying",
94
+ "Other",
95
+ "NoViolation",
96
+ ]
97
+ # Model is trained on only 16kHz audio
98
+ audio, _ = librosa.core.load(args.audio_file, sr=16000)
99
+ input_np = feature_extract_simple(audio, sr=16000)
100
+ input_pt = torch.Tensor(input_np)
101
+ model = WavLMForSequenceClassification.from_pretrained(
102
+ args.model_path, num_labels=len(labels_name_list)
103
+ )
104
+ probs = infer(model, input_pt)
105
+ probs = probs.reshape(-1, 6).detach().tolist()
106
+ print(f"Probabilities for {args.audio_file} is:")
107
+ for chunk_idx in range(len(probs)):
108
+ print(f"\nSegment {chunk_idx}:")
109
+ for label_idx, label in enumerate(labels_name_list):
110
+ print(f"{label} : {probs[chunk_idx][label_idx]}")
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ librosa
4
+ numpy